├── .gitattributes ├── .gitignore ├── README.md ├── adaboost ├── .ipynb_checkpoints │ └── Adaboost-checkpoint.ipynb ├── AdaBoost.py ├── Adaboost.ipynb └── README.md ├── decision_tree ├── README.md └── tree_id3.py ├── em ├── README.md ├── data │ ├── amix1-est.dat │ ├── amix1-tst.dat │ ├── amix1-val.dat │ ├── amix2-est.dat │ ├── amix2-tst.dat │ ├── amix2-val.dat │ ├── golub-est.dat │ ├── golub-tst.dat │ └── golub-val.dat ├── gmm.py ├── gmm_penality.py ├── main.py └── main_panelity.py ├── kmeans ├── .ipynb_checkpoints │ └── kmeans-checkpoint.ipynb ├── README.md ├── kmeans.ipynb ├── kmeans_base.py └── kmeans_plus.py ├── knn ├── .ipynb_checkpoints │ └── KNN-checkpoint.ipynb ├── KNN.ipynb ├── README.md ├── knn_base.py └── knn_kdtree.py ├── logistic_regression ├── .ipynb_checkpoints │ └── logistic_regression-checkpoint.ipynb ├── LogisticRegressionClassifier.py ├── README.md ├── logistic_regression.ipynb └── max_entropy.py ├── naive_bayes ├── .ipynb_checkpoints │ └── naiveBayes-checkpoint.ipynb ├── README.md ├── naiveBayes.ipynb ├── naiveBayesBase.py └── naiveBayesGaussian.py ├── perceptron ├── .ipynb_checkpoints │ └── perceptron-checkpoint.ipynb ├── README.md ├── perceptron.ipynb ├── perceptron_base.py └── perceptron_dual.py ├── support_vector_machine ├── .ipynb_checkpoints │ └── svm-checkpoint.ipynb ├── README.md ├── svm.ipynb └── svm.py └── utils ├── data_generater.py ├── misc_utils.py ├── plot.py └── word_utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-language=Python 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.xml 3 | .idea/ 4 | .DS_Store 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # machine-learning 2 | 通过阅读网上的资料代码,进行自我加工,努力实现常用的机器学习算法。 3 | 4 | # 目前已经实现可运行算法 5 | ### [KNN和KdTree的实现](https://github.com/SmallVagetable/machine_learning_python/tree/master/knn) 6 | ### [感知机的基本形式和对偶形式的实现](https://github.com/SmallVagetable/machine_learning_python/tree/master/perceptron) 7 | ### [Kmeans和Kmeans++的实现](https://github.com/SmallVagetable/machine_learning_python/tree/master/kmeans) 8 | ### [EM GMM高斯混合和GMM+LASSO的实现](https://github.com/SmallVagetable/machine_learning_python/tree/master/em) 9 | ### [实现朴素贝叶斯的基本算法和高斯混合朴素贝叶斯算法](https://github.com/SmallVagetable/machine_learning_python/tree/master/naive_bayes) 10 | ### [实现决策树的基本算法](https://github.com/SmallVagetable/machine_learning_python/tree/master/decision_tree) 11 | ### [实现adaboost基本算法](https://github.com/SmallVagetable/machine_learning_python/tree/master/adaboost) 12 | ### [实现svm基本算法](https://github.com/SmallVagetable/machine_learning_python/tree/master/support_vector_machine) 13 | ### [实现逻辑回归基本算法](https://github.com/SmallVagetable/machine_learning_python/tree/master/logistic_regression) -------------------------------------------------------------------------------- /adaboost/.ipynb_checkpoints/Adaboost-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "import os\n", 11 | "searchPath=os.path.abspath('..')\n", 12 | "sys.path.append(searchPath)" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "from adaboost.AdaBoost import AdaBoost\n", 22 | "from sklearn.ensemble import AdaBoostClassifier\n", 23 | "\n", 24 | "import matplotlib.pyplot as plt\n", 25 | "import numpy as np\n", 26 | "import pandas as pd\n", 27 | "from sklearn.datasets import load_iris\n", 28 | "from sklearn.model_selection import train_test_split\n", 29 | "np.random.seed(10)\n", 30 | "\n", 31 | "%matplotlib inline" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 3, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "def create_data():\n", 41 | " iris = load_iris()\n", 42 | " df = pd.DataFrame(iris.data, columns=iris.feature_names)\n", 43 | " df['label'] = iris.target\n", 44 | " df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']\n", 45 | " data = np.array(df.iloc[:100, [0, 1, -1]])\n", 46 | " for i in range(len(data)):\n", 47 | " if data[i,-1] == 0:\n", 48 | " data[i,-1] = -1\n", 49 | " return data[:,:2], data[:,-1]" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 4, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "X, y = create_data()\n", 59 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 5, 65 | "metadata": {}, 66 | "outputs": [ 67 | { 68 | "data": { 69 | "text/plain": [ 70 | "" 71 | ] 72 | }, 73 | "execution_count": 5, 74 | "metadata": {}, 75 | "output_type": "execute_result" 76 | }, 77 | { 78 | "data": { 79 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAGihJREFUeJzt3X+MXWWdx/H3d4dZOiowoYwrzJQtP0yjQNfCCJImxAV3q7UWgiyU4I8qC7sGFwwuRgxBbUzAkOCPJdEUyALCFrsVS2H5sQhLVAI1U8B2bSWCoJ2BXYZii6wFyvDdP+6ddubOnbn3ufeeuc/z3M8raTr33Ken3+cc/XJ7zuc819wdERHJy5+1uwAREWk9NXcRkQypuYuIZEjNXUQkQ2ruIiIZUnMXEcmQmruISIbU3EVEMqTmLiKSof3qHWhmXcAQMOLuyyreWwlcA4yUN13n7jfMtL9DDjnE58+fH1SsiEin27Rp00vu3ldrXN3NHbgE2AYcOM37P3T3z9e7s/nz5zM0NBTw14uIiJn9rp5xdV2WMbMB4KPAjJ/GRUQkDvVec/828CXgrRnGfNzMNpvZOjObV22AmV1oZkNmNjQ6Ohpaq4iI1KlmczezZcCL7r5phmF3AfPdfSHwE+DmaoPcfbW7D7r7YF9fzUtGIiLSoHquuS8GlpvZUmAOcKCZ3erunxgf4O47Joy/Hvhma8sUEWmdPXv2MDw8zGuvvdbuUqY1Z84cBgYG6O7ubujP12zu7n45cDmAmX0Q+OeJjb28/VB3f6H8cjmlG68iIlEaHh7mgAMOYP78+ZhZu8uZwt3ZsWMHw8PDHHHEEQ3to+Gcu5mtMrPl5ZcXm9mvzOyXwMXAykb3KyJStNdee425c+dG2dgBzIy5c+c29S+LkCgk7v4w8HD55ysnbN/76V4kN+ufGOGa+5/i+Z27Oay3h8uWLOCMRf3tLkuaFGtjH9dsfUHNXaTTrH9ihMvv2MLuPWMAjOzczeV3bAFQg5eoafkBkRlcc/9Texv7uN17xrjm/qfaVJHk4r777mPBggUcffTRXH311S3fv5q7yAye37k7aLtIPcbGxrjooou499572bp1K2vWrGHr1q0t/Tt0WUZkBof19jBSpZEf1tvThmqkXVp93+UXv/gFRx99NEceeSQAK1as4M477+S9731vq0rWJ3eRmVy2ZAE93V2TtvV0d3HZkgVtqkhm2/h9l5Gdu3H23XdZ/8RIzT87nZGREebN2/cg/8DAACMjje+vGjV3kRmcsaifq848jv7eHgzo7+3hqjOP083UDlLEfRd3n7Kt1ekdXZYRqeGMRf1q5h2siPsuAwMDbN++fe/r4eFhDjvssIb3V40+uYuIzGC6+yvN3Hd5//vfz29+8xueffZZ3njjDW6//XaWL19e+w8GUHMXEZlBEfdd9ttvP6677jqWLFnCe97zHs4++2yOOeaYZkud/He0dG8iIpkZvyTX6qeUly5dytKlS1tRYlVq7iIiNaR430WXZUREMqTmLiKSITV3EZEMqbmLiGRIzV1EJENq7pKN9U+MsPjqhzjiy//B4qsfamrtD5Giffazn+Wd73wnxx57bCH7V3OXLBSxuJNIkVauXMl9991X2P7V3CUL+lINKdTmtfCtY+FrvaXfN69tepennHIKBx98cAuKq04PMUkW9KUaUpjNa+Gui2FP+X9Lu7aXXgMsPLt9ddWgT+6ShSIWdxIB4MFV+xr7uD27S9sjpuYuWdCXakhhdg2HbY+ELstIFopa3EmEgwZKl2KqbY+YmrtkI8XFnSQBp105+Zo7QHdPaXsTzj33XB5++GFeeuklBgYG+PrXv87555/fZLH7qLlL01r95cEiURm/afrgqtKlmIMGSo29yZupa9asaUFx01Nzl6aM58vHY4jj+XJADV7ysfDsqJMx1eiGqjRF+XKROKm5S1OUL5dUuXu7S5hRs/WpuUtTlC+XFM2ZM4cdO3ZE2+DdnR07djBnzpyG96Fr7tKUy5YsmHTNHZQvl/gNDAwwPDzM6Ohou0uZ1pw5cxgYaDxuqeYuTVG+XFLU3d3NEUcc0e4yClV3czezLmAIGHH3ZRXv7Q/cApwA7ADOcffnWlinREz5cpH4hHxyvwTYBhxY5b3zgT+4+9FmtgL4JnBOC+oTSYoy/xKLum6omtkA8FHghmmGnA7cXP55HXCamVnz5YmkQ2vKS0zqTct8G/gS8NY07/cD2wHc/U1gFzC36epEEqLMv8SkZnM3s2XAi+6+aaZhVbZNyRiZ2YVmNmRmQzHfpRZphDL/EpN6PrkvBpab2XPA7cCpZnZrxZhhYB6Ame0HHAS8XLkjd1/t7oPuPtjX19dU4SKxUeZfYlKzubv75e4+4O7zgRXAQ+7+iYphG4BPl38+qzwmzqcDRAqiNeUlJg3n3M1sFTDk7huAG4EfmNnTlD6xr2hRfSLJUOZfYmLt+oA9ODjoQ0NDbfm7RURSZWab3H2w1jg9oSrRumL9FtZs3M6YO11mnHvSPL5xxnHtLkskCWruEqUr1m/h1sd+v/f1mPve12rwIrVpVUiJ0pqNVb6zcobtIjKZmrtEaWyae0HTbReRydTcJUpd06xeMd12EZlMzV2idO5J84K2i8hkuqEqURq/aaq0jEhjlHMXEUmIcu7SlPOuf5RHntm3PNDiow7mtgtObmNF7aM12iVFuuYuU1Q2doBHnnmZ865/tE0VtY/WaJdUqbnLFJWNvdb2nGmNdkmVmrvIDLRGu6RKzV1kBlqjXVKl5i5TLD7q4KDtOdMa7ZIqNXeZ4rYLTp7SyDs1LXPGon6uOvM4+nt7MKC/t4erzjxOaRmJnnLuIiIJUc5dmlJUtjtkv8qXizROzV2mGM92j0cAx7PdQFPNNWS/RdUg0il0zV2mKCrbHbJf5ctFmqPmLlMUle0O2a/y5SLNUXOXKYrKdofsV/lykeaoucsURWW7Q/arfLlIc3RDVaYYv2HZ6qRKyH6LqkGkUyjnLiKSEOXcC5ZiBjvFmkWkMWruDUgxg51izSLSON1QbUCKGewUaxaRxqm5NyDFDHaKNYtI49TcG5BiBjvFmkWkcWruDUgxg51izSLSON1QbUCKGewUaxaRxtXMuZvZHOCnwP6U/mOwzt2/WjFmJXANMP6V8Ne5+w0z7Vc5dxGRcK3Mub8OnOrur5pZN/BzM7vX3R+rGPdDd/98I8XK7Lhi/RbWbNzOmDtdZpx70jy+ccZxTY+NJT8fSx0iMajZ3L300f7V8svu8q/2PNYqDbti/RZufez3e1+Pue99Xdm0Q8bGkp+PpQ6RWNR1Q9XMuszsSeBF4AF331hl2MfNbLOZrTOzeS2tUpq2ZuP2ureHjI0lPx9LHSKxqKu5u/uYu78PGABONLNjK4bcBcx394XAT4Cbq+3HzC40syEzGxodHW2mbgk0Ns29lWrbQ8bGkp+PpQ6RWARFId19J/Aw8OGK7Tvc/fXyy+uBE6b586vdfdDdB/v6+hooVxrVZVb39pCxseTnY6lDJBY1m7uZ9ZlZb/nnHuBDwK8rxhw64eVyYFsri5TmnXtS9Stl1baHjI0lPx9LHSKxqCctcyhws5l1UfqPwVp3v9vMVgFD7r4BuNjMlgNvAi8DK4sqWBozfiO0ngRMyNhY8vOx1CESC63nLiKSEK3nXrCiMtUh+fIi9x0yvxSPRXI2r4UHV8GuYThoAE67Ehae3e6qJGJq7g0oKlMdki8vct8h80vxWCRn81q462LYU07+7Npeeg1q8DItLRzWgKIy1SH58iL3HTK/FI9Fch5cta+xj9uzu7RdZBpq7g0oKlMdki8vct8h80vxWCRn13DYdhHU3BtSVKY6JF9e5L5D5pfisUjOQQNh20VQc29IUZnqkHx5kfsOmV+KxyI5p10J3RX/sezuKW0XmYZuqDagqEx1SL68yH2HzC/FY5Gc8ZumSstIAOXcRUQSopy7TBFDdl0Sp7x9MtTcO0QM2XVJnPL2SdEN1Q4RQ3ZdEqe8fVLU3DtEDNl1SZzy9klRc+8QMWTXJXHK2ydFzb1DxJBdl8Qpb58U3VDtEDFk1yVxytsnRTl3EZGEKOdeVlReO2S/saxLrux6ZHLPjOc+vxBtOBZZN/ei8toh+41lXXJl1yOTe2Y89/mFaNOxyPqGalF57ZD9xrIuubLrkck9M577/EK06Vhk3dyLymuH7DeWdcmVXY9M7pnx3OcXok3HIuvmXlReO2S/saxLrux6ZHLPjOc+vxBtOhZZN/ei8toh+41lXXJl1yOTe2Y89/mFaNOxyPqGalF57ZD9xrIuubLrkck9M577/EK06Vgo5y4ikhDl3AsWQ37+vOsf5ZFnXt77evFRB3PbBSc3XYNIVu6+FDbdBD4G1gUnrIRl1za/38hz/Flfcy/KeGZ8ZOdunH2Z8fVPjMzafisbO8Ajz7zMedc/2lQNIlm5+1IYurHU2KH0+9CNpe3NGM+u79oO+L7s+ua1TZfcKmruDYghP1/Z2GttF+lIm24K216vBHL8au4NiCE/LyJ18LGw7fVKIMev5t6AGPLzIlIH6wrbXq8Ecvxq7g2IIT+/+KiDq+5juu0iHemElWHb65VAjl/NvQFnLOrnqjOPo7+3BwP6e3u46szjWpKfr3e/t11w8pRGrrSMSIVl18Lg+fs+qVtX6XWzaZmFZ8PHvgsHzQOs9PvHvhtVWkY5dxGRhLQs525mc4CfAvuXx69z969WjNkfuAU4AdgBnOPuzzVQd02h+fLU1jAPWfs992NRaI44JPtcVB1Fzi/yDHZTQueW87GYQT0PMb0OnOrur5pZN/BzM7vX3R+bMOZ84A/ufrSZrQC+CZzT6mJD1yRPbQ3zkLXfcz8Wha6BPZ59HjeefYapDb6oOoqcX85rqYfOLedjUUPNa+5e8mr5ZXf5V+W1nNOBm8s/rwNOM2v9soeh+fLU1jAPWfs992NRaI44JPtcVB1Fzi+BDHbDQueW87Gooa4bqmbWZWZPAi8CD7j7xooh/cB2AHd/E9gFzK2ynwvNbMjMhkZHR4OLDc2Bp5YbD1n7PfdjUWiOOCT7XFQdRc4vgQx2w0LnlvOxqKGu5u7uY+7+PmAAONHMjq0YUu1T+pSO5O6r3X3Q3Qf7+vqCiw3NgaeWGw9Z+z33Y1Fojjgk+1xUHUXOL4EMdsNC55bzsaghKArp7juBh4EPV7w1DMwDMLP9gIOAlj8HH5ovT20N85C133M/FoXmiEOyz0XVUeT8EshgNyx0bjkfixrqScv0AXvcfaeZ9QAfonTDdKINwKeBR4GzgIe8gIxl6Jrkqa1hHrL2e+7HotA1sMdvmtaTlimqjiLnl/Na6qFzy/lY1FAz525mCyndLO2i9El/rbuvMrNVwJC7byjHJX8ALKL0iX2Fu/92pv0q5y4iEq5lOXd330ypaVduv3LCz68BfxdapIiIFCP7L+tI7sEdmR0hD7bE8BBMkQ/upPaQVgznIwFZN/fkHtyR2RHyYEsMD8EU+eBOag9pxXA+EpH1wmHJPbgjsyPkwZYYHoIp8sGd1B7SiuF8JCLr5p7cgzsyO0IebInhIZgiH9xJ7SGtGM5HIrJu7sk9uCOzI+TBlhgeginywZ3UHtKK4XwkIuvmntyDOzI7Qh5sieEhmCIf3EntIa0Yzkcism7uRX2phiQu5IsWYvhShtAaYphfavvNkL6sQ0QkIS17iEmk44V8sUcsUqs5lux6LHW0gJq7yExCvtgjFqnVHEt2PZY6WiTra+4iTQv5Yo9YpFZzLNn1WOpoETV3kZmEfLFHLFKrOZbseix1tIiau8hMQr7YIxap1RxLdj2WOlpEzV1kJiFf7BGL1GqOJbseSx0touYuMpNl18Lg+fs+9VpX6XWMNybHpVZzLNn1WOpoEeXcRUQSopy7zJ4Us8FF1VxUvjzFYyxtpeYuzUkxG1xUzUXly1M8xtJ2uuYuzUkxG1xUzUXly1M8xtJ2au7SnBSzwUXVXFS+PMVjLG2n5i7NSTEbXFTNReXLUzzG0nZq7tKcFLPBRdVcVL48xWMsbafmLs1JMRtcVM1F5ctTPMbSdsq5i4gkpN6cuz65Sz42r4VvHQtf6y39vnnt7O+3qBpEAinnLnkoKgsesl/l0SUi+uQueSgqCx6yX+XRJSJq7pKHorLgIftVHl0iouYueSgqCx6yX+XRJSJq7pKHorLgIftVHl0iouYueSgqCx6yX+XRJSI1c+5mNg+4BXgX8Baw2t2/UzHmg8CdwLPlTXe4+4x3kZRzFxEJ18r13N8Evujuj5vZAcAmM3vA3bdWjPuZuy9rpFiJUIrrh4fUnOL8YqDjloyazd3dXwBeKP/8RzPbBvQDlc1dcpFiXlt59OLpuCUl6Jq7mc0HFgEbq7x9spn90szuNbNjWlCbtEuKeW3l0Yun45aUup9QNbN3AD8CvuDur1S8/Tjwl+7+qpktBdYD766yjwuBCwEOP/zwhouWgqWY11YevXg6bkmp65O7mXVTauy3ufsdle+7+yvu/mr553uAbjM7pMq41e4+6O6DfX19TZYuhUkxr608evF03JJSs7mbmQE3AtvcverapWb2rvI4zOzE8n53tLJQmUUp5rWVRy+ejltS6rkssxj4JLDFzJ4sb/sKcDiAu38fOAv4nJm9CewGVni71hKW5o3fHEspFRFSc4rzi4GOW1K0nruISEJamXOXWClzPNndl8Kmm0pfSG1dpa+3a/ZbkEQSpeaeKmWOJ7v7Uhi6cd9rH9v3Wg1eOpDWlkmVMseTbbopbLtI5tTcU6XM8WQ+FrZdJHNq7qlS5ngy6wrbLpI5NfdUKXM82Qkrw7aLZE7NPVVaO3yyZdfC4Pn7PqlbV+m1bqZKh1LOXUQkIcq5N2D9EyNcc/9TPL9zN4f19nDZkgWcsai/3WW1Tu65+NznFwMd42SouZetf2KEy+/Ywu49pXTFyM7dXH7HFoA8Gnzuufjc5xcDHeOk6Jp72TX3P7W3sY/bvWeMa+5/qk0VtVjuufjc5xcDHeOkqLmXPb9zd9D25OSei899fjHQMU6KmnvZYb09QduTk3suPvf5xUDHOClq7mWXLVlAT/fkB156uru4bMmCNlXUYrnn4nOfXwx0jJOiG6pl4zdNs03L5L4Wd+7zi4GOcVKUcxcRSUi9OXddlhFJwea18K1j4Wu9pd83r01j39I2uiwjErsi8+XKrmdLn9xFYldkvlzZ9WypuYvErsh8ubLr2VJzF4ldkflyZdezpeYuErsi8+XKrmdLzV0kdkWu3a/vBciWcu4iIglRzl1EpIOpuYuIZEjNXUQkQ2ruIiIZUnMXEcmQmruISIbU3EVEMqTmLiKSoZrN3czmmdl/mdk2M/uVmV1SZYyZ2XfN7Gkz22xmxxdTrjRF63aLdIx61nN/E/iiuz9uZgcAm8zsAXffOmHMR4B3l3+dBHyv/LvEQut2i3SUmp/c3f0Fd3+8/PMfgW1A5ReLng7c4iWPAb1mdmjLq5XGad1ukY4SdM3dzOYDi4CNFW/1A9snvB5m6n8AMLMLzWzIzIZGR0fDKpXmaN1ukY5Sd3M3s3cAPwK+4O6vVL5d5Y9MWZHM3Ve7+6C7D/b19YVVKs3Rut0iHaWu5m5m3ZQa+23ufkeVIcPAvAmvB4Dnmy9PWkbrdot0lHrSMgbcCGxz92unGbYB+FQ5NfMBYJe7v9DCOqVZWrdbpKPUk5ZZDHwS2GJmT5a3fQU4HMDdvw/cAywFngb+BHym9aVK0xaerWYu0iFqNnd3/znVr6lPHOPARa0qSkREmqMnVEVEMqTmLiKSITV3EZEMqbmLiGRIzV1EJENq7iIiGVJzFxHJkJUi6m34i81Ggd+15S+v7RDgpXYXUSDNL105zw00v3r8pbvXXJyrbc09ZmY25O6D7a6jKJpfunKeG2h+raTLMiIiGVJzFxHJkJp7davbXUDBNL905Tw30PxaRtfcRUQypE/uIiIZ6ujmbmZdZvaEmd1d5b2VZjZqZk+Wf/19O2pshpk9Z2ZbyvUPVXnfzOy7Zva0mW02s+PbUWcj6pjbB81s14Tzl9RXTplZr5mtM7Nfm9k2Mzu54v1kzx3UNb9kz5+ZLZhQ95Nm9oqZfaFiTOHnr54v68jZJcA24MBp3v+hu39+Fuspwl+7+3S52o8A7y7/Ogn4Xvn3VMw0N4CfufuyWaumtb4D3OfuZ5nZnwNvq3g/9XNXa36Q6Plz96eA90HpAyQwAvy4Yljh569jP7mb2QDwUeCGdtfSRqcDt3jJY0CvmR3a7qI6nZkdCJxC6estcfc33H1nxbBkz12d88vFacAz7l75wGbh569jmzvwbeBLwFszjPl4+Z9M68xs3gzjYuXAf5rZJjO7sMr7/cD2Ca+Hy9tSUGtuACeb2S/N7F4zO2Y2i2vSkcAo8K/ly4Y3mNnbK8akfO7qmR+ke/4mWgGsqbK98PPXkc3dzJYBL7r7phmG3QXMd/eFwE+Am2eluNZa7O7HU/on4EVmdkrF+9W+PjGV+FStuT1O6THtvwL+BVg/2wU2YT/geOB77r4I+D/gyxVjUj539cwv5fMHQPly03Lg36u9XWVbS89fRzZ3Sl/6vdzMngNuB041s1snDnD3He7+evnl9cAJs1ti89z9+fLvL1K65ndixZBhYOK/SAaA52enuubUmpu7v+Lur5Z/vgfoNrNDZr3QxgwDw+6+sfx6HaVmWDkmyXNHHfNL/PyN+wjwuLv/b5X3Cj9/Hdnc3f1ydx9w9/mU/tn0kLt/YuKYiutfyyndeE2Gmb3dzA4Y/xn4W+C/K4ZtAD5VvnP/AWCXu78wy6UGq2duZvYuM7PyzydS+t/6jtmutRHu/j/AdjNbUN50GrC1YliS5w7qm1/K52+Cc6l+SQZm4fx1elpmEjNbBQy5+wbgYjNbDrwJvAysbGdtDfgL4Mfl/3/sB/ybu99nZv8I4O7fB+4BlgJPA38CPtOmWkPVM7ezgM+Z2ZvAbmCFp/XE3j8Bt5X/af9b4DOZnLtxteaX9Pkzs7cBfwP8w4Rts3r+9ISqiEiGOvKyjIhI7tTcRUQypOYuIpIhNXcRkQypuYuIZEjNXUQkQ2ruIiIZUnMXEcnQ/wPmMFqpaGCFHwAAAABJRU5ErkJggg==\n", 80 | "text/plain": [ 81 | "" 82 | ] 83 | }, 84 | "metadata": {}, 85 | "output_type": "display_data" 86 | } 87 | ], 88 | "source": [ 89 | "plt.scatter(X[:50,0],X[:50,1], label='0')\n", 90 | "plt.scatter(X[50:,0],X[50:,1], label='1')\n", 91 | "plt.legend()" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 6, 97 | "metadata": {}, 98 | "outputs": [ 99 | { 100 | "data": { 101 | "text/plain": [ 102 | "0.65" 103 | ] 104 | }, 105 | "execution_count": 6, 106 | "metadata": {}, 107 | "output_type": "execute_result" 108 | } 109 | ], 110 | "source": [ 111 | "clf = AdaBoost(n_estimators=10, learning_rate=0.2)\n", 112 | "clf.fit(X_train, y_train)\n", 113 | "clf.score(X_test, y_test)" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 7, 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "data": { 123 | "text/plain": [ 124 | "AdaBoostClassifier(algorithm='SAMME.R', base_estimator=None,\n", 125 | " learning_rate=0.5, n_estimators=100, random_state=None)" 126 | ] 127 | }, 128 | "execution_count": 7, 129 | "metadata": {}, 130 | "output_type": "execute_result" 131 | } 132 | ], 133 | "source": [ 134 | "clf = AdaBoostClassifier(n_estimators=100, learning_rate=0.5)\n", 135 | "clf.fit(X_train, y_train)" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 8, 141 | "metadata": {}, 142 | "outputs": [ 143 | { 144 | "data": { 145 | "text/plain": [ 146 | "0.9" 147 | ] 148 | }, 149 | "execution_count": 8, 150 | "metadata": {}, 151 | "output_type": "execute_result" 152 | } 153 | ], 154 | "source": [ 155 | "clf.score(X_test, y_test)" 156 | ] 157 | } 158 | ], 159 | "metadata": { 160 | "kernelspec": { 161 | "display_name": "Python 3", 162 | "language": "python", 163 | "name": "python3" 164 | }, 165 | "language_info": { 166 | "codemirror_mode": { 167 | "name": "ipython", 168 | "version": 3 169 | }, 170 | "file_extension": ".py", 171 | "mimetype": "text/x-python", 172 | "name": "python", 173 | "nbconvert_exporter": "python", 174 | "pygments_lexer": "ipython3", 175 | "version": "3.6.4" 176 | } 177 | }, 178 | "nbformat": 4, 179 | "nbformat_minor": 2 180 | } 181 | -------------------------------------------------------------------------------- /adaboost/AdaBoost.py: -------------------------------------------------------------------------------- 1 | from sklearn.ensemble import AdaBoostClassifier 2 | 3 | from utils.data_generater import * 4 | 5 | class AdaBoost(object): 6 | def __init__(self, n_estimators=50, learning_rate=1.0): 7 | self.clf_num = n_estimators 8 | self.learning_rate = learning_rate 9 | 10 | def init_args(self, datasets, labels): 11 | 12 | self.X = datasets 13 | self.Y = labels 14 | self.M, self.N = datasets.shape 15 | 16 | # 弱分类器数目和集合 17 | self.clf_sets = [] 18 | 19 | # 初始化weights 20 | self.weights = [1.0 / self.M] * self.M 21 | 22 | # G(x)系数 alpha 23 | self.alpha = [] 24 | 25 | def _G(self, features, labels, weights): 26 | m = len(features) 27 | error = 100000.0 # 无穷大 28 | best_v = 0.0 29 | # 单维features 30 | features_min = min(features) 31 | features_max = max(features) 32 | n_step = (features_max - features_min + self.learning_rate) // self.learning_rate 33 | # print('n_step:{}'.format(n_step)) 34 | direct, compare_array = None, None 35 | for i in range(1, int(n_step)): 36 | v = features_min + self.learning_rate * i 37 | 38 | if v not in features: 39 | # 误分类计算 40 | compare_array_positive = np.array([1 if features[k] > v else -1 for k in range(m)]) 41 | weight_error_positive = sum([weights[k] for k in range(m) if compare_array_positive[k] != labels[k]]) 42 | 43 | compare_array_nagetive = np.array([-1 if features[k] > v else 1 for k in range(m)]) 44 | weight_error_nagetive = sum([weights[k] for k in range(m) if compare_array_nagetive[k] != labels[k]]) 45 | 46 | if weight_error_positive < weight_error_nagetive: 47 | weight_error = weight_error_positive 48 | _compare_array = compare_array_positive 49 | direct = 'positive' 50 | else: 51 | weight_error = weight_error_nagetive 52 | _compare_array = compare_array_nagetive 53 | direct = 'nagetive' 54 | 55 | # print('v:{} error:{}'.format(v, weight_error)) 56 | if weight_error < error: 57 | error = weight_error 58 | compare_array = _compare_array 59 | best_v = v 60 | return best_v, direct, error, compare_array 61 | 62 | # 计算alpha 63 | def _alpha(self, error): 64 | return 0.5 * np.log((1 - error) / error) 65 | 66 | # 规范化因子 67 | def _Z(self, weights, a, clf): 68 | return sum([weights[i] * np.exp(-1 * a * self.Y[i] * clf[i]) for i in range(self.M)]) 69 | 70 | # 权值更新 71 | def _w(self, a, clf, Z): 72 | for i in range(self.M): 73 | self.weights[i] = self.weights[i] * np.exp(-1 * a * self.Y[i] * clf[i]) / Z 74 | 75 | # G(x)的线性组合 76 | def _f(self, alpha, clf_sets): 77 | pass 78 | 79 | def G(self, x, v, direct): 80 | if direct == 'positive': 81 | return 1 if x > v else -1 82 | else: 83 | return -1 if x > v else 1 84 | 85 | def fit(self, X, y): 86 | self.init_args(X, y) 87 | 88 | for epoch in range(self.clf_num): 89 | best_clf_error, best_v, clf_result = 100000, None, None 90 | # 根据特征维度, 选择误差最小的 91 | for j in range(self.N): 92 | features = self.X[:, j] 93 | # 分类阈值,分类误差,分类结果 94 | v, direct, error, compare_array = self._G(features, self.Y, self.weights) 95 | 96 | if error < best_clf_error: 97 | best_clf_error = error 98 | best_v = v 99 | final_direct = direct 100 | clf_result = compare_array 101 | axis = j 102 | 103 | # print('epoch:{}/{} feature:{} error:{} v:{}'.format(epoch, self.clf_num, j, error, best_v)) 104 | if best_clf_error == 0: 105 | break 106 | 107 | # 计算G(x)系数a 108 | a = self._alpha(best_clf_error) 109 | self.alpha.append(a) 110 | # 记录分类器 111 | self.clf_sets.append((axis, best_v, final_direct)) 112 | # 规范化因子 113 | Z = self._Z(self.weights, a, clf_result) 114 | # 权值更新 115 | self._w(a, clf_result, Z) 116 | 117 | # print('classifier:{}/{} error:{:.3f} v:{} direct:{} a:{:.5f}'.format(epoch+1, self.clf_num, error, best_v, final_direct, a)) 118 | # print('weight:{}'.format(self.weights)) 119 | # print('\n') 120 | 121 | def predict(self, feature): 122 | result = 0.0 123 | for i in range(len(self.clf_sets)): 124 | axis, clf_v, direct = self.clf_sets[i] 125 | f_input = feature[axis] 126 | result += self.alpha[i] * self.G(f_input, clf_v, direct) 127 | # sign 128 | return 1 if result > 0 else -1 129 | 130 | def score(self, X_test, y_test): 131 | right_count = 0 132 | for i in range(len(X_test)): 133 | feature = X_test[i] 134 | if self.predict(feature) == y_test[i]: 135 | right_count += 1 136 | 137 | return right_count / len(X_test) 138 | 139 | if __name__ == "__main__": 140 | X_train, X_test, y_train, y_test = create_svm_data() 141 | my_ada = AdaBoost(n_estimators=10, learning_rate=0.2) 142 | my_ada.fit(X_train, y_train) 143 | print("my AdaBoost score", my_ada.score(X_test, y_test)) 144 | 145 | sk_ada = AdaBoostClassifier(n_estimators=100, learning_rate=0.5) 146 | sk_ada.fit(X_train, y_train) 147 | print("sklearn AdaBoost score", sk_ada.score(X_test, y_test)) -------------------------------------------------------------------------------- /adaboost/Adaboost.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "import os\n", 11 | "searchPath=os.path.abspath('..')\n", 12 | "sys.path.append(searchPath)" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "from adaboost.AdaBoost import AdaBoost\n", 22 | "from sklearn.ensemble import AdaBoostClassifier\n", 23 | "\n", 24 | "import matplotlib.pyplot as plt\n", 25 | "import numpy as np\n", 26 | "import pandas as pd\n", 27 | "from sklearn.datasets import load_iris\n", 28 | "from sklearn.model_selection import train_test_split\n", 29 | "np.random.seed(10)\n", 30 | "\n", 31 | "%matplotlib inline" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 3, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "def create_data():\n", 41 | " iris = load_iris()\n", 42 | " df = pd.DataFrame(iris.data, columns=iris.feature_names)\n", 43 | " df['label'] = iris.target\n", 44 | " df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']\n", 45 | " data = np.array(df.iloc[:100, [0, 1, -1]])\n", 46 | " for i in range(len(data)):\n", 47 | " if data[i,-1] == 0:\n", 48 | " data[i,-1] = -1\n", 49 | " return data[:,:2], data[:,-1]" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 4, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "X, y = create_data()\n", 59 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 5, 65 | "metadata": {}, 66 | "outputs": [ 67 | { 68 | "data": { 69 | "text/plain": [ 70 | "" 71 | ] 72 | }, 73 | "execution_count": 5, 74 | "metadata": {}, 75 | "output_type": "execute_result" 76 | }, 77 | { 78 | "data": { 79 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAGihJREFUeJzt3X+MXWWdx/H3d4dZOiowoYwrzJQtP0yjQNfCCJImxAV3q7UWgiyU4I8qC7sGFwwuRgxBbUzAkOCPJdEUyALCFrsVS2H5sQhLVAI1U8B2bSWCoJ2BXYZii6wFyvDdP+6ddubOnbn3ufeeuc/z3M8raTr33Ken3+cc/XJ7zuc819wdERHJy5+1uwAREWk9NXcRkQypuYuIZEjNXUQkQ2ruIiIZUnMXEcmQmruISIbU3EVEMqTmLiKSof3qHWhmXcAQMOLuyyreWwlcA4yUN13n7jfMtL9DDjnE58+fH1SsiEin27Rp00vu3ldrXN3NHbgE2AYcOM37P3T3z9e7s/nz5zM0NBTw14uIiJn9rp5xdV2WMbMB4KPAjJ/GRUQkDvVec/828CXgrRnGfNzMNpvZOjObV22AmV1oZkNmNjQ6Ohpaq4iI1KlmczezZcCL7r5phmF3AfPdfSHwE+DmaoPcfbW7D7r7YF9fzUtGIiLSoHquuS8GlpvZUmAOcKCZ3erunxgf4O47Joy/Hvhma8sUEWmdPXv2MDw8zGuvvdbuUqY1Z84cBgYG6O7ubujP12zu7n45cDmAmX0Q+OeJjb28/VB3f6H8cjmlG68iIlEaHh7mgAMOYP78+ZhZu8uZwt3ZsWMHw8PDHHHEEQ3to+Gcu5mtMrPl5ZcXm9mvzOyXwMXAykb3KyJStNdee425c+dG2dgBzIy5c+c29S+LkCgk7v4w8HD55ysnbN/76V4kN+ufGOGa+5/i+Z27Oay3h8uWLOCMRf3tLkuaFGtjH9dsfUHNXaTTrH9ihMvv2MLuPWMAjOzczeV3bAFQg5eoafkBkRlcc/9Texv7uN17xrjm/qfaVJHk4r777mPBggUcffTRXH311S3fv5q7yAye37k7aLtIPcbGxrjooou499572bp1K2vWrGHr1q0t/Tt0WUZkBof19jBSpZEf1tvThmqkXVp93+UXv/gFRx99NEceeSQAK1as4M477+S9731vq0rWJ3eRmVy2ZAE93V2TtvV0d3HZkgVtqkhm2/h9l5Gdu3H23XdZ/8RIzT87nZGREebN2/cg/8DAACMjje+vGjV3kRmcsaifq848jv7eHgzo7+3hqjOP083UDlLEfRd3n7Kt1ekdXZYRqeGMRf1q5h2siPsuAwMDbN++fe/r4eFhDjvssIb3V40+uYuIzGC6+yvN3Hd5//vfz29+8xueffZZ3njjDW6//XaWL19e+w8GUHMXEZlBEfdd9ttvP6677jqWLFnCe97zHs4++2yOOeaYZkud/He0dG8iIpkZvyTX6qeUly5dytKlS1tRYlVq7iIiNaR430WXZUREMqTmLiKSITV3EZEMqbmLiGRIzV1EJENq7pKN9U+MsPjqhzjiy//B4qsfamrtD5Giffazn+Wd73wnxx57bCH7V3OXLBSxuJNIkVauXMl9991X2P7V3CUL+lINKdTmtfCtY+FrvaXfN69tepennHIKBx98cAuKq04PMUkW9KUaUpjNa+Gui2FP+X9Lu7aXXgMsPLt9ddWgT+6ShSIWdxIB4MFV+xr7uD27S9sjpuYuWdCXakhhdg2HbY+ELstIFopa3EmEgwZKl2KqbY+YmrtkI8XFnSQBp105+Zo7QHdPaXsTzj33XB5++GFeeuklBgYG+PrXv87555/fZLH7qLlL01r95cEiURm/afrgqtKlmIMGSo29yZupa9asaUFx01Nzl6aM58vHY4jj+XJADV7ysfDsqJMx1eiGqjRF+XKROKm5S1OUL5dUuXu7S5hRs/WpuUtTlC+XFM2ZM4cdO3ZE2+DdnR07djBnzpyG96Fr7tKUy5YsmHTNHZQvl/gNDAwwPDzM6Ohou0uZ1pw5cxgYaDxuqeYuTVG+XFLU3d3NEUcc0e4yClV3czezLmAIGHH3ZRXv7Q/cApwA7ADOcffnWlinREz5cpH4hHxyvwTYBhxY5b3zgT+4+9FmtgL4JnBOC+oTSYoy/xKLum6omtkA8FHghmmGnA7cXP55HXCamVnz5YmkQ2vKS0zqTct8G/gS8NY07/cD2wHc/U1gFzC36epEEqLMv8SkZnM3s2XAi+6+aaZhVbZNyRiZ2YVmNmRmQzHfpRZphDL/EpN6PrkvBpab2XPA7cCpZnZrxZhhYB6Ame0HHAS8XLkjd1/t7oPuPtjX19dU4SKxUeZfYlKzubv75e4+4O7zgRXAQ+7+iYphG4BPl38+qzwmzqcDRAqiNeUlJg3n3M1sFTDk7huAG4EfmNnTlD6xr2hRfSLJUOZfYmLt+oA9ODjoQ0NDbfm7RURSZWab3H2w1jg9oSrRumL9FtZs3M6YO11mnHvSPL5xxnHtLkskCWruEqUr1m/h1sd+v/f1mPve12rwIrVpVUiJ0pqNVb6zcobtIjKZmrtEaWyae0HTbReRydTcJUpd06xeMd12EZlMzV2idO5J84K2i8hkuqEqURq/aaq0jEhjlHMXEUmIcu7SlPOuf5RHntm3PNDiow7mtgtObmNF7aM12iVFuuYuU1Q2doBHnnmZ865/tE0VtY/WaJdUqbnLFJWNvdb2nGmNdkmVmrvIDLRGu6RKzV1kBlqjXVKl5i5TLD7q4KDtOdMa7ZIqNXeZ4rYLTp7SyDs1LXPGon6uOvM4+nt7MKC/t4erzjxOaRmJnnLuIiIJUc5dmlJUtjtkv8qXizROzV2mGM92j0cAx7PdQFPNNWS/RdUg0il0zV2mKCrbHbJf5ctFmqPmLlMUle0O2a/y5SLNUXOXKYrKdofsV/lykeaoucsURWW7Q/arfLlIc3RDVaYYv2HZ6qRKyH6LqkGkUyjnLiKSEOXcC5ZiBjvFmkWkMWruDUgxg51izSLSON1QbUCKGewUaxaRxqm5NyDFDHaKNYtI49TcG5BiBjvFmkWkcWruDUgxg51izSLSON1QbUCKGewUaxaRxtXMuZvZHOCnwP6U/mOwzt2/WjFmJXANMP6V8Ne5+w0z7Vc5dxGRcK3Mub8OnOrur5pZN/BzM7vX3R+rGPdDd/98I8XK7Lhi/RbWbNzOmDtdZpx70jy+ccZxTY+NJT8fSx0iMajZ3L300f7V8svu8q/2PNYqDbti/RZufez3e1+Pue99Xdm0Q8bGkp+PpQ6RWNR1Q9XMuszsSeBF4AF331hl2MfNbLOZrTOzeS2tUpq2ZuP2ureHjI0lPx9LHSKxqKu5u/uYu78PGABONLNjK4bcBcx394XAT4Cbq+3HzC40syEzGxodHW2mbgk0Ns29lWrbQ8bGkp+PpQ6RWARFId19J/Aw8OGK7Tvc/fXyy+uBE6b586vdfdDdB/v6+hooVxrVZVb39pCxseTnY6lDJBY1m7uZ9ZlZb/nnHuBDwK8rxhw64eVyYFsri5TmnXtS9Stl1baHjI0lPx9LHSKxqCctcyhws5l1UfqPwVp3v9vMVgFD7r4BuNjMlgNvAi8DK4sqWBozfiO0ngRMyNhY8vOx1CESC63nLiKSEK3nXrCiMtUh+fIi9x0yvxSPRXI2r4UHV8GuYThoAE67Ehae3e6qJGJq7g0oKlMdki8vct8h80vxWCRn81q462LYU07+7Npeeg1q8DItLRzWgKIy1SH58iL3HTK/FI9Fch5cta+xj9uzu7RdZBpq7g0oKlMdki8vct8h80vxWCRn13DYdhHU3BtSVKY6JF9e5L5D5pfisUjOQQNh20VQc29IUZnqkHx5kfsOmV+KxyI5p10J3RX/sezuKW0XmYZuqDagqEx1SL68yH2HzC/FY5Gc8ZumSstIAOXcRUQSopy7TBFDdl0Sp7x9MtTcO0QM2XVJnPL2SdEN1Q4RQ3ZdEqe8fVLU3DtEDNl1SZzy9klRc+8QMWTXJXHK2ydFzb1DxJBdl8Qpb58U3VDtEDFk1yVxytsnRTl3EZGEKOdeVlReO2S/saxLrux6ZHLPjOc+vxBtOBZZN/ei8toh+41lXXJl1yOTe2Y89/mFaNOxyPqGalF57ZD9xrIuubLrkck9M577/EK06Vhk3dyLymuH7DeWdcmVXY9M7pnx3OcXok3HIuvmXlReO2S/saxLrux6ZHLPjOc+vxBtOhZZN/ei8toh+41lXXJl1yOTe2Y89/mFaNOxyPqGalF57ZD9xrIuubLrkck9M577/EK06Vgo5y4ikhDl3AsWQ37+vOsf5ZFnXt77evFRB3PbBSc3XYNIVu6+FDbdBD4G1gUnrIRl1za/38hz/Flfcy/KeGZ8ZOdunH2Z8fVPjMzafisbO8Ajz7zMedc/2lQNIlm5+1IYurHU2KH0+9CNpe3NGM+u79oO+L7s+ua1TZfcKmruDYghP1/Z2GttF+lIm24K216vBHL8au4NiCE/LyJ18LGw7fVKIMev5t6AGPLzIlIH6wrbXq8Ecvxq7g2IIT+/+KiDq+5juu0iHemElWHb65VAjl/NvQFnLOrnqjOPo7+3BwP6e3u46szjWpKfr3e/t11w8pRGrrSMSIVl18Lg+fs+qVtX6XWzaZmFZ8PHvgsHzQOs9PvHvhtVWkY5dxGRhLQs525mc4CfAvuXx69z969WjNkfuAU4AdgBnOPuzzVQd02h+fLU1jAPWfs992NRaI44JPtcVB1Fzi/yDHZTQueW87GYQT0PMb0OnOrur5pZN/BzM7vX3R+bMOZ84A/ufrSZrQC+CZzT6mJD1yRPbQ3zkLXfcz8Wha6BPZ59HjeefYapDb6oOoqcX85rqYfOLedjUUPNa+5e8mr5ZXf5V+W1nNOBm8s/rwNOM2v9soeh+fLU1jAPWfs992NRaI44JPtcVB1Fzi+BDHbDQueW87Gooa4bqmbWZWZPAi8CD7j7xooh/cB2AHd/E9gFzK2ynwvNbMjMhkZHR4OLDc2Bp5YbD1n7PfdjUWiOOCT7XFQdRc4vgQx2w0LnlvOxqKGu5u7uY+7+PmAAONHMjq0YUu1T+pSO5O6r3X3Q3Qf7+vqCiw3NgaeWGw9Z+z33Y1Fojjgk+1xUHUXOL4EMdsNC55bzsaghKArp7juBh4EPV7w1DMwDMLP9gIOAlj8HH5ovT20N85C133M/FoXmiEOyz0XVUeT8EshgNyx0bjkfixrqScv0AXvcfaeZ9QAfonTDdKINwKeBR4GzgIe8gIxl6Jrkqa1hHrL2e+7HotA1sMdvmtaTlimqjiLnl/Na6qFzy/lY1FAz525mCyndLO2i9El/rbuvMrNVwJC7byjHJX8ALKL0iX2Fu/92pv0q5y4iEq5lOXd330ypaVduv3LCz68BfxdapIiIFCP7L+tI7sEdmR0hD7bE8BBMkQ/upPaQVgznIwFZN/fkHtyR2RHyYEsMD8EU+eBOag9pxXA+EpH1wmHJPbgjsyPkwZYYHoIp8sGd1B7SiuF8JCLr5p7cgzsyO0IebInhIZgiH9xJ7SGtGM5HIrJu7sk9uCOzI+TBlhgeginywZ3UHtKK4XwkIuvmntyDOzI7Qh5sieEhmCIf3EntIa0Yzkcism7uRX2phiQu5IsWYvhShtAaYphfavvNkL6sQ0QkIS17iEmk44V8sUcsUqs5lux6LHW0gJq7yExCvtgjFqnVHEt2PZY6WiTra+4iTQv5Yo9YpFZzLNn1WOpoETV3kZmEfLFHLFKrOZbseix1tIiau8hMQr7YIxap1RxLdj2WOlpEzV1kJiFf7BGL1GqOJbseSx0touYuMpNl18Lg+fs+9VpX6XWMNybHpVZzLNn1WOpoEeXcRUQSopy7zJ4Us8FF1VxUvjzFYyxtpeYuzUkxG1xUzUXly1M8xtJ2uuYuzUkxG1xUzUXly1M8xtJ2au7SnBSzwUXVXFS+PMVjLG2n5i7NSTEbXFTNReXLUzzG0nZq7tKcFLPBRdVcVL48xWMsbafmLs1JMRtcVM1F5ctTPMbSdsq5i4gkpN6cuz65Sz42r4VvHQtf6y39vnnt7O+3qBpEAinnLnkoKgsesl/l0SUi+uQueSgqCx6yX+XRJSJq7pKHorLgIftVHl0iouYueSgqCx6yX+XRJSJq7pKHorLgIftVHl0iouYueSgqCx6yX+XRJSI1c+5mNg+4BXgX8Baw2t2/UzHmg8CdwLPlTXe4+4x3kZRzFxEJ18r13N8Evujuj5vZAcAmM3vA3bdWjPuZuy9rpFiJUIrrh4fUnOL8YqDjloyazd3dXwBeKP/8RzPbBvQDlc1dcpFiXlt59OLpuCUl6Jq7mc0HFgEbq7x9spn90szuNbNjWlCbtEuKeW3l0Yun45aUup9QNbN3AD8CvuDur1S8/Tjwl+7+qpktBdYD766yjwuBCwEOP/zwhouWgqWY11YevXg6bkmp65O7mXVTauy3ufsdle+7+yvu/mr553uAbjM7pMq41e4+6O6DfX19TZYuhUkxr608evF03JJSs7mbmQE3AtvcverapWb2rvI4zOzE8n53tLJQmUUp5rWVRy+ejltS6rkssxj4JLDFzJ4sb/sKcDiAu38fOAv4nJm9CewGVni71hKW5o3fHEspFRFSc4rzi4GOW1K0nruISEJamXOXWClzPNndl8Kmm0pfSG1dpa+3a/ZbkEQSpeaeKmWOJ7v7Uhi6cd9rH9v3Wg1eOpDWlkmVMseTbbopbLtI5tTcU6XM8WQ+FrZdJHNq7qlS5ngy6wrbLpI5NfdUKXM82Qkrw7aLZE7NPVVaO3yyZdfC4Pn7PqlbV+m1bqZKh1LOXUQkIcq5N2D9EyNcc/9TPL9zN4f19nDZkgWcsai/3WW1Tu65+NznFwMd42SouZetf2KEy+/Ywu49pXTFyM7dXH7HFoA8Gnzuufjc5xcDHeOk6Jp72TX3P7W3sY/bvWeMa+5/qk0VtVjuufjc5xcDHeOkqLmXPb9zd9D25OSei899fjHQMU6KmnvZYb09QduTk3suPvf5xUDHOClq7mWXLVlAT/fkB156uru4bMmCNlXUYrnn4nOfXwx0jJOiG6pl4zdNs03L5L4Wd+7zi4GOcVKUcxcRSUi9OXddlhFJwea18K1j4Wu9pd83r01j39I2uiwjErsi8+XKrmdLn9xFYldkvlzZ9WypuYvErsh8ubLr2VJzF4ldkflyZdezpeYuErsi8+XKrmdLzV0kdkWu3a/vBciWcu4iIglRzl1EpIOpuYuIZEjNXUQkQ2ruIiIZUnMXEcmQmruISIbU3EVEMqTmLiKSoZrN3czmmdl/mdk2M/uVmV1SZYyZ2XfN7Gkz22xmxxdTrjRF63aLdIx61nN/E/iiuz9uZgcAm8zsAXffOmHMR4B3l3+dBHyv/LvEQut2i3SUmp/c3f0Fd3+8/PMfgW1A5ReLng7c4iWPAb1mdmjLq5XGad1ukY4SdM3dzOYDi4CNFW/1A9snvB5m6n8AMLMLzWzIzIZGR0fDKpXmaN1ukY5Sd3M3s3cAPwK+4O6vVL5d5Y9MWZHM3Ve7+6C7D/b19YVVKs3Rut0iHaWu5m5m3ZQa+23ufkeVIcPAvAmvB4Dnmy9PWkbrdot0lHrSMgbcCGxz92unGbYB+FQ5NfMBYJe7v9DCOqVZWrdbpKPUk5ZZDHwS2GJmT5a3fQU4HMDdvw/cAywFngb+BHym9aVK0xaerWYu0iFqNnd3/znVr6lPHOPARa0qSkREmqMnVEVEMqTmLiKSITV3EZEMqbmLiGRIzV1EJENq7iIiGVJzFxHJkJUi6m34i81Ggd+15S+v7RDgpXYXUSDNL105zw00v3r8pbvXXJyrbc09ZmY25O6D7a6jKJpfunKeG2h+raTLMiIiGVJzFxHJkJp7davbXUDBNL905Tw30PxaRtfcRUQypE/uIiIZ6ujmbmZdZvaEmd1d5b2VZjZqZk+Wf/19O2pshpk9Z2ZbyvUPVXnfzOy7Zva0mW02s+PbUWcj6pjbB81s14Tzl9RXTplZr5mtM7Nfm9k2Mzu54v1kzx3UNb9kz5+ZLZhQ95Nm9oqZfaFiTOHnr54v68jZJcA24MBp3v+hu39+Fuspwl+7+3S52o8A7y7/Ogn4Xvn3VMw0N4CfufuyWaumtb4D3OfuZ5nZnwNvq3g/9XNXa36Q6Plz96eA90HpAyQwAvy4Yljh569jP7mb2QDwUeCGdtfSRqcDt3jJY0CvmR3a7qI6nZkdCJxC6estcfc33H1nxbBkz12d88vFacAz7l75wGbh569jmzvwbeBLwFszjPl4+Z9M68xs3gzjYuXAf5rZJjO7sMr7/cD2Ca+Hy9tSUGtuACeb2S/N7F4zO2Y2i2vSkcAo8K/ly4Y3mNnbK8akfO7qmR+ke/4mWgGsqbK98PPXkc3dzJYBL7r7phmG3QXMd/eFwE+Am2eluNZa7O7HU/on4EVmdkrF+9W+PjGV+FStuT1O6THtvwL+BVg/2wU2YT/geOB77r4I+D/gyxVjUj539cwv5fMHQPly03Lg36u9XWVbS89fRzZ3Sl/6vdzMngNuB041s1snDnD3He7+evnl9cAJs1ti89z9+fLvL1K65ndixZBhYOK/SAaA52enuubUmpu7v+Lur5Z/vgfoNrNDZr3QxgwDw+6+sfx6HaVmWDkmyXNHHfNL/PyN+wjwuLv/b5X3Cj9/Hdnc3f1ydx9w9/mU/tn0kLt/YuKYiutfyyndeE2Gmb3dzA4Y/xn4W+C/K4ZtAD5VvnP/AWCXu78wy6UGq2duZvYuM7PyzydS+t/6jtmutRHu/j/AdjNbUN50GrC1YliS5w7qm1/K52+Cc6l+SQZm4fx1elpmEjNbBQy5+wbgYjNbDrwJvAysbGdtDfgL4Mfl/3/sB/ybu99nZv8I4O7fB+4BlgJPA38CPtOmWkPVM7ezgM+Z2ZvAbmCFp/XE3j8Bt5X/af9b4DOZnLtxteaX9Pkzs7cBfwP8w4Rts3r+9ISqiEiGOvKyjIhI7tTcRUQypOYuIpIhNXcRkQypuYuIZEjNXUQkQ2ruIiIZUnMXEcnQ/wPmMFqpaGCFHwAAAABJRU5ErkJggg==\n", 80 | "text/plain": [ 81 | "" 82 | ] 83 | }, 84 | "metadata": {}, 85 | "output_type": "display_data" 86 | } 87 | ], 88 | "source": [ 89 | "plt.scatter(X[:50,0],X[:50,1], label='0')\n", 90 | "plt.scatter(X[50:,0],X[50:,1], label='1')\n", 91 | "plt.legend()" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 6, 97 | "metadata": {}, 98 | "outputs": [ 99 | { 100 | "data": { 101 | "text/plain": [ 102 | "0.65" 103 | ] 104 | }, 105 | "execution_count": 6, 106 | "metadata": {}, 107 | "output_type": "execute_result" 108 | } 109 | ], 110 | "source": [ 111 | "clf = AdaBoost(n_estimators=10, learning_rate=0.2)\n", 112 | "clf.fit(X_train, y_train)\n", 113 | "clf.score(X_test, y_test)" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 7, 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "data": { 123 | "text/plain": [ 124 | "AdaBoostClassifier(algorithm='SAMME.R', base_estimator=None,\n", 125 | " learning_rate=0.5, n_estimators=100, random_state=None)" 126 | ] 127 | }, 128 | "execution_count": 7, 129 | "metadata": {}, 130 | "output_type": "execute_result" 131 | } 132 | ], 133 | "source": [ 134 | "clf = AdaBoostClassifier(n_estimators=100, learning_rate=0.5)\n", 135 | "clf.fit(X_train, y_train)" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 8, 141 | "metadata": {}, 142 | "outputs": [ 143 | { 144 | "data": { 145 | "text/plain": [ 146 | "0.9" 147 | ] 148 | }, 149 | "execution_count": 8, 150 | "metadata": {}, 151 | "output_type": "execute_result" 152 | } 153 | ], 154 | "source": [ 155 | "clf.score(X_test, y_test)" 156 | ] 157 | } 158 | ], 159 | "metadata": { 160 | "kernelspec": { 161 | "display_name": "Python 3", 162 | "language": "python", 163 | "name": "python3" 164 | }, 165 | "language_info": { 166 | "codemirror_mode": { 167 | "name": "ipython", 168 | "version": 3 169 | }, 170 | "file_extension": ".py", 171 | "mimetype": "text/x-python", 172 | "name": "python", 173 | "nbconvert_exporter": "python", 174 | "pygments_lexer": "ipython3", 175 | "version": "3.6.4" 176 | } 177 | }, 178 | "nbformat": 4, 179 | "nbformat_minor": 2 180 | } 181 | -------------------------------------------------------------------------------- /adaboost/README.md: -------------------------------------------------------------------------------- 1 | # 实现adaboost的算法 2 | 3 | # 结果比较 4 | 结果Adaboost.ipynb中展示 5 | 6 | # 相关博客 7 | #### [1. 集成学习(Ensemble Learning)原理](https://www.cnblogs.com/huangyc/p/9949598.html) 8 | #### [2. 集成学习(Ensemble Learning)Bagging](https://www.cnblogs.com/huangyc/p/9957216.html) 9 | #### [3. 集成学习(Ensemble Learning)随机森林(Random Forest)](https://www.cnblogs.com/huangyc/p/9960820.html) 10 | #### [4. 集成学习(Ensemble Learning)Adaboost](https://www.cnblogs.com/huangyc/p/9969958.html) 11 | #### [5. 集成学习(Ensemble Learning)GBDT](https://www.cnblogs.com/huangyc/p/9973148.html) 12 | #### [6. 集成学习(Ensemble Learning)算法比较](https://www.cnblogs.com/huangyc/p/9973253.html) 13 | #### [7. 集成学习(Ensemble Learning)Stacking](https://www.cnblogs.com/huangyc/p/9975183.html) -------------------------------------------------------------------------------- /decision_tree/README.md: -------------------------------------------------------------------------------- 1 | # 实现决策树的基本算法 2 | 3 | # 结果比较 4 | tree_id3.py中展示 5 | 6 | # 相关博客 7 | #### [1. 决策树(Decision Tree)-决策树原理](https://www.cnblogs.com/huangyc/p/9734972.html) 8 | -------------------------------------------------------------------------------- /decision_tree/tree_id3.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from collections import Counter 4 | 5 | 6 | class Node(object): 7 | def __init__(self, x=None, label=None, y=None, data=None): 8 | self.label = label # label:子节点分类依据的特征 9 | self.x = x # x:特征 10 | self.child = [] # child:子节点 11 | self.y = y # y:类标记(叶节点才有) 12 | self.data = data # data:包含数据(叶节点才有) 13 | 14 | def append(self, node): # 添加子节点 15 | self.child.append(node) 16 | 17 | def predict(self, features): # 预测数据所述类 18 | if self.y is not None: 19 | return self.y 20 | for c in self.child: 21 | if c.x == features[self.label]: 22 | return c.predict(features) 23 | 24 | 25 | def printnode(node, depth=0): # 打印树所有节点 26 | if node.label is None: 27 | print(depth, (node.label, node.x, node.y, len(node.data))) 28 | else: 29 | print(depth, (node.label, node.x)) 30 | for c in node.child: 31 | printnode(c, depth+1) 32 | 33 | 34 | class DTreeID3(object): 35 | def __init__(self, epsilon=0, alpha=0): 36 | # 信息增益阈值 37 | self.epsilon = epsilon 38 | self.alpha = alpha 39 | self.tree = Node() 40 | 41 | # 求概率 42 | def prob(self, datasets): 43 | datalen = len(datasets) 44 | labelx = set(datasets) 45 | p = {l: 0 for l in labelx} 46 | for d in datasets: 47 | p[d] += 1 48 | for i in p.items(): 49 | p[i[0]] /= datalen 50 | return p 51 | 52 | # 求数据集的熵 53 | def calc_ent(self, datasets): 54 | p = self.prob(datasets) 55 | value = list(p.values()) 56 | return -np.sum(np.multiply(value, np.log2(value))) 57 | 58 | # 求条件熵 59 | def cond_ent(self, datasets, col): 60 | labelx = set(datasets.iloc[col]) 61 | p = {x: [] for x in labelx} 62 | for i, d in enumerate(datasets.iloc[-1]): 63 | p[datasets.iloc[col][i]].append(d) 64 | return sum([self.prob(datasets.iloc[col])[k] * self.calc_ent(p[k]) for k in p.keys()]) 65 | 66 | 67 | # 求信息增益 68 | def info_gain_train(self, datasets, datalabels): 69 | datasets = datasets.T 70 | ent = self.calc_ent(datasets.iloc[-1]) 71 | gainmax = {} 72 | for i in range(len(datasets) - 1): 73 | cond = self.cond_ent(datasets, i) 74 | gainmax[ent - cond] = i 75 | m = max(gainmax.keys()) 76 | return gainmax[m], m 77 | 78 | 79 | def train(self, datasets, node): 80 | labely = datasets.columns[-1] 81 | # 判断样本是否为同一类输出Di,如果是则返回单节点树T。标记类别为Di 82 | if len(datasets[labely].value_counts()) == 1: 83 | node.data = datasets[labely] 84 | node.y = datasets[labely][0] 85 | return 86 | # 判断特征是否为空,如果是则返回单节点树T,标记类别为样本中输出类别D实例数最多的类别 87 | if len(datasets.columns[:-1]) == 0: 88 | node.data = datasets[labely] 89 | node.y = datasets[labely].value_counts().index[0] 90 | return 91 | # 计算A中的各个特征(一共n个)对输出D的信息增益,选择信息增益最大的特征Ag。 92 | gainmaxi, gainmax = self.info_gain_train(datasets, datasets.columns) 93 | # 如果Ag的信息增益小于阈值ε,则返回单节点树T,标记类别为样本中输出类别D实例数最多的类别。 94 | if gainmax <= self.epsilon: 95 | node.data = datasets[labely] 96 | node.y = datasets[labely].value_counts().index[0] 97 | return 98 | # 按特征Ag的不同取值Agi将对应的样本输出D分成不同的类别Di。每个类别产生一个子节点。对应特征值为Agi。返回增加了节点的数T。 99 | vc = datasets[datasets.columns[gainmaxi]].value_counts() 100 | for Di in vc.index: 101 | node.label = gainmaxi 102 | child = Node(Di) 103 | node.append(child) 104 | new_datasets = pd.DataFrame([list(i) for i in datasets.values if i[gainmaxi]==Di], columns=datasets.columns) 105 | self.train(new_datasets, child) 106 | 107 | #训练数据 108 | def fit(self, datasets): 109 | self.train(datasets, self.tree) 110 | 111 | # 找到所有节点 112 | def findleaf(self, node, leaf): 113 | for t in node.child: 114 | if t.y is not None: 115 | leaf.append(t.data) 116 | else: 117 | for c in node.child: 118 | self.findleaf(c, leaf) 119 | 120 | def findfather(self, node, errormin): 121 | if node.label is not None: 122 | cy = [c.y for c in node.child] 123 | if None not in cy: # 全是叶节点 124 | childdata = [] 125 | for c in node.child: 126 | for d in list(c.data): 127 | childdata.append(d) 128 | childcounter = Counter(childdata) 129 | 130 | old_child = node.child # 剪枝前先拷贝一下 131 | old_label = node.label 132 | old_y = node.y 133 | old_data = node.data 134 | 135 | node.label = None # 剪枝 136 | node.y = childcounter.most_common(1)[0][0] 137 | node.data = childdata 138 | 139 | error = self.c_error() 140 | if error <= errormin: # 剪枝前后损失比较 141 | errormin = error 142 | return 1 143 | else: 144 | node.child = old_child # 剪枝效果不好,则复原 145 | node.label = old_label 146 | node.y = old_y 147 | node.data = old_data 148 | else: 149 | re = 0 150 | i = 0 151 | while i < len(node.child): 152 | if_re = self.findfather(node.child[i], errormin) # 若剪过枝,则其父节点要重新检测 153 | if if_re == 1: 154 | re = 1 155 | elif if_re == 2: 156 | i -= 1 157 | i += 1 158 | if re: 159 | return 2 160 | return 0 161 | 162 | def c_error(self): # 求C(T) 163 | leaf = [] 164 | self.findleaf(self.tree, leaf) 165 | leafnum = [len(l) for l in leaf] 166 | ent = [self.calc_ent(l) for l in leaf] 167 | print("Ent:", ent) 168 | error = self.alpha*len(leafnum) 169 | for l, e in zip(leafnum, ent): 170 | error += l*e 171 | print("C(T):", error) 172 | return error 173 | 174 | def cut(self, alpha=0): # 剪枝 175 | if alpha: 176 | self.alpha = alpha 177 | errormin = self.c_error() 178 | self.findfather(self.tree, errormin) 179 | 180 | if __name__ == "__main__": 181 | 182 | 183 | datasets = np.array([ 184 | ['青年', '否', '否', '一般', '否'], 185 | ['青年', '否', '否', '好', '否'], 186 | ['青年', '是', '否', '好', '是'], 187 | ['青年', '是', '是', '一般', '是'], 188 | ['青年', '否', '否', '一般', '否'], 189 | ['中年', '否', '否', '一般', '否'], 190 | ['中年', '否', '否', '好', '否'], 191 | ['中年', '是', '是', '好', '是'], 192 | ['中年', '否', '是', '非常好', '是'], 193 | ['中年', '否', '是', '非常好', '是'], 194 | ['老年', '否', '是', '非常好', '是'], 195 | ['老年', '否', '是', '好', '是'], 196 | ['老年', '是', '否', '好', '是'], 197 | ['老年', '是', '否', '非常好', '是'], 198 | ['老年', '否', '否', '一般', '否'], 199 | ['青年', '否', '否', '一般', '是']]) # 在李航原始数据上多加了最后这行数据,以便体现剪枝效果 200 | 201 | datalabels = np.array(['年龄', '有工作', '有自己的房子', '信贷情况', '类别']) 202 | train_data = pd.DataFrame(datasets, columns=datalabels) 203 | test_data = ['老年', '否', '否', '一般'] 204 | 205 | dt = DTreeID3(epsilon=0) # 可修改epsilon查看预剪枝效果 206 | dt.fit(train_data) 207 | 208 | print('DTree:') 209 | printnode(dt.tree) 210 | y = dt.tree.predict(test_data) 211 | print('result:', y) 212 | 213 | dt.cut(alpha=0.5) # 可修改正则化参数alpha查看后剪枝效果 214 | 215 | print('DTree:') 216 | printnode(dt.tree) 217 | y = dt.tree.predict(test_data) 218 | print('result:', y) 219 | -------------------------------------------------------------------------------- /em/README.md: -------------------------------------------------------------------------------- 1 | # 实现em gmm算法 2 | 高斯混合模型(GMM 聚类)的 EM 算法实现。 3 | 4 | # 入口 5 | 基本的gmm的入口是main.py 6 | 带penality的gmm的入口是main_penality.py 7 | 8 | # 相关博客 9 | #### [1. EM算法-数学基础](https://www.cnblogs.com/huangyc/p/10123320.html) 10 | #### [2. EM算法-原理详解](https://www.cnblogs.com/huangyc/p/10123780.html) 11 | #### [3. EM算法-高斯混合模型GMM](https://www.cnblogs.com/huangyc/p/10125117.html) 12 | #### [4. EM算法-高斯混合模型GMM详细代码实现](https://www.cnblogs.com/huangyc/p/10274881.html) 13 | #### [5. EM算法-高斯混合模型GMM+Lasso](https://www.cnblogs.com/huangyc/p/10275131.html) 14 | -------------------------------------------------------------------------------- /em/data/amix1-est.dat: -------------------------------------------------------------------------------- 1 | 4.8695 4.6649 2 | 1.3191 4.1648 3 | 4.6286 3.8526 4 | 5.2303 3.8473 5 | 0.8402 3.4262 6 | 4.1141 4.5625 7 | 5.1009 3.554 8 | 3.2464 4.297 9 | 1.0755 4.0667 10 | 2.4032 4.3099 11 | 5.4343 3.6545 12 | 4.8087 4.209 13 | 0.8914 3.7833 14 | 4.9785 3.3115 15 | 5.3969 3.7852 16 | 1.22 4.3816 17 | 5.2243 3.7736 18 | 5.2227 3.6556 19 | 4.3745 4.5783 20 | 2.9663 4.4406 21 | 2.887 5.2207 22 | 4.5483 4.7165 23 | 1.1236 3.9137 24 | 5.2279 3.9325 25 | 0.9978 3.5297 26 | 5.0112 3.6284 27 | 2.5725 4.183 28 | 3.046 5.0863 29 | 3.0772 3.6755 30 | 4.6264 3.3732 31 | 3.0431 3.9411 32 | 5.0297 3.7052 33 | 5.5016 4.3861 34 | 4.3906 4.297 35 | 3.2991 3.8621 36 | 1.074 4.1335 37 | 4.6507 4.0042 38 | 5.2801 4.7934 39 | 2.9153 4.728 40 | 2.6326 3.7633 41 | 4.6893 3.692 42 | 3.0474 3.6727 43 | 0.7439 4.8127 44 | 5.1537 4.8411 45 | 4.6111 4.0127 46 | 2.8048 3.8963 47 | 4.5522 4.1678 48 | 3.1483 3.9113 49 | 5.0791 3.4657 50 | 4.9282 3.9499 51 | 3.4493 4.1313 52 | 4.7273 3.8379 53 | 2.8573 4.399 54 | 4.3926 4.1533 55 | 4.5073 3.5216 56 | 4.7392 3.9716 57 | 5.0274 3.8381 58 | 4.8146 3.442 59 | 4.5521 3.7759 60 | 2.6141 3.1797 61 | 2.6066 3.9517 62 | 0.8141 3.2563 63 | 4.6537 4.2913 64 | 4.4462 4.9346 65 | 2.8534 5.3293 66 | 1.1773 3.6904 67 | 0.8914 4.4429 68 | 4.9422 4.1038 69 | 0.8142 3.8529 70 | 4.9691 3.833 71 | 4.3906 3.8969 72 | 5.411 4.604 73 | 4.8276 3.7039 74 | 0.7065 4.2605 75 | 5.0626 3.6313 76 | 4.8279 3.5369 77 | 5.0647 5.0124 78 | 5.3321 3.396 79 | 4.5743 4.0088 80 | 5.1648 3.8095 81 | 2.7311 4.6347 82 | 2.9007 2.8574 83 | 0.9269 4.4083 84 | 4.9946 4.2551 85 | 5.0148 3.4669 86 | 4.4519 4.4939 87 | 0.7085 4.0475 88 | 1.1284 3.9776 89 | 4.7833 4.1393 90 | 4.8636 3.4217 91 | 0.9711 3.5234 92 | 3.0592 4.1316 93 | 4.9753 3.9264 94 | 1.1141 3.9701 95 | 5.6788 3.8696 96 | 5.4387 3.9973 97 | 2.6959 4.8074 98 | 2.7364 4.5815 99 | 5.619 3.5116 100 | 4.9336 4.5102 101 | 3.1473 4.2108 102 | 0.9505 4.3478 103 | 0.9835 4.303 104 | 5.1499 3.7874 105 | 1.1101 3.4328 106 | 1.1715 4.8598 107 | 2.433 3.7051 108 | 4.8676 3.5301 109 | 5.681 3.7783 110 | 6.0597 4.078 111 | 1.0364 4.6592 112 | 4.4735 4.4599 113 | 2.6661 4.3085 114 | 2.7896 3.7097 115 | 4.9175 3.3429 116 | 3.1301 4.5026 117 | 5.6539 3.0958 118 | 4.4295 4.8892 119 | 2.791 4.1953 120 | 2.9848 3.8471 121 | 1.2933 4.337 122 | 0.9233 4.8271 123 | 0.9209 4.0487 124 | 4.5902 4.1615 125 | 0.8405 3.4519 126 | 2.8033 3.982 127 | 0.9341 4.4742 128 | 4.4823 4.2472 129 | 0.8961 3.5827 130 | 4.2521 4.241 131 | 3.3022 3.8634 132 | 4.6544 4.3396 133 | 5.335 4.1573 134 | 3.4093 3.8329 135 | 3.0167 4.034 136 | 4.8591 3.7352 137 | 3.2298 4.0195 138 | 2.9132 2.8676 139 | 5.0074 4.1246 140 | 3.241 4.9437 141 | 1.2574 4.2291 142 | 4.9229 4.7889 143 | 2.9775 4.3913 144 | 4.2324 3.9024 145 | 5.1051 4.2718 146 | 0.914 3.9153 147 | 3.0363 3.1813 148 | 4.771 3.3184 149 | 5.0425 3.6957 150 | 4.9725 3.4162 151 | 5.5178 3.4334 152 | 4.9344 4.8818 153 | 2.6995 3.5901 154 | 5.0657 4.5774 155 | 5.3666 5.3308 156 | 4.238 3.5944 157 | 1.1121 3.9655 158 | 5.1226 3.9449 159 | 4.3711 3.7975 160 | 2.7855 4.4326 161 | 0.6362 4.3296 162 | 2.8827 3.7919 163 | 4.8853 4.8752 164 | 3.4325 3.6512 165 | 5.0546 4.4571 166 | 4.6912 3.746 167 | 3.4023 4.3836 168 | 2.571 4.8546 169 | 0.7789 3.9874 170 | 4.2962 4.2849 171 | 1.1306 5.0307 172 | 4.7435 3.6594 173 | 4.7483 3.0919 174 | 4.99 4.2965 175 | 5.7755 3.6202 176 | 4.7409 4.3232 177 | 5.6385 4.6387 178 | 3.2558 4.1026 179 | 4.3969 2.8885 180 | 5.0947 2.8286 181 | 1.1457 3.5577 182 | 4.3778 3.2731 183 | 3.2411 3.2491 184 | 4.5983 4.2679 185 | 5.7576 3.1199 186 | 1.1191 4.0024 187 | 3.4342 3.6964 188 | 3.4441 4.5417 189 | 5.5643 3.8877 190 | 5.1764 4.6755 191 | 4.6307 4.5734 192 | 5.4776 3.986 193 | 4.9418 4.1585 194 | 4.8552 5.1753 195 | 1.02 4.9039 196 | 4.2108 4.2573 197 | 4.6019 3.8716 198 | 5.0444 5.9051 199 | 4.9381 3.5844 200 | 3.1043 3.5634 201 | 3.1444 4.8476 202 | 5.0045 4.4958 203 | 1.0058 5.2801 204 | 1.2319 3.4429 205 | 5.3655 4.5036 206 | 3.0142 3.7981 207 | 2.977 4.925 208 | 4.8536 4.5951 209 | 4.6071 4.0021 210 | 5.0832 3.9761 211 | 5.035 4.6594 212 | 5.2726 4.0716 213 | 5.5764 3.5093 214 | 0.9096 4.821 215 | 3.8499 3.8766 216 | 3.1807 3.8465 217 | 4.8613 3.0802 218 | 1.0037 4.6592 219 | 5.2257 3.7373 220 | 2.3395 4.196 221 | 2.8012 3.4197 222 | 2.8815 4.7749 223 | 5.1987 4.2425 224 | 2.8496 3.7446 225 | 5.3334 4.4829 226 | 0.8825 3.6966 227 | 0.7545 3.8812 228 | 4.7497 4.4502 229 | 4.777 3.5114 230 | 2.8597 4.4242 231 | 0.718 4.5507 232 | 4.2151 3.4175 233 | 4.7699 3.778 234 | 5.6712 4.2441 235 | 1.1429 3.9456 236 | 5.1505 3.6195 237 | 4.8395 3.8562 238 | 4.5852 2.9428 239 | 3.4558 4.1837 240 | 1.1809 3.6185 241 | 5.2591 4.2244 242 | 1.1084 3.9317 243 | 5.2526 3.7098 244 | 1.4003 4.1109 245 | 2.6125 4.4548 246 | 4.4405 5.3379 247 | 4.8064 2.9576 248 | 0.9504 3.8763 249 | 5.002 3.7349 250 | 1.3527 4.3448 251 | 4.5371 4.6508 252 | 4.8721 4.6813 253 | 5.2752 4.2971 254 | 5.0079 3.2136 255 | 4.8833 4.5507 256 | 2.7078 4.2796 257 | 5.0975 4.3247 258 | 4.7364 3.9818 259 | 5.5104 3.9762 260 | 2.9029 3.8724 261 | 2.7022 4.9819 262 | 2.571 4.1271 263 | 4.9351 4.2194 264 | 2.6746 4.3524 265 | 2.6231 3.7855 266 | 2.7088 3.1142 267 | 5.4714 3.9949 268 | 3.0214 5.0949 269 | 4.6252 3.4239 270 | 4.4173 4.2865 271 | 4.6226 3.8628 272 | 3.0633 3.5005 273 | 1.2647 3.0975 274 | 5.0572 3.4671 275 | 4.5308 3.8119 276 | 3.4668 3.4796 277 | 1.0674 3.6764 278 | 4.6976 4.2306 279 | 2.7953 4.8172 280 | 5.262 3.8477 281 | 5.1517 3.8446 282 | 0.9972 3.4383 283 | 4.9024 4.078 284 | 3.4083 4.8769 285 | 1.0939 3.6687 286 | 4.5103 3.8361 287 | 5.3921 3.8541 288 | 5.0524 4.5479 289 | 2.9607 4.2426 290 | 4.9686 3.5122 291 | 2.783 4.2279 292 | 5.2006 3.8226 293 | 4.9469 3.7825 294 | 4.6065 4.2121 295 | 5.3839 3.4805 296 | 5.4779 4.2478 297 | 4.2894 3.8718 298 | 5.2763 3.5165 299 | 4.6071 4.1635 300 | 3.5405 4.8858 301 | -------------------------------------------------------------------------------- /em/data/amix1-tst.dat: -------------------------------------------------------------------------------- 1 | 5.2808 3.6156 2 | 4.7303 3.7494 3 | 0.9959 3.559 4 | 2.5807 4.0899 5 | 1.1083 3.5048 6 | 1.2733 3.6579 7 | 3.3905 2.981 8 | 1.2434 4.3976 9 | 5.0299 4.0407 10 | 3.2442 3.7179 11 | 5.4128 4.1687 12 | 4.5364 3.881 13 | 4.8332 3.7551 14 | 4.9865 3.6888 15 | 5.3561 4.1125 16 | 4.9252 2.3934 17 | 4.4775 3.056 18 | 2.6375 3.7803 19 | 4.905 4.8708 20 | 1.3042 4.494 21 | 1.1917 4.5679 22 | 4.6076 3.9304 23 | 0.5703 3.9075 24 | 2.8247 4.3856 25 | 1.0218 3.7012 26 | 3.6221 3.6106 27 | 1.1001 4.2486 28 | 1.2437 3.1001 29 | 5.1167 3.791 30 | 2.5086 3.6793 31 | 4.3519 4.4113 32 | 2.3682 4.3488 33 | 1.2128 4.1716 34 | 5.29 4.6994 35 | 5.2526 4.2648 36 | 2.8982 4.6326 37 | 1.0913 3.733 38 | 1.1552 4.1479 39 | 4.8601 3.8985 40 | 0.9891 4.2645 41 | 3.252 4.5492 42 | 6.0426 3.9637 43 | 4.7879 3.9115 44 | 4.7847 3.4586 45 | 0.9149 3.1805 46 | 2.9126 3.1739 47 | 2.8925 3.1594 48 | 2.7046 3.332 49 | 4.7097 3.0637 50 | 0.9489 3.3557 51 | 2.5646 3.903 52 | 5.5259 3.565 53 | 5.0284 4.8402 54 | 5.1101 3.3439 55 | 4.9721 3.8096 56 | 2.7843 3.7926 57 | 5.0704 3.6051 58 | 0.8286 4.8081 59 | 0.7874 3.9323 60 | 2.7067 5.0244 61 | 0.8293 4.4076 62 | 3.0184 3.9446 63 | 4.9348 3.4348 64 | 1.418 4.0849 65 | 4.9021 3.6787 66 | 1.2955 4.204 67 | 1.0887 4.1641 68 | 3.3557 5.8198 69 | 4.133 3.7754 70 | 2.8783 4.3176 71 | 2.7642 4.1089 72 | 2.6383 4.7838 73 | 3.3499 3.9634 74 | 4.7791 4.2952 75 | 4.6629 3.0806 76 | 5.0784 3.9922 77 | 5.5871 3.6545 78 | 4.4645 4.5434 79 | 4.6457 3.9255 80 | 0.7308 4.1735 81 | 4.8115 3.1875 82 | 2.4637 3.573 83 | 4.6434 3.6336 84 | 0.8809 3.7946 85 | 3.305 4.07 86 | 4.999 3.9809 87 | 4.8807 3.2523 88 | 3.2542 4.6954 89 | 1.0067 3.791 90 | 0.9821 3.4702 91 | 5.1158 4.1002 92 | 5.2458 3.3985 93 | 4.7352 4.1047 94 | 3.0688 3.5925 95 | 5.3214 3.9451 96 | 0.8274 4.8519 97 | 5.1144 3.5427 98 | 4.9397 3.803 99 | 3.3689 3.3539 100 | 5.288 4.528 101 | 3.3282 3.2188 102 | 5.3732 3.921 103 | 4.8967 3.9011 104 | 0.8503 4.288 105 | 4.544 3.6257 106 | 5.6051 3.8816 107 | 4.8368 4.3767 108 | 3.2499 3.9575 109 | 3.5713 3.8257 110 | 3.03 4.3685 111 | 1.0295 3.3795 112 | 0.8465 3.7118 113 | 4.7721 4.8774 114 | 4.3019 3.6663 115 | 0.7939 4.0134 116 | 3.2009 3.8243 117 | 5.1745 2.6161 118 | 3.3878 4.6413 119 | 0.9815 3.9234 120 | 5.8172 4.6643 121 | 5.5945 2.9402 122 | 2.8334 4.1375 123 | 4.884 3.6355 124 | 3.2892 4.3709 125 | 3.3255 3.8691 126 | 5.5505 4.7877 127 | 3.185 4.1638 128 | 5.1828 3.3479 129 | 5.3289 3.8798 130 | 3.0701 3.8655 131 | 4.5693 3.9423 132 | 5.0988 4.062 133 | 4.7722 3.5183 134 | 1.1334 4.7485 135 | 4.6581 4.7538 136 | 0.8514 3.5916 137 | 1.1706 3.8742 138 | 4.8935 3.4505 139 | 6.1248 4.1598 140 | 4.6464 4.4867 141 | 5.1465 4.0466 142 | 2.7186 3.8577 143 | 5.2278 4.1286 144 | 1.0164 4.3792 145 | 2.5044 4.3931 146 | 0.8386 4.2247 147 | 2.574 4.4343 148 | 5.4381 4.6702 149 | 2.9834 4.0859 150 | 4.5025 3.544 151 | 0.9897 3.8696 152 | 0.8118 4.2695 153 | 4.5605 4.4332 154 | 2.969 4.3075 155 | 4.503 3.4867 156 | 5.0818 3.993 157 | 4.7683 3.9542 158 | 4.9384 4.9271 159 | 4.8529 3.7184 160 | 4.5457 3.6019 161 | 4.9043 3.7024 162 | 4.8152 4.9632 163 | 3.3412 4.1404 164 | 5.4514 4.656 165 | 4.1075 4.3762 166 | 3.0124 4.5191 167 | 2.9558 4.4889 168 | 2.574 2.9308 169 | 5.1904 4.633 170 | 3.1892 4.2373 171 | 4.2647 3.3939 172 | 5.1258 4.4008 173 | 2.3128 3.9976 174 | 4.6041 3.7698 175 | 2.501 3.7543 176 | 5.2659 3.6122 177 | 5.6491 4.4456 178 | 5.4201 3.8439 179 | 3.0587 4.6743 180 | 3.1404 3.5467 181 | 4.9517 4.3761 182 | 2.9241 4.1846 183 | 4.5509 4.1954 184 | 4.5982 4.3383 185 | 0.5424 3.5991 186 | 3.1025 3.9824 187 | 5.2808 3.8439 188 | 5.4825 3.2557 189 | 5.4327 4.6308 190 | 2.2122 4.0087 191 | 3.8483 4.7502 192 | 5.1371 3.9113 193 | 0.9276 3.8034 194 | 3.0976 4.204 195 | 3.2193 3.6319 196 | 0.7306 4.1816 197 | 0.6186 4.5547 198 | 4.9655 4.7927 199 | 3.1066 3.2273 200 | 4.7889 3.9301 201 | 4.862 4.0241 202 | 4.1996 3.9913 203 | 4.3515 3.6324 204 | 4.5531 3.3232 205 | 2.2 4.0195 206 | 1.0455 3.4489 207 | 0.9426 4.1335 208 | 3.2244 4.4995 209 | 5.4273 3.7381 210 | 3.6689 4.2284 211 | 4.9604 2.9791 212 | 0.7798 4.0317 213 | 3.4034 4.1642 214 | 0.6741 4.2845 215 | 3.2707 4.768 216 | 5.1033 3.3124 217 | 5.2156 3.7731 218 | 5.3163 4.3573 219 | 4.8102 3.9377 220 | 4.8029 3.6335 221 | 3.2729 3.0605 222 | 5.4828 4.4888 223 | 5.2878 4.265 224 | 3.2877 4.0889 225 | 2.4633 3.2546 226 | 5.4056 4.0506 227 | 5.0033 3.7806 228 | 1.4244 3.9102 229 | 5.3272 3.0913 230 | 3.3667 3.5744 231 | 0.7067 4.2196 232 | 3.6579 3.4588 233 | 5.07 4.8839 234 | 4.3925 3.2447 235 | 1.3431 3.8823 236 | 1.3416 3.6104 237 | 5.6241 4.1722 238 | 4.6975 3.811 239 | 5.7592 3.0604 240 | 5.091 3.7997 241 | 5.472 4.8472 242 | 0.9918 4.6022 243 | 0.6186 3.4324 244 | 3.289 4.3466 245 | 5.0572 3.5829 246 | 4.9967 4.5694 247 | 0.6641 4.4889 248 | 1.1578 4.2788 249 | 4.6694 4.4939 250 | 2.8946 3.8528 251 | 3.2466 4.0504 252 | 2.8721 4.3062 253 | 0.8955 3.5623 254 | 4.0227 4.4983 255 | 2.8698 4.3888 256 | 5.0855 4.4885 257 | 5.6935 4.2692 258 | 4.3638 4.1547 259 | 4.5868 4.2828 260 | 1.1752 4.1467 261 | 5.3091 3.6172 262 | 2.6961 3.6601 263 | 5.0798 3.3318 264 | 2.9944 3.2159 265 | 2.9504 3.4741 266 | 3.0428 3.813 267 | 5.1831 3.143 268 | 5.5448 3.8197 269 | 1.4118 3.9519 270 | 5.4304 3.6213 271 | 3.0712 3.7146 272 | 3.4521 3.8602 273 | 0.8021 4.3964 274 | 2.8428 4.1771 275 | 5.5698 3.5349 276 | 1.1836 4.0157 277 | 1.4631 3.8891 278 | 0.9893 3.7876 279 | 2.7166 3.7105 280 | 3.2988 4.8664 281 | 3.308 4.1059 282 | 4.9744 3.4959 283 | 4.7819 3.3882 284 | 5.5673 3.7031 285 | 1.1481 5.004 286 | 1.074 3.7124 287 | 1.3516 4.8563 288 | 5.3478 4.4656 289 | 1.1599 3.7186 290 | 2.7321 4.226 291 | 4.9143 4.5593 292 | 3.2522 2.8902 293 | 5.1298 3.6726 294 | 4.7728 4.3218 295 | 5.4805 4.3575 296 | 4.1564 3.9634 297 | 1.1348 4.6887 298 | 0.9133 3.9089 299 | 4.8875 3.3841 300 | 0.9607 4.1171 301 | 4.9266 3.961 302 | 4.8449 4.8613 303 | 5.7812 4.4323 304 | 0.6085 4.0713 305 | 3.0365 3.5759 306 | 5.3222 4.0267 307 | 4.6915 4.2862 308 | 4.6456 4.0756 309 | 3.0846 3.028 310 | 3.9956 4.3692 311 | 5.2676 4.016 312 | 5.0558 4.0445 313 | 5.1812 3.3993 314 | 2.5752 3.9438 315 | 2.7856 4.3257 316 | 4.4155 4.1321 317 | 5.436 4.2057 318 | 4.6139 3.9186 319 | 0.9158 3.724 320 | 5.6396 3.9443 321 | 3.0521 4.596 322 | 2.772 3.388 323 | 1.0194 5.1259 324 | 4.6366 2.7989 325 | 2.8118 3.7727 326 | 5.3549 4.5302 327 | 4.961 3.3516 328 | 3.1266 3.4082 329 | 4.9581 3.6522 330 | 3.1357 4.0288 331 | 5.3384 3.688 332 | 0.9555 3.1122 333 | 4.7808 4.1652 334 | 5.5924 4.6764 335 | 4.7812 3.8603 336 | 3.2681 3.6404 337 | 1.3109 3.1311 338 | 4.7988 4.7129 339 | 4.6414 3.805 340 | 4.9767 3.6869 341 | 4.8885 3.4749 342 | 5.2237 3.8943 343 | 4.1008 4.0239 344 | 0.7188 3.2519 345 | 2.3244 3.8884 346 | 5.2533 4.8241 347 | 2.8289 4.4806 348 | 4.8392 4.4618 349 | 4.8613 3.7571 350 | 3.3013 3.844 351 | 5.0048 3.9954 352 | 5.2077 3.9721 353 | 2.7953 4.1434 354 | 2.3631 3.896 355 | 4.4283 4.0674 356 | 2.9054 4.4128 357 | 1.4217 4.3982 358 | 5.0573 4.7126 359 | 5.0663 3.9178 360 | 1.075 5.1917 361 | 4.4828 2.961 362 | 2.6876 4.3844 363 | 5.4164 4.3629 364 | 4.9751 4.8098 365 | 1.3507 3.2537 366 | 4.8149 4.5216 367 | 4.917 3.9046 368 | 0.6262 3.3965 369 | 1.2092 3.975 370 | 1.1735 4.3513 371 | 1.4158 3.6555 372 | 0.9373 3.5389 373 | 4.5526 4.1553 374 | 3.1832 4.1208 375 | 1.13 3.9127 376 | 4.8315 3.5208 377 | 2.6738 4.0241 378 | 4.9289 4.3874 379 | 1.1168 5.5118 380 | 0.9907 3.6707 381 | 1.084 3.7924 382 | 1.0757 4.7078 383 | 5.134 4.2303 384 | 4.5865 4.2487 385 | 4.308 3.5195 386 | 4.9017 3.6726 387 | 5.1369 4.5036 388 | 4.886 3.8498 389 | 5.4253 3.6585 390 | 5.0125 3.3954 391 | 4.5694 4.2296 392 | 5.5688 4.0916 393 | 5.7322 4.1592 394 | 2.7326 3.3213 395 | 4.6748 4.6319 396 | 5.223 3.1816 397 | 5.192 3.2064 398 | 2.7954 4.026 399 | 4.534 4.7195 400 | 3.1653 4.6244 401 | 3.0238 3.6992 402 | 2.5544 4.7346 403 | 5.2465 4.2805 404 | 4.8344 4.9946 405 | 4.371 4.5115 406 | 2.7297 3.213 407 | 4.7318 4.4003 408 | 1.0556 4.3125 409 | 5.4367 3.159 410 | 4.6245 3.9461 411 | 4.7271 4.2177 412 | 3.3256 4.3649 413 | 5.0487 3.5361 414 | 4.8008 3.302 415 | 5.5722 3.8113 416 | 5.9997 4.0726 417 | 5.3115 3.9845 418 | 5.5637 4.0269 419 | 1.2047 3.3915 420 | 3.5936 3.3169 421 | 3.0671 3.3575 422 | 3.1323 3.7855 423 | 4.749 4.5747 424 | 4.6891 4.2039 425 | 2.462 3.2064 426 | 4.5353 3.6791 427 | 1.1287 3.1483 428 | 0.8134 4.1552 429 | 1.5185 4.6436 430 | 5.0303 4.5297 431 | 0.9306 4.6699 432 | 2.1931 3.7165 433 | 1.2326 4.451 434 | 2.3894 3.6978 435 | 4.5237 3.3263 436 | 4.9063 4.6313 437 | 5.2813 4.142 438 | 4.9884 5.0868 439 | 0.8562 4.8939 440 | 1.2528 4.154 441 | 5.5967 3.9834 442 | 4.9577 4.1012 443 | 3.0178 3.425 444 | 2.9935 4.6111 445 | 5.002 3.9095 446 | 2.7248 4.2186 447 | 5.1259 3.6454 448 | 2.8768 4.0504 449 | 1.113 3.348 450 | 2.8551 3.8419 451 | 0.8397 3.7341 452 | 2.9237 3.3518 453 | 5.6635 4.9363 454 | 5.2723 4.7847 455 | 3.2407 4.2915 456 | 5.3228 4.0323 457 | 1.0346 4.1657 458 | 0.8835 4.4433 459 | 5.8051 3.5039 460 | 4.192 4.1113 461 | 4.8271 3.8778 462 | 5.1683 4.3214 463 | 3.1913 3.8804 464 | 1.0735 3.9433 465 | 4.9958 4.0542 466 | 1.455 3.7793 467 | 3.0691 3.8833 468 | 3.3866 3.8728 469 | 2.7381 4.0332 470 | 5.0062 4.1163 471 | 5.3517 4.069 472 | 2.654 4.2637 473 | 1.1177 4.0792 474 | 4.8822 4.707 475 | 0.7954 4.6055 476 | 3.0822 4.2498 477 | 2.9222 3.6261 478 | 4.5418 4.4762 479 | 4.9263 3.906 480 | 4.5209 3.965 481 | 4.6943 4.1057 482 | 3.2213 3.1542 483 | 3.1328 3.9383 484 | 5.2021 3.8708 485 | 4.9985 3.6477 486 | 4.6424 3.6174 487 | 1.3318 3.6509 488 | 0.9025 3.8641 489 | 1.0622 3.6838 490 | 5.6526 4.4347 491 | 0.8095 4.4002 492 | 2.9693 4.0836 493 | 2.9553 3.807 494 | 4.7165 4.1767 495 | 5.1836 4.025 496 | 0.9312 3.6513 497 | 2.5587 4.1767 498 | 3.062 5.1722 499 | 5.2236 3.4979 500 | 1.0734 3.9971 501 | 3.0154 4.1458 502 | 3.2313 4.0445 503 | 5.3342 4.8603 504 | 5.193 3.5242 505 | 3.0232 3.7691 506 | 2.6014 4.0635 507 | 3.1654 3.3849 508 | 4.4766 3.8308 509 | 0.7611 4.6085 510 | 2.7058 3.6646 511 | 4.8712 3.4467 512 | 3.4449 4.28 513 | 2.7522 4.5558 514 | 5.2767 4.014 515 | 5.1144 3.8186 516 | 4.9109 4.0137 517 | 1.098 4.2609 518 | 1.1868 3.4573 519 | 2.8212 4.5355 520 | 5.4771 3.6561 521 | 5.5141 3.9473 522 | 4.8352 4.4129 523 | 5.0341 4.0536 524 | 1.0012 4.1233 525 | 2.0889 4.6573 526 | 3.3369 4.7904 527 | 5.3041 4.3862 528 | 5.5284 4.4651 529 | 1.0103 4.7705 530 | 1.1418 3.3653 531 | 5.3429 3.6973 532 | 3.3947 4.3295 533 | 4.9883 3.4881 534 | 3.0571 2.2302 535 | 1.1698 4.4049 536 | 2.9333 4.7174 537 | 2.7673 4.5941 538 | 5.3603 3.2769 539 | 4.7697 4.1511 540 | 5.8385 4.0837 541 | 2.9742 4.7775 542 | 5.0292 4.6738 543 | 1.0444 3.9042 544 | 5.5997 4.4073 545 | 0.9254 4.5186 546 | 4.222 4.3442 547 | 3.4097 3.6716 548 | 2.7011 4.6083 549 | 5.305 2.9614 550 | 4.4834 3.3338 551 | 4.852 4.5854 552 | 5.411 3.7013 553 | 0.9101 3.5767 554 | 4.5702 3.2744 555 | 4.2721 4.03 556 | 1.0007 4.4937 557 | 0.9035 4.379 558 | 4.5404 3.9124 559 | 3.1279 3.3456 560 | 5.0488 4.0327 561 | 4.8457 4.0284 562 | 2.823 4.2776 563 | 5.6386 3.1736 564 | 5.0673 4.4028 565 | 4.9684 4.5456 566 | 3.9821 4.2681 567 | 3.1856 3.7668 568 | 5.2592 4.2429 569 | 5.2464 3.9414 570 | 4.6027 4.03 571 | 3.2784 4.1566 572 | 3.1623 3.5824 573 | 3.1012 4.2443 574 | 4.5574 3.9461 575 | 2.4097 4.4367 576 | 2.6821 3.5132 577 | 4.7402 3.8712 578 | 5.0285 3.799 579 | 5.7294 4.0198 580 | 2.942 3.6839 581 | 4.9264 4.852 582 | 1.1176 3.7339 583 | 0.9758 4.1539 584 | 4.809 2.8785 585 | 3.3407 3.8689 586 | 4.8225 4.13 587 | 3.6868 3.5705 588 | 2.9325 3.2407 589 | 5.7106 3.1783 590 | 5.4507 3.9943 591 | 1.1982 4.2141 592 | 1.0066 4.7764 593 | 4.3716 3.7573 594 | 4.3017 3.6531 595 | 5.6057 4.181 596 | 1.1374 3.4447 597 | 3.1665 4.4231 598 | 4.9011 4.827 599 | 3.0949 3.6362 600 | 5.0508 4.133 601 | 0.9678 4.893 602 | 4.7868 3.5887 603 | 4.5122 3.6611 604 | 2.8057 3.61 605 | 4.9214 3.6995 606 | 4.7627 3.9375 607 | 4.9471 4.0981 608 | 2.7817 4.151 609 | 3.0278 4.3305 610 | 3.2474 4.5874 611 | 1.0077 4.072 612 | 5.4934 4.1429 613 | 3.0496 3.481 614 | 3.0053 3.1831 615 | 5.2766 4.0612 616 | 4.0768 4.1374 617 | 4.9549 3.8413 618 | 4.5259 3.9122 619 | 5.1537 3.332 620 | 2.8684 3.7162 621 | 0.8939 4.407 622 | 5.6872 3.5537 623 | 3.481 4.0458 624 | 1.1023 3.7201 625 | 5.2692 4.5195 626 | 0.8538 3.9809 627 | 0.9631 4.1045 628 | 3.2808 4.1916 629 | 4.9764 3.9777 630 | 3.4379 2.9044 631 | 5.1748 4.1806 632 | 2.9394 3.4152 633 | 5.1734 3.5529 634 | 4.8181 3.5397 635 | 4.8869 4.2301 636 | 5.629 4.3392 637 | 2.7421 4.2502 638 | 2.6126 4.4804 639 | 3.2117 3.3251 640 | 4.9948 4.1011 641 | 5.5071 3.0247 642 | 3.6566 3.427 643 | 3.4487 3.9924 644 | 0.9507 3.9974 645 | 5.518 4.0827 646 | 2.9529 3.6996 647 | 5.065 4.3934 648 | 4.7713 3.9823 649 | 4.833 2.7657 650 | 2.893 3.4873 651 | 3.456 4.1414 652 | 5.2727 4.1017 653 | 4.5923 3.9657 654 | 5.529 4.1871 655 | 0.9841 4.8589 656 | 2.5828 4.6429 657 | 5.4123 4.565 658 | 2.9229 3.4659 659 | 5.0159 4.4292 660 | 2.8953 3.927 661 | 5.2695 4.4847 662 | 5.5306 3.2134 663 | 3.2055 4.4503 664 | 3.6669 3.9961 665 | 3.225 4.1517 666 | 3.3599 2.9705 667 | 5.0428 3.9667 668 | 5.0567 4.0462 669 | 2.838 5.2492 670 | 5.4028 2.8191 671 | 4.6582 4.7425 672 | 4.3677 4.3536 673 | 4.4199 4.5562 674 | 4.5298 4.8723 675 | 2.6148 3.8665 676 | 4.8683 3.9348 677 | 5.5684 4.2624 678 | 1.0001 4.772 679 | 3.1619 3.9041 680 | 5.0754 4.2922 681 | 3.4182 4.3802 682 | 3.0963 3.2212 683 | 3.1227 3.4437 684 | 5.7958 4.2402 685 | 2.6899 3.8518 686 | 4.5411 3.91 687 | 4.5811 4.6373 688 | 2.8206 5.2087 689 | 3.0948 4.1377 690 | 5.0201 4.3273 691 | 3.1402 3.5668 692 | 1.2087 4.0289 693 | 4.8946 3.093 694 | 2.5448 3.841 695 | 5.2433 3.7538 696 | 5.6792 4.4593 697 | 3.1437 4.3405 698 | 2.5219 4.9746 699 | 0.9406 4.0669 700 | 1.1157 4.2461 701 | 4.9928 3.9746 702 | 5.1612 3.6863 703 | 0.644 3.937 704 | 0.8716 3.5312 705 | 5.037 5.255 706 | 3.1493 3.2683 707 | 5.313 3.9501 708 | 3.3415 4.2509 709 | 4.138 4.545 710 | 3.0329 4.0945 711 | 3.092 4.0959 712 | 2.5994 4.2744 713 | 4.6555 3.5785 714 | 3.428 4.4626 715 | 0.8835 4.3983 716 | 4.8258 5.3236 717 | 5.1681 4.6182 718 | 4.9952 4.6299 719 | 4.6944 3.7158 720 | 0.95 4.5394 721 | 2.9154 3.999 722 | 3.1106 3.4155 723 | 5.373 2.8737 724 | 5.3227 4.2441 725 | 1.0016 4.3048 726 | 3.3647 3.4567 727 | 4.9228 3.5546 728 | 2.8283 4.5256 729 | 4.3777 2.8883 730 | 2.6814 4.9463 731 | 2.8933 3.6018 732 | 5.4881 3.9332 733 | 2.9763 3.6902 734 | 0.9196 4.9425 735 | 3.2993 3.4309 736 | 2.8881 3.3268 737 | 5.4902 3.0064 738 | 2.9352 4.0376 739 | 0.9867 3.8514 740 | 3.3821 3.8463 741 | 2.9641 3.8853 742 | 4.2392 3.7534 743 | 0.9465 3.6307 744 | 5.1771 4.035 745 | 5.4939 4.6772 746 | 2.8888 3.7332 747 | 1.149 3.7643 748 | 5.2186 4.3845 749 | 5.2939 4.5754 750 | 4.1341 3.4611 751 | 3.149 3.279 752 | 4.6786 3.1987 753 | 0.9572 3.7706 754 | 1.3772 4.7892 755 | 2.6234 4.0088 756 | 5.8945 4.166 757 | 4.8751 3.7698 758 | 2.9031 3.4822 759 | 2.9384 4.0832 760 | 2.4503 4.0665 761 | 4.9544 4.0768 762 | 0.7213 3.499 763 | 0.9493 3.1213 764 | 3.0126 3.9336 765 | 5.1881 4.2606 766 | 1.332 4.7775 767 | 2.8081 3.9704 768 | 0.7154 3.9706 769 | 3.0485 3.6468 770 | 3.7136 3.556 771 | 5.1198 3.3239 772 | 4.8216 3.5328 773 | 3.2582 3.8181 774 | 2.9678 4.2889 775 | 2.4949 4.8639 776 | 4.5696 4.097 777 | 0.7349 4.1652 778 | 3.8143 4.5582 779 | 1.1333 4.3634 780 | 4.9825 4.2045 781 | 2.8601 4.0841 782 | 3.6147 3.5777 783 | 4.5534 3.9244 784 | 5.2465 4.2964 785 | 0.7387 3.4409 786 | 3.2752 3.9267 787 | 5.1831 4.1002 788 | 5.2889 4.297 789 | 4.7919 3.862 790 | 4.5459 3.762 791 | 2.9286 3.489 792 | 4.9476 3.9766 793 | 1.5276 4.1878 794 | 4.6421 4.3867 795 | 2.7351 4.1831 796 | 1.2263 3.7808 797 | 0.6025 4.366 798 | 0.9209 4.3533 799 | 4.6915 2.8895 800 | 5.1245 3.9978 801 | 5.4439 3.9702 802 | 3.014 4.1309 803 | 1.2126 3.7126 804 | 2.9033 4.3311 805 | 0.7421 3.7247 806 | 5.0493 3.704 807 | 1.1983 4.3622 808 | 2.5002 4.4212 809 | 3.1493 3.6508 810 | 3.5522 3.0189 811 | 5.5442 4.4887 812 | 0.8392 3.4083 813 | 4.2547 5.1818 814 | 3.2741 3.6164 815 | 1.0527 3.7992 816 | 4.1409 3.0426 817 | 2.8108 4.3897 818 | 4.9545 4.6357 819 | 5.1394 5.1478 820 | 2.6433 3.9272 821 | 5.1968 3.2871 822 | 3.0363 4.4906 823 | 4.4595 3.8679 824 | 3.4049 4.4939 825 | 3.1519 3.0761 826 | 2.981 4.4711 827 | 4.8262 3.227 828 | 3.3105 3.0764 829 | 3.06 4.8103 830 | 1.2882 3.9629 831 | 5.4383 3.652 832 | 1.2856 5.0409 833 | 5.0577 4.1248 834 | 5.2873 4.3598 835 | 5.1058 4.1738 836 | 2.8901 4.1699 837 | 5.0149 3.7761 838 | 0.6828 4.6515 839 | 2.9109 3.6121 840 | 5.3584 4.5454 841 | 4.5785 3.7879 842 | 5.1736 3.9062 843 | 5.5506 3.873 844 | 0.9979 3.7499 845 | 5.2291 4.707 846 | 3.5022 5.2736 847 | 3.3648 3.8305 848 | 0.7682 4.7466 849 | 5.0011 3.9875 850 | 3.2108 4.0082 851 | 2.8071 3.9037 852 | 1.1235 4.6679 853 | 5.3035 2.8908 854 | 5.0438 4.5501 855 | 5.194 3.7497 856 | 4.6613 2.9135 857 | 5.4902 3.3699 858 | 4.4887 3.7123 859 | 5.1633 3.7682 860 | 2.9127 4.3386 861 | 5.1225 4.0616 862 | 0.8416 3.5341 863 | 5.097 3.8484 864 | 4.1906 4.9474 865 | 0.9729 3.755 866 | 3.6214 4.1164 867 | 4.6543 3.6491 868 | 5.4143 4.6464 869 | 1.022 4.0965 870 | 5.1031 4.9975 871 | 4.6988 4.1799 872 | 4.7653 3.3628 873 | 4.3214 4.1723 874 | 3.4677 3.4472 875 | 3.5345 3.5322 876 | 1.19 4.1049 877 | 0.8246 3.483 878 | 2.7613 3.8551 879 | 4.6895 3.9291 880 | 4.0661 3.4515 881 | 3.1734 4.6754 882 | 4.5436 3.6242 883 | 0.9645 2.9589 884 | 4.957 3.9742 885 | 3.1666 4.3171 886 | 4.8902 4.0138 887 | 0.7495 4.5262 888 | 5.0719 4.2075 889 | 3.0198 3.1873 890 | 4.5981 4.2834 891 | 2.9111 3.9991 892 | 4.4086 4.1939 893 | 5.0469 4.28 894 | 5.4547 3.1546 895 | 4.8995 3.5117 896 | 6.0475 3.3004 897 | 2.884 4.6561 898 | 5.5348 4.2302 899 | 1.197 3.9447 900 | 2.9208 4.5118 901 | 4.9792 3.0601 902 | 2.873 3.4456 903 | 2.8296 4.0058 904 | 1.0744 4.223 905 | 2.93 3.913 906 | 3.3283 4.0387 907 | 4.6272 3.762 908 | 4.7915 4.4167 909 | 3.1248 3.6153 910 | 2.621 4.5781 911 | 4.7726 4.3268 912 | 3.3571 4.2704 913 | 5.0543 4.3873 914 | 0.9787 2.7082 915 | 1.193 4.4833 916 | 5.0627 3.6256 917 | 1.0927 3.657 918 | 1.2713 3.6438 919 | 4.981 3.1023 920 | 3.3347 4.0862 921 | 2.677 4.1773 922 | 5.1108 4.216 923 | 4.4402 4.8017 924 | 3.2424 4.1488 925 | 3.1711 4.306 926 | 5.0346 4.0085 927 | 3.0556 3.5946 928 | 3.0948 4.1141 929 | 5.1639 3.8605 930 | 2.9051 4.1832 931 | 5.2045 4.5641 932 | 0.8469 4.2311 933 | 1.0168 3.8673 934 | 4.6974 3.8947 935 | 3.3245 3.1904 936 | 4.9871 3.2427 937 | 5.1189 3.6338 938 | 0.681 3.4455 939 | 3.2338 3.5537 940 | 2.3823 3.5042 941 | 2.6818 4.4415 942 | 4.6647 4.3294 943 | 2.8513 4.6737 944 | 1.4172 3.9643 945 | 2.7988 3.9579 946 | 3.2862 4.0409 947 | 1.342 4.602 948 | 4.1959 4.5682 949 | 1.2648 3.6172 950 | 1.0422 4.223 951 | 2.9659 4.1446 952 | 0.8855 4.1563 953 | 5.3763 4.1304 954 | 4.7979 3.8967 955 | 3.1445 3.7926 956 | 2.8304 4.8333 957 | 4.7357 3.5897 958 | 1.0942 4.0179 959 | 4.6167 4.5772 960 | 2.4404 4.0553 961 | 5.3407 3.724 962 | 4.7093 3.8761 963 | 2.5641 4.1726 964 | 5.5151 3.8306 965 | 3.4127 3.6304 966 | 4.6455 5.0863 967 | 2.728 4.4278 968 | 2.8577 3.9651 969 | 5.4471 3.5472 970 | 5.0491 4.6144 971 | 5.0687 3.4301 972 | 5.2026 4.2618 973 | 2.8341 3.4102 974 | 5.184 4.1957 975 | 5.1773 3.6695 976 | 1.0422 3.9316 977 | 3.1061 3.4632 978 | 2.8503 3.5865 979 | 5.7043 4.543 980 | 1.0196 3.7514 981 | 3.128 3.846 982 | 4.5021 3.9198 983 | 0.6728 4.9975 984 | 3.1234 3.4927 985 | 1.1381 3.5623 986 | 4.338 3.5833 987 | 1.0512 3.6688 988 | 0.9514 3.7226 989 | 4.7476 3.3507 990 | 5.3501 3.6278 991 | 0.9033 4.5204 992 | 1.2982 4.3518 993 | 4.6507 3.9855 994 | 5.2804 3.8763 995 | 5.7031 4.1011 996 | 0.7648 5.2418 997 | 4.2416 4.0762 998 | 4.9642 4.1004 999 | 6.0255 3.3242 1000 | 5.3024 3.9224 1001 | -------------------------------------------------------------------------------- /em/data/amix1-val.dat: -------------------------------------------------------------------------------- 1 | 5.4588 4.0434 2 | 5.2669 3.3546 3 | 5.1538 4.5024 4 | 5.3268 3.668 5 | 5.429 4.9751 6 | 0.9751 4.0879 7 | 0.8857 3.2649 8 | 5.0562 3.4437 9 | 3.9136 4.5926 10 | 5.5667 3.9686 11 | 4.4762 3.6467 12 | 0.9227 4.5618 13 | 5.2953 3.1459 14 | 1.0783 4.4457 15 | 5.0753 4.4619 16 | 1.1351 3.6057 17 | 5.5509 4.4574 18 | 3.247 3.8896 19 | 5.3474 3.461 20 | 3.9616 4.5846 21 | 4.3323 4.6896 22 | 4.9953 3.7002 23 | 5.2051 4.264 24 | 1.1455 3.1367 25 | 3.0411 3.3502 26 | 0.5151 3.8823 27 | 0.9139 3.3276 28 | 2.5009 4.2456 29 | 5.2816 3.8505 30 | 5.0324 4.0918 31 | 2.5245 3.7778 32 | 2.5764 3.9915 33 | 4.9752 3.9037 34 | 0.8873 4.1027 35 | 1.2373 3.3638 36 | 2.6872 3.5082 37 | 1.0648 4.4754 38 | 0.9477 3.7728 39 | 3.0601 3.5764 40 | 4.2175 4.1768 41 | 3.1501 3.4225 42 | 4.0516 3.8377 43 | 4.6686 3.6681 44 | 0.9204 3.8444 45 | 2.86 5.4809 46 | 4.8765 4.1807 47 | 0.6867 4.2577 48 | 3.0696 3.9405 49 | 1.4491 3.9061 50 | 4.8242 3.7474 51 | 3.0737 3.6106 52 | 5.1281 3.8167 53 | 4.9932 4.0138 54 | 3.0549 4.5518 55 | 0.9394 3.6332 56 | 4.6533 4.1332 57 | 2.7773 3.8045 58 | 4.7364 4.3303 59 | 5.1487 4.1306 60 | 4.1647 3.7369 61 | 4.8516 4.044 62 | 4.5525 4.3255 63 | 4.7004 3.1804 64 | 5.2861 3.5308 65 | 3.2726 4.1039 66 | 3.0458 4.8869 67 | 5.5777 4.4507 68 | 5.0145 4.9718 69 | 3.3416 4.9069 70 | 4.5464 4.2107 71 | 4.521 4.5344 72 | 1.1868 4.172 73 | 2.4532 4.2446 74 | 4.0391 4.398 75 | 1.0141 4.3971 76 | 4.853 3.5327 77 | 5.3511 3.8039 78 | 5.0941 3.533 79 | 2.9149 3.4171 80 | 2.8519 4.237 81 | 0.711 4.6081 82 | 5.0858 3.9328 83 | 2.9362 3.4897 84 | 4.8134 2.8801 85 | 6.0261 4.1628 86 | 5.8844 3.2556 87 | 4.9089 4.0682 88 | 1.1478 3.828 89 | 3.2382 3.293 90 | 4.7177 2.5015 91 | 4.585 3.9819 92 | 5.8108 4.0323 93 | 2.7276 4.4238 94 | 3.2247 4.2321 95 | 3.4443 3.6305 96 | 5.3126 4.3382 97 | 5.2237 3.7511 98 | 5.4196 4.0605 99 | 4.612 4.5332 100 | 2.5552 3.6513 101 | 1.3682 3.9216 102 | 5.0843 4.0347 103 | 5.0957 4.1252 104 | 4.5013 4.4991 105 | 5.3706 2.84 106 | 2.5647 4.1755 107 | 4.4386 4.4103 108 | 4.5611 5.0181 109 | 4.5267 3.2452 110 | 5.3053 4.1474 111 | 5.2279 3.6808 112 | 2.6075 4.2638 113 | 5.0631 5.342 114 | 2.931 4.4386 115 | 3.3528 4.6242 116 | 5.6624 3.9968 117 | 3.0357 3.7495 118 | 0.8147 4.9599 119 | 1.0776 3.0912 120 | 5.1942 4.0759 121 | 3.1221 3.2972 122 | 1.4802 4.401 123 | 4.8954 4.4617 124 | 5.6845 3.8028 125 | 4.6837 3.171 126 | 2.7677 3.2948 127 | 5.3021 4.3943 128 | 4.6742 4.0818 129 | 0.8057 3.8054 130 | 5.102 4.8595 131 | 5.0186 3.5459 132 | 5.2223 3.9699 133 | 3.5308 3.7035 134 | 3.0673 4.0118 135 | 4.6243 4.0785 136 | 5.2667 4.0823 137 | 1.1471 3.4876 138 | 0.9982 3.9373 139 | 5.4543 4.6206 140 | 4.6934 4.191 141 | 3.4778 4.031 142 | 4.9007 4.0918 143 | 3.253 3.9989 144 | 2.9854 4.288 145 | 3.071 3.1229 146 | 5.2182 3.2985 147 | 2.714 3.4527 148 | 2.9685 3.1761 149 | 4.3204 3.808 150 | 4.7782 4.6227 151 | -------------------------------------------------------------------------------- /em/gmm.py: -------------------------------------------------------------------------------- 1 | from scipy.special import logsumexp 2 | from utils.misc_utils import * 3 | 4 | 5 | 6 | class GMM(object): 7 | def __init__(self, k, tol = 1e-3, reg_covar = 1e-7): 8 | self.K = k 9 | self.tol = tol 10 | self.reg_covar=reg_covar 11 | self.times = 100 12 | self.loglike = 0 13 | 14 | 15 | def fit(self, trainMat): 16 | self.X = trainMat 17 | self.N, self.D = trainMat.shape 18 | self.GMM_EM() 19 | 20 | # gmm入口 21 | def GMM_EM(self): 22 | self.scale_data() 23 | self.init_params() 24 | for i in range(self.times): 25 | log_prob_norm, self.gamma = self.e_step(self.X) 26 | self.mu, self.cov, self.alpha = self.m_step() 27 | newloglike = self.loglikelihood(log_prob_norm) 28 | # print(newloglike) 29 | if abs(newloglike - self.loglike) < self.tol: 30 | break 31 | self.loglike = newloglike 32 | 33 | 34 | #预测类别 35 | def predict(self, testMat): 36 | log_prob_norm, gamma = self.e_step(testMat) 37 | category = gamma.argmax(axis=1).flatten().tolist()[0] 38 | return np.array(category) 39 | 40 | 41 | #e步,估计gamma 42 | def e_step(self, data): 43 | gamma_log_prob = np.mat(np.zeros((self.N, self.K))) 44 | 45 | for k in range(self.K): 46 | gamma_log_prob[:, k] = log_weight_prob(data, self.alpha[k], self.mu[k], self.cov[k]) 47 | 48 | log_prob_norm = logsumexp(gamma_log_prob, axis=1) 49 | log_gamma = gamma_log_prob - log_prob_norm[:, np.newaxis] 50 | return log_prob_norm, np.exp(log_gamma) 51 | 52 | 53 | #m步,最大化loglikelihood 54 | def m_step(self): 55 | newmu = np.zeros([self.K, self.D]) 56 | newcov = [] 57 | newalpha = np.zeros(self.K) 58 | for k in range(self.K): 59 | Nk = np.sum(self.gamma[:, k]) 60 | newmu[k, :] = np.dot(self.gamma[:, k].T, self.X) / Nk 61 | cov_k = self.compute_cov(k, Nk) 62 | newcov.append(cov_k) 63 | newalpha[k] = Nk / self.N 64 | 65 | newcov = np.array(newcov) 66 | return newmu, newcov, newalpha 67 | 68 | 69 | #计算cov,防止非正定矩阵reg_covar 70 | def compute_cov(self, k, Nk): 71 | diff = np.mat(self.X - self.mu[k]) 72 | cov = np.array(diff.T * np.multiply(diff, self.gamma[:,k]) / Nk) 73 | cov.flat[::self.D + 1] += self.reg_covar 74 | return cov 75 | 76 | 77 | #数据预处理 78 | def scale_data(self): 79 | for d in range(self.D): 80 | max_ = self.X[:, d].max() 81 | min_ = self.X[:, d].min() 82 | self.X[:, d] = (self.X[:, d] - min_) / (max_ - min_) 83 | self.xj_mean = np.mean(self.X, axis=0) 84 | self.xj_s = np.sqrt(np.var(self.X, axis=0)) 85 | 86 | 87 | #初始化参数 88 | def init_params(self): 89 | self.mu = np.random.rand(self.K, self.D) 90 | self.cov = np.array([np.eye(self.D)] * self.K) * 0.1 91 | self.alpha = np.array([1.0 / self.K] * self.K) 92 | 93 | 94 | #log近似算法,可以防止underflow,overflow 95 | def loglikelihood(self, log_prob_norm): 96 | return np.sum(log_prob_norm) 97 | 98 | 99 | # def loglikelihood(self): 100 | # P = np.zeros([self.N, self.K]) 101 | # for k in range(self.K): 102 | # P[:,k] = prob(self.X, self.mu[k], self.cov[k]) 103 | # 104 | # return np.sum(np.log(P.dot(self.alpha))) 105 | 106 | 107 | -------------------------------------------------------------------------------- /em/gmm_penality.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.stats import norm 3 | from scipy.special import logsumexp 4 | 5 | 6 | class GMMPenality(object): 7 | 8 | def __init__(self, K, tol = 1e-3, penalty = 1): 9 | self.K = K 10 | self.tol = tol 11 | self.times = 80 12 | self.penalty = penalty 13 | self.beginPenaltyTime = 10 14 | 15 | 16 | def fit(self, train): 17 | # self.X = self.scale_data(train) 18 | self.X = train 19 | self.GMM_EM() 20 | 21 | 22 | def init_paras(self): 23 | self.N, self.D = self.X.shape 24 | self.means = np.mean(self.X, axis=0) 25 | self.std = np.sqrt(np.var(self.X, axis=0)) 26 | 27 | self.mu = np.random.rand(self.K, self.D) 28 | self.sigma = np.random.rand(self.K, self.D) 29 | self.alpha = np.array([1.0 / self.K] * self.K) 30 | 31 | self.gamma = np.random.rand(self.N, self.K) 32 | self.loglike = 0 33 | 34 | def scale_data(self, X): 35 | for i in range(X.shape[1]): 36 | max_ = X[:, i].max() 37 | min_ = X[:, i].min() 38 | X[:, i] = (X[:, i] - min_) / (max_ - min_) 39 | return X 40 | 41 | def GMM_EM(self): 42 | self.init_paras() 43 | for i in range(self.times): 44 | #m step 45 | self.m_step(i) 46 | # e step 47 | logGammaNorm, self.gamma= self.e_step(self.X) 48 | #loglikelihood 49 | loglike = self.logLikelihood(logGammaNorm) 50 | #penalty 51 | pen = 0 52 | if i >= self.beginPenaltyTime: 53 | for j in range(self.D): 54 | pen += self.penalty * np.sum(abs(self.mu[:,j] - self.means[j])) / self.std[j] 55 | 56 | # print("step = %s, alpha = %s, loglike = %s"%(i, [round(p[0], 5) for p in self.alpha.tolist()], round(loglike - pen, 5))) 57 | # if abs(self.loglike - loglike) < self.tol: 58 | # break 59 | # else: 60 | 61 | self.loglike = loglike - pen 62 | 63 | def e_step(self, data): 64 | N, D = data.shape 65 | gamma = np.random.rand(N, self.K) 66 | for k in range(self.K): 67 | gamma[:, k] = np.log(self.alpha[k]) + self.log_prob(data, self.mu[k,], self.sigma[k, :]) 68 | 69 | logGammaNorm = logsumexp(gamma, axis=1) 70 | gamma = np.exp(gamma - logGammaNorm[:, np.newaxis]) 71 | return logGammaNorm, gamma 72 | 73 | 74 | def m_step(self, step): 75 | gammaNorm = np.array(np.sum(self.gamma, axis=0)).reshape(self.K, 1) 76 | self.alpha = gammaNorm / np.sum(gammaNorm) 77 | for k in range(self.K): 78 | Nk = gammaNorm[k] 79 | if Nk == 0: 80 | continue 81 | for j in range(self.D): 82 | if step >= self.beginPenaltyTime: 83 | # 算出penality的偏移量shift,通过当前维度的mu和样本均值比较,确定shift的符号,相当于把lasso的绝对值拆开了 84 | shift = np.square(self.sigma[k, j]) * self.penalty / (self.std[j] * Nk) 85 | if self.mu[k, j] >= self.means[j]: 86 | shift = shift 87 | else: 88 | shift = -shift 89 | else: 90 | shift = 0 91 | self.mu[k, j] = np.dot(self.gamma[:, k].T, self.X[:, j]) / Nk - shift 92 | self.sigma[k, j] = np.sqrt(np.sum(np.multiply(self.gamma[:, k], np.square(self.X[:, j] - self.mu[k, j]))) / Nk) 93 | 94 | 95 | def predict(self, test): 96 | logGammaNorm, gamma = self.e_step(test) 97 | category = gamma.argmax(axis=1).flatten().tolist() 98 | return np.array(category),self.logLikelihood(logGammaNorm) 99 | 100 | 101 | #计算极大似然,通过np.sum所有的logGammaNorm 102 | def logLikelihood(self, logNorm): 103 | return np.sum(logNorm) 104 | 105 | 106 | #计算高斯密度概率函数,样本的高斯概率密度函数,其实就是每个一维mu,sigma的高斯的和 107 | def log_prob(self, X, mu, sigma): 108 | N, D = X.shape 109 | logRes = np.zeros(N) 110 | for i in range(N): 111 | a = norm.logpdf(X[i,:], loc=mu, scale=sigma) 112 | logRes[i] = np.sum(a) 113 | return logRes -------------------------------------------------------------------------------- /em/main.py: -------------------------------------------------------------------------------- 1 | from em.gmm import * 2 | from sklearn import mixture 3 | 4 | 5 | 6 | def checkResult(): 7 | X = np.loadtxt("./data/amix1-est.dat") 8 | searchK = 4 9 | epoch = 5 10 | maxLogLikelihood = 0 11 | maxResult = None 12 | maxK = 0 13 | alpha = None 14 | for i in range(2, searchK): 15 | k = i 16 | for j in range(epoch): 17 | model1 = GMM(k) 18 | model1.fit(X) 19 | if model1.loglike > maxLogLikelihood: 20 | maxLogLikelihood = model1.loglike 21 | maxResult = model1.predict(X) 22 | maxK = k 23 | alpha = model1.alpha 24 | 25 | alpha, maxResult = changeLabel(alpha, maxResult) 26 | print("my gmm k = %s, alpha = %s, maxloglike = %s"%(maxK,[round(a, 5) for a in alpha],maxLogLikelihood)) 27 | 28 | 29 | model2 = mixture.BayesianGaussianMixture(n_components=maxK,covariance_type='full') 30 | result2 = model2.fit_predict(X) 31 | alpha2, result2 = changeLabel(model2.weights_.tolist(), result2) 32 | 33 | result = np.sum(maxResult==result2) 34 | percent = np.mean(maxResult==result2) 35 | print("sklearn gmm k = %s, alpha = %s, maxloglike = %s"%(maxK,[round(a, 5) for a in alpha2],model2.lower_bound_)) 36 | 37 | print("succ = %s/%s"%(result, len(result2))) 38 | print("succ = %s"%(percent)) 39 | 40 | print(maxResult[:100]) 41 | print(result2[:100]) 42 | 43 | 44 | def changeLabel(alpha, predict): 45 | alphaSorted = sorted(alpha, reverse=True) 46 | labelOld = [] 47 | for i in predict: 48 | if i not in labelOld: 49 | labelOld.append(i) 50 | if len(labelOld) == len(alpha): 51 | break 52 | labelNew = sorted(labelOld) 53 | for i, old in enumerate(labelOld): 54 | predict[predict == old] = labelNew[i] + 100 55 | return alphaSorted, predict - 100 56 | 57 | 58 | if __name__ == "__main__": 59 | checkResult() 60 | 61 | -------------------------------------------------------------------------------- /em/main_panelity.py: -------------------------------------------------------------------------------- 1 | 2 | from em.gmm_penality import * 3 | from sklearn import mixture 4 | 5 | def getDataList(): 6 | fileNameList = [] 7 | dataList = ["amix1-est", "amix2-est", "golub-est"] 8 | for file in dataList: 9 | filePath = "./data/%s.dat" % (file) 10 | fileNameList.append(filePath) 11 | return fileNameList 12 | 13 | def checkResult(): 14 | fileNameList = getDataList() 15 | for fileName in fileNameList: 16 | X = np.loadtxt(fileName) 17 | searchK = 3 18 | penalty = [0, 1, 2] 19 | epoch = 2 20 | maxLogLikelihood = float('-inf') 21 | maxResult = None 22 | maxK = 0 23 | alpha = None 24 | bestP = 0 25 | for i in range(2, searchK): 26 | k = i 27 | for p in penalty: 28 | for j in range(epoch): 29 | model1 = GMMPenality(k, penalty = p) 30 | model1.fit(X) 31 | if model1.loglike > maxLogLikelihood: 32 | maxLogLikelihood = model1.loglike 33 | maxResult, _ = model1.predict(X) 34 | maxK = k 35 | alpha = model1.alpha 36 | bestP = p 37 | alphaSorted = sorted(model1.alpha.tolist(), reverse=True) 38 | print("fileName = %s, k = %s, penalty = %s alpha = %s, loglike = %s" % (fileName.split("/")[-1], k, p, [round(p[0], 5) for p in alphaSorted], round(model1.loglike, 5))) 39 | 40 | alpha, maxResult = changeLabel(alpha.reshape(1, -1).tolist()[0], maxResult) 41 | print("myself GMM alpha = %s, loglikelihood = %s, bestP = %s"% 42 | ([round(a, 5) for a in alpha], round(maxLogLikelihood, 5), bestP)) 43 | 44 | # maxK = 3 45 | model2 = mixture.BayesianGaussianMixture(n_components=maxK,covariance_type='full') 46 | result2 = model2.fit_predict(X) 47 | sklearnAlpha, result2 = changeLabel(model2.weights_.tolist(), result2) 48 | 49 | result = np.sum(maxResult==result2) 50 | percent = np.mean(maxResult==result2) 51 | 52 | print("sklearn GMM alpha = %s, loglikelihood = %s"% 53 | ([round(a, 5) for a in sklearnAlpha], round(model2.lower_bound_, 5))) 54 | print("succ = %s/%s"%(result, len(result2))) 55 | print("succ = %s"%(percent)) 56 | print(maxResult[:20]) 57 | print(result2[:20]) 58 | suffix = ["tst", "val"] 59 | for suf in suffix: 60 | newfileName = fileName.replace("-est", "-%s"%suf) 61 | newX = np.loadtxt(newfileName) 62 | maxResult, loglike = model1.predict(newX) 63 | print("fileName = %s, loglike = %s"%(newfileName.split("/")[-1],loglike)) 64 | 65 | 66 | def changeLabel(alpha, predict): 67 | alphaSorted = sorted(alpha, reverse=True) 68 | labelOld = [] 69 | for i in predict: 70 | if i not in labelOld: 71 | labelOld.append(i) 72 | if len(labelOld) == len(alpha): 73 | break 74 | labelNew = sorted(labelOld) 75 | for i, old in enumerate(labelOld): 76 | predict[predict == old] = labelNew[i] + 100 77 | return alphaSorted, predict - 100 78 | 79 | 80 | if __name__ == "__main__": 81 | checkResult() -------------------------------------------------------------------------------- /kmeans/README.md: -------------------------------------------------------------------------------- 1 | # 实现kmeans和kmeans++算法 2 | 3 | # 结果比较 4 | kmeans的结果用jupyter notebook打开kmeans.ipynb 5 | 6 | # 相关博客 7 | #### [1. K-Means原理解析](https://www.cnblogs.com/huangyc/p/10224045.html) 8 | #### [2. K-Means的优化](https://www.cnblogs.com/huangyc/p/10226492.html) 9 | #### [3. sklearn的K-Means的使用](https://www.cnblogs.com/huangyc/p/10229064.html) 10 | #### [4. K-Means和K-Means++实现](https://www.cnblogs.com/huangyc/p/10274001.html) 11 | -------------------------------------------------------------------------------- /kmeans/kmeans_base.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import time 3 | 4 | from sklearn.cluster import KMeans 5 | from sklearn import datasets 6 | 7 | import numpy as np 8 | 9 | from utils.misc_utils import distance, check_random_state, sortLabel 10 | 11 | 12 | 13 | class KMeansBase(object): 14 | 15 | def __init__(self, n_clusters = 8, init = "random", max_iter = 300, random_state = None, n_init = 10, tol = 1e-4): 16 | self.k = n_clusters # 聚类个数 17 | self.init = init # 输出化方式 18 | self.max_iter = max_iter # 最大迭代次数 19 | self.random_state = check_random_state(random_state) #随机数 20 | self.n_init = n_init # 进行多次聚类,选择最好的一次 21 | self.tol = tol # 停止聚类的阈值 22 | 23 | # fit对train建立模型 24 | def fit(self, dataset): 25 | self.tol = self._tolerance(dataset, self.tol) 26 | 27 | bestError = None 28 | bestCenters = None 29 | bestLabels = None 30 | for i in range(self.n_init): 31 | labels, centers, error = self._kmeans(dataset) 32 | if bestError == None or error < bestError: 33 | bestError = error 34 | bestCenters = centers 35 | bestLabels = labels 36 | self.centers = bestCenters 37 | return bestLabels, bestCenters, bestError 38 | 39 | # predict根据训练好的模型预测新的数据 40 | def predict(self, X): 41 | return self.update_labels_error(X, self.centers)[0] 42 | 43 | # 合并fit和predict 44 | def fit_predict(self, dataset): 45 | self.fit(dataset) 46 | return self.predict(dataset) 47 | 48 | # kmeans的主要方法,完成一次聚类的过程 49 | def _kmeans(self, dataset): 50 | self.dataset = np.array(dataset) 51 | bestError = None 52 | bestCenters = None 53 | bestLabels = None 54 | centerShiftTotal = 0 55 | centers = self._init_centroids(dataset) 56 | 57 | for i in range(self.max_iter): 58 | oldCenters = centers.copy() 59 | labels, error = self.update_labels_error(dataset, centers) 60 | centers = self.update_centers(dataset, labels) 61 | 62 | if bestError == None or error < bestError: 63 | bestLabels = labels.copy() 64 | bestCenters = centers.copy() 65 | bestError = error 66 | 67 | ## oldCenters和centers的偏移量 68 | centerShiftTotal = np.linalg.norm(oldCenters - centers) ** 2 69 | if centerShiftTotal <= self.tol: 70 | break 71 | 72 | #由于上面的循环,最后一步更新了centers,所以如果和旧的centers不一样的话,再更新一次labels,error 73 | if centerShiftTotal > 0: 74 | bestLabels, bestError = self.update_labels_error(dataset, bestCenters) 75 | 76 | return bestLabels, bestCenters, bestError 77 | 78 | 79 | # k个数据点,随机生成 80 | def _init_centroids(self, dataset): 81 | n_samples = dataset.shape[0] 82 | centers = [] 83 | if self.init == "random": 84 | seeds = self.random_state.permutation(n_samples)[:self.k] 85 | centers = dataset[seeds] 86 | elif self.init == "k-means++": 87 | pass 88 | return np.array(centers) 89 | 90 | 91 | # 把tol和dataset相关联 92 | def _tolerance(self, dataset, tol): 93 | variances = np.var(dataset, axis=0) 94 | return np.mean(variances) * tol 95 | 96 | 97 | # 更新每个点的标签,和计算误差 98 | def update_labels_error(self, dataset, centers): 99 | labels = self.assign_points(dataset, centers) 100 | new_means = defaultdict(list) 101 | error = 0 102 | for assignment, point in zip(labels, dataset): 103 | new_means[assignment].append(point) 104 | 105 | for points in new_means.values(): 106 | newCenter = np.mean(points, axis=0) 107 | error += np.sqrt(np.sum(np.square(points - newCenter))) 108 | 109 | return labels, error 110 | 111 | # 更新中心点 112 | def update_centers(self, dataset, labels): 113 | new_means = defaultdict(list) 114 | centers = [] 115 | for assignment, point in zip(labels, dataset): 116 | new_means[assignment].append(point) 117 | 118 | for points in new_means.values(): 119 | newCenter = np.mean(points, axis=0) 120 | centers.append(newCenter) 121 | 122 | return np.array(centers) 123 | 124 | 125 | # 分配每个点到最近的center 126 | def assign_points(self, dataset, centers): 127 | labels = [] 128 | for point in dataset: 129 | shortest = float("inf") # 正无穷 130 | shortest_index = 0 131 | for i in range(len(centers)): 132 | val = distance(point[np.newaxis], centers[i]) 133 | if val < shortest: 134 | shortest = val 135 | shortest_index = i 136 | labels.append(shortest_index) 137 | return labels 138 | 139 | 140 | if __name__ == "__main__": 141 | # 用iris数据集测试准确度和速度 142 | iris = datasets.load_iris() 143 | km = KMeansBase(3) 144 | startTime = time.time() 145 | labels = km.fit_predict(iris.data) 146 | print("km time", time.time() - startTime) 147 | print(np.array(sortLabel(labels))) 148 | 149 | kmeans = KMeans(init='k-means++', n_clusters=3, n_init=10) 150 | startTime = time.time() 151 | label = kmeans.fit_predict(iris.data) 152 | print("sklearn time", time.time() - startTime) 153 | print(sortLabel(label)) -------------------------------------------------------------------------------- /kmeans/kmeans_plus.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from sklearn.cluster import KMeans 4 | from sklearn import datasets 5 | 6 | import numpy as np 7 | 8 | from utils.misc_utils import distance, sortLabel 9 | from kmeans.kmeans_base import KMeansBase 10 | 11 | 12 | class KMeansPlusPlus(KMeansBase): 13 | 14 | def __init__(self, n_clusters = 8, init="random", max_iter = 300, random_state = None, n_init = 10, tol = 1e-4): 15 | super(KMeansPlusPlus, self).__init__( 16 | n_clusters=n_clusters, init=init, max_iter=max_iter, 17 | random_state=random_state, tol=tol, n_init=n_init) 18 | 19 | def _init_centroids(self, dataset): 20 | n_samples = dataset.shape[0] 21 | centers = [] 22 | if self.init == "random": 23 | seeds = self.random_state.permutation(n_samples)[:self.k] 24 | centers = dataset[seeds] 25 | elif self.init == "k-means++": 26 | centers = self._k_means_plus_plus(dataset) 27 | return np.array(centers) 28 | 29 | 30 | # kmeans++的初始化方式,加速聚类速度 31 | def _k_means_plus_plus(self, dataset): 32 | n_samples, n_features = dataset.shape 33 | centers = np.empty((self.k, n_features)) 34 | # n_local_trials是每次选择候选点个数 35 | n_local_trials = None 36 | if n_local_trials is None: 37 | n_local_trials = 2 + int(np.log(self.k)) 38 | 39 | # 第一个随机点 40 | center_id = self.random_state.randint(n_samples) 41 | centers[0] = dataset[center_id] 42 | 43 | # closest_dist_sq是每个样本,到所有中心点最近距离 44 | # 假设现在有3个中心点,closest_dist_sq = [min(样本1到3个中心距离),min(样本2到3个中心距离),...min(样本n到3个中心距离)] 45 | closest_dist_sq = distance(centers[0, np.newaxis], dataset) 46 | 47 | # current_pot所有最短距离的和 48 | current_pot = closest_dist_sq.sum() 49 | 50 | for c in range(1, self.k): 51 | # 选出n_local_trials随机址,并映射到current_pot的长度 52 | rand_vals = self.random_state.random_sample(n_local_trials) * current_pot 53 | # np.cumsum([1,2,3,4]) = [1, 3, 6, 10],就是累加当前索引前面的值 54 | # np.searchsorted搜索随机出的rand_vals落在np.cumsum(closest_dist_sq)中的位置。 55 | # candidate_ids候选节点的索引 56 | candidate_ids = np.searchsorted(np.cumsum(closest_dist_sq), rand_vals) 57 | 58 | # best_candidate最好的候选节点 59 | # best_pot最好的候选节点计算出的距离和 60 | # best_dist_sq最好的候选节点计算出的距离列表 61 | best_candidate = None 62 | best_pot = None 63 | best_dist_sq = None 64 | for trial in range(n_local_trials): 65 | # 计算每个样本到候选节点的欧式距离 66 | distance_to_candidate = distance(dataset[candidate_ids[trial], np.newaxis], dataset) 67 | 68 | # 计算每个候选节点的距离序列new_dist_sq, 距离总和new_pot 69 | new_dist_sq = np.minimum(closest_dist_sq, distance_to_candidate) 70 | new_pot = new_dist_sq.sum() 71 | 72 | # 选择最小的new_pot 73 | if (best_candidate is None) or (new_pot < best_pot): 74 | best_candidate = candidate_ids[trial] 75 | best_pot = new_pot 76 | best_dist_sq = new_dist_sq 77 | 78 | centers[c] = dataset[best_candidate] 79 | current_pot = best_pot 80 | closest_dist_sq = best_dist_sq 81 | 82 | return centers 83 | 84 | 85 | 86 | if __name__ == "__main__": 87 | iris = datasets.load_boston() 88 | km1 = KMeansBase(3) 89 | startTime = time.time() 90 | labels = km1.fit_predict(iris.data) 91 | print("km1 time",time.time() - startTime) 92 | print(np.array(sortLabel(labels))) 93 | 94 | km2 = KMeansPlusPlus(3, init="k-means++") 95 | startTime = time.time() 96 | labels = km2.fit_predict(iris.data) 97 | print("km2 time", time.time() - startTime) 98 | print(np.array(sortLabel(labels))) 99 | 100 | kmeans = KMeans(init='k-means++', n_clusters= 3, n_init=10) 101 | startTime = time.time() 102 | label = kmeans.fit_predict(iris.data) 103 | print("sklearn time",time.time() - startTime) 104 | print(sortLabel(label)) -------------------------------------------------------------------------------- /knn/.ipynb_checkpoints/KNN-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "import os\n", 11 | "import time\n", 12 | "searchPath=os.path.abspath('..')\n", 13 | "sys.path.append(searchPath)" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "from sklearn.model_selection import train_test_split\n", 23 | "from sklearn.datasets import load_iris\n", 24 | "import numpy as np\n", 25 | "import matplotlib.pyplot as plt\n", 26 | "from knn.knn_base import KNN\n", 27 | "from knn.knn_kdtree import KNNKdTree\n", 28 | "from utils.data_generater import random_points\n", 29 | "from utils.plot import plot_knn_predict" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 3, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "def getData(number):\n", 39 | " data = random_points(2, number)\n", 40 | " label = [0] * (number // 2) + [1] * (number // 2)\n", 41 | " return np.array(data), np.array(label)" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 4, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "dataList = []\n", 51 | "labelList = []\n", 52 | "for num in [30, 500, 1000, 2000, 5000, 10000, 50000, 400000]:\n", 53 | " data, label = getData(num)\n", 54 | " dataList.append(data)\n", 55 | " labelList.append(label)" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "models = {\"knn\":KNN(), \"kdtree\":KNNKdTree()}" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": {}, 71 | "outputs": [ 72 | { 73 | "name": "stdout", 74 | "output_type": "stream", 75 | "text": [ 76 | "model = knn, dataNum = 30, takeTime = 0.00391\n", 77 | "model = kdtree, dataNum = 30, takeTime = 0.00417\n", 78 | "model = knn, dataNum = 500, takeTime = 0.03806\n", 79 | "model = kdtree, dataNum = 500, takeTime = 0.00856\n", 80 | "model = knn, dataNum = 1000, takeTime = 0.05203\n", 81 | "model = kdtree, dataNum = 1000, takeTime = 0.01386\n", 82 | "model = knn, dataNum = 2000, takeTime = 0.1387\n", 83 | "model = kdtree, dataNum = 2000, takeTime = 0.02863\n", 84 | "model = knn, dataNum = 5000, takeTime = 0.28177\n", 85 | "model = kdtree, dataNum = 5000, takeTime = 0.07277\n", 86 | "model = knn, dataNum = 10000, takeTime = 0.47404\n", 87 | "model = kdtree, dataNum = 10000, takeTime = 0.16433\n", 88 | "model = knn, dataNum = 50000, takeTime = 2.0887\n", 89 | "model = kdtree, dataNum = 50000, takeTime = 0.93545\n" 90 | ] 91 | } 92 | ], 93 | "source": [ 94 | "for data, label in zip(dataList, labelList):\n", 95 | " for name, model in models.items():\n", 96 | " startTime = time.time()\n", 97 | " model.fit(data, label)\n", 98 | " for i in range(5):\n", 99 | " model.predict([0.3, 0.2])\n", 100 | " print(\"model = %s, dataNum = %s, takeTime = %s\"%(name, len(data), round(time.time() - startTime, 5)))" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "point = (0.3, 0.2)\n", 110 | "model = KNNKdTree()\n", 111 | "model.fit(dataList[0], labelList[0])\n", 112 | "plot_knn_predict(model, dataList[0], labelList[0], point)" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [] 121 | } 122 | ], 123 | "metadata": { 124 | "kernelspec": { 125 | "display_name": "Python 3", 126 | "language": "python", 127 | "name": "python3" 128 | }, 129 | "language_info": { 130 | "codemirror_mode": { 131 | "name": "ipython", 132 | "version": 3 133 | }, 134 | "file_extension": ".py", 135 | "mimetype": "text/x-python", 136 | "name": "python", 137 | "nbconvert_exporter": "python", 138 | "pygments_lexer": "ipython3", 139 | "version": "3.6.4" 140 | } 141 | }, 142 | "nbformat": 4, 143 | "nbformat_minor": 2 144 | } 145 | -------------------------------------------------------------------------------- /knn/KNN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "import os\n", 11 | "import time\n", 12 | "searchPath=os.path.abspath('..')\n", 13 | "sys.path.append(searchPath)" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "from sklearn.model_selection import train_test_split\n", 23 | "from sklearn.datasets import load_iris\n", 24 | "import numpy as np\n", 25 | "import matplotlib.pyplot as plt\n", 26 | "from knn.knn_base import KNN\n", 27 | "from knn.knn_kdtree import KNNKdTree\n", 28 | "from utils.data_generater import random_points\n", 29 | "from utils.plot import plot_knn_predict" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 3, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "def getData(number):\n", 39 | " data = random_points(2, number)\n", 40 | " label = [0] * (number // 2) + [1] * (number // 2)\n", 41 | " return np.array(data), np.array(label)" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 4, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "dataList = []\n", 51 | "labelList = []\n", 52 | "for num in [30, 500, 1000, 2000, 5000, 10000, 50000, 400000]:\n", 53 | " data, label = getData(num)\n", 54 | " dataList.append(data)\n", 55 | " labelList.append(label)" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 5, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "models = {\"knn\":KNN(), \"kdtree\":KNNKdTree()}" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 6, 70 | "metadata": {}, 71 | "outputs": [ 72 | { 73 | "name": "stdout", 74 | "output_type": "stream", 75 | "text": [ 76 | "model = knn, dataNum = 30, takeTime = 0.00391\n", 77 | "model = kdtree, dataNum = 30, takeTime = 0.00417\n", 78 | "model = knn, dataNum = 500, takeTime = 0.03806\n", 79 | "model = kdtree, dataNum = 500, takeTime = 0.00856\n", 80 | "model = knn, dataNum = 1000, takeTime = 0.05203\n", 81 | "model = kdtree, dataNum = 1000, takeTime = 0.01386\n", 82 | "model = knn, dataNum = 2000, takeTime = 0.1387\n", 83 | "model = kdtree, dataNum = 2000, takeTime = 0.02863\n", 84 | "model = knn, dataNum = 5000, takeTime = 0.28177\n", 85 | "model = kdtree, dataNum = 5000, takeTime = 0.07277\n", 86 | "model = knn, dataNum = 10000, takeTime = 0.47404\n", 87 | "model = kdtree, dataNum = 10000, takeTime = 0.16433\n", 88 | "model = knn, dataNum = 50000, takeTime = 2.0887\n", 89 | "model = kdtree, dataNum = 50000, takeTime = 0.93545\n", 90 | "model = knn, dataNum = 400000, takeTime = 16.82156\n", 91 | "model = kdtree, dataNum = 400000, takeTime = 11.85994\n" 92 | ] 93 | } 94 | ], 95 | "source": [ 96 | "for data, label in zip(dataList, labelList):\n", 97 | " for name, model in models.items():\n", 98 | " startTime = time.time()\n", 99 | " model.fit(data, label)\n", 100 | " for i in range(5):\n", 101 | " model.predict([0.3, 0.2])\n", 102 | " print(\"model = %s, dataNum = %s, takeTime = %s\"%(name, len(data), round(time.time() - startTime, 5)))" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 7, 108 | "metadata": {}, 109 | "outputs": [ 110 | { 111 | "name": "stdout", 112 | "output_type": "stream", 113 | "text": [ 114 | "\n", 115 | "\n", 116 | "\n" 117 | ] 118 | }, 119 | { 120 | "data": { 121 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzt3XmUXFW5/vHvW9Vd1VMSMiFkItGEMTJoMyuESYJCAhIkQbgy5ncZxAtcEJlEcEAQvVwFIYKCCkRARpn0AoIgEYIMAgkaAmSCJCQhQ3fX/P7+qCb2UJ2uJFV1qqqfz1q90lX79Dlvzup++vQ+++xt7o6IiFSXUNAFiIhI4SncRUSqkMJdRKQKKdxFRKqQwl1EpAop3EVEqpDCXUSkCincRUSqkMJdRKQK1QR14CFDhvjo0aODOryISEV66aWXPnT3ob1tF1i4jx49mtmzZwd1eBGRimRm7+WznbplRESqkMJdRKQKKdxFRKqQwl1EpAoFdkNVSmfpe8t58IbHmffyfIZ9aismnXEoYz69TdBliUgRKdyr3AuPvswVx1xLOpUmlUjxylNv8KdfP80pV32Vo77+xYIcY9G/3mfmVffx9z+9RqS+loknHcjksyZS31RfkP2LyMazoFZiam5udg2FLK62lhhf2epUYi3xbm2R+ggzXv0Rw8duvVnHeGv225x/4OXE2xJk0pn1+/7ENkP52d9+QEM/BbxIIZnZS+7e3Nt26nOvYs8/8CJmlrMtk0rz2C1PbvYxfnTS9bSti60PdoBEW4Kl7y7j3v/5w2bvX0Q2jcK9Arhn8Mxq3JMb9XWrlq4mlUjlbEsl0yxbtGKz6lr63nKWzF+asy0RS/LYL5/arP2LyKZTn3sZc8/gLb+AlpvBW4EQXn8E1u8iLNTU69eP2XkbaiI1JHMEfF1DlB32HLdZ9cVaYoTDPV8fxFq7dweJSGnoyr2M+Zpvw7obwFcDSSAObQ/iK6fmdRW/6wE7MWjrgYRyBHC4NswhJ+y3WfUNG7tVzn0DWMjYZcJOm7V/Edl0Cvcy5ekl0HY/0NalJQHpRRB/otd9hEIhrnni24zaYQR1jVHqm+qo71fP4GED+dFTl9M4oHGzaqyN1HL8pVOoa4h2a4vURTj+0imbtX8R2XTqlilX8efo8Xevt+Jtj2J1E3vdzdARg5nx6o9468V5LJy7hMHDB7HrATsRChXm9/rR5xxOOp3hju/+HgzSqQyDthrA+b86izHjRxXkGCKy8RTuZSsMZtDTSFUL570nM2P7Pcax/R6b18fe076PPX8yR539RRa8uYhoQ4QR2w7rcZSOiJSGwr1cRfeDNblHukADVndEScvpTSRay9jdxgRdhgSsZXULT975HIv/9T4jtxvGhKn70ti/Ieiy+iSFe5my8BC88WRouY3O/e5RqN0hG/4iZeSVp17n0klXgWdHStU1Rrnp/F/z3Ye+xc777Rh0eX2ObqiWMWs6B/pfDuFRQAhsIDSeig26DduIbhmRYmtZ3cKlk39IrCW+fghsrCVO29oYlxzxA9rWdR0YIMWmK/cyZmZYw1HQcFTQpYhs0FMz/wqZ3DeIPOM8fdfzTDz5wBJX1bcp3CvIgrmLWbbgQ4aP24qtx3wi6HJE1ls87/0eH1qLtcR5v4cnmaV4FO4VYOl7y/nO0dewYM7i7BOn8SQ77LUtl951LgOG9A+6PBFGbjuMusZozknq6prqGDZ2qwCq6tvU517mEvEk3/jcJbz96nvE2xK0rG4lEUvyxnNzOf+g7xDUrJ4iHU2Yui8Wyj38NRwOsd8xe5e4IlG4l7lnfz+L1tWtnWZdhOzEXx+8s4zXnn4zoMpE/q2hXz3ff/gi6vvVU9dUB2Sv2BsHNPD9Ry6ivrEu4Ar7HnXLlLnXnnmTtnWxnG2JWII5s/6pOVykLIz/3A78bvFNPH33LN6fv5ThY7div2P2zjk9hRSfwr2IVry/ivuue5jnH5pNpC7CoScdwMSTD9yob/YBQ/oTrgmTTqW7tdVEamka2PvskCKlUt9Uz8STDgi6DEHhXjQL5i7mG/tcTLwtTjKefdJ04VuLeXjGn/jfv34v7yXoDj5hP+75yR9yhrtnMnz+6D0LWreIVAf1uRfJtafeQMvq1vXBDhBvTbBk3gfcfe1Dee9n5HbDOfaCyUQ7XO1byIg2RDjjupM0WkZEctIaqkXw0fLVHDfyP3MukgEwZPgg7lx400bt89U/v8E9P36IJfOXMmb8SI45bxLb7T62EOWKSAXJdw3VvLplzGwicB0QBm5296u6tI8CbgO2aN/mQnd/ZKOrrhKta9oI14Z7DPeebpBuyC4TdtKNUxHJW6/dMpadxOR64DBgR2CamXWdBegS4C533w2YCtxQ6EIryZajhhCu6Xnul81d3k5EpDf59LnvAcxz9/nungBmApO7bOPAx52/A4AlhSux8tTU1nDcxUfnHBUTbYhwwuVfCaAqEelL8umWGQ4s7PB6EdB1iMblwB/N7OtAI3BwQaqrYMecdwTx1ji/u/oBamrDZDIZInURzrv5dHbca9ugyxORKpdPuOd6prjrXdhpwK3ufq2Z7Q38xszGu3unxyrNbDowHWDUqOpegs3MOOGyY5hy3hH8c/bbROoibNv8ScJhTdUrIsWXT7gvAkZ2eD2C7t0upwATAdz9eTOrA4YAyzpu5O4zgBmQHS2ziTVXlPrGOnbZXzdCRaS08ulzfxEYZ2ZjzCxC9obpg122WQAcBGBmOwB1wPJCFioiIvnrNdzdPQWcBTwOzCE7KuYNM7vCzCa1b3YecJqZvQrcCZzomq5QRCQweY1zbx+z/kiX9y7r8PmbwL6FLU1ERDaVph8QEalCCncRkSqkcBcRqUIKdxGRKqRwFxGpQgp3EZEqpHAXEalCCncRkSqkcBcRqUIKdxGRKqRwFxGpQgp3EZEqpHAXEalCec0KKSKyqdydZ+6ZxV1X388H7y5nqzFbMvWbR/K5L++JWa6F3qQQFO4iUlQ3/fdtPDzj/4i1xAFYs2ItV5/4M+b87V9Mv/qEgKurXuqWEZGiWfjWYh668U/rg/1jsZY4D/zsUZa8/UFAlVU/hbuIFM3Tdz9POpnO2ZZJZ3jmnlklrqjvULiLSNHEW+OkU7nDPZVME2uJlbiivkPhLiJFs+sB46lvqsvZVtdUxy4TdipxRX2Hwr0Pi7fFefnJf/Dyk/8g3hbv/QtENtJuB32a4eO2pjbSeexGbbSGkdsNY9cDxgdUWfXTaJk+6v6fPcotF91BKJQdipbJOKd8bxpHfv2LAVcm1SQUCvGjJ7/NtafdyKyHXqI2UkMykWKfyc2cM+M/NRSyiBTufdCff/ccN194O/HWzlfrN3/rDrbYcgATjt03oMqkGjUOaOSyu85j7ap1fLh4JUOGD6LfwKagy6p66pbpg3516cxuwQ7Zm1+/umRmj1/n7iyYu5h5r7xDMpEsZolShfoNbGLM+FEK9hLRlXsfk0qmeP/tpT22vz9/Kalkiprazt8aLz/5D6499eesXr6GUCiEhYz/uPwYjjr7S/rTWqQMKdz7mHBNmJpIDcl47ivvmkiYcE2403vzXn6HSyddRbw10en9X148k5raGiadMbFo9YrIplG3TB9jZhwwdV/CteFubeHaMBOO3bfblfhtl99Foi3Rbft4a5zbvv27Hscxi0hwFO590GlXH8/grQcSqY+sfy9SH2Hw1gOZfk33uT5e/8sc3HPvKxFLsvS95cUqVQLg6cV48l+4d/+FLpVD3TJ90BZDBzDjtWt59OYnePLOZwE4cNrnOOzUg2js39Bt+0h9BD5qybmvdCpDXWO0qPVKaXjyDXz1NyH1Hlg2GrzxdKzxNN1XqUAK9z6qsX8DU849ginnHtHrtoeeOIF7fvwQyXiqW9vonUYwaKuBxShRSshTC/CVXwVvbX+jfTTVuutxUljTGcEVJ5tE3TLSq2MvmMzQkUOI1NWufy9cE6K+Xx3n3nx6gJVJoXjLzZCzG6YNWmbgrjlgKo3CXXrVOKCRG2b/kK9ecjTDPrUVQ4YP4gsnHsBNL/+IsbuOCbo8KYT4s0D3v8yyQpB8q5TVSAGoW0by0ti/geMuOprjLjo66FKkGGxD903SYPUlK0UKQ1fuIgL1U4DcszdiA6FmXEnLkc2ncBcRrPE4qBkFdLyCDwF12IAfarRMBVK4iwhm9digu6DpDAiPhNBgiH4BG3w3Ft0z6PJkE+TV525mE4HrgDBws7tflWObrwCXAw686u7HFbBOESkyCzVgTadDk0ZAVYNew93MwsD1wCHAIuBFM3vQ3d/ssM044FvAvu6+ysy2LFbBIiLSu3y6ZfYA5rn7fM8+jzwTmNxlm9OA6919FYC7LytsmSIisjHyCffhwMIOrxe1v9fRtsC2Zvacmc1q78YREZGA5NPnnus2eddppGqAccAEYATwFzMb7+4fddqR2XRgOsCoUaM2uliRvqxtXRvLF61k4CcGaMEL6VU+4b4IGNnh9QhgSY5tZrl7EnjHzN4iG/YvdtzI3WcAMwCam5t7mGdQRDpKxBLccM6v+NNtzxCuDZFKpGk+dBfOu/l0BgzpH3R5Uqby6ZZ5ERhnZmPMLAJMBR7sss39wAEAZjaEbDfN/EIWKuXL0x/gqXdx17zuxfCdKdfyp18/QyKWoG1tjGQ8yQuPvsw39r2EVLKnKQOkr+s13N09BZwFPA7MAe5y9zfM7Aozm9S+2ePACjN7E3gKON/dVxSraCkPnnyNzIeH48sPwVcciS/bh0xLz2uwVgv3JJmW35BZ/gUyS/cks/IkPDG7KMd65x/v8epTr3dbLCWdTLPy/VU8d98LRTmuVL68xrm7+yPAI13eu6zD5w6c2/4hfYCn3sZX/keHKWIBWmHtD8jghBqnBVle0bin8VWnQuJloH2mxMRz+MqX8P5XEmroOpBs87zy1BtkMrl7MNvWxZj1yN/Z/yv7FPSYUh30hKpsEl93A+ScBrYN1v2Y7B98VSj+f5B8lfXBvl4M1n674FPjRupqCYVz/5iaGfVaKEV6oHCXTRN/Dsj00JiE9HulrKZkvPX3//5rpZsQxJ8v6PH2mbw7nsl9nqMNEQ4+fr+CHk+qh8JdNo1Fem7zNFgPMwxWuh6DHbJ9U/GCHm7gJ7bg+MuOIdrQ+Qq9rjHK3pOa2WGvbQt6PKkems9dNk39UdByC5Bj9Z7wcCzc9Tm3KhE9AJKv0b1bBvAk1O5W8ENOu/AoxowfxR3f+z0L/7mEwVsPZMq5R/CFEydotkbpkcJdNok1nozHHoL0Mv4d8CEgig34foCVFZc1HIO3/KJ9SbqO3SV1UH84Fv5EUY671+GfZa/DP1uUfUt1UreMbBILDcAG3weNJ0NoS7AtIHooNvgeLPKZoMsrGgv1xwbfDbWfASJgjUA9NByH9b8i6PJE1tOVu2wyCw3A+p0L/frWCFirGYkNvgNPL4fMR1AzAtMydFJmFO5ScO4ZSM1t74PeHtvg+pyVy8JDITw06DJEclK4S0F57El8zcXgbWR7/Rxv+jqhxpODLk2kT1G4S8F44iX8o/+i20iStdeRob5qn1oVKUe6oSoF4+v+h5xDBGmDdddlu2tEpCQU7lI4ydd6bvNWyGiBLpFSUbhLAW1oxEgarKFklWyMTCZD69o20mlNWSzVQ+EuhVN/FFCbo8GgdlcsVF4LS6TTaX773d9z9NCT+fLgkzhyi6/x83NuJd5W2CkERIKgcJeCsabTIbw10HHoYy1YEzbgu0GV1aOrv/YzZl51L+tWtZBOpYm1xPnDTX/km4dcSaaHybpEKoXCXQom+/Tm/dB0NoTHQngbaDgBG/IwVjMm6PI6eW/OIp6992/EWzvPjZOIJXn7tfd4+Yl/BFSZSGFoKKQUlIWasKbToOm0oEvZoNmPvUJ2jZnuYutiPHvfC3z2kF1KXJVI4ejKXfokCxmQe0ZFMwiHNduiVDaFu/RJex3+WXqaLTfaWMfnp+xd2oJECkzhLn3SsE9txcEn7NdtEYxIfYQd996WnffbMaDKRApD4S591jd+Pp1TfnAcQ4YPAoP+g/tx7DeP5LsPXahFMKTiWU83lYqtubnZZ8+eHcixRbpydwV6hfL4c3jLLZBeAOHRWOOpWHSvoMsqGjN7yd2be9tOo2VEQMFeoTLrrod1M4C27BvpBXjiRbzpTEJN0wOtLWjqlhGRiuSphbDuRtYH+3ptsO6nePr9IMoqGwp3EalIHnuYzuvYdmqF2COlLKfsKNxFpDJl1gDJHhoTeGZtKaspOxUb7ulUmg8Xr6B1bdc/yUSkL7BIc88zjVpDVS/Uno+Ku6Hq7tx1zQPM/OH9JGNJMukMnzlkZ/7rxukMGT446PJEpFSi+0NoS0gvAlIdGmogtDVEPhdUZWWh4q7cb77wt/zmintYt6qFeFuCZCLFi4+9wpl7fIuW1S1BlyciJWIWxgbdAbWfAaJg/bL/RpqxwbdjVnHxVlAVdeW+ZuVa7v/poyRinfvZMukMLatbeOyXT3L0OUcEVJ2IlJqFh2CDf4unF0N6CYRHYOGtgy6rLFTUr7bXn51LTST376N4a4Jnfv+3ElckIuXAwsOxyO4K9g4qKtxrajf8h0ZttKL+EBERKZqKCvddJuxIJpN7uoS6xiiHnnhAiSsSESlPFRXu0fooZ/zkRKINkU7vR+pqGbn9cCYcu09AlYmIlJeK68c47JSDGDJiML++/C7eee09GgY0cPj/O4SvnD+Z2kiuxZlFRPqeigt3gN0P3ZXdD9016DJERMpWXt0yZjbRzN4ys3lmduEGtptiZm5mvU5HKSIixdNruJtZGLgeOAzYEZhmZt2WqTGzfsDZgMYjiogELJ8r9z2Aee4+390TwExgco7trgSuBmIFrE9ERDZBPuE+HFjY4fWi9vfWM7PdgJHu/ocN7cjMppvZbDObvXz58o0uVkRE8pNPuOdaomb9YHPLTuDwE+C83nbk7jPcvdndm4cOHZp/lSIislHyCfdFwMgOr0cASzq87geMB/5sZu8CewEP6qaqiEhw8gn3F4FxZjbGzCLAVODBjxvdfbW7D3H30e4+GpgFTHJ3rX4tIhKQXsPd3VPAWcDjwBzgLnd/w8yuMLNJxS5QREQ2Xl4PMbn7I8AjXd67rIdtJ2x+WSIisjkqam4ZERHJj8JdRKQKKdxFRKqQwl1EpAop3EVEqpDCXUSkCincRaqMp98n89F/k/ng02Q+2IHMiml44qWgy5ISU7iLVBFPL8U/PBJifwDiQBqSL+ErT8LjzwVdnpSQwl2kB+4JMi13kvnwKDLLDyWz5ko8vaT3LwyQr7sRfC2Q6dISw9dchnvuBeal+lTkMntSHdwTkHgRPAGR3bDQFkGXtJ57HF8xDVLzWL9EQetCvO1eGHQ7VtttvZryEH8MSOVuSy+DzBIID8/dLlVF4S6ByLQ+AGsvZ/2M0p7EG76K9buA7CzSwfKW2zsHOwAp8BS++gJsyAaXLgiOd71i78jA0yUrRYIV/E+R9DkenwVrLgVvwY6aix01F4hD6514y42bts9MC578J55eUZgi22bS46JiqQV4amHutqDVHQiEc7eFBkB4ZO42qToKdyk5X/e/5A7ONmi5Bfdk/vvyBJnV38GX7Y2vPBZfvj+ZlSfg6Q82s8iWntssDL5u8/ZfJNZ4Blg93dfYqYN+l2CWa+0dqUYKdym91Bzsy4uyH8+3ZT/aX0MKNiKY/aMLoO33QKw9kBOQmI2vmIJnWje9xtrP0vOPRwZqxmz6vovIakZig++GyN5kr+DDEB6DbXEdofpDgy5PSkh97lJ61q/nNk9BaAPtnTZdAPEnyA756ygNmXV42wNY47RNK7HpDDz+NNDWpaUeGk7GrG6T9lsKVvMpbNCtuMfBk1ioKeiSJAAKdym9hqn4vauAOHx5EQB+7wggBJHm/EfNJGe3d5HkamyF+FOwqeFeuz0MvAFffUH7XwQh8CQ0nIA1fX2T9llqZlGwaNBlSEAU7lJy1ngyHnuifTTKx+rAGrD+39uIPdWRe/32dqHGTawwy6L7wtC/QOpN8Dao2UFXwVIxFO5ScmZ1MPhOiD2CP3QPeAzqDsEajt24se7R/Xoe2mcNWP3RBag1BLXjN3s/IqWmcJdAmEWg/kis/shN30eoCe93Eaz9Pp1H39RDZC+I7LPZdVYiT87NPqmafAmsCRqOy/7itEjQpUkJKdylooUap+I1o/GWGyD5FoQGQ8PXsIYpZfEwVKl5/Dl81elAguwUBEth7TV47GEY9GsFfB+icJeKZ9G9sOheQZcROPcMvvp8uj9DEIPkHGh7EBqmBFGaBKDvXdqIVKvkP7I3fnNqw9t+V9JyJFgKd5Fq4a1scPRQZgNP3UrVUbiLVIvanbJj8XM3QvTzJS1HgqVwF6kSFuoPDdOA+hyNEazxxFKXJAFSuItUEev3TWg4HqjPDoOkDsJjsUG3Y+Gtgy5PSkijZUSqiFkY638+3nQmpOeD9cdqRgVdlgRA4S5ShSzUACE9WduXqVtGRKQKKdxFRKqQumWkbHns//CWmyC1EMLDsKbpED1UqwmJ5EHhLmUps/Yn0HIr0Na+QtNr+H1vQ/3fsf4XBVydSPlTt4yUHU8thJZf0m0VJG+D1pl46u1A6hKpJLpyl/ITfxzItF+xgz3fHvJfXgQY/vjDWL+zAytPpBLoyl3Kj7cBqZ4aNzA5loh8LK9wN7OJZvaWmc0zswtztJ9rZm+a2Wtm9oSZbVP4UqUU3NN4ZhXuiY37wtbWwhUR2ROsHr93RPZj7/rsx70j8Pu2xaJ7F+5YIlWq13A3szBwPXAYsCMwzcx27LLZy0Czu+8M3ANcXehCpbjc02TWXY8v2wNf9nl86WfJrP4Wnlnb+xd/8AFsuSUsXVqYYmp3h/BYoOvCErUQHgGRzxXmOCJVLJ8r9z2Aee4+37OXczOByR03cPen3P3jS7dZwIjClinF5msuhXUzwNdiX56PffltaHsIXzkN73GmwXb33w8tLdl/C8DMsEG3QvQgIILftz1+7ychOiE7R0ofXGFJZGPlc0N1OLCww+tFwJ4b2P4U4NHNKUpKy1OLoO0hIN6lJQHpRRB/EuoO7f6F48fD3Lngnn19xhlw5pmw/fbw+uubVZOFmrCB1+GZjyD9AYQ/gYUGbtY+RfqSfMI91xMjnnNDs+OBZmD/HtqnA9MBRo3SZEZlI/EcEOphdAr4I49hucL9jjvgiCNg+XJoa4NoNNs9c+edBSvNQltAaIuC7U+kr8jn79tFwMgOr0cAS7puZGYHAxcDk9y96yUgAO4+w92b3b156NChm1KvFEV4gwv49HgNsPPOcNllkEpBfX3238sug09/uhhFishGyCfcXwTGmdkYyy6dPhV4sOMGZrYbcBPZYF9W+DI7S6fTZDKZYh+m74juD57ueXRK/eE9f+1vfws1NXDeedl/f/Ob0tUtIj3qNdzdPQWcBTwOzAHucvc3zOwKM5vUvtk1QBNwt5m9YmYP9rC7zfL6c3M5e5+LOCwyjcOi07joi9/nvTmLinGoPsXCQ6HxZLqv4FMHNeMhsoHl2U4/PdvvfuWV2X9PP72YpYpInsw9Z/d50TU3N/vs2bPz3v7Vp9/g4i99n3jrv8dfm0FdUz03vHgVI7YdVowy+wx3x9vuh5YbsjdRbQA0Ho81Tif7B5uIlAMze8ndm3vbrmLGlF1/9i87BTtkB2nEWmL86tKZAVVVPcyMUMNRhIb+idBWcwh9YhahprMU7CIVqiLCfc3KtSx8q9s9XAA847zw8N9LXJGISHmriHDvTTAdSyIi5asiwr3/oH6M3C53n7qFjD2/9JkSVyQiUt4qItwBzvzfk4k2dO7/NYO6xjpOunJqQFWJFIe7k4glCGrAg1S+ign3Xfbfiasev5Qd9hqXvfkXDtF86G78dNb3NVJGqoa7c+91D3PssNM4oul4JvU/gZ+fcyux1pzPBYr0qGKGQnaUTqWxkBEKVczvJpG8/PSsm/njrX/uFOa10Vo+ufMorvvr9wiHwwFWJ+Wg6oZCdhSuCSvYpeosW7CcR295sttVejKeZMGcxfxNo8JkIyghRcrE7MdfJRTOPclP27oYz9z9fIkrkkqmcBepEGYbnN1NpBOFu0iZ2P2w3cikc98Dq2uqY/+v7FPiiqSSKdxFysTQEYP50vSDqWuIdnq/tq6W0TuNZPfDdg2oMqlECneRMnLG/5zEqT/8KoOHDcQMGvrXc+RZh/GjJ7+tkTKyUSpyKKRIX5BOpQnXKNCls6oeCinSFyjYZXMo3EVEqlA+C2SLFFwmk+Efz8xh+aIVbLPjCMZ95pNBlyRSVRTuUnLzXnmHSw6/ita1reDZ+VSGj92a7z1yEYO3Hhh0eSJVQd0yUlItq1s4/8DvsGLJStrWxmhbFyPWEufdNxZwwcHf0SyIIgWicJeS+uOvnyaVSHV7P53KsHzhCl575s0AqhKpPgp3Kak5z/+zx+lrU8kUb7/ybmkLEqlS6nOXkhoyYhDhmjDpVLpbW02khi22HBBAVRIU9wwknsXjfwNrxOq/iNWMDrqsqqArdympw045iHBtD+O3HfaZvHtpC5LAeOYjfMUk/KOzofUX0PIz/MMjyKy5WvdeCkDhLiU1crvhnHTlVKL1EULh7LdfTaSGaEOUi2ee021eFalevvpCSL0D3tr+TgqIQ+vtEH8iyNKqgrplpOSmnHsEux44ngeuf4wP5i9j7GfGMPnMiWw1esugS5MS8cxKiD8LJHO0tuEtv8DqDi51WVVF4S6BGLvrGM77xelBlyFBSb8PFgFP9NC+qLT1VCF1y4hI6YW37jnYAcIjS1dLlVK4i0jJWWgQRPcDanM01mON00teU7VRuItIIGzAVVAzDqyh/Z1aIAoNJ2J1BwZZWlVQn7uIBMJC/WHwfZB4Hk+8gFkj1E3EatQlUwgKdxEJjJlBdB8sqvVhC62iwt0zqyD2CJ5ejtVuB9GDMIsEXZaISNmpmHDPtD0Gqy9ofxXDrRGsHgbdjtWMCbQ2EZF8rP5wDe/PX8rgYYMYOmJwUY9VEeHuqYXtwR7r8GYLeCu+8mQY+mT2zzsRkTLUtq6NH592I3994EU6OWFIAAAFoUlEQVRqo7Uk40m222MsF93+DYYML07IV8RoGW+9g+yjyd1awFdB4oVSlyQikrdLJ/2Q5x54kUQsScvqVhKxJG/+9S3O3udiErENjPffDHmFu5lNNLO3zGyemV2Yoz1qZr9rb/+bmY0uaJWpf5E73AF3SC8o6OFERArlX3+fz9wX5pGMdZ5qIZ3KsG5VC8/cM6sox+013M0sDFwPHAbsCEwzsx27bHYKsMrdxwI/AX5Y0CprxpLzYQcAC0F4REEPJyJSKK8/O5dMOpOzrW1djNmPv1KU4+Zz5b4HMM/d57t7ApgJTO6yzWTgtvbP7wEOsgJ2glvDtB5KNbB+ENmzUIcSESmo+qY6wjW5ozYUMpoGNhbluPmE+3BgYYfXi9rfy7mNu6eA1UDB7hJYzTbQ/3tAtP2D7FNttgU26JeYVcStAxHpg/aZvHuPV+61dREOOWH/ohw3n1TMdQXedSb9fLbBzKab2Wwzm718+fJ86lsv1DAJG/oUNJ0LDSdj/b+DbfkMVjN2o/YjIlJK/Qf3Y/o1JxBt6PxMTl1jlIO++jm22704GZbPUMhFQMfngUcAS3rYZpGZ1QADgJVdd+TuM4AZAM3NzRu91IqFh2BNJ23sl4mIBGrymYcxevwoZl51H+++sYghwwcx5dwj2G/KXkU7Zj7h/iIwzszGAIuBqcBxXbZ5EPga8DwwBXjStU6WiMh6u+y/E7vsv1PJjtdruLt7yszOAh4HwsAv3f0NM7sCmO3uDwK3AL8xs3lkr9inFrNoERHZsLyeUHX3R4BHurx3WYfPY8AxhS1NREQ2lYaZiIhUIYW7iEgVUriLiFQhC2pQi5ktB97biC8ZAnxYpHIqlc5Jdzon3emcdFbp52Mbdx/a20aBhfvGMrPZ7t4cdB3lROekO52T7nROOusr50PdMiIiVUjhLiJShSop3GcEXUAZ0jnpTuekO52TzvrE+aiYPncREclfJV25i4hInsou3ANf0q8M5XFOzjWzN83sNTN7wsy2CaLOUurtnHTYboqZuZlV9eiIfM6HmX2l/fvkDTO7o9Q1lloePzejzOwpM3u5/Wfni0HUWTTuXjYfZCcmexv4JBABXgV27LLNGcCN7Z9PBX4XdN1lcE4OABraPz9d52T9dv2AZ4BZQHPQdQf8PTIOeBkY2P56y6DrLoNzMgM4vf3zHYF3g667kB/lduUe+JJ+ZajXc+LuT7l7a/vLWWTn3K9m+XyfAFwJXA3ESllcAPI5H6cB17v7KgB3X1biGkstn3PiQP/2zwfQfZ2KilZu4R74kn5lKJ9z0tEpwKNFrSh4vZ4TM9sNGOnufyhlYQHJ53tkW2BbM3vOzGaZ2cSSVReMfM7J5cDxZraI7Ky3Xy9NaaWR15S/JVSwJf2qSN7/XzM7HmgGirMoY/nY4Dmx7KK6PwFOLFVBAcvne6SGbNfMBLJ/2f3FzMa7+0dFri0o+ZyTacCt7n6tme1Ndk2K8e6ee8HTClNuV+4bs6QfG1rSr4rkc04ws4OBi4FJ7h4vUW1B6e2c9APGA382s3eBvYAHq/imar4/Nw+4e9Ld3wHeIhv21Sqfc3IKcBeAuz8P1JGdd6YqlFu4r1/Sz8wiZG+YPthlm4+X9IO+saRfr+ekvQviJrLBXu19qdDLOXH31e4+xN1Hu/tosvchJrn77GDKLbp8fm7uJ3vjHTMbQrabZn5JqyytfM7JAuAgADPbgWy4Ly9plUVUVuHe3of+8ZJ+c4C7vH1JPzOb1L7ZLcDg9iX9zgV6HAZXDfI8J9cATcDdZvaKmXX9Jq4qeZ6TPiPP8/E4sMLM3gSeAs539xXBVFx8eZ6T84DTzOxV4E7gxGq6UNQTqiIiVaisrtxFRKQwFO4iIlVI4S4iUoUU7iIiVUjhLiJShRTuIiJVSOEuIlKFFO4iIlXo/wN7Vchil0gbWwAAAABJRU5ErkJggg==\n", 122 | "text/plain": [ 123 | "" 124 | ] 125 | }, 126 | "metadata": {}, 127 | "output_type": "display_data" 128 | } 129 | ], 130 | "source": [ 131 | "point = (0.3, 0.2)\n", 132 | "model = KNNKdTree()\n", 133 | "model.fit(dataList[0], labelList[0])\n", 134 | "plot_knn_predict(model, dataList[0], labelList[0], point)" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [] 143 | } 144 | ], 145 | "metadata": { 146 | "kernelspec": { 147 | "display_name": "Python 3", 148 | "language": "python", 149 | "name": "python3" 150 | }, 151 | "language_info": { 152 | "codemirror_mode": { 153 | "name": "ipython", 154 | "version": 3 155 | }, 156 | "file_extension": ".py", 157 | "mimetype": "text/x-python", 158 | "name": "python", 159 | "nbconvert_exporter": "python", 160 | "pygments_lexer": "ipython3", 161 | "version": "3.6.4" 162 | } 163 | }, 164 | "nbformat": 4, 165 | "nbformat_minor": 2 166 | } 167 | -------------------------------------------------------------------------------- /knn/README.md: -------------------------------------------------------------------------------- 1 | # 实现KNN的基本算法和KNN_KdTree的算法 2 | 3 | # 结果比较 4 | 结果在knn.ipynb中展示 5 | 6 | # 相关博客 7 | #### [k近邻算法(KNN)](https://www.cnblogs.com/huangyc/p/9716079.html) 8 | -------------------------------------------------------------------------------- /knn/knn_base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.datasets import load_iris 3 | from sklearn.model_selection import train_test_split 4 | from collections import Counter 5 | from sklearn.neighbors import KNeighborsClassifier,KDTree 6 | 7 | class KNN(object): 8 | def __init__(self, n_neighbors=3, p=2): 9 | """ 10 | parameter: n_neighbors 临近点个数 11 | parameter: p 距离度量 12 | """ 13 | self.n = n_neighbors 14 | self.p = p 15 | 16 | 17 | def fit(self, X_train, y_train): 18 | self.X_train = X_train 19 | self.y_train = y_train 20 | 21 | def predict(self, X): 22 | # 取出n个点 23 | knn_list = [] 24 | for i in range(self.n): 25 | dist = np.linalg.norm(X - self.X_train[i], ord=self.p) 26 | knn_list.append((dist, self.y_train[i])) 27 | 28 | for i in range(self.n, len(self.X_train)): 29 | max_index = knn_list.index(max(knn_list, key=lambda x: x[0])) 30 | dist = np.linalg.norm(X - self.X_train[i], ord=self.p) 31 | if knn_list[max_index][0] > dist: 32 | knn_list[max_index] = (dist, self.y_train[i]) 33 | 34 | # 统计 35 | knn = [k[-1] for k in knn_list] 36 | return Counter(knn).most_common()[0][0] 37 | 38 | # 统计准确度 39 | def score(self, X_test, y_test): 40 | right_count = 0 41 | for X, y in zip(X_test, y_test): 42 | label = self.predict(X) 43 | if label == y: 44 | right_count += 1 45 | return right_count / len(X_test) 46 | 47 | 48 | def main(model): 49 | iris = load_iris() 50 | X = iris.data[:100, [0, 2]] 51 | y = iris.target[:100] 52 | y = np.where(y == 1, 1, -1) 53 | X_train, X_test, y_train, y_test = \ 54 | train_test_split(X, y, test_size=0.3) 55 | knn = model() 56 | knn.fit(X_train, y_train) 57 | score = knn.score(X_test, y_test) 58 | print("socre = %s"%score) 59 | 60 | if __name__ == "__main__": 61 | main(KNN) -------------------------------------------------------------------------------- /knn/knn_kdtree.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | from collections import Counter 4 | from knn.knn_base import KNN 5 | from utils.data_generater import random_points 6 | from utils.plot import plot_knn_predict 7 | 8 | # kd-tree每个结点中主要包含的数据结构如下 9 | class Node: 10 | def __init__(self, data, label, depth=0, lchild=None, rchild=None): 11 | self.data = data 12 | self.depth = depth 13 | self.lchild = lchild 14 | self.rchild = rchild 15 | self.label = label 16 | 17 | 18 | class KdTree: 19 | def __init__(self, dataSet, label): 20 | self.KdTree = None 21 | self.n = 0 22 | self.nearest = None 23 | self.create(dataSet, label) 24 | 25 | # 建立kdtree 26 | def create(self, dataSet, label, depth=0): 27 | if len(dataSet) > 0: 28 | m, n = np.shape(dataSet) 29 | self.n = n 30 | axis = depth % self.n 31 | mid = int(m / 2) 32 | dataSetcopy = sorted(dataSet, key=lambda x: x[axis]) 33 | node = Node(dataSetcopy[mid], label[mid], depth) 34 | if depth == 0: 35 | self.KdTree = node 36 | node.lchild = self.create(dataSetcopy[:mid], label, depth+1) 37 | node.rchild = self.create(dataSetcopy[mid+1:], label, depth+1) 38 | return node 39 | return None 40 | 41 | # 前序遍历 42 | def preOrder(self, node): 43 | if node is not None: 44 | print(node.depth, node.data) 45 | self.preOrder(node.lchild) 46 | self.preOrder(node.rchild) 47 | 48 | # 搜索kdtree的前count个近的点 49 | def search(self, x, count = 1): 50 | nearest = [] 51 | for i in range(count): 52 | nearest.append([-1, None]) 53 | # 初始化n个点,nearest是按照距离递减的方式 54 | self.nearest = np.array(nearest) 55 | 56 | def recurve(node): 57 | if node is not None: 58 | # 计算当前点的维度axis 59 | axis = node.depth % self.n 60 | # 计算测试点和当前点在axis维度上的差 61 | daxis = x[axis] - node.data[axis] 62 | # 如果小于进左子树,大于进右子树 63 | if daxis < 0: 64 | recurve(node.lchild) 65 | else: 66 | recurve(node.rchild) 67 | # 计算预测点x到当前点的距离dist 68 | dist = np.sqrt(np.sum(np.square(x - node.data))) 69 | for i, d in enumerate(self.nearest): 70 | # 如果有比现在最近的n个点更近的点,更新最近的点 71 | if d[0] < 0 or dist < d[0]: 72 | # 插入第i个位置的点 73 | self.nearest = np.insert(self.nearest, i, [dist, node], axis=0) 74 | # 删除最后一个多出来的点 75 | self.nearest = self.nearest[:-1] 76 | break 77 | 78 | # 统计距离为-1的个数n 79 | n = list(self.nearest[:, 0]).count(-1) 80 | ''' 81 | self.nearest[-n-1, 0]是当前nearest中已经有的最近点中,距离最大的点。 82 | self.nearest[-n-1, 0] > abs(daxis)代表以x为圆心,self.nearest[-n-1, 0]为半径的圆与axis 83 | 相交,说明在左右子树里面有比self.nearest[-n-1, 0]更近的点 84 | ''' 85 | if self.nearest[-n-1, 0] > abs(daxis): 86 | if daxis < 0: 87 | recurve(node.rchild) 88 | else: 89 | recurve(node.lchild) 90 | 91 | recurve(self.KdTree) 92 | 93 | # nodeList是最近n个点的 94 | nodeList = self.nearest[:, 1] 95 | 96 | # knn是n个点的标签 97 | knn = [node.label for node in nodeList] 98 | return self.nearest[:, 1], Counter(knn).most_common()[0][0] 99 | 100 | 101 | class KNNKdTree(KNN): 102 | def __init__(self, n_neighbors=3, p=2): 103 | super(KNNKdTree, self).__init__(n_neighbors=n_neighbors, p=p) 104 | 105 | 106 | def fit(self, X_train, y_train): 107 | self.X_train = np.array(X_train) 108 | self.y_train = np.array(y_train) 109 | self.kdTree = KdTree(self.X_train, self.y_train) 110 | 111 | 112 | def predict(self, point): 113 | nearest, label = self.kdTree.search(point, self.n) 114 | # print("nearest", [node.data for node in nearest]) 115 | return nearest, label 116 | 117 | 118 | 119 | def score(self, X_test, y_test): 120 | right_count = 0 121 | for X, y in zip(X_test, y_test): 122 | _, label = self.predict(X) 123 | if label == y: 124 | right_count += 1 125 | return right_count / len(X_test) 126 | 127 | 128 | 129 | def simpleTest(): 130 | data = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]] 131 | label = [0, 0, 0, 1, 1, 1] 132 | kdtree = KNNKdTree() 133 | kdtree.fit(data, label) 134 | _, predict_label = kdtree.predict([3, 4.5]) 135 | print("predict label:", predict_label) 136 | # plot_knn_predict(kdtree, data, [3,4.5]) 137 | 138 | def largeTest(): 139 | N = 400000 140 | startTime = time.time() 141 | data = random_points(2, N) 142 | label = [0] * (N // 2) + [1] * (N // 2) 143 | kdtree2 = KNNKdTree() 144 | kdtree2.fit(data, label) 145 | _, predict_label = kdtree2.predict([0.1, 0.5]) # 四十万个样本点中寻找离目标最近的点 146 | 147 | print("time: %s" % round(time.time() - startTime, 5)) 148 | print("predict label:", predict_label) 149 | 150 | 151 | def main(): 152 | simpleTest() 153 | largeTest() 154 | 155 | if __name__ == "__main__": 156 | main() -------------------------------------------------------------------------------- /logistic_regression/LogisticRegressionClassifier.py: -------------------------------------------------------------------------------- 1 | from sklearn.linear_model import LogisticRegression 2 | 3 | from math import exp 4 | from utils.data_generater import * 5 | 6 | 7 | class LogisticRegressionClassifier(object): 8 | def __init__(self, max_iter=200, learning_rate=0.01): 9 | # 最大迭代次数 10 | self.max_iter = max_iter 11 | # 学习率 12 | self.learning_rate = learning_rate 13 | 14 | # sigmoid函数 15 | def sigmoid(self, x): 16 | return 1 / (1 + exp(-x)) 17 | 18 | # 处理训练数据,增加一列,为了weight和bias合并处理 19 | def data_matrix(self, X): 20 | data_mat = [] 21 | for d in X: 22 | data_mat.append([1.0, *d]) 23 | return data_mat 24 | 25 | 26 | def fit(self, X, y): 27 | data_mat = self.data_matrix(X) 28 | # self.weights包含了weight和bias合并处理 29 | self.weights = np.zeros((len(data_mat[0]), 1), dtype=np.float32) 30 | 31 | for iter_ in range(self.max_iter): 32 | for i in range(len(X)): 33 | result = self.sigmoid(np.dot(data_mat[i], self.weights)) 34 | error = y[i] - result 35 | # 梯度下降迭代权重参数self.weights 36 | self.weights += self.learning_rate * error * np.transpose([data_mat[i]]) 37 | print('LogisticRegression Model(learning_rate={},max_iter={})'.format(self.learning_rate, self.max_iter)) 38 | 39 | # 计算准确度 40 | def score(self, X_test, y_test): 41 | right = 0 42 | X_test = self.data_matrix(X_test) 43 | for x, y in zip(X_test, y_test): 44 | result = np.dot(x, self.weights) 45 | if (result > 0 and y == 1) or (result < 0 and y == 0): 46 | right += 1 47 | return right / len(X_test) 48 | 49 | 50 | if __name__ == "__main__": 51 | X_train, X_test, y_train, y_test = create_logistic_data() 52 | 53 | # 我们的LogisticRegression 54 | my_lr = LogisticRegressionClassifier() 55 | my_lr.fit(X_train, y_train) 56 | print("my LogisticRegression score", my_lr.score(X_test, y_test)) 57 | 58 | # sklearn的LogisticRegression 59 | sklearn_lr = LogisticRegression(max_iter=200) 60 | sklearn_lr.fit(X_train, y_train) 61 | print("sklearn LogisticRegression score", sklearn_lr.score(X_test, y_test)) 62 | -------------------------------------------------------------------------------- /logistic_regression/README.md: -------------------------------------------------------------------------------- 1 | # 实现逻辑回归的算法 2 | 3 | # 结果比较 4 | 结果在logistic_regression.ipynb中展示 5 | 6 | # 相关博客 7 | #### [逻辑回归(Logistic Regression)](https://www.cnblogs.com/huangyc/p/9813891.html) 8 | -------------------------------------------------------------------------------- /logistic_regression/max_entropy.py: -------------------------------------------------------------------------------- 1 | import math 2 | from copy import deepcopy 3 | 4 | 5 | class MaxEntropy: 6 | def __init__(self, EPS=0.005): 7 | self._samples = [] 8 | self._Y = set() # 标签集合,相当去去重后的y 9 | self._numXY = {} # key为(x,y),value为出现次数 10 | self._N = 0 # 样本数 11 | self._Ep_ = [] # 样本分布的特征期望值 12 | self._xyID = {} # key记录(x,y),value记录id号 13 | self._n = 0 # 特征键值(x,y)的个数 14 | self._C = 0 # 最大特征数 15 | self._IDxy = {} # key为(x,y),value为对应的id号 16 | self._w = [] 17 | self._EPS = EPS # 收敛条件 18 | self._lastw = [] # 上一次w参数值 19 | 20 | def loadData(self, dataset): 21 | self._samples = deepcopy(dataset) 22 | for items in self._samples: 23 | y = items[0] 24 | X = items[1:] 25 | self._Y.add(y) # 集合中y若已存在则会自动忽略 26 | for x in X: 27 | if (x, y) in self._numXY: 28 | self._numXY[(x, y)] += 1 29 | else: 30 | self._numXY[(x, y)] = 1 31 | 32 | self._N = len(self._samples) 33 | self._n = len(self._numXY) 34 | self._C = max([len(sample)-1 for sample in self._samples]) 35 | self._w = [0]*self._n 36 | self._lastw = self._w[:] 37 | 38 | self._Ep_ = [0] * self._n 39 | for i, xy in enumerate(self._numXY): # 计算特征函数fi关于经验分布的期望 40 | self._Ep_[i] = self._numXY[xy]/self._N 41 | self._xyID[xy] = i 42 | self._IDxy[i] = xy 43 | 44 | def _Zx(self, X): # 计算每个Z(x)值 45 | zx = 0 46 | for y in self._Y: 47 | ss = 0 48 | for x in X: 49 | if (x, y) in self._numXY: 50 | ss += self._w[self._xyID[(x, y)]] 51 | zx += math.exp(ss) 52 | return zx 53 | 54 | def _model_pyx(self, y, X): # 计算每个P(y|x) 55 | zx = self._Zx(X) 56 | ss = 0 57 | for x in X: 58 | if (x, y) in self._numXY: 59 | ss += self._w[self._xyID[(x, y)]] 60 | pyx = math.exp(ss)/zx 61 | return pyx 62 | 63 | def _model_ep(self, index): # 计算特征函数fi关于模型的期望 64 | x, y = self._IDxy[index] 65 | ep = 0 66 | for sample in self._samples: 67 | if x not in sample: 68 | continue 69 | pyx = self._model_pyx(y, sample) 70 | ep += pyx/self._N 71 | return ep 72 | 73 | def _convergence(self): # 判断是否全部收敛 74 | for last, now in zip(self._lastw, self._w): 75 | if abs(last - now) >= self._EPS: 76 | return False 77 | return True 78 | 79 | def predict(self, X): # 计算预测概率 80 | Z = self._Zx(X) 81 | result = {} 82 | for y in self._Y: 83 | ss = 0 84 | for x in X: 85 | if (x, y) in self._numXY: 86 | ss += self._w[self._xyID[(x, y)]] 87 | pyx = math.exp(ss)/Z 88 | result[y] = pyx 89 | return result 90 | 91 | def train(self, maxiter=1000): # 训练数据 92 | for loop in range(maxiter): # 最大训练次数 93 | print("iter:%d" % loop) 94 | self._lastw = self._w[:] 95 | for i in range(self._n): 96 | ep = self._model_ep(i) # 计算第i个特征的模型期望 97 | self._w[i] += math.log(self._Ep_[i]/ep)/self._C # 更新参数 98 | print("w:", self._w) 99 | if self._convergence(): # 判断是否收敛 100 | break 101 | 102 | if __name__ == "__main__": 103 | dataset = [['no', 'sunny', 'hot', 'high', 'FALSE'], 104 | ['no', 'sunny', 'hot', 'high', 'TRUE'], 105 | ['yes', 'overcast', 'hot', 'high', 'FALSE'], 106 | ['yes', 'rainy', 'mild', 'high', 'FALSE'], 107 | ['yes', 'rainy', 'cool', 'normal', 'FALSE'], 108 | ['no', 'rainy', 'cool', 'normal', 'TRUE'], 109 | ['yes', 'overcast', 'cool', 'normal', 'TRUE'], 110 | ['no', 'sunny', 'mild', 'high', 'FALSE'], 111 | ['yes', 'sunny', 'cool', 'normal', 'FALSE'], 112 | ['yes', 'rainy', 'mild', 'normal', 'FALSE'], 113 | ['yes', 'sunny', 'mild', 'normal', 'TRUE'], 114 | ['yes', 'overcast', 'mild', 'high', 'TRUE'], 115 | ['yes', 'overcast', 'hot', 'normal', 'FALSE'], 116 | ['no', 'rainy', 'mild', 'high', 'TRUE']] 117 | 118 | maxent = MaxEntropy() 119 | x = ['overcast', 'mild', 'high', 'FALSE'] 120 | maxent.loadData(dataset) 121 | maxent.train() 122 | print('predict:', maxent.predict(x)) 123 | -------------------------------------------------------------------------------- /naive_bayes/.ipynb_checkpoints/naiveBayes-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "import os\n", 11 | "searchPath=os.path.abspath('..')\n", 12 | "sys.path.append(searchPath)" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "import numpy as np\n", 22 | "from sklearn.datasets import load_iris\n", 23 | "from sklearn.model_selection import train_test_split\n", 24 | "from naiveBayesBase import NaiveBayesBase\n", 25 | "from naiveBayesGaussian import GaussianNaiveBayes\n", 26 | "from utils.word_utils import *" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "# Test NaiveBayesBase" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 3, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "def loadDataSet():\n", 43 | " '''数据加载函数。这里是一个小例子'''\n", 44 | " postingList = [['my', 'dog', 'has', 'flea', 'problems', 'help', 'please'],\n", 45 | " ['maybe', 'not', 'take', 'him', 'to', 'dog', 'park', 'stupid'],\n", 46 | " ['my', 'dalmation', 'is', 'so', 'cute', 'I', 'love', 'him'],\n", 47 | " ['stop', 'posting', 'stupid', 'worthless', 'garbage'],\n", 48 | " ['mr', 'licks', 'ate', 'my', 'steak', 'how', 'to', 'stop', 'him'],\n", 49 | " ['quit', 'buying', 'worthless', 'dog', 'food', 'stupid']]\n", 50 | " classVec = [0, 1, 0, 1, 0, 1] # 1代表侮辱性文字,0代表正常言论,代表上面6个样本的类别\n", 51 | " return postingList, classVec" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 4, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "def checkNB():\n", 61 | " '''测试'''\n", 62 | " listOPosts, lisClasses = loadDataSet()\n", 63 | " myVocabList = createVocabList(listOPosts)\n", 64 | " trainMat = []\n", 65 | " for postinDoc in listOPosts:\n", 66 | " trainMat.append(setOfWord2Vec(myVocabList, postinDoc))\n", 67 | "\n", 68 | " nb = NaiveBayesBase()\n", 69 | " nb.fit(np.array(trainMat), np.array(lisClasses))\n", 70 | "\n", 71 | " testEntry1 = ['love', 'my', 'dalmation']\n", 72 | " thisDoc = np.array(setOfWord2Vec(myVocabList, testEntry1))\n", 73 | " print(testEntry1, 'classified as:', nb.predict(thisDoc))\n", 74 | "\n", 75 | " testEntry2 = ['stupid', 'garbage']\n", 76 | " thisDoc2 = np.array(setOfWord2Vec(myVocabList, testEntry2))\n", 77 | " print(testEntry2, 'classified as:', nb.predict(thisDoc2))" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 5, 83 | "metadata": {}, 84 | "outputs": [ 85 | { 86 | "name": "stdout", 87 | "output_type": "stream", 88 | "text": [ 89 | "['love', 'my', 'dalmation'] classified as: 0\n", 90 | "['stupid', 'garbage'] classified as: 1\n" 91 | ] 92 | } 93 | ], 94 | "source": [ 95 | "checkNB()" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "metadata": {}, 101 | "source": [ 102 | "# Test GaussianNaiveBayes" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 6, 108 | "metadata": {}, 109 | "outputs": [ 110 | { 111 | "ename": "NameError", 112 | "evalue": "name 'create_data' is not defined", 113 | "output_type": "error", 114 | "traceback": [ 115 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", 116 | "\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)", 117 | "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m()\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[0miris\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mload_iris\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 2\u001b[1;33m \u001b[0mX\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcreate_data\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 3\u001b[0m \u001b[0mX_train\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mX_test\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my_train\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my_test\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtrain_test_split\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0miris\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0miris\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtarget\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtest_size\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m0.3\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 118 | "\u001b[1;31mNameError\u001b[0m: name 'create_data' is not defined" 119 | ] 120 | } 121 | ], 122 | "source": [ 123 | "iris = load_iris()\n", 124 | "X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.3)" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "print(len(X_train))\n", 134 | "print(len(X_test))\n", 135 | "model = GaussianNaiveBayes()\n", 136 | "model.fit(X_train, y_train)\n", 137 | "print(model.predict([4.4, 3.2, 1.3, 0.2]))\n", 138 | "print(model.score(X_test, y_test))" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [] 147 | } 148 | ], 149 | "metadata": { 150 | "kernelspec": { 151 | "display_name": "Python 3", 152 | "language": "python", 153 | "name": "python3" 154 | }, 155 | "language_info": { 156 | "codemirror_mode": { 157 | "name": "ipython", 158 | "version": 3 159 | }, 160 | "file_extension": ".py", 161 | "mimetype": "text/x-python", 162 | "name": "python", 163 | "nbconvert_exporter": "python", 164 | "pygments_lexer": "ipython3", 165 | "version": "3.6.5" 166 | } 167 | }, 168 | "nbformat": 4, 169 | "nbformat_minor": 2 170 | } 171 | -------------------------------------------------------------------------------- /naive_bayes/README.md: -------------------------------------------------------------------------------- 1 | # 实现朴素贝叶斯的基本算法和高斯混合朴素贝叶斯算法 2 | 3 | # 结果比较 4 | 结果在naiveBayes.ipynb中展示 5 | 6 | # 相关博客 7 | #### [朴素贝叶斯算法(Naive Bayes)](https://www.cnblogs.com/huangyc/p/9734956.html) 8 | -------------------------------------------------------------------------------- /naive_bayes/naiveBayes.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "import os\n", 11 | "searchPath=os.path.abspath('..')\n", 12 | "sys.path.append(searchPath)" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "import numpy as np\n", 22 | "from sklearn.datasets import load_iris\n", 23 | "from sklearn.model_selection import train_test_split\n", 24 | "from naiveBayesBase import NaiveBayesBase\n", 25 | "from naiveBayesGaussian import GaussianNaiveBayes\n", 26 | "from utils.word_utils import *" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "# Test NaiveBayesBase" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 3, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "def loadDataSet():\n", 43 | " '''数据加载函数。这里是一个小例子'''\n", 44 | " postingList = [['my', 'dog', 'has', 'flea', 'problems', 'help', 'please'],\n", 45 | " ['maybe', 'not', 'take', 'him', 'to', 'dog', 'park', 'stupid'],\n", 46 | " ['my', 'dalmation', 'is', 'so', 'cute', 'I', 'love', 'him'],\n", 47 | " ['stop', 'posting', 'stupid', 'worthless', 'garbage'],\n", 48 | " ['mr', 'licks', 'ate', 'my', 'steak', 'how', 'to', 'stop', 'him'],\n", 49 | " ['quit', 'buying', 'worthless', 'dog', 'food', 'stupid']]\n", 50 | " classVec = [0, 1, 0, 1, 0, 1] # 1代表侮辱性文字,0代表正常言论,代表上面6个样本的类别\n", 51 | " return postingList, classVec" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 4, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "def checkNB():\n", 61 | " '''测试'''\n", 62 | " listOPosts, lisClasses = loadDataSet()\n", 63 | " myVocabList = createVocabList(listOPosts)\n", 64 | " trainMat = []\n", 65 | " for postinDoc in listOPosts:\n", 66 | " trainMat.append(setOfWord2Vec(myVocabList, postinDoc))\n", 67 | "\n", 68 | " nb = NaiveBayesBase()\n", 69 | " nb.fit(np.array(trainMat), np.array(lisClasses))\n", 70 | "\n", 71 | " testEntry1 = ['love', 'my', 'dalmation']\n", 72 | " thisDoc = np.array(setOfWord2Vec(myVocabList, testEntry1))\n", 73 | " print(testEntry1, 'classified as:', nb.predict(thisDoc))\n", 74 | "\n", 75 | " testEntry2 = ['stupid', 'garbage']\n", 76 | " thisDoc2 = np.array(setOfWord2Vec(myVocabList, testEntry2))\n", 77 | " print(testEntry2, 'classified as:', nb.predict(thisDoc2))" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 5, 83 | "metadata": {}, 84 | "outputs": [ 85 | { 86 | "name": "stdout", 87 | "output_type": "stream", 88 | "text": [ 89 | "['love', 'my', 'dalmation'] classified as: 0\n", 90 | "['stupid', 'garbage'] classified as: 1\n" 91 | ] 92 | } 93 | ], 94 | "source": [ 95 | "checkNB()" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "metadata": {}, 101 | "source": [ 102 | "# Test GaussianNaiveBayes" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 6, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "iris = load_iris()\n", 112 | "X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.3)" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 7, 118 | "metadata": {}, 119 | "outputs": [ 120 | { 121 | "name": "stdout", 122 | "output_type": "stream", 123 | "text": [ 124 | "105\n", 125 | "45\n", 126 | "0\n", 127 | "0.9333333333333333\n" 128 | ] 129 | } 130 | ], 131 | "source": [ 132 | "print(len(X_train))\n", 133 | "print(len(X_test))\n", 134 | "model = GaussianNaiveBayes()\n", 135 | "model.fit(X_train, y_train)\n", 136 | "print(model.predict([4.4, 3.2, 1.3, 0.2]))\n", 137 | "print(model.score(X_test, y_test))" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [] 146 | } 147 | ], 148 | "metadata": { 149 | "kernelspec": { 150 | "display_name": "Python 3", 151 | "language": "python", 152 | "name": "python3" 153 | }, 154 | "language_info": { 155 | "codemirror_mode": { 156 | "name": "ipython", 157 | "version": 3 158 | }, 159 | "file_extension": ".py", 160 | "mimetype": "text/x-python", 161 | "name": "python", 162 | "nbconvert_exporter": "python", 163 | "pygments_lexer": "ipython3", 164 | "version": "3.6.5" 165 | } 166 | }, 167 | "nbformat": 4, 168 | "nbformat_minor": 2 169 | } 170 | -------------------------------------------------------------------------------- /naive_bayes/naiveBayesBase.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from utils.word_utils import * 3 | 4 | 5 | 6 | class NaiveBayesBase(object): 7 | 8 | def __init__(self): 9 | pass 10 | 11 | 12 | def fit(self, trainMatrix, trainCategory): 13 | ''' 14 | 朴素贝叶斯分类器训练函数,求:p(Ci),基于词汇表的p(w|Ci) 15 | Args: 16 | trainMatrix : 训练矩阵,即向量化表示后的文档(词条集合) 17 | trainCategory : 文档中每个词条的列表标注 18 | Return: 19 | p0Vect : 属于0类别的概率向量(p(w1|C0),p(w2|C0),...,p(wn|C0)) 20 | p1Vect : 属于1类别的概率向量(p(w1|C1),p(w2|C1),...,p(wn|C1)) 21 | pAbusive : 属于1类别文档的概率 22 | ''' 23 | numTrainDocs = len(trainMatrix) 24 | # 长度为词汇表长度 25 | numWords = len(trainMatrix[0]) 26 | # p(ci) 27 | self.pAbusive = sum(trainCategory) / float(numTrainDocs) 28 | # 由于后期要计算p(w|Ci)=p(w1|Ci)*p(w2|Ci)*...*p(wn|Ci),若wj未出现,则p(wj|Ci)=0,因此p(w|Ci)=0,这样显然是不对的 29 | # 故在初始化时,将所有词的出现数初始化为1,分母即出现词条总数初始化为2 30 | p0Num = np.ones(numWords) 31 | p1Num = np.ones(numWords) 32 | p0Denom = 2.0 33 | p1Denom = 2.0 34 | for i in range(numTrainDocs): 35 | if trainCategory[i] == 1: 36 | p1Num += trainMatrix[i] 37 | p1Denom += sum(trainMatrix[i]) 38 | else: 39 | p0Num += trainMatrix[i] 40 | p0Denom += sum(trainMatrix[i]) 41 | # p(wi | c1) 42 | # 为了避免下溢出(当所有的p都很小时,再相乘会得到0.0,使用log则会避免得到0.0) 43 | self.p1Vect = np.log(p1Num / p1Denom) 44 | # p(wi | c2) 45 | self.p0Vect = np.log(p0Num / p0Denom) 46 | return self 47 | 48 | 49 | def predict(self, testX): 50 | ''' 51 | 朴素贝叶斯分类器 52 | Args: 53 | testX : 待分类的文档向量(已转换成array) 54 | p0Vect : p(w|C0) 55 | p1Vect : p(w|C1) 56 | pAbusive : p(C1) 57 | Return: 58 | 1 : 为侮辱性文档 (基于当前文档的p(w|C1)*p(C1)=log(基于当前文档的p(w|C1))+log(p(C1))) 59 | 0 : 非侮辱性文档 (基于当前文档的p(w|C0)*p(C0)=log(基于当前文档的p(w|C0))+log(p(C0))) 60 | ''' 61 | 62 | p1 = np.sum(testX * self.p1Vect) + np.log(self.pAbusive) 63 | p0 = np.sum(testX * self.p0Vect) + np.log(1 - self.pAbusive) 64 | if p1 > p0: 65 | return 1 66 | else: 67 | return 0 68 | 69 | def loadDataSet(): 70 | '''数据加载函数。这里是一个小例子''' 71 | postingList = [['my', 'dog', 'has', 'flea', 'problems', 'help', 'please'], 72 | ['maybe', 'not', 'take', 'him', 'to', 'dog', 'park', 'stupid'], 73 | ['my', 'dalmation', 'is', 'so', 'cute', 'I', 'love', 'him'], 74 | ['stop', 'posting', 'stupid', 'worthless', 'garbage'], 75 | ['mr', 'licks', 'ate', 'my', 'steak', 'how', 'to', 'stop', 'him'], 76 | ['quit', 'buying', 'worthless', 'dog', 'food', 'stupid']] 77 | classVec = [0, 1, 0, 1, 0, 1] # 1代表侮辱性文字,0代表正常言论,代表上面6个样本的类别 78 | return postingList, classVec 79 | 80 | 81 | def checkNB(): 82 | '''测试''' 83 | listPosts, listClasses = loadDataSet() 84 | myVocabList = createVocabList(listPosts) 85 | trainMat = [] 86 | for postDoc in listPosts: 87 | trainMat.append(setOfWord2Vec(myVocabList, postDoc)) 88 | 89 | nb = NaiveBayesBase() 90 | nb.fit(np.array(trainMat), np.array(listClasses)) 91 | 92 | testEntry1 = ['love', 'my', 'dalmation'] 93 | thisDoc = np.array(setOfWord2Vec(myVocabList, testEntry1)) 94 | print(testEntry1, 'classified as:', nb.predict(thisDoc)) 95 | 96 | testEntry2 = ['stupid', 'garbage'] 97 | thisDoc2 = np.array(setOfWord2Vec(myVocabList, testEntry2)) 98 | print(testEntry2, 'classified as:', nb.predict(thisDoc2)) 99 | 100 | 101 | if __name__ == "__main__": 102 | checkNB() -------------------------------------------------------------------------------- /naive_bayes/naiveBayesGaussian.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.datasets import load_iris 3 | from sklearn.model_selection import train_test_split 4 | from scipy.stats import norm 5 | 6 | class GaussianNaiveBayes(object): 7 | def __init__(self): 8 | self.model = None 9 | 10 | # 概率密度函数 11 | def gaussian_probability(self, x, mean, stdev): 12 | return norm.pdf(x, loc=mean, scale=stdev) 13 | 14 | # 处理X_train 15 | def summarize(self, train_data): 16 | summaries = [(np.mean(X), np.std(X)) for X in zip(*train_data)] 17 | return summaries 18 | 19 | # 分类别求出数学期望和标准差 20 | def fit(self, X, y): 21 | labels = list(set(y)) 22 | data = {label:[] for label in labels} 23 | for f, label in zip(X, y): 24 | data[label].append(f) 25 | self.model = {label: self.summarize(value) for label, value in data.items()} 26 | return self 27 | 28 | # 计算概率 29 | def calculate_probabilities(self, input_data): 30 | probabilities = {} 31 | for label, value in self.model.items(): 32 | probabilities[label] = 1 33 | for i in range(len(value)): 34 | mean, stdev = value[i] 35 | probabilities[label] *= self.gaussian_probability(input_data[i], mean, stdev) 36 | return probabilities 37 | 38 | # 类别 39 | def predict(self, X_test): 40 | label = sorted(self.calculate_probabilities(X_test).items(), key=lambda x: x[-1])[-1][0] 41 | return label 42 | 43 | def score(self, X_test, y_test): 44 | right = 0 45 | for X, y in zip(X_test, y_test): 46 | label = self.predict(X) 47 | if label == y: 48 | right += 1 49 | 50 | return right / float(len(X_test)) 51 | 52 | if __name__ == "__main__": 53 | iris = load_iris() 54 | X,y = iris.data, iris.target 55 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3) 56 | print(len(X_train)) 57 | print(len(X_test)) 58 | model = GaussianNaiveBayes() 59 | model.fit(X_train, y_train) 60 | print(model.predict([4.4, 3.2, 1.3, 0.2])) 61 | print(model.score(X_test, y_test)) -------------------------------------------------------------------------------- /perceptron/README.md: -------------------------------------------------------------------------------- 1 | # 实现感知机的基本算法和对偶算法 2 | 3 | # 结果比较 4 | 结果在perceptron.ipynb中展示 5 | 6 | # 相关博客 7 | #### [感知机原理(Perceptron)](https://www.cnblogs.com/huangyc/p/9706575.html) 8 | -------------------------------------------------------------------------------- /perceptron/perceptron.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "import os\n", 11 | "searchPath=os.path.abspath('..')\n", 12 | "sys.path.append(searchPath)" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "from sklearn.model_selection import train_test_split\n", 22 | "from sklearn.datasets import load_iris\n", 23 | "import numpy as np\n", 24 | "from utils.plot import plot_decision_regions\n", 25 | "from perceptron.perceptron_base import PerceptronBase\n", 26 | "from perceptron.perceptron_dual import PerceptronDual" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 3, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "iris = load_iris()\n", 36 | "X = iris.data[:100, [0, 2]]\n", 37 | "y = iris.target[:100]\n", 38 | "y = np.where(y == 1, 1, -1)" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 4, 44 | "metadata": {}, 45 | "outputs": [ 46 | { 47 | "data": { 48 | "text/plain": [ 49 | "numpy.ndarray" 50 | ] 51 | }, 52 | "execution_count": 4, 53 | "metadata": {}, 54 | "output_type": "execute_result" 55 | } 56 | ], 57 | "source": [ 58 | "type(X)" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 8, 64 | "metadata": {}, 65 | "outputs": [ 66 | { 67 | "data": { 68 | "text/plain": [ 69 | "numpy.ndarray" 70 | ] 71 | }, 72 | "execution_count": 8, 73 | "metadata": {}, 74 | "output_type": "execute_result" 75 | } 76 | ], 77 | "source": [ 78 | "type(y)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 5, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "def trainData(model, X , y):\n", 88 | " X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4)\n", 89 | " ppn = model(eta=0.1, n_iter=10)\n", 90 | " ppn.fit(X_train, y_train)\n", 91 | " plot_decision_regions(ppn, X, y)" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": {}, 97 | "source": [ 98 | "## test PerceptronBase" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 6, 104 | "metadata": {}, 105 | "outputs": [ 106 | { 107 | "data": { 108 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAW4AAAD8CAYAAABXe05zAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAG1hJREFUeJzt3X+MXeV95/HPd8YwA7axZWzCqjHMtthEKVaTMLa0coEtNqgpVhpVjQRSK7VBMsluq1S0qorzR//anewfjTbSNmatTLtbkRA1ZKkqQlIwKBqwUuxxG9VuCKYNNPwo2E4x2IQZ6pnv/vHcM/fO9b3nnPvj3HOee98vaXTvOXPuOc+5IV8/832+z3PM3QUAiMdY2Q0AAHSGwA0AkSFwA0BkCNwAEBkCNwBEhsANAJEhcANAZAjcABAZAjcARGZNnoPMbKOkr0i6SZJL+rS7f6/d8evWbfarr57qSwMBYBT8+MfHz7r7ljzH5grckr4k6Tvu/utmdrmkK9MOvvrqKX3+8/M5Tw0AuO8++5e8x2YGbjO7StKtkn5Lktz9fUnvd9s4AEBv8uS4f1bSGUl/bmZ/b2ZfMbO1zQeZ2X4zmzez+QsXzvS9oQCAIE/gXiPpY5IOuvtHJb0r6Y+aD3L3Q+4+7e7T69blStMAALqQJ3C/KulVd3+utv2IQiAHAJQgM3C7+xuSXjGzG2u79kj6QaGtAgC0lbeq5HclfbVWUfIjSb9dXJMAAGlyBW53/76k6YLbAgDIgZmTABAZAjcARIbADQCRIXADQGQI3AAQGQI3AESGwA0AkSFwA0BkCNwAEBkCNwBEhsANAJEhcANAZAjcABAZAjcARIbADQCRIXADQGQI3AAQGQI3AESGwA0AkSFwA0BkCNwAEBkCNwBEhsAN9JF7+jbQDwRuoE/m5qTDh+vB2j1sz82V2y4MHwI30Afu0uKidPRoPXgfPhy2FxfpeaO/1pTdAGAYmEl794b3R4+GH0natSvsNyuvbRg+uXrcZvaymZ0ws++b2XzRjQJi1Bi8E7EGbXL11dZJquSX3P0j7j5dWGuAiCXpkUaNOe9YkKuvPnLcQB805rR37ZIOHAivjTnvGJCrj0PeHLdLesLMXNL/dvdDBbYJiI6ZNDGxOqedpE0mJuJJl5Crj0PewL3b3V83s2skPWlmP3T3VX84mdl+SfsladOm6/rcTKD6br019EiT4JYEwV6DXeM5W233W9LuJGhL+e9j0G0dVblSJe7+eu31tKRHJe1qccwhd5929+l167b0t5VAJJqDVK9Bq4x8c7e5enLjg5MZuM1srZmtT95LulPSyaIbBoy6MvLN3ebqyY0PVp5UyQckPWqh67BG0tfc/TuFtgqosEGlA8rIN+fJ1be7f3Ljg5MZuN39R5J+YQBtASpvbi70IJNglPQsJyZCjrvfesk3dystV591/4Nu66iiHBDIqczURaNBlBe2ytVn3f/y8nDUsceAKe9AToNOBzTnm/furW9Lg+/Npt3/nj3SU09Vp63Djh430IFeprUvL6dvt7rWxIS0c+fqfPPOncXXhreb8t7u/sfGWufGd+1anRtPuwbyI3ADHeg2dTE7Kx08WA/Wy8the3a2mHb2Iq2sL+3+kxx3c2781lspFew3AjeQU7elcsvL0sKC9MMf1oP3wYNhe2Ghfc87ySkfO7Y6p3zsWLE59XZ57IUF6ckn0++/m9w4Pe/OkeMGcup2WvvYmPTZz0pf/nII1p/7XNh/441h/1ib7lOvOfWsssVuyvqeeabz+6dUsP/ocQMdSEsHpHn2WWn79tWpgu3bw/403ebUs1ITab9Pu2a39z9MS95WAYEb6FCn09rdQ5rh8cdDWsQ9vD7+eNifliroJqeep2yvl7K+bqb1D8uSt1VBqgQomLt06pT0/vvS5ZdLW7ZIZ86E7VOn2vc8uy0HzJOaGGRZX9XKGocBgRuV1+0U87TPFTVtfXl5dc462b7iCunDH5befjvs37JF2rAh7B8ba9+ebqefZ824TPt9v5enHZYlb6uEwI1K63aKedrnpGKmrc/OhtRHMuCYVI9MTkqf/nSoyDh2rH789u3SHXdk32M3089vuaV1aqL52Fa/L2J52qKWvB1V5LhRWd2WkWWVtC0s9L80La3k77336kG7sYzu2LGwP6s9nZbYZZXtJTnsTsv6elXEOUcVPW5UVp5cbTclbcm5+1malpT8JcE6Kfn70IfC/mefbZ8quOWW9Pb0u2wvmeXYPCPTPTsFg2qgx41KSysj67akrajStCR4N0rSJmlldEXcY7dlexKzHGNA4EaltcvF9lLSVlRpWpIeadQ4zb1dqqCIe8y6XrsZmUWlktBfpEpQWVllZHv2hNdOStqSwJPkm/tVmtaY007SI8n2wYPtZ0gWcY9Z91FGKgn9RY8bldWujCxZcW5srPOV6iYnw0/aKnbdGBsL502CdpI2+dCHwv60ae1pKwBm3WO3KweWkUpC/9DjRqWllZH1UtJWRGnavfeuruNOgne7oJ1H2j32Iuu8aaWEKB+BG5XXrhyu09l4jdtFlaY1B+msoN2Yb07+AUnyzTt3ri4jbJfyaf7crl3pVSBp311RqST0F4EbURqW2Xi9lPVllRGmXTPtu2s+T4zf67AjcCNanaRDmrfbTU3Po981zmnTz7PuMW1ae5oyUknoHwYnEbV2KY+0WuRenkZTRI1zVnlip2WEeUv2ykglobWZmc6Op8eNodM4HVxanaedng5T0F94oV6m11jGl9bzTjtvVl45ra3drJzHinvxSv6RP/JXb0oLi12dg8CNodM4hbsx/7tzZ1jU6Y470p9Gk7XintS/Guduc/XDkuMfFTMH3glvzp2TJF0z/hNdI+m5L9dXHbP78p+PwI2h9Mwz6fu3bw+BO6lQaXwaTdpKfWn56G51u3IeK+5V19ycdOTrP16176Wtt0nrJT3wQM/nJ3Bj6CRPnHnqqRDE1q+Xzp8P27ffHo759rfraZHkaTQf/3g4Pll6tVUqRCqmxrnbnDK56PLNzUkvviid/ud3VnrUkvTAxge1/79PNRzZe8BOELgxtJIec/KTbL/4Yuun0bz4YkiXtCuxk8grI5iZkfT6a2FjaUm7J49rm6SHbv5CmIklSZoq7Pq5A7eZjUual/Sau+8rrEVAB9qV9U1Oht710aP1Xncy5f3KK9s/jWZ8PKwD0pgK2bOnHpC7fSJN8/tW26imdoOJL229Tdq9u2nJxXs1CJ30uD8n6XlJVxXUFqAjeZ4405j/lcKklVtuSX8azYkTq3voBw9KO3aE/392+0QaqZin7qAYeQYT+5n66FSuwG1mH5R0l6T/Jun+QlsE5ND8xJnGsr4bb5SeeEKan2+d0pDaT+k+dSqUCjav8CdJv/iL4R+IdlPw25UK7txZv2bz77otI0R/FT2Y2G/mOar1zewRSTMKt/EHWamS66+f9s9/fr4/LcRQKCJN0LiUaqLxiTMLC6EXnQTWJ58MvXGpfe/XPfS4kzSKFNIoO3ZIt92WfY/N/0C0yo83/46gXY5kwtXp4yFgXzqYOFh2333H3X06z7GZPW4z2yfptLsfN7P/nHLcfkn7JWnTputyNhWjoNsH/mZJVt9LarGlfKvxZaU7tm0LvfXEtm35gmtWqWC/ywjRmbm5Wp5aWslV7548rns3HqsF7KmymtaxPKmS3ZI+YWa/ImlS0lVm9pC7/0bjQe5+SNIhKfS4+95SRKmI2YaJdk+c+cxn2q+41+6ayT8oCwvS00+vLiN8+ukwQJnVVpZKrZa5OenIEUmv1FMguyeP66FPfjNsrPQapgbdtJ5lBm53f0C1LHytx/0HzUEbaKeI2YZS+hNnHnwwBO9ur9mqjDALS6WWb3a2Vku98J60sFgfUNz6a0156vhHg6njRuGKmG2YlPw1TlX/7GfDVPbJyVDW1+k1zdLLCFkqtXpm7m9RonfJgGL1Bhd71VHgdvfvSvpuIS3B0EpLIfQStLZtCwtGNeaqt20L9djdXjMpFWxVRpiFpVKL1TyYmHjp7gMNaY/hC9Kt0ONGoYpaxS7Jnc/P15/LePhw2E57ckzaNRufdt78ubyBlqVS+yd7MDERf+qjUwRuFKqoVeyyVgBMe3IMK+5VU9vBxJ//QvhTKuLBxH4jcKNwRa1il6z01zjLMdnPinvVtzKYeP68tLRUH0y8+TMN631Ig5pGHhMCNwai32mCxtI999Wle7ff3r7kr4y2oi7fYCKBOguBG1FrfMhBY68b5Ws1mHjN+E/03N1/MnKDif1G4EaUktK9PXvqa4CsXx9y3EnpHqvxDRaDiYND4Ea0ktK9VvuLmmaPOgYTy0PgRpTSSvcSrMbXXzMzks4zmFgFBG5EKat075Zb2j/JhqCdH4OJ1UTgRkeqlDfOKt1Lm/JepfuoCgYT40HgRm5VzBu3K91Lm/L+zDPVu48yMJgYLwI3cilyedZ+y7tSn1Tt+yjCzIwYTBwCBG7kUtTyrEUg/x00DyZKtdQHg4nRI3AjtyKWZy1KL/nvmDGYOBoI3MitqOVZi9JN/ruK95FmZUDxH99ceXgAg4nDj8CNXIpannXQhuE+Zg7Un/IihTz1NkkP3f3NWsBmMHHYEbhxiXalchMTYUp5Y97YPa4lT2NcunVuTjry9aaHB9z8qfBmVa6agD0qCNxYJa3kb1hUfenWmQPvhDfnzq3s2z15XA998UzDUeSpRxmBGyvSSv527gz7OnlyepVVaenWliV6W+6Xfu6aph41EBC4sSKr5C85ZtjL6IrUPJiYeGnrbUP/gFv0D4Ebq2SV/A1rGV2RsgcTJQI1OkHgxipppXLJ+1a/I3jXMZiIohG4sSLvVPFYy+iKsjJDsWEw8YGNDzat90GuGv1D4MaKrFI5Ka4yuqLkG0ycKqNpGBEEbqySVSpX5TK6IjCYiCoicOMSaaVyVSqjK0rbwcQvN9ZRE6hRHgI3Rl7+wUSgGjIDt5lNSpqTNFE7/hF3/+OiG4b4xPJUGQYTEbs8Pe5FSbe7+wUzu0zSs2b2bXf/24LbhohU8ek4iZkZSa+/trIm9e7J47p18pj23/wUg4mIUmbgdneXdKG2eVntx4tsFOJSpafj5B9MnBK9asQqV47bzMYlHZd0g6Q/dffnCm0VolL203EYTMSoyRW43X1J0kfMbKOkR83sJnc/2XiMme2XtF+SNm26ru8NRbUN8uk4DCZi1HVUVeLu58zsu5J+WdLJpt8dknRIkq6/fppUyogp6qkyc3PSkSNiMBFokKeqZIukf68F7Ssk7ZX0PwpvGaLR76fKMJgIpMvT4/4Pkv5vLc89Jukv3f2xYpuFmPTjqTKzs9Lp01qZSs5gItBenqqSf5D00QG0BRHr9Kkys7O1yg9p5SG3uy97WQ9tvb8WsBlMBNph5iT6Jms6/Mz9LUr0rml+ygsBG8hC4EYhZmZqb15pqv64+wAPDwB6ROBGX8zNSUe+8VrYqA0qPrDxQWmjmqo/eHgA0CsCN7rSPJgo1dal/mTj47imymgaMPQI3MglezAxQY8aKBqBG20xmAhUE4EbkloPJl4z/hM9d/efMJgIVAyBe0QxmAjEi8A9IlbW/GAwEYgegXtIrQwm/vtFaWmpPph48xeactT0qIHYELiHSL7BRNb7AGJH4I4Ug4nA6CJwR4LBRAAJAndFpQ4mSrVe9VQZTQNQMgJ3RczMSDrDYCKAbATuEjGYCKAbBO4BYTARQL8QuAvCYCKAohC4+2huLrwe+XroVTOYCKAIBO4ezRx4R1p4byVXfc34T7R78mU99MUztSPoUQPoLwJ3h2ZmdGmJ3s9/IWzcy0AigOIRuFOsPOXlTIvqj1UPDyBgAxgcAneTmQPvSOfPrx5QnJT2f/LHVH8AqISRDtzNg4mJl+4+EN6sGlCcEgBUwcgF7pkD74Q3585JajWYKDGgCKDKhj5wM5gIYNgMVeBmMBHAKMgM3Ga2VdJfSLpW0rKkQ+7+paIblheDiQBGTZ4e90VJv+/uf2dm6yUdN7Mn3f0HBbftEgwmAkCOwO3u/yrpX2vvz5vZ85J+RtJAAnerwcQHNn6D9T4AjKyOctxmNiXpo5Kea/G7/ZL2S9KmTdd13aCWg4lb7pd+rnG506muzw8AscsduM1snaRvSvo9d3+n+ffufkjSIUm6/vppz3PO/IOJ5Kj77uRJ6YknpLNnpc2bpTvvlG66qXrnBHCJXIHbzC5TCNpfdff/18sF2w4mfnGq4SgCdaFOnpQeflgaH5euvDKkoR5+WLrnnu4DbRHnBNBSnqoSkzQr6Xl3/2InJ+9sMBED88QTIcBOTITtiQlpcTHs7zbIFnFOAC3l6XHvlvSbkk6Y2fdr+w64++PtPvDGa0ua+S8hWDOYWEFnz4ZecaPLLw/7q3ROAC3lqSp5VpJ1ctId4z/Q/Ec+xWBiVW3eHFIZSe9Ykt5/P+yv0jkBtDRWyFmvvZbp5FV2551hjGFxUXIPr0tLYX+VzgmgpWICN6rtppvCoOHGjdJPfxpeex1ELOKcAFoaqrVK0IGbboonqD72mHT4cOjFT0xIe/dK+/b1dk5KFxExetzoj6Qc8Ny51eWAJ0/2dt7HHpO+9a2QLx8bC6/f+lbYX7W2AgNC4EZ/NJYDmoXX8fGwvxeHD4fzjY+HwD0+HrYPH65eW4EBIXCjP86eDeV/jfpRDri4GIJrI7Owv1tFtRUYEHLcwywtj9tt3nh2Vpqfl5aXQw94ejpUEG3eLL3xhrSwIF28KK1ZI01OhgqjXkxMhPRII/fVZYedonQRkaPHPazS8rjd5o1nZ6WjR0PQlsLr0aNh//btYSmDixdDj/jixbC9fXtv97F3bwjUS0vhektLYXvv3u7PSekiIkfgHlZpedxu88bz8+HVrP6T7D91StqwIfS03cPrhg1hfy/27ZPuuiukMpaXw+tdd/VWVULpIiJHqiR27dIhaVPQFxdDwG7UmDdulw5JetretPjj8nI478JC/RxLS+GzSd44KzVDeR6QGz3umKWlQzZvvjQ3nORxJyYuDb5J3jgtHdI8SJgwC+d+773V+997L+zPSs0Ukdbp9nsDIkDgjllaOiQtj5uWN05Lh6xd27oda9eGfHYr589np2aKSOt0+70BESBwxyytrC0tj5uWN0562s2Wl0OAm5xcvX9yMrvCI6ukL+0+KAcELkGOuwq6ze9u3iy98kpISbiHgHbFFdLWreH33UxrHxsLQbo5lTI2Fq63sBCuk1xvfDzsf+ut1kF/bKy+NnfSs0968knATyvP++lP08sBu/nuKAdE5Ohxl62XfOuGDSGwJUHWPWxv2JD+ubS88Q03tP7MDTeE87777urrvftu2N+uXvvaa6UdO1b/Y+AetnfsCNvdpnW6/e4oB0TkCNxl6yXfeuJECLxJKsEsbJ84kf65tLyxWes0gln69d56q/W13npLevvtkAdv/NzatWG/1H1ap9vvjnJARI5USdnOng093tOn62mEdevCBBYpPRWQlPU1lvYtL9fzv+0+m1YOePas9IEPrM4ru6/ONzemURrzzWvWtG7L2bPSpk3S1Vdfes5EWlpnakq67rr6fUxN1b87M+nNN+uzNdevz5erjml1RKAJPe6yuYfKi8Y0QrKdlQpIK+tL+2za59LKCNesWZ2nTlIYa9Z0f84safcxORl69EnN+NJS2G4eQAWGDIG7bBcu1N839nIvXMhOBaTlf9M+m/a5tPzvunX19jUG6XXruj9nFkr3gEsQuMt28WI9x9xYqXHxYnbZWlr+N+2zaZ9Ly/+aSVddtTpXnWx3e84safexsBBSMOPj4Zrj42F7YaG7/y2ASJDjLluy+t1ll9X3LS3VUwxZZWv79rVetyPrs+0+J7XP/27eHHLxl19ezykn5YDdnjNL1n2cOxdy8onFxRC8e8H0e1QcPe6yFZViKKLkbfv2UAnSuALg22/3vgJgmrT7KOIemQ6PCBC4y1ZUiqGIkrdTp0LVRuMKgOvX974CYJq0+yjiHsmpIwKkSgYl7c/vIlIMvX62lbNnQ067cYJPc1lfEdLuo4h7bLeqIlAR9LgHYVj+/O6lrC8Wo3CPiB6BexCG5c/vUZgqPgr3iOhlBm4z+zMzO21mkXUPK2RYVqMbhanio3CPiF6eHPf/kfS/JP1FsU0ZYkWuRjfo0rVRmCo+CveIqGX2uN19TtK/DaAtw6uoP7+HJXcOoCPkuAehqD+/hyV3DqAjfSsHNLP9kvZL0nW9zlwbRkX8+U3pGjCS+tbjdvdD7j7t7tNbGhcjQnEoXQNGEqmSmFG6BoykPOWAD0v6nqQbzexVM7u3+GYhF0rXgJGUmeN293sG0RB0idI1YOSQKgGAyBC4ASAyBG4AiAyBGwAiQ+AGgMgQuAEgMgRuAIgMgRsAIkPgBoDIELgBIDIEbgCIDIEbACJD4AaAyBC4ASAyBG4AiAyBGwAiQ+AGgMgQuAEgMgRuAIgMgRsAIkPgBoDIELgBIDIEbgCIDIEbACJD4AaAyBC4ASAyBG4AiEyuwG1mv2xmL5jZP5nZHxXdKABAe5mB28zGJf2ppI9L+rCke8zsw0U3DADQWp4e9y5J/+TuP3L39yV9XdKvFtssAEA7eQL3z0h6pWH71do+AEAJ1uQ4xlrs80sOMtsvaX9t84Ldd98LvTSsIjZLOlt2IyqK7yYd3097fDetXZ/3wDyB+1VJWxu2Pyjp9eaD3P2QpEN5LxwDM5t39+my21FFfDfp+H7a47vpXZ5UyTFJ28zsP5rZ5ZLulvTXxTYLANBOZo/b3S+a2e9I+htJ45L+zN3/sfCWAQBaypMqkbs/LunxgttSRUOV+ukzvpt0fD/t8d30yNwvGWcEAFQYU94BIDIE7jbMbNzM/t7MHiu7LVVjZi+b2Qkz+76ZzZfdnioxs41m9oiZ/dDMnjez/1R2m6rCzG6s/TeT/LxjZr9XdrtilCvHPaI+J+l5SVeV3ZCK+iV3pxb3Ul+S9B13//VaFdaVZTeoKtz9BUkfkVaW0nhN0qOlNipS9LhbMLMPSrpL0lfKbgviYWZXSbpV0qwkufv77n6u3FZV1h5J/+zu/1J2Q2JE4G7tf0r6Q0nLZTekolzSE2Z2vDZjFsHPSjoj6c9rabavmNnashtVUXdLerjsRsSKwN3EzPZJOu3ux8tuS4XtdvePKawY+V/N7NayG1QRayR9TNJBd/+opHclsQxyk1oK6ROSvlF2W2JF4L7UbkmfMLOXFVZCvN3MHiq3SdXi7q/XXk8r5Ch3lduiynhV0qvu/lxt+xGFQI7VPi7p79z9zbIbEisCdxN3f8DdP+juUwp/zj3t7r9RcrMqw8zWmtn65L2kOyWdLLdV1eDub0h6xcxurO3aI+kHJTapqu4RaZKeUFWCTn1A0qNmJoX/fr7m7t8pt0mV8ruSvlpLB/xI0m+X3J5KMbMrJd0h6b6y2xIzZk4CQGRIlQBAZAjcABAZAjcARIbADQCRIXADQGQI3AAQGQI3AESGwA0Akfn/yEuKPnBghloAAAAASUVORK5CYII=\n", 109 | "text/plain": [ 110 | "" 111 | ] 112 | }, 113 | "metadata": {}, 114 | "output_type": "display_data" 115 | }, 116 | { 117 | "name": "stdout", 118 | "output_type": "stream", 119 | "text": [ 120 | "CPU times: user 113 ms, sys: 9.36 ms, total: 123 ms\n", 121 | "Wall time: 128 ms\n" 122 | ] 123 | } 124 | ], 125 | "source": [ 126 | "%%time\n", 127 | "trainData(PerceptronBase ,X ,y)" 128 | ] 129 | }, 130 | { 131 | "cell_type": "markdown", 132 | "metadata": {}, 133 | "source": [ 134 | "## test PerceptronDual" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 7, 140 | "metadata": {}, 141 | "outputs": [ 142 | { 143 | "data": { 144 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAW4AAAD8CAYAAABXe05zAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAHChJREFUeJzt3XGMHNV9B/Dv7+7sW+w7uJizoQ0214SzkWurEJ8tVSebBhsrEVYaRUECKZWaWDqgLSJKqyh2/uh/vfSPRInSYNfikrYyMSqkRI2h1BgUHaDU+C5BsROInWLHEILPphhs8J1zt7/+8XZuZ/dm3szO7uzOm/1+pNPezM7MvllZP7/7ze+9J6oKIiJyR0erG0BERLVh4CYicgwDNxGRYxi4iYgcw8BNROQYBm4iIscwcBMROYaBm4jIMQzcRESO6YpzkIj0AXgYwDoACuALqvqTsON7evr12msHGtJAIqJ2cObM5HlVXR7n2FiBG8C3ADytqp8VkcUAltgOvvbaAXz1qxMxL01ERPfeK7+Je2xk4BaRqwFsAfCXAKCqVwBcSdo4IiKqT5wc90cAnAPwPRH5mYg8LCJLqw8SkRERmRCRiUuXzjW8oUREZMQJ3F0APgZgj6reCuB9AF+pPkhV96nqkKoO9fTEStMQEVECcQL3GwDeUNUjpe3HYQI5ERG1QGTgVtW3ALwuImtKu7YC+GWqrSIiolBxq0oeAPBIqaLkNQCfT69JRERkEytwq+rLAIZSbgsREcXAkZNERI5h4CYicgwDNxGRYxi4iYgcw8BNROQYBm4iIscwcBMROYaBm4jIMQzcRESOYeAmInIMAzcRkWMYuImIHMPATUTkGAZuIiLHMHATETmGgZuIyDEM3EREjmHgJiJyDAM3EZFjGLiJiBzDwE1E5BgGbiIixzBwEzWQqn2bqBEYuIkaZHwcOHy4HKxVzfb4eGvbRfnDwE3UAKrAzAzw0kvl4H34sNmemWHPmxqrq9UNIMoDEWDbNvP7Sy+ZHwDYtMnsF2ld2yh/YvW4ReS0iBwTkZdFZCLtRhG5yB+8Pa4Gbebqs62WVMnHVfUWVR1KrTVEDvPSI37+nLcrmKvPPua4iRrAn9PetAnYvdu8+nPeLmCu3g1xc9wK4JCIKIB/VtV9KbaJyDkiQHd3ZU7bS5t0d7uTLmGu3g1xA/ewqr4pIisAPCMir6pqxR9OIjICYAQAli1b1eBmEmXfli2mR+oFNy8I1hvs/NcM2m40r91e0Abi30ez29quYqVKVPXN0usUgCcAbAo4Zp+qDqnqUE/P8sa2ksgR1UGq3qDVinxz0lw9c+PNExm4RWSpiPR6vwPYDuB42g0janetyDcnzdUzN95ccVIl1wF4QkzXoQvA91X16VRbRZRhzUoHtCLfHCdXH3b/zI03T2TgVtXXAPxJE9pClHnj46YH6QUjr2fZ3W1y3I1WT745KVuuPur+m93WdsVyQKKYWpm68GtGeWFQrj7q/ovFfNSxu4BD3olianY6oDrfvG1beRtofm/Wdv9btwLPPpudtuYde9xENahnWHuxaN8O+qzubmDjxsp888aN6deGhw15D7v/jo7g3PimTZW5cdtnUHwM3EQ1SJq6GBsD9uwpB+ti0WyPjaXTznrYyvps9+/luKtz41u2sFSw0Ri4iWJKWipXLALT08Crr5aD9549Znt6Orzn7eWUjx6tzCkfPZpuTj0sjz09DTzzjP3+k+TG2fOuHXPcRDElHdbe0QHcfz/w0EMmWD/4oNm/Zo3Z3xHSfao3px5VtpikrO/552u/f5YKNh573EQ1sKUDbF54AVi9ujJVsHq12W+TNKcelZqwvW/7zKT3n6cpb7OAgZuoRrUOa1c1aYannjJpEVXz+tRTZr8tVZAkpx6nbK+esr4kw/rzMuVtVjBVQpQyVeDECeDKFWDxYmD5cuDcObN94kR4zzNpOWCc1EQzy/qyVtaYBwzclHlJh5jbzktr2HqxWJmz9ravugpYuxZ4912zf/ly4JprzP6OjvD2JB1+HjXi0vZ+o6enzcuUt1nCwE2ZlnSIue08IJ1h62NjJvXhPXD0qkcKBeALXzAVGUePlo9fvRq4447oe0wy/Hzz5uDURPWxQe+nMT1tWlPetivmuCmzkpaRRZW0TU83vjTNVvJ3+XI5aPvL6I4eNfuj2lNriV1U2Z6Xw661rK9eaVyzXbHHTZkVJ1ebpKTNu3YjS9O8kj8vWHslfzffbPa/8EJ4qmDzZnt7Gl22541yrB6RqRqdgqFsYI+bMs1WRpa0pC2t0jQvePt5aRNbGV0a95i0bA/gKEcXMHBTpoXlYuspaUurNM1Lj/j5h7mHpQrSuMeozwsbkZlWKokai6kSyqyoMrKtW81rLSVtXuDx8s2NKk3z57S99Ii3vWdP+AjJNO4x6j5akUqixmKPmzIrrIzMm3Guo6P2meoKBfNjm8UuiY4Oc10vaHtpk5tvNvttw9ptMwBG3WPSmQNbkUqixmGPmzLNVkZWT0lbGqVpO3dW1nF7wTssaMdhu8d6RF3XVkpIrcfATZkXVg5X62g8/3ZapWnVQToqaPvzzd5/IF6+eePGyjLCsJRP9XmbNtmrQGzfXVqpJGosBm5yUl5G49VT1hdVRmj7TNt3V30dF7/XvGPgJmfVkg6p3g4bmh5Ho2ucbcPPo+7RNqzdphWpJGocPpwkp4WlPGy1yPWsRpNGjXNUeWKtZYRxS/ZakUqixmDgptyxDQe/fNn81LoaTdR1k9Y4J11VJ+l5lA9MlVDu+Idw+/O/GzeaSZ3uuMO+Gk3UjHtA42qck+bq85Ljp2QYuCmXnn/evn/1ahO4vQoV/2o0tpn6bPnopJLOnMcZ99oXUyWUO96KM88+C1y8aPZdvGi2vVRJ0Go0ly9HD/dOa7h80pwyc9HtiT1uyi2vx+z9eNsnTwavRnPypEmXhJXYAVzJhbIhdo9bRDpF5GcicjDNBhHVovphYrFogmehANx+O9Dba7Z7e832kiXmZ+1aE7QB87p2rVmNprOzPD+IZ+vWco7bNgTf+4/Bz79te4+oFrX0uB8E8AqAq1NqC1FN4qw448//AmbQyubN9tVojh2r7KHv2QOsX29yyklXpAHSWXWH2lOsHreI3ADgTgAPp9sconiiVpw5dCh8xRnbajQ//7m5xoc+BOzaZV5ffdUEc9v0rK1YdYfaV9we9zcBfBlAb4ptoRxr9GjDOCvOBK3wUiiY48LK6NavN79fuACMjprf16wB1q2zj6zkVKnUTJE9bhHZAWBKVScjjhsRkQkRmbh06VzDGkjuS2tFFduKMzZRq9EMDlYePzgYL7hyqlRqljipkmEAnxKR0wAeBXC7iOyvPkhV96nqkKoO9fQsb3AzyVVpjDb0hK04MzcXvsJL1AK809PAc89VlhE+95zZH9VWW6lgWmWE1J4iUyWqugvALgAQkT8D8Heq+rmU20U5kcZoQ8C+4szevcB99yX/zKAywiicKpWaiQNwKHVppAm8FWf8Q9Xvv99sFwqmrK/Wz7SVERYKyaZKTXPVHWpfNQ3AUdUfA/hxKi2h3LKttlJP0BocNBUk/lz14KCpx076mV6pYFAZYRROlUrNwpGTlKokK9XEve7MDDAxUV6X8fBhs21bOcb2mf5cePV5cQMtp0qlZmDgplSlNYtd1AyAtpVjOOMeuY6Bm1KX1ix23kx//lGO3n7OuEd5xoeT1BSNThNUl+6pLizd44x7lFcM3OQ0/yIHQZM8EeURAzc5ySvd27rVlOwB5nXr1nLpHmfjo7xi4CZnhZXobd6c3jB7oizgw0lykq10z+NN2+p/b9Om+ie4Imo1Bm5yUlTp3ubNnI2P8ouBm2rS6OlZ6xFVumdb1DdL90FUK+a4KbYs5o3DSvdss/Fl8T6IasEeN8Xin54VyHbeOO5MfUC274MoDAM3xZLW9KxpYP6b8o6Bm2LzAmBY3jhL6sl/E2Udc9wUm2uruCTJfxO5gD1uiiWt6VmbLS/3Qe2NgZsWCCuV6+4OXjndpSlPOXUr5QEDN1UYHzfVI15Q83qo3d2tblnjcOpWch1z3DTPtiL79LT5sa2c7hJO3UouY4+b5kWV/HnHsIyOqDHGxpKdx8BNFaJK/lhGR5Tc+Djw4g/Pmo3pGQDAcGESAPBS2EkBGLipgm11dO/3oPcYvIkWGn3gt8DcXMW+4cIk9v/x18zGzp3z+x+5N/51GbhpXtyh4iyjIyobGwOmpgCcOzvfi/as6HwbRx46GnDWzoB98TFw07yoUjmAZXTU3sbHgRcfW9iL3tW3FyPLDwDDw6ZsKWUM3FQhqlSOZXTUDvwPDaemALx+Zn57uDCJ/d8+V3XGAIBdTWiZwcBNC9hK5VhGR3kzNgZM/eJseUfVQ8NBAPs3fK0iH91qDNxE1HZGd78HXLgwv72rby9GPvps+YAFQTo7QRuIEbhFpABgHEB36fjHVfXv024YuYerylDWjI6WfvGlOjyn7t7ty0cPIGvB2SZOj3sGwO2qeklEFgF4QUT+S1X/J+W2kUNsQ+Wb8KyGCKNfWljVMVyYxJbCUYxseDagF+3uP8zIwK2qCuBSaXNR6cexAc6UJpdWxyG3jY8DJ0+Wt6cmyz3pFZ1v48jdXw/oKQzApd50HLFy3CLSCWASwE0AvqOqR1JtFTnFpdVxyC2jXwp/aAgAO/uOYuQfBnxnuNuLrkWswK2qcwBuEZE+AE+IyDpVPe4/RkRGAIwAwLJlqxreUMo2l1bHoewaHcXC0jtvlCEQkO4YaEazMqemqhJVvSAiPwbwCQDHq97bB2AfANx44xBTKW3GNlSewZuqWR8arrwN2OWvic5XmqMR4lSVLAfw+1LQvgrANgD/mHrLyBlcVYZsRne/B1y8GDzacOWBqiANNHMgi6vi9Lj/AMC/lvLcHQD+XVUPptsscglXlSGgarThy5XDwk9tuCskzcEgnUScqpKfA7i1CW0hh3FVmfYyuvs9YPpyecf0DFZ0vo3BRacBAIOLUDUsnOmORuLISWoYDofPr/Fx4MVHK/PRpzbcVd7I0HDwdsDATUTz5h8avhkyA15F6R2DdaswcBO1qbGxUi7aUwrUu/r2Ar2oCtJAu5beZREDN1HOeQ8N/aMMPadW3mbmkAaq5u2gLGPgJsqR0VGYlVg8vtGGC0cZAqzqcBMDN5HjqidXOrXyNmDFivIB8w8OB5raLkoPAzeRA8bHgRdfRMx1DdmLzjsGbqKMmV/X0ON/aFgARj59hnPltjkGbqIW8aYonfpF8DzS+z/9g/KOLVtQTnUMgNobAzdRE9jWNRwEsP+h6sVngXaZopRqx8BNlJLa1zUkioeBm6hO1ilKKyZXGgBHG1IjMHAT1WD0S2eB38+Wd8zNYUXn29jZ+1jIuoYM1NR4DNxEVcbHzeuLT1emOoBS6d1dXy/vcHSVcHIbAze1vYp1DX8/C8zNYbgwiWEA+zd8LVerg1M+MHBT2xkbC1gd/Jb7ygdUBGr2oil7GLgpd+ZHGXq4riHlDAM35cLoKCrmkB4uTGJLoTQMvC9oilIOCyd3MXCTU+Z702Gld4ODnJ6Uco+BmzIrbF3DFQCOrPxMwOrgTHdQe2DgpsyIXNewojfNVAe1LwZuaqr5UYZA4LqGw4VJ7P8GVwcnsmHgplRVrGvoG2UIIGRdQyKKwsBNDTE/2vDRhQ8Nd/XtxcgnzvChIVGDMHBTzUZHAVx8r7yjNCx8Refb2NX3WMjq4NX7iCgpBm6KhesaEmUHAzcBqHpoeDFkciWua0iUCZGBW0RWAvg3ANcDKALYp6rfSrthlL7RB8oPDYHSmoYlI3dzXUOirIrT454F8Leq+lMR6QUwKSLPqOovU24bNcD8Q8MfRqxrWLGmIcCUB1F2RQZuVf0dgN+Vfr8oIq8A+DAABu6MGR8vzSHt8T00XAFUpTo87FUTuaamHLeIDAC4FcCRgPdGAIwAwLJlqxrQNLIZHwdefGzhAJZdfXsx0nvAbHx0Bdc1JMqh2IFbRHoA/ADAF1X1ver3VXUfgH0AcOONQ9qwFra5sTFgasq3wze50nBhEvu/Xb06+ABiPTg8fhw4dAg4fx7o7we2bwfWrauvsWlck4gWiBW4RWQRTNB+RFX/I90m0fy6hkEPDVceCJhcqUbHjwMHDgCdncCSJSalcuAAcM89yQNtGtckokBxqkoEwBiAV1T1G+k3qT3Y1jUEgFN37za/LHho2IAyvEOHTIDt7jbb3d3AzIzZnzTIpnFNIgoUp8c9DOAvABwTkZdL+3ar6lPpNSt/Rnf7sksXL87P27ECwJEN9zV3XcPz502v2G/xYrM/S9ckokBxqkpeACBNaEtuVI8yBEoDWP7wM2ajF61dNqu/3/Tyvd4xAFy5YvZn6ZpEFIgjJxOKWtfw1N27AwawZGS04fbtJv88M2N6xVeumHz69u3ZuiYRBWLgrsH8uoYAMDcXsa5hhuuj160zDw0bWQGSxjWJKBADd4DxceDkSWBqMsfrGq5b505QPXgQOHzY9Oa7u4Ft24AdO+q7JksXyWFtH7grHhpOX55f1xAATnFdw/jSKgc8eBB48klABOjoMCmYJ5807yUN3ixdJMe1VeAeHUXw6uArbzO/9AIYHua6hkmkVQ54+LAJ2p2d5X1zc2Z/0sDN0kVyXC4Dt21dw1Mb7goovWOArlta5YAzM6an7Sdi9ifF0kVyXC4Cd23rGrZRqsOWx02aNx4bAyYmgGLRBNShIfMfYX8/8NZbwPQ0MDsLdHUBhQJw/fX13UN3t0mP+KlWlh3WiqWL5DinAnesdQ2BgNGGbciWxz19OlneeGwMeOml8naxWN5evdo80RUxP7OzZqBRvXN6b9tm2jY3Z66ran62bUt+TZYukuMyG7ht6xoOF05j/zeCJlcaaE7jXGDL4545kyxvPDFhXsU3HkvV7B8cBK65Bvjgg3KPe8kS4MSJ+u7Da08jq0pYukiOy0TgHhsLKb3zHhoCnKI0TFg6xJbHjcobh6VDikXzvlZN/lgsmutOT5evMTdnzvXyxlGpGZbnEcXW1MBdMUVp1bqGu/r2BqwOzoeGVrZ0iC2P+8EH4XljWzrES1VUEzHXu3y5cv/ly6bnHVXSl0ZaJ+n3xv8syAGpB27rFKUV6xoOpN2U/LGlQ2x5XC8YBuWNvaAYlA5ZuhS4dGlhO5YuNfnsIBcvRpf0pZHWSfq9MXCTA1IJ3G+duYLRvzKpj+HCJPbfxXUNU2FLh9jyuF5wCkpd/OhHwZ9VLJrjZmdNSsRTKJj9QQHdE5WaqSetkwTLAclxqQTu9avexcRXn/btyfC8HVmQNL/b3w+8/rpJSaiagHbVVcDKleb9JMPaOzpMkK5OiXR0mM+bni730r2ecH8/8M475Rx49Xlej3ZurnyeSLnHmzStAyT77lgOSI7riD6EUuXlWy9cqMy3Hj8efa5XxeEFWVWzfc019vO8nPOVK5V544MHgZtuCj7nppvMdd9/v/Lz3n/f7A+r177+emD9+sr/DFTN9vr1Znv7dhPUZ2bMe16Q377d/CWgaraLxXLw37Yt+Xdn+zwiBzBwt5o/3+r1Qjs7zf4ox46ZwOvlo70HeMeO2c/z55w7OsyrSHn/4sWVxy9ebPbbPu+dd4I/6513gHffNXlw/3lLl5r9QDmt09dn/uPp6ys/KNyxA7jzTtOGYtG83nnnwtx4Ld+d7fOIHJCJcsC2dv686fFOTZXTCD09JpcM2FMBXv7XnwMuFsv537BzbXnj8+eB665b+HDSyzdXV5b4881dXcFtOX8eWLYMuPbahdf02NI6AwPAqlXl+xgYKH93IsDZs+Xa8d7eeLlql2ZHJKrCHnerqZrKC38awduOSgV0dy/MRXv5X9u5tvP6+xfmlL38b1dXZZ7aS2F0dSW/ZhTbfRQKpkfv1YzPzZntQiH6ukQOY+BuNX81hr+Xe+lSdCrAlv+1nWs7z5b/7ekpt88fpHt6kl8zSj2pJKKcYuButdnZco7ZX6kxO2v+5A/KN3upAFv+13au7Txb/lcEuPrqyly1t530mlFs9zE9bVIwnZ3mMzs7zba/XJEoh5jjbjVv9rtFi8r75ubKKYaosrUdO4IHokSdG3YeEJ7/7e83ufjFi8s5Za8cMOk1o0Tdx4ULJifvmZkxwbseHH5PGcced6ullWJIo+Rt9WpTCTI7W54B8N13zf602O4jjXuspzyTqEkYuFstrRRDGiVvJ06Yqo2uLhMovSqOemcAtLHdRxr3yJw6OYCpkmax/fmdRoqh3nODnD9vctr+AT7VZX1psN1HGvfI4fCUcexxN0Ne/vyup6zPFe1wj+Q8Bu5myMuf3+0wVLwd7pGcFxm4ReS7IjIlIo51DzMkqqzPFe0wVLwd7pGcFyfH/S8A/gnAv6XblBxLcza6ZpeutcNQ8Xa4R3JaZI9bVccB/F8T2pJfaf35nZfcORHVhDnuZkjrz++85M6JqCYNKwcUkREAIwCwqt6Ra3mUxp/fLF0jaksN63Gr6j5VHVLVoeX+yYgoPSxdI2pLTJW4jKVrRG0pTjngAQA/AbBGRN4QkZ3pN4tiYekaUVuKzHGr6j3NaAglxNI1orbDVAkRkWMYuImIHMPATUTkGAZuIiLHMHATETmGgZuIyDEM3EREjmHgJiJyDAM3EZFjGLiJiBzDwE1E5BgGbiIixzBwExE5hoGbiMgxDNxERI5h4CYicgwDNxGRYxi4iYgcw8BNROQYBm4iIscwcBMROYaBm4jIMQzcRESOYeAmInIMAzcRkWMYuImIHMPATUTkmFiBW0Q+ISK/EpFfi8hX0m4UERGFiwzcItIJ4DsAPglgLYB7RGRt2g0jIqJgcXrcmwD8WlVfU9UrAB4F8OfpNouIiMLECdwfBvC6b/uN0j4iImqBrhjHSMA+XXCQyAiAkdLmJbn33l/V07CM6AdwvtWNyCh+N3b8fsLxuwl2Y9wD4wTuNwCs9G3fAODN6oNUdR+AfXE/2AUiMqGqQ61uRxbxu7Hj9xOO30394qRKjgIYFJE/EpHFAO4G8J/pNouIiMJE9rhVdVZE/gbAfwPoBPBdVf1F6i0jIqJAcVIlUNWnADyVcluyKFepnwbjd2PH7yccv5s6ieqC54xERJRhHPJOROQYBu4QItIpIj8TkYOtbkvWiMhpETkmIi+LyESr25MlItInIo+LyKsi8oqI/Gmr25QVIrKm9G/G+3lPRL7Y6na5KFaOu009COAVAFe3uiEZ9XFVZS3uQt8C8LSqfrZUhbWk1Q3KClX9FYBbgPmpNH4L4ImWNspR7HEHEJEbANwJ4OFWt4XcISJXA9gCYAwAVPWKql5obasyayuA/1XV37S6IS5i4A72TQBfBlBsdUMySgEcEpHJ0ohZMj4C4ByA75XSbA+LyNJWNyqj7gZwoNWNcBUDdxUR2QFgSlUnW92WDBtW1Y/BzBj51yKypdUNyoguAB8DsEdVbwXwPgBOg1yllEL6FIDHWt0WVzFwLzQM4FMichpmJsTbRWR/a5uULar6Zul1CiZHuam1LcqMNwC8oapHStuPwwRyqvRJAD9V1bOtboirGLirqOouVb1BVQdg/px7TlU/1+JmZYaILBWRXu93ANsBHG9tq7JBVd8C8LqIrCnt2grgly1sUlbdA6ZJ6sKqEqrVdQCeEBHA/Pv5vqo+3domZcoDAB4ppQNeA/D5FrcnU0RkCYA7ANzb6ra4jCMniYgcw1QJEZFjGLiJiBzDwE1E5BgGbiIixzBwExE5hoGbiMgxDNxERI5h4CYicsz/A4uBJteAUG02AAAAAElFTkSuQmCC\n", 145 | "text/plain": [ 146 | "" 147 | ] 148 | }, 149 | "metadata": {}, 150 | "output_type": "display_data" 151 | }, 152 | { 153 | "name": "stdout", 154 | "output_type": "stream", 155 | "text": [ 156 | "CPU times: user 132 ms, sys: 5.06 ms, total: 137 ms\n", 157 | "Wall time: 135 ms\n" 158 | ] 159 | } 160 | ], 161 | "source": [ 162 | "%%time\n", 163 | "trainData(PerceptronDual ,X ,y)" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [] 172 | } 173 | ], 174 | "metadata": { 175 | "kernelspec": { 176 | "display_name": "Python 3", 177 | "language": "python", 178 | "name": "python3" 179 | }, 180 | "language_info": { 181 | "codemirror_mode": { 182 | "name": "ipython", 183 | "version": 3 184 | }, 185 | "file_extension": ".py", 186 | "mimetype": "text/x-python", 187 | "name": "python", 188 | "nbconvert_exporter": "python", 189 | "pygments_lexer": "ipython3", 190 | "version": "3.6.4" 191 | } 192 | }, 193 | "nbformat": 4, 194 | "nbformat_minor": 2 195 | } 196 | -------------------------------------------------------------------------------- /perceptron/perceptron_base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.datasets import load_iris 3 | from sklearn.model_selection import train_test_split 4 | from utils.plot import plot_decision_regions 5 | 6 | 7 | class PerceptronBase(object): 8 | """ 9 | 原始形态感知机 10 | """ 11 | def __init__(self, eta=0.1, n_iter=50): 12 | # 学习率 13 | self.eta = eta 14 | # 迭代次数 15 | self.n_iter = n_iter 16 | 17 | def fit(self, X, y): 18 | # 初始化参数w,b 19 | self.w = np.zeros(X.shape[1]) 20 | self.b = 0 21 | # 记录所有error 22 | self.errors_ = [] 23 | for _ in range(self.n_iter): 24 | errors = 0 25 | for xi, yi in zip(X, y): 26 | update = self.eta * (yi - self.predict(xi)) 27 | self.w += update * xi 28 | self.b += update 29 | errors += int(update != 0.0) 30 | if errors == 0: 31 | break 32 | self.errors_.append(errors) 33 | 34 | return self 35 | 36 | def sign(self, xi): 37 | return np.dot(xi, self.w) + self.b 38 | 39 | def predict(self, xi): 40 | return np.where(self.sign(xi) <= 0.0, -1, 1) 41 | 42 | 43 | def main(): 44 | iris = load_iris() 45 | X = iris.data[:100, [0, 2]] 46 | y = iris.target[:100] 47 | y = np.where(y == 1, 1, -1) 48 | X_train, X_test, y_train, y_test = \ 49 | train_test_split(X, y, test_size=0.3) 50 | ppn = PerceptronBase(eta=0.1, n_iter=10) 51 | ppn.fit(X_train, y_train) 52 | plot_decision_regions(ppn, X, y) 53 | 54 | 55 | if __name__ == "__main__": 56 | main() 57 | -------------------------------------------------------------------------------- /perceptron/perceptron_dual.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.datasets import load_iris 3 | from sklearn.model_selection import train_test_split 4 | from utils.plot import plot_decision_regions 5 | from perceptron.perceptron_base import PerceptronBase 6 | 7 | 8 | class PerceptronDual(PerceptronBase): 9 | """ 10 | 对偶形态感知机 11 | """ 12 | def __init__(self, eta=0.1, n_iter=50): 13 | super(PerceptronDual, self).__init__(eta=eta, n_iter=n_iter) 14 | 15 | 16 | # 计算Gram Matrix 17 | def calculate_g_matrix(self, X): 18 | n_sample = X.shape[0] 19 | self.G_matrix = np.zeros((n_sample, n_sample)) 20 | # 填充Gram Matrix 21 | for i in range(n_sample): 22 | for j in range(n_sample): 23 | self.G_matrix[i][j] = np.sum(X[i] * X[j]) 24 | 25 | # 迭代的判定条件 26 | def judge(self, X, y, index): 27 | tmp = self.b 28 | n_sample = X.shape[0] 29 | for m in range(n_sample): 30 | tmp += self.alpha[m] * y[m] * self.G_matrix[index][m] 31 | 32 | return tmp * y[index] 33 | 34 | def fit(self, X, y): 35 | """ 36 | 对偶形态的感知机 37 | 由于对偶形式中训练实例仅以内积的形式出现 38 | 因此,若事先求出Gram Matrix,能大大减少计算量 39 | """ 40 | # 读取数据集中含有的样本数,特征向量数 41 | n_samples, n_features = X.shape 42 | self.alpha, self.b = [0] * n_samples, 0 43 | self.w = np.zeros(n_features) 44 | # 计算Gram_Matrix 45 | self.calculate_g_matrix(X) 46 | 47 | i = 0 48 | while i < n_samples: 49 | if self.judge(X, y, i) <= 0: 50 | self.alpha[i] += self.eta 51 | self.b += self.eta * y[i] 52 | i = 0 53 | else: 54 | i += 1 55 | 56 | for j in range(n_samples): 57 | self.w += self.alpha[j] * X[j] * y[j] 58 | 59 | return self 60 | 61 | 62 | def main(): 63 | iris = load_iris() 64 | X = iris.data[:100, [0, 2]] 65 | y = iris.target[:100] 66 | y = np.where(y == 1, 1, -1) 67 | X_train, X_test, y_train, y_test = \ 68 | train_test_split(X, y, test_size=0.3) 69 | ppn = PerceptronDual(eta=0.1, n_iter=10) 70 | ppn.fit(X_train, y_train) 71 | plot_decision_regions(ppn, X, y) 72 | 73 | 74 | if __name__ == "__main__": 75 | main() 76 | -------------------------------------------------------------------------------- /support_vector_machine/.ipynb_checkpoints/svm-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "import os\n", 11 | "searchPath=os.path.abspath('..')\n", 12 | "sys.path.append(searchPath)" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "import numpy as np\n", 22 | "import pandas as pd\n", 23 | "from sklearn.datasets import load_iris\n", 24 | "from sklearn.model_selection import train_test_split\n", 25 | "import matplotlib.pyplot as plt\n", 26 | "from support_vector_machine.svm import SVM\n", 27 | "from sklearn.svm import SVC\n", 28 | "np.random.seed(10)\n", 29 | "\n", 30 | "%matplotlib inline" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 3, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "# data\n", 40 | "def create_data():\n", 41 | " iris = load_iris()\n", 42 | " df = pd.DataFrame(iris.data, columns=iris.feature_names)\n", 43 | " df['label'] = iris.target\n", 44 | " df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']\n", 45 | " data = np.array(df.iloc[:100, [0, 1, -1]])\n", 46 | " for i in range(len(data)):\n", 47 | " if data[i,-1] == 0:\n", 48 | " data[i,-1] = -1\n", 49 | " # print(data)\n", 50 | " return data[:,:2], data[:,-1]" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 4, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "X, y = create_data()\n", 60 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25)" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 5, 66 | "metadata": {}, 67 | "outputs": [ 68 | { 69 | "data": { 70 | "text/plain": [ 71 | "" 72 | ] 73 | }, 74 | "execution_count": 5, 75 | "metadata": {}, 76 | "output_type": "execute_result" 77 | }, 78 | { 79 | "data": { 80 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAGihJREFUeJzt3X+MXWWdx/H3d4dZOiowoYwrzJQtP0yjQNfCCJImxAV3q7UWgiyU4I8qC7sGFwwuRgxBbUzAkOCPJdEUyALCFrsVS2H5sQhLVAI1U8B2bSWCoJ2BXYZii6wFyvDdP+6ddubOnbn3ufeeuc/z3M8raTr33Ken3+cc/XJ7zuc819wdERHJy5+1uwAREWk9NXcRkQypuYuIZEjNXUQkQ2ruIiIZUnMXEcmQmruISIbU3EVEMqTmLiKSof3qHWhmXcAQMOLuyyreWwlcA4yUN13n7jfMtL9DDjnE58+fH1SsiEin27Rp00vu3ldrXN3NHbgE2AYcOM37P3T3z9e7s/nz5zM0NBTw14uIiJn9rp5xdV2WMbMB4KPAjJ/GRUQkDvVec/828CXgrRnGfNzMNpvZOjObV22AmV1oZkNmNjQ6Ohpaq4iI1KlmczezZcCL7r5phmF3AfPdfSHwE+DmaoPcfbW7D7r7YF9fzUtGIiLSoHquuS8GlpvZUmAOcKCZ3erunxgf4O47Joy/Hvhma8sUEWmdPXv2MDw8zGuvvdbuUqY1Z84cBgYG6O7ubujP12zu7n45cDmAmX0Q+OeJjb28/VB3f6H8cjmlG68iIlEaHh7mgAMOYP78+ZhZu8uZwt3ZsWMHw8PDHHHEEQ3to+Gcu5mtMrPl5ZcXm9mvzOyXwMXAykb3KyJStNdee425c+dG2dgBzIy5c+c29S+LkCgk7v4w8HD55ysnbN/76V4kN+ufGOGa+5/i+Z27Oay3h8uWLOCMRf3tLkuaFGtjH9dsfUHNXaTTrH9ihMvv2MLuPWMAjOzczeV3bAFQg5eoafkBkRlcc/9Texv7uN17xrjm/qfaVJHk4r777mPBggUcffTRXH311S3fv5q7yAye37k7aLtIPcbGxrjooou499572bp1K2vWrGHr1q0t/Tt0WUZkBof19jBSpZEf1tvThmqkXVp93+UXv/gFRx99NEceeSQAK1as4M477+S9731vq0rWJ3eRmVy2ZAE93V2TtvV0d3HZkgVtqkhm2/h9l5Gdu3H23XdZ/8RIzT87nZGREebN2/cg/8DAACMjje+vGjV3kRmcsaifq848jv7eHgzo7+3hqjOP083UDlLEfRd3n7Kt1ekdXZYRqeGMRf1q5h2siPsuAwMDbN++fe/r4eFhDjvssIb3V40+uYuIzGC6+yvN3Hd5//vfz29+8xueffZZ3njjDW6//XaWL19e+w8GUHMXEZlBEfdd9ttvP6677jqWLFnCe97zHs4++2yOOeaYZkud/He0dG8iIpkZvyTX6qeUly5dytKlS1tRYlVq7iIiNaR430WXZUREMqTmLiKSITV3EZEMqbmLiGRIzV1EJENq7pKN9U+MsPjqhzjiy//B4qsfamrtD5Giffazn+Wd73wnxx57bCH7V3OXLBSxuJNIkVauXMl9991X2P7V3CUL+lINKdTmtfCtY+FrvaXfN69tepennHIKBx98cAuKq04PMUkW9KUaUpjNa+Gui2FP+X9Lu7aXXgMsPLt9ddWgT+6ShSIWdxIB4MFV+xr7uD27S9sjpuYuWdCXakhhdg2HbY+ELstIFopa3EmEgwZKl2KqbY+YmrtkI8XFnSQBp105+Zo7QHdPaXsTzj33XB5++GFeeuklBgYG+PrXv87555/fZLH7qLlL01r95cEiURm/afrgqtKlmIMGSo29yZupa9asaUFx01Nzl6aM58vHY4jj+XJADV7ysfDsqJMx1eiGqjRF+XKROKm5S1OUL5dUuXu7S5hRs/WpuUtTlC+XFM2ZM4cdO3ZE2+DdnR07djBnzpyG96Fr7tKUy5YsmHTNHZQvl/gNDAwwPDzM6Ohou0uZ1pw5cxgYaDxuqeYuTVG+XFLU3d3NEUcc0e4yClV3czezLmAIGHH3ZRXv7Q/cApwA7ADOcffnWlinREz5cpH4hHxyvwTYBhxY5b3zgT+4+9FmtgL4JnBOC+oTSYoy/xKLum6omtkA8FHghmmGnA7cXP55HXCamVnz5YmkQ2vKS0zqTct8G/gS8NY07/cD2wHc/U1gFzC36epEEqLMv8SkZnM3s2XAi+6+aaZhVbZNyRiZ2YVmNmRmQzHfpRZphDL/EpN6PrkvBpab2XPA7cCpZnZrxZhhYB6Ame0HHAS8XLkjd1/t7oPuPtjX19dU4SKxUeZfYlKzubv75e4+4O7zgRXAQ+7+iYphG4BPl38+qzwmzqcDRAqiNeUlJg3n3M1sFTDk7huAG4EfmNnTlD6xr2hRfSLJUOZfYmLt+oA9ODjoQ0NDbfm7RURSZWab3H2w1jg9oSrRumL9FtZs3M6YO11mnHvSPL5xxnHtLkskCWruEqUr1m/h1sd+v/f1mPve12rwIrVpVUiJ0pqNVb6zcobtIjKZmrtEaWyae0HTbReRydTcJUpd06xeMd12EZlMzV2idO5J84K2i8hkuqEqURq/aaq0jEhjlHMXEUmIcu7SlPOuf5RHntm3PNDiow7mtgtObmNF7aM12iVFuuYuU1Q2doBHnnmZ865/tE0VtY/WaJdUqbnLFJWNvdb2nGmNdkmVmrvIDLRGu6RKzV1kBlqjXVKl5i5TLD7q4KDtOdMa7ZIqNXeZ4rYLTp7SyDs1LXPGon6uOvM4+nt7MKC/t4erzjxOaRmJnnLuIiIJUc5dmlJUtjtkv8qXizROzV2mGM92j0cAx7PdQFPNNWS/RdUg0il0zV2mKCrbHbJf5ctFmqPmLlMUle0O2a/y5SLNUXOXKYrKdofsV/lykeaoucsURWW7Q/arfLlIc3RDVaYYv2HZ6qRKyH6LqkGkUyjnLiKSEOXcC5ZiBjvFmkWkMWruDUgxg51izSLSON1QbUCKGewUaxaRxqm5NyDFDHaKNYtI49TcG5BiBjvFmkWkcWruDUgxg51izSLSON1QbUCKGewUaxaRxtXMuZvZHOCnwP6U/mOwzt2/WjFmJXANMP6V8Ne5+w0z7Vc5dxGRcK3Mub8OnOrur5pZN/BzM7vX3R+rGPdDd/98I8XK7Lhi/RbWbNzOmDtdZpx70jy+ccZxTY+NJT8fSx0iMajZ3L300f7V8svu8q/2PNYqDbti/RZufez3e1+Pue99Xdm0Q8bGkp+PpQ6RWNR1Q9XMuszsSeBF4AF331hl2MfNbLOZrTOzeS2tUpq2ZuP2ureHjI0lPx9LHSKxqKu5u/uYu78PGABONLNjK4bcBcx394XAT4Cbq+3HzC40syEzGxodHW2mbgk0Ns29lWrbQ8bGkp+PpQ6RWARFId19J/Aw8OGK7Tvc/fXyy+uBE6b586vdfdDdB/v6+hooVxrVZVb39pCxseTnY6lDJBY1m7uZ9ZlZb/nnHuBDwK8rxhw64eVyYFsri5TmnXtS9Stl1baHjI0lPx9LHSKxqCctcyhws5l1UfqPwVp3v9vMVgFD7r4BuNjMlgNvAi8DK4sqWBozfiO0ngRMyNhY8vOx1CESC63nLiKSEK3nXrCiMtUh+fIi9x0yvxSPRXI2r4UHV8GuYThoAE67Ehae3e6qJGJq7g0oKlMdki8vct8h80vxWCRn81q462LYU07+7Npeeg1q8DItLRzWgKIy1SH58iL3HTK/FI9Fch5cta+xj9uzu7RdZBpq7g0oKlMdki8vct8h80vxWCRn13DYdhHU3BtSVKY6JF9e5L5D5pfisUjOQQNh20VQc29IUZnqkHx5kfsOmV+KxyI5p10J3RX/sezuKW0XmYZuqDagqEx1SL68yH2HzC/FY5Gc8ZumSstIAOXcRUQSopy7TBFDdl0Sp7x9MtTcO0QM2XVJnPL2SdEN1Q4RQ3ZdEqe8fVLU3DtEDNl1SZzy9klRc+8QMWTXJXHK2ydFzb1DxJBdl8Qpb58U3VDtEDFk1yVxytsnRTl3EZGEKOdeVlReO2S/saxLrux6ZHLPjOc+vxBtOBZZN/ei8toh+41lXXJl1yOTe2Y89/mFaNOxyPqGalF57ZD9xrIuubLrkck9M577/EK06Vhk3dyLymuH7DeWdcmVXY9M7pnx3OcXok3HIuvmXlReO2S/saxLrux6ZHLPjOc+vxBtOhZZN/ei8toh+41lXXJl1yOTe2Y89/mFaNOxyPqGalF57ZD9xrIuubLrkck9M577/EK06Vgo5y4ikhDl3AsWQ37+vOsf5ZFnXt77evFRB3PbBSc3XYNIVu6+FDbdBD4G1gUnrIRl1za/38hz/Flfcy/KeGZ8ZOdunH2Z8fVPjMzafisbO8Ajz7zMedc/2lQNIlm5+1IYurHU2KH0+9CNpe3NGM+u79oO+L7s+ua1TZfcKmruDYghP1/Z2GttF+lIm24K216vBHL8au4NiCE/LyJ18LGw7fVKIMev5t6AGPLzIlIH6wrbXq8Ecvxq7g2IIT+/+KiDq+5juu0iHemElWHb65VAjl/NvQFnLOrnqjOPo7+3BwP6e3u46szjWpKfr3e/t11w8pRGrrSMSIVl18Lg+fs+qVtX6XWzaZmFZ8PHvgsHzQOs9PvHvhtVWkY5dxGRhLQs525mc4CfAvuXx69z969WjNkfuAU4AdgBnOPuzzVQd02h+fLU1jAPWfs992NRaI44JPtcVB1Fzi/yDHZTQueW87GYQT0PMb0OnOrur5pZN/BzM7vX3R+bMOZ84A/ufrSZrQC+CZzT6mJD1yRPbQ3zkLXfcz8Wha6BPZ59HjeefYapDb6oOoqcX85rqYfOLedjUUPNa+5e8mr5ZXf5V+W1nNOBm8s/rwNOM2v9soeh+fLU1jAPWfs992NRaI44JPtcVB1Fzi+BDHbDQueW87Gooa4bqmbWZWZPAi8CD7j7xooh/cB2AHd/E9gFzK2ynwvNbMjMhkZHR4OLDc2Bp5YbD1n7PfdjUWiOOCT7XFQdRc4vgQx2w0LnlvOxqKGu5u7uY+7+PmAAONHMjq0YUu1T+pSO5O6r3X3Q3Qf7+vqCiw3NgaeWGw9Z+z33Y1Fojjgk+1xUHUXOL4EMdsNC55bzsaghKArp7juBh4EPV7w1DMwDMLP9gIOAlj8HH5ovT20N85C133M/FoXmiEOyz0XVUeT8EshgNyx0bjkfixrqScv0AXvcfaeZ9QAfonTDdKINwKeBR4GzgIe8gIxl6Jrkqa1hHrL2e+7HotA1sMdvmtaTlimqjiLnl/Na6qFzy/lY1FAz525mCyndLO2i9El/rbuvMrNVwJC7byjHJX8ALKL0iX2Fu/92pv0q5y4iEq5lOXd330ypaVduv3LCz68BfxdapIiIFCP7L+tI7sEdmR0hD7bE8BBMkQ/upPaQVgznIwFZN/fkHtyR2RHyYEsMD8EU+eBOag9pxXA+EpH1wmHJPbgjsyPkwZYYHoIp8sGd1B7SiuF8JCLr5p7cgzsyO0IebInhIZgiH9xJ7SGtGM5HIrJu7sk9uCOzI+TBlhgeginywZ3UHtKK4XwkIuvmntyDOzI7Qh5sieEhmCIf3EntIa0Yzkcism7uRX2phiQu5IsWYvhShtAaYphfavvNkL6sQ0QkIS17iEmk44V8sUcsUqs5lux6LHW0gJq7yExCvtgjFqnVHEt2PZY6WiTra+4iTQv5Yo9YpFZzLNn1WOpoETV3kZmEfLFHLFKrOZbseix1tIiau8hMQr7YIxap1RxLdj2WOlpEzV1kJiFf7BGL1GqOJbseSx0touYuMpNl18Lg+fs+9VpX6XWMNybHpVZzLNn1WOpoEeXcRUQSopy7zJ4Us8FF1VxUvjzFYyxtpeYuzUkxG1xUzUXly1M8xtJ2uuYuzUkxG1xUzUXly1M8xtJ2au7SnBSzwUXVXFS+PMVjLG2n5i7NSTEbXFTNReXLUzzG0nZq7tKcFLPBRdVcVL48xWMsbafmLs1JMRtcVM1F5ctTPMbSdsq5i4gkpN6cuz65Sz42r4VvHQtf6y39vnnt7O+3qBpEAinnLnkoKgsesl/l0SUi+uQueSgqCx6yX+XRJSJq7pKHorLgIftVHl0iouYueSgqCx6yX+XRJSJq7pKHorLgIftVHl0iouYueSgqCx6yX+XRJSI1c+5mNg+4BXgX8Baw2t2/UzHmg8CdwLPlTXe4+4x3kZRzFxEJ18r13N8Evujuj5vZAcAmM3vA3bdWjPuZuy9rpFiJUIrrh4fUnOL8YqDjloyazd3dXwBeKP/8RzPbBvQDlc1dcpFiXlt59OLpuCUl6Jq7mc0HFgEbq7x9spn90szuNbNjWlCbtEuKeW3l0Yun45aUup9QNbN3AD8CvuDur1S8/Tjwl+7+qpktBdYD766yjwuBCwEOP/zwhouWgqWY11YevXg6bkmp65O7mXVTauy3ufsdle+7+yvu/mr553uAbjM7pMq41e4+6O6DfX19TZYuhUkxr608evF03JJSs7mbmQE3AtvcverapWb2rvI4zOzE8n53tLJQmUUp5rWVRy+ejltS6rkssxj4JLDFzJ4sb/sKcDiAu38fOAv4nJm9CewGVni71hKW5o3fHEspFRFSc4rzi4GOW1K0nruISEJamXOXWClzPNndl8Kmm0pfSG1dpa+3a/ZbkEQSpeaeKmWOJ7v7Uhi6cd9rH9v3Wg1eOpDWlkmVMseTbbopbLtI5tTcU6XM8WQ+FrZdJHNq7qlS5ngy6wrbLpI5NfdUKXM82Qkrw7aLZE7NPVVaO3yyZdfC4Pn7PqlbV+m1bqZKh1LOXUQkIcq5N2D9EyNcc/9TPL9zN4f19nDZkgWcsai/3WW1Tu65+NznFwMd42SouZetf2KEy+/Ywu49pXTFyM7dXH7HFoA8Gnzuufjc5xcDHeOk6Jp72TX3P7W3sY/bvWeMa+5/qk0VtVjuufjc5xcDHeOkqLmXPb9zd9D25OSei899fjHQMU6KmnvZYb09QduTk3suPvf5xUDHOClq7mWXLVlAT/fkB156uru4bMmCNlXUYrnn4nOfXwx0jJOiG6pl4zdNs03L5L4Wd+7zi4GOcVKUcxcRSUi9OXddlhFJwea18K1j4Wu9pd83r01j39I2uiwjErsi8+XKrmdLn9xFYldkvlzZ9WypuYvErsh8ubLr2VJzF4ldkflyZdezpeYuErsi8+XKrmdLzV0kdkWu3a/vBciWcu4iIglRzl1EpIOpuYuIZEjNXUQkQ2ruIiIZUnMXEcmQmruISIbU3EVEMqTmLiKSoZrN3czmmdl/mdk2M/uVmV1SZYyZ2XfN7Gkz22xmxxdTrjRF63aLdIx61nN/E/iiuz9uZgcAm8zsAXffOmHMR4B3l3+dBHyv/LvEQut2i3SUmp/c3f0Fd3+8/PMfgW1A5ReLng7c4iWPAb1mdmjLq5XGad1ukY4SdM3dzOYDi4CNFW/1A9snvB5m6n8AMLMLzWzIzIZGR0fDKpXmaN1ukY5Sd3M3s3cAPwK+4O6vVL5d5Y9MWZHM3Ve7+6C7D/b19YVVKs3Rut0iHaWu5m5m3ZQa+23ufkeVIcPAvAmvB4Dnmy9PWkbrdot0lHrSMgbcCGxz92unGbYB+FQ5NfMBYJe7v9DCOqVZWrdbpKPUk5ZZDHwS2GJmT5a3fQU4HMDdvw/cAywFngb+BHym9aVK0xaerWYu0iFqNnd3/znVr6lPHOPARa0qSkREmqMnVEVEMqTmLiKSITV3EZEMqbmLiGRIzV1EJENq7iIiGVJzFxHJkJUi6m34i81Ggd+15S+v7RDgpXYXUSDNL105zw00v3r8pbvXXJyrbc09ZmY25O6D7a6jKJpfunKeG2h+raTLMiIiGVJzFxHJkJp7davbXUDBNL905Tw30PxaRtfcRUQypE/uIiIZ6ujmbmZdZvaEmd1d5b2VZjZqZk+Wf/19O2pshpk9Z2ZbyvUPVXnfzOy7Zva0mW02s+PbUWcj6pjbB81s14Tzl9RXTplZr5mtM7Nfm9k2Mzu54v1kzx3UNb9kz5+ZLZhQ95Nm9oqZfaFiTOHnr54v68jZJcA24MBp3v+hu39+Fuspwl+7+3S52o8A7y7/Ogn4Xvn3VMw0N4CfufuyWaumtb4D3OfuZ5nZnwNvq3g/9XNXa36Q6Plz96eA90HpAyQwAvy4Yljh569jP7mb2QDwUeCGdtfSRqcDt3jJY0CvmR3a7qI6nZkdCJxC6estcfc33H1nxbBkz12d88vFacAz7l75wGbh569jmzvwbeBLwFszjPl4+Z9M68xs3gzjYuXAf5rZJjO7sMr7/cD2Ca+Hy9tSUGtuACeb2S/N7F4zO2Y2i2vSkcAo8K/ly4Y3mNnbK8akfO7qmR+ke/4mWgGsqbK98PPXkc3dzJYBL7r7phmG3QXMd/eFwE+Am2eluNZa7O7HU/on4EVmdkrF+9W+PjGV+FStuT1O6THtvwL+BVg/2wU2YT/geOB77r4I+D/gyxVjUj539cwv5fMHQPly03Lg36u9XWVbS89fRzZ3Sl/6vdzMngNuB041s1snDnD3He7+evnl9cAJs1ti89z9+fLvL1K65ndixZBhYOK/SAaA52enuubUmpu7v+Lur5Z/vgfoNrNDZr3QxgwDw+6+sfx6HaVmWDkmyXNHHfNL/PyN+wjwuLv/b5X3Cj9/Hdnc3f1ydx9w9/mU/tn0kLt/YuKYiutfyyndeE2Gmb3dzA4Y/xn4W+C/K4ZtAD5VvnP/AWCXu78wy6UGq2duZvYuM7PyzydS+t/6jtmutRHu/j/AdjNbUN50GrC1YliS5w7qm1/K52+Cc6l+SQZm4fx1elpmEjNbBQy5+wbgYjNbDrwJvAysbGdtDfgL4Mfl/3/sB/ybu99nZv8I4O7fB+4BlgJPA38CPtOmWkPVM7ezgM+Z2ZvAbmCFp/XE3j8Bt5X/af9b4DOZnLtxteaX9Pkzs7cBfwP8w4Rts3r+9ISqiEiGOvKyjIhI7tTcRUQypOYuIpIhNXcRkQypuYuIZEjNXUQkQ2ruIiIZUnMXEcnQ/wPmMFqpaGCFHwAAAABJRU5ErkJggg==\n", 81 | "text/plain": [ 82 | "" 83 | ] 84 | }, 85 | "metadata": {}, 86 | "output_type": "display_data" 87 | } 88 | ], 89 | "source": [ 90 | "plt.scatter(X[:50,0],X[:50,1], label='0')\n", 91 | "plt.scatter(X[50:,0],X[50:,1], label='1')\n", 92 | "plt.legend()" 93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "metadata": {}, 98 | "source": [ 99 | "----\n", 100 | "\n" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 6, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "svm = SVM(max_iter=200)" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 7, 115 | "metadata": {}, 116 | "outputs": [ 117 | { 118 | "data": { 119 | "text/plain": [ 120 | "'train done!'" 121 | ] 122 | }, 123 | "execution_count": 7, 124 | "metadata": {}, 125 | "output_type": "execute_result" 126 | } 127 | ], 128 | "source": [ 129 | "svm.fit(X_train, y_train)" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 8, 135 | "metadata": {}, 136 | "outputs": [ 137 | { 138 | "data": { 139 | "text/plain": [ 140 | "0.48" 141 | ] 142 | }, 143 | "execution_count": 8, 144 | "metadata": {}, 145 | "output_type": "execute_result" 146 | } 147 | ], 148 | "source": [ 149 | "svm.score(X_test, y_test)" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 9, 155 | "metadata": {}, 156 | "outputs": [ 157 | { 158 | "data": { 159 | "text/plain": [ 160 | "SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,\n", 161 | " decision_function_shape='ovr', degree=3, gamma='auto', kernel='rbf',\n", 162 | " max_iter=-1, probability=False, random_state=None, shrinking=True,\n", 163 | " tol=0.001, verbose=False)" 164 | ] 165 | }, 166 | "execution_count": 9, 167 | "metadata": {}, 168 | "output_type": "execute_result" 169 | } 170 | ], 171 | "source": [ 172 | "\n", 173 | "clf = SVC()\n", 174 | "clf.fit(X_train, y_train)" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 10, 180 | "metadata": {}, 181 | "outputs": [ 182 | { 183 | "data": { 184 | "text/plain": [ 185 | "0.96" 186 | ] 187 | }, 188 | "execution_count": 10, 189 | "metadata": {}, 190 | "output_type": "execute_result" 191 | } 192 | ], 193 | "source": [ 194 | "clf.score(X_test, y_test)" 195 | ] 196 | } 197 | ], 198 | "metadata": { 199 | "kernelspec": { 200 | "display_name": "Python 3", 201 | "language": "python", 202 | "name": "python3" 203 | }, 204 | "language_info": { 205 | "codemirror_mode": { 206 | "name": "ipython", 207 | "version": 3 208 | }, 209 | "file_extension": ".py", 210 | "mimetype": "text/x-python", 211 | "name": "python", 212 | "nbconvert_exporter": "python", 213 | "pygments_lexer": "ipython3", 214 | "version": "3.6.4" 215 | } 216 | }, 217 | "nbformat": 4, 218 | "nbformat_minor": 2 219 | } 220 | -------------------------------------------------------------------------------- /support_vector_machine/README.md: -------------------------------------------------------------------------------- 1 | # 实现支持向量机(SVM)的算法 2 | 3 | # 结果比较 4 | 结果在svm.ipynb中展示 5 | 6 | # 相关博客 7 | #### [1. 感知机原理(Perceptron)](https://www.cnblogs.com/huangyc/p/9706575.html) 8 | #### [2. 感知机(Perceptron)基本形式和对偶形式实现](https://www.cnblogs.com/huangyc/p/10294583.html) 9 | #### [3. 支持向量机(SVM)拉格朗日对偶性(KKT)](https://www.cnblogs.com/huangyc/p/9979178.html) 10 | #### [4. 支持向量机(SVM)原理](https://www.cnblogs.com/huangyc/p/9931233.html) 11 | #### [5. 支持向量机(SVM)软间隔](https://www.cnblogs.com/huangyc/p/9938306.html) 12 | #### [6. 支持向量机(SVM)核函数](https://www.cnblogs.com/huangyc/p/9940487.html) -------------------------------------------------------------------------------- /support_vector_machine/svm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "import os\n", 11 | "searchPath=os.path.abspath('..')\n", 12 | "sys.path.append(searchPath)" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "import numpy as np\n", 22 | "import pandas as pd\n", 23 | "from sklearn.datasets import load_iris\n", 24 | "from sklearn.model_selection import train_test_split\n", 25 | "import matplotlib.pyplot as plt\n", 26 | "from support_vector_machine.svm import SVM\n", 27 | "from sklearn.svm import SVC\n", 28 | "np.random.seed(10)\n", 29 | "\n", 30 | "%matplotlib inline" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 3, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "# data\n", 40 | "def create_data():\n", 41 | " iris = load_iris()\n", 42 | " df = pd.DataFrame(iris.data, columns=iris.feature_names)\n", 43 | " df['label'] = iris.target\n", 44 | " df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']\n", 45 | " data = np.array(df.iloc[:100, [0, 1, -1]])\n", 46 | " for i in range(len(data)):\n", 47 | " if data[i,-1] == 0:\n", 48 | " data[i,-1] = -1\n", 49 | " # print(data)\n", 50 | " return data[:,:2], data[:,-1]" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 4, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "X, y = create_data()\n", 60 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25)" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 5, 66 | "metadata": {}, 67 | "outputs": [ 68 | { 69 | "data": { 70 | "text/plain": [ 71 | "" 72 | ] 73 | }, 74 | "execution_count": 5, 75 | "metadata": {}, 76 | "output_type": "execute_result" 77 | }, 78 | { 79 | "data": { 80 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAGihJREFUeJzt3X+MXWWdx/H3d4dZOiowoYwrzJQtP0yjQNfCCJImxAV3q7UWgiyU4I8qC7sGFwwuRgxBbUzAkOCPJdEUyALCFrsVS2H5sQhLVAI1U8B2bSWCoJ2BXYZii6wFyvDdP+6ddubOnbn3ufeeuc/z3M8raTr33Ken3+cc/XJ7zuc819wdERHJy5+1uwAREWk9NXcRkQypuYuIZEjNXUQkQ2ruIiIZUnMXEcmQmruISIbU3EVEMqTmLiKSof3qHWhmXcAQMOLuyyreWwlcA4yUN13n7jfMtL9DDjnE58+fH1SsiEin27Rp00vu3ldrXN3NHbgE2AYcOM37P3T3z9e7s/nz5zM0NBTw14uIiJn9rp5xdV2WMbMB4KPAjJ/GRUQkDvVec/828CXgrRnGfNzMNpvZOjObV22AmV1oZkNmNjQ6Ohpaq4iI1KlmczezZcCL7r5phmF3AfPdfSHwE+DmaoPcfbW7D7r7YF9fzUtGIiLSoHquuS8GlpvZUmAOcKCZ3erunxgf4O47Joy/Hvhma8sUEWmdPXv2MDw8zGuvvdbuUqY1Z84cBgYG6O7ubujP12zu7n45cDmAmX0Q+OeJjb28/VB3f6H8cjmlG68iIlEaHh7mgAMOYP78+ZhZu8uZwt3ZsWMHw8PDHHHEEQ3to+Gcu5mtMrPl5ZcXm9mvzOyXwMXAykb3KyJStNdee425c+dG2dgBzIy5c+c29S+LkCgk7v4w8HD55ysnbN/76V4kN+ufGOGa+5/i+Z27Oay3h8uWLOCMRf3tLkuaFGtjH9dsfUHNXaTTrH9ihMvv2MLuPWMAjOzczeV3bAFQg5eoafkBkRlcc/9Texv7uN17xrjm/qfaVJHk4r777mPBggUcffTRXH311S3fv5q7yAye37k7aLtIPcbGxrjooou499572bp1K2vWrGHr1q0t/Tt0WUZkBof19jBSpZEf1tvThmqkXVp93+UXv/gFRx99NEceeSQAK1as4M477+S9731vq0rWJ3eRmVy2ZAE93V2TtvV0d3HZkgVtqkhm2/h9l5Gdu3H23XdZ/8RIzT87nZGREebN2/cg/8DAACMjje+vGjV3kRmcsaifq848jv7eHgzo7+3hqjOP083UDlLEfRd3n7Kt1ekdXZYRqeGMRf1q5h2siPsuAwMDbN++fe/r4eFhDjvssIb3V40+uYuIzGC6+yvN3Hd5//vfz29+8xueffZZ3njjDW6//XaWL19e+w8GUHMXEZlBEfdd9ttvP6677jqWLFnCe97zHs4++2yOOeaYZkud/He0dG8iIpkZvyTX6qeUly5dytKlS1tRYlVq7iIiNaR430WXZUREMqTmLiKSITV3EZEMqbmLiGRIzV1EJENq7pKN9U+MsPjqhzjiy//B4qsfamrtD5Giffazn+Wd73wnxx57bCH7V3OXLBSxuJNIkVauXMl9991X2P7V3CUL+lINKdTmtfCtY+FrvaXfN69tepennHIKBx98cAuKq04PMUkW9KUaUpjNa+Gui2FP+X9Lu7aXXgMsPLt9ddWgT+6ShSIWdxIB4MFV+xr7uD27S9sjpuYuWdCXakhhdg2HbY+ELstIFopa3EmEgwZKl2KqbY+YmrtkI8XFnSQBp105+Zo7QHdPaXsTzj33XB5++GFeeuklBgYG+PrXv87555/fZLH7qLlL01r95cEiURm/afrgqtKlmIMGSo29yZupa9asaUFx01Nzl6aM58vHY4jj+XJADV7ysfDsqJMx1eiGqjRF+XKROKm5S1OUL5dUuXu7S5hRs/WpuUtTlC+XFM2ZM4cdO3ZE2+DdnR07djBnzpyG96Fr7tKUy5YsmHTNHZQvl/gNDAwwPDzM6Ohou0uZ1pw5cxgYaDxuqeYuTVG+XFLU3d3NEUcc0e4yClV3czezLmAIGHH3ZRXv7Q/cApwA7ADOcffnWlinREz5cpH4hHxyvwTYBhxY5b3zgT+4+9FmtgL4JnBOC+oTSYoy/xKLum6omtkA8FHghmmGnA7cXP55HXCamVnz5YmkQ2vKS0zqTct8G/gS8NY07/cD2wHc/U1gFzC36epEEqLMv8SkZnM3s2XAi+6+aaZhVbZNyRiZ2YVmNmRmQzHfpRZphDL/EpN6PrkvBpab2XPA7cCpZnZrxZhhYB6Ame0HHAS8XLkjd1/t7oPuPtjX19dU4SKxUeZfYlKzubv75e4+4O7zgRXAQ+7+iYphG4BPl38+qzwmzqcDRAqiNeUlJg3n3M1sFTDk7huAG4EfmNnTlD6xr2hRfSLJUOZfYmLt+oA9ODjoQ0NDbfm7RURSZWab3H2w1jg9oSrRumL9FtZs3M6YO11mnHvSPL5xxnHtLkskCWruEqUr1m/h1sd+v/f1mPve12rwIrVpVUiJ0pqNVb6zcobtIjKZmrtEaWyae0HTbReRydTcJUpd06xeMd12EZlMzV2idO5J84K2i8hkuqEqURq/aaq0jEhjlHMXEUmIcu7SlPOuf5RHntm3PNDiow7mtgtObmNF7aM12iVFuuYuU1Q2doBHnnmZ865/tE0VtY/WaJdUqbnLFJWNvdb2nGmNdkmVmrvIDLRGu6RKzV1kBlqjXVKl5i5TLD7q4KDtOdMa7ZIqNXeZ4rYLTp7SyDs1LXPGon6uOvM4+nt7MKC/t4erzjxOaRmJnnLuIiIJUc5dmlJUtjtkv8qXizROzV2mGM92j0cAx7PdQFPNNWS/RdUg0il0zV2mKCrbHbJf5ctFmqPmLlMUle0O2a/y5SLNUXOXKYrKdofsV/lykeaoucsURWW7Q/arfLlIc3RDVaYYv2HZ6qRKyH6LqkGkUyjnLiKSEOXcC5ZiBjvFmkWkMWruDUgxg51izSLSON1QbUCKGewUaxaRxqm5NyDFDHaKNYtI49TcG5BiBjvFmkWkcWruDUgxg51izSLSON1QbUCKGewUaxaRxtXMuZvZHOCnwP6U/mOwzt2/WjFmJXANMP6V8Ne5+w0z7Vc5dxGRcK3Mub8OnOrur5pZN/BzM7vX3R+rGPdDd/98I8XK7Lhi/RbWbNzOmDtdZpx70jy+ccZxTY+NJT8fSx0iMajZ3L300f7V8svu8q/2PNYqDbti/RZufez3e1+Pue99Xdm0Q8bGkp+PpQ6RWNR1Q9XMuszsSeBF4AF331hl2MfNbLOZrTOzeS2tUpq2ZuP2ureHjI0lPx9LHSKxqKu5u/uYu78PGABONLNjK4bcBcx394XAT4Cbq+3HzC40syEzGxodHW2mbgk0Ns29lWrbQ8bGkp+PpQ6RWARFId19J/Aw8OGK7Tvc/fXyy+uBE6b586vdfdDdB/v6+hooVxrVZVb39pCxseTnY6lDJBY1m7uZ9ZlZb/nnHuBDwK8rxhw64eVyYFsri5TmnXtS9Stl1baHjI0lPx9LHSKxqCctcyhws5l1UfqPwVp3v9vMVgFD7r4BuNjMlgNvAi8DK4sqWBozfiO0ngRMyNhY8vOx1CESC63nLiKSEK3nXrCiMtUh+fIi9x0yvxSPRXI2r4UHV8GuYThoAE67Ehae3e6qJGJq7g0oKlMdki8vct8h80vxWCRn81q462LYU07+7Npeeg1q8DItLRzWgKIy1SH58iL3HTK/FI9Fch5cta+xj9uzu7RdZBpq7g0oKlMdki8vct8h80vxWCRn13DYdhHU3BtSVKY6JF9e5L5D5pfisUjOQQNh20VQc29IUZnqkHx5kfsOmV+KxyI5p10J3RX/sezuKW0XmYZuqDagqEx1SL68yH2HzC/FY5Gc8ZumSstIAOXcRUQSopy7TBFDdl0Sp7x9MtTcO0QM2XVJnPL2SdEN1Q4RQ3ZdEqe8fVLU3DtEDNl1SZzy9klRc+8QMWTXJXHK2ydFzb1DxJBdl8Qpb58U3VDtEDFk1yVxytsnRTl3EZGEKOdeVlReO2S/saxLrux6ZHLPjOc+vxBtOBZZN/ei8toh+41lXXJl1yOTe2Y89/mFaNOxyPqGalF57ZD9xrIuubLrkck9M577/EK06Vhk3dyLymuH7DeWdcmVXY9M7pnx3OcXok3HIuvmXlReO2S/saxLrux6ZHLPjOc+vxBtOhZZN/ei8toh+41lXXJl1yOTe2Y89/mFaNOxyPqGalF57ZD9xrIuubLrkck9M577/EK06Vgo5y4ikhDl3AsWQ37+vOsf5ZFnXt77evFRB3PbBSc3XYNIVu6+FDbdBD4G1gUnrIRl1za/38hz/Flfcy/KeGZ8ZOdunH2Z8fVPjMzafisbO8Ajz7zMedc/2lQNIlm5+1IYurHU2KH0+9CNpe3NGM+u79oO+L7s+ua1TZfcKmruDYghP1/Z2GttF+lIm24K216vBHL8au4NiCE/LyJ18LGw7fVKIMev5t6AGPLzIlIH6wrbXq8Ecvxq7g2IIT+/+KiDq+5juu0iHemElWHb65VAjl/NvQFnLOrnqjOPo7+3BwP6e3u46szjWpKfr3e/t11w8pRGrrSMSIVl18Lg+fs+qVtX6XWzaZmFZ8PHvgsHzQOs9PvHvhtVWkY5dxGRhLQs525mc4CfAvuXx69z969WjNkfuAU4AdgBnOPuzzVQd02h+fLU1jAPWfs992NRaI44JPtcVB1Fzi/yDHZTQueW87GYQT0PMb0OnOrur5pZN/BzM7vX3R+bMOZ84A/ufrSZrQC+CZzT6mJD1yRPbQ3zkLXfcz8Wha6BPZ59HjeefYapDb6oOoqcX85rqYfOLedjUUPNa+5e8mr5ZXf5V+W1nNOBm8s/rwNOM2v9soeh+fLU1jAPWfs992NRaI44JPtcVB1Fzi+BDHbDQueW87Gooa4bqmbWZWZPAi8CD7j7xooh/cB2AHd/E9gFzK2ynwvNbMjMhkZHR4OLDc2Bp5YbD1n7PfdjUWiOOCT7XFQdRc4vgQx2w0LnlvOxqKGu5u7uY+7+PmAAONHMjq0YUu1T+pSO5O6r3X3Q3Qf7+vqCiw3NgaeWGw9Z+z33Y1Fojjgk+1xUHUXOL4EMdsNC55bzsaghKArp7juBh4EPV7w1DMwDMLP9gIOAlj8HH5ovT20N85C133M/FoXmiEOyz0XVUeT8EshgNyx0bjkfixrqScv0AXvcfaeZ9QAfonTDdKINwKeBR4GzgIe8gIxl6Jrkqa1hHrL2e+7HotA1sMdvmtaTlimqjiLnl/Na6qFzy/lY1FAz525mCyndLO2i9El/rbuvMrNVwJC7byjHJX8ALKL0iX2Fu/92pv0q5y4iEq5lOXd330ypaVduv3LCz68BfxdapIiIFCP7L+tI7sEdmR0hD7bE8BBMkQ/upPaQVgznIwFZN/fkHtyR2RHyYEsMD8EU+eBOag9pxXA+EpH1wmHJPbgjsyPkwZYYHoIp8sGd1B7SiuF8JCLr5p7cgzsyO0IebInhIZgiH9xJ7SGtGM5HIrJu7sk9uCOzI+TBlhgeginywZ3UHtKK4XwkIuvmntyDOzI7Qh5sieEhmCIf3EntIa0Yzkcism7uRX2phiQu5IsWYvhShtAaYphfavvNkL6sQ0QkIS17iEmk44V8sUcsUqs5lux6LHW0gJq7yExCvtgjFqnVHEt2PZY6WiTra+4iTQv5Yo9YpFZzLNn1WOpoETV3kZmEfLFHLFKrOZbseix1tIiau8hMQr7YIxap1RxLdj2WOlpEzV1kJiFf7BGL1GqOJbseSx0touYuMpNl18Lg+fs+9VpX6XWMNybHpVZzLNn1WOpoEeXcRUQSopy7zJ4Us8FF1VxUvjzFYyxtpeYuzUkxG1xUzUXly1M8xtJ2uuYuzUkxG1xUzUXly1M8xtJ2au7SnBSzwUXVXFS+PMVjLG2n5i7NSTEbXFTNReXLUzzG0nZq7tKcFLPBRdVcVL48xWMsbafmLs1JMRtcVM1F5ctTPMbSdsq5i4gkpN6cuz65Sz42r4VvHQtf6y39vnnt7O+3qBpEAinnLnkoKgsesl/l0SUi+uQueSgqCx6yX+XRJSJq7pKHorLgIftVHl0iouYueSgqCx6yX+XRJSJq7pKHorLgIftVHl0iouYueSgqCx6yX+XRJSI1c+5mNg+4BXgX8Baw2t2/UzHmg8CdwLPlTXe4+4x3kZRzFxEJ18r13N8Evujuj5vZAcAmM3vA3bdWjPuZuy9rpFiJUIrrh4fUnOL8YqDjloyazd3dXwBeKP/8RzPbBvQDlc1dcpFiXlt59OLpuCUl6Jq7mc0HFgEbq7x9spn90szuNbNjWlCbtEuKeW3l0Yun45aUup9QNbN3AD8CvuDur1S8/Tjwl+7+qpktBdYD766yjwuBCwEOP/zwhouWgqWY11YevXg6bkmp65O7mXVTauy3ufsdle+7+yvu/mr553uAbjM7pMq41e4+6O6DfX19TZYuhUkxr608evF03JJSs7mbmQE3AtvcverapWb2rvI4zOzE8n53tLJQmUUp5rWVRy+ejltS6rkssxj4JLDFzJ4sb/sKcDiAu38fOAv4nJm9CewGVni71hKW5o3fHEspFRFSc4rzi4GOW1K0nruISEJamXOXWClzPNndl8Kmm0pfSG1dpa+3a/ZbkEQSpeaeKmWOJ7v7Uhi6cd9rH9v3Wg1eOpDWlkmVMseTbbopbLtI5tTcU6XM8WQ+FrZdJHNq7qlS5ngy6wrbLpI5NfdUKXM82Qkrw7aLZE7NPVVaO3yyZdfC4Pn7PqlbV+m1bqZKh1LOXUQkIcq5N2D9EyNcc/9TPL9zN4f19nDZkgWcsai/3WW1Tu65+NznFwMd42SouZetf2KEy+/Ywu49pXTFyM7dXH7HFoA8Gnzuufjc5xcDHeOk6Jp72TX3P7W3sY/bvWeMa+5/qk0VtVjuufjc5xcDHeOkqLmXPb9zd9D25OSei899fjHQMU6KmnvZYb09QduTk3suPvf5xUDHOClq7mWXLVlAT/fkB156uru4bMmCNlXUYrnn4nOfXwx0jJOiG6pl4zdNs03L5L4Wd+7zi4GOcVKUcxcRSUi9OXddlhFJwea18K1j4Wu9pd83r01j39I2uiwjErsi8+XKrmdLn9xFYldkvlzZ9WypuYvErsh8ubLr2VJzF4ldkflyZdezpeYuErsi8+XKrmdLzV0kdkWu3a/vBciWcu4iIglRzl1EpIOpuYuIZEjNXUQkQ2ruIiIZUnMXEcmQmruISIbU3EVEMqTmLiKSoZrN3czmmdl/mdk2M/uVmV1SZYyZ2XfN7Gkz22xmxxdTrjRF63aLdIx61nN/E/iiuz9uZgcAm8zsAXffOmHMR4B3l3+dBHyv/LvEQut2i3SUmp/c3f0Fd3+8/PMfgW1A5ReLng7c4iWPAb1mdmjLq5XGad1ukY4SdM3dzOYDi4CNFW/1A9snvB5m6n8AMLMLzWzIzIZGR0fDKpXmaN1ukY5Sd3M3s3cAPwK+4O6vVL5d5Y9MWZHM3Ve7+6C7D/b19YVVKs3Rut0iHaWu5m5m3ZQa+23ufkeVIcPAvAmvB4Dnmy9PWkbrdot0lHrSMgbcCGxz92unGbYB+FQ5NfMBYJe7v9DCOqVZWrdbpKPUk5ZZDHwS2GJmT5a3fQU4HMDdvw/cAywFngb+BHym9aVK0xaerWYu0iFqNnd3/znVr6lPHOPARa0qSkREmqMnVEVEMqTmLiKSITV3EZEMqbmLiGRIzV1EJENq7iIiGVJzFxHJkJUi6m34i81Ggd+15S+v7RDgpXYXUSDNL105zw00v3r8pbvXXJyrbc09ZmY25O6D7a6jKJpfunKeG2h+raTLMiIiGVJzFxHJkJp7davbXUDBNL905Tw30PxaRtfcRUQypE/uIiIZ6ujmbmZdZvaEmd1d5b2VZjZqZk+Wf/19O2pshpk9Z2ZbyvUPVXnfzOy7Zva0mW02s+PbUWcj6pjbB81s14Tzl9RXTplZr5mtM7Nfm9k2Mzu54v1kzx3UNb9kz5+ZLZhQ95Nm9oqZfaFiTOHnr54v68jZJcA24MBp3v+hu39+Fuspwl+7+3S52o8A7y7/Ogn4Xvn3VMw0N4CfufuyWaumtb4D3OfuZ5nZnwNvq3g/9XNXa36Q6Plz96eA90HpAyQwAvy4Yljh569jP7mb2QDwUeCGdtfSRqcDt3jJY0CvmR3a7qI6nZkdCJxC6estcfc33H1nxbBkz12d88vFacAz7l75wGbh569jmzvwbeBLwFszjPl4+Z9M68xs3gzjYuXAf5rZJjO7sMr7/cD2Ca+Hy9tSUGtuACeb2S/N7F4zO2Y2i2vSkcAo8K/ly4Y3mNnbK8akfO7qmR+ke/4mWgGsqbK98PPXkc3dzJYBL7r7phmG3QXMd/eFwE+Am2eluNZa7O7HU/on4EVmdkrF+9W+PjGV+FStuT1O6THtvwL+BVg/2wU2YT/geOB77r4I+D/gyxVjUj539cwv5fMHQPly03Lg36u9XWVbS89fRzZ3Sl/6vdzMngNuB041s1snDnD3He7+evnl9cAJs1ti89z9+fLvL1K65ndixZBhYOK/SAaA52enuubUmpu7v+Lur5Z/vgfoNrNDZr3QxgwDw+6+sfx6HaVmWDkmyXNHHfNL/PyN+wjwuLv/b5X3Cj9/Hdnc3f1ydx9w9/mU/tn0kLt/YuKYiutfyyndeE2Gmb3dzA4Y/xn4W+C/K4ZtAD5VvnP/AWCXu78wy6UGq2duZvYuM7PyzydS+t/6jtmutRHu/j/AdjNbUN50GrC1YliS5w7qm1/K52+Cc6l+SQZm4fx1elpmEjNbBQy5+wbgYjNbDrwJvAysbGdtDfgL4Mfl/3/sB/ybu99nZv8I4O7fB+4BlgJPA38CPtOmWkPVM7ezgM+Z2ZvAbmCFp/XE3j8Bt5X/af9b4DOZnLtxteaX9Pkzs7cBfwP8w4Rts3r+9ISqiEiGOvKyjIhI7tTcRUQypOYuIpIhNXcRkQypuYuIZEjNXUQkQ2ruIiIZUnMXEcnQ/wPmMFqpaGCFHwAAAABJRU5ErkJggg==\n", 81 | "text/plain": [ 82 | "" 83 | ] 84 | }, 85 | "metadata": {}, 86 | "output_type": "display_data" 87 | } 88 | ], 89 | "source": [ 90 | "plt.scatter(X[:50,0],X[:50,1], label='0')\n", 91 | "plt.scatter(X[50:,0],X[50:,1], label='1')\n", 92 | "plt.legend()" 93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "metadata": {}, 98 | "source": [ 99 | "----\n", 100 | "\n" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 6, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "svm = SVM(max_iter=200)" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 7, 115 | "metadata": {}, 116 | "outputs": [ 117 | { 118 | "data": { 119 | "text/plain": [ 120 | "'train done!'" 121 | ] 122 | }, 123 | "execution_count": 7, 124 | "metadata": {}, 125 | "output_type": "execute_result" 126 | } 127 | ], 128 | "source": [ 129 | "svm.fit(X_train, y_train)" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 8, 135 | "metadata": {}, 136 | "outputs": [ 137 | { 138 | "data": { 139 | "text/plain": [ 140 | "0.48" 141 | ] 142 | }, 143 | "execution_count": 8, 144 | "metadata": {}, 145 | "output_type": "execute_result" 146 | } 147 | ], 148 | "source": [ 149 | "svm.score(X_test, y_test)" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 9, 155 | "metadata": {}, 156 | "outputs": [ 157 | { 158 | "data": { 159 | "text/plain": [ 160 | "SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,\n", 161 | " decision_function_shape='ovr', degree=3, gamma='auto', kernel='rbf',\n", 162 | " max_iter=-1, probability=False, random_state=None, shrinking=True,\n", 163 | " tol=0.001, verbose=False)" 164 | ] 165 | }, 166 | "execution_count": 9, 167 | "metadata": {}, 168 | "output_type": "execute_result" 169 | } 170 | ], 171 | "source": [ 172 | "\n", 173 | "clf = SVC()\n", 174 | "clf.fit(X_train, y_train)" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 10, 180 | "metadata": {}, 181 | "outputs": [ 182 | { 183 | "data": { 184 | "text/plain": [ 185 | "0.96" 186 | ] 187 | }, 188 | "execution_count": 10, 189 | "metadata": {}, 190 | "output_type": "execute_result" 191 | } 192 | ], 193 | "source": [ 194 | "clf.score(X_test, y_test)" 195 | ] 196 | } 197 | ], 198 | "metadata": { 199 | "kernelspec": { 200 | "display_name": "Python 3", 201 | "language": "python", 202 | "name": "python3" 203 | }, 204 | "language_info": { 205 | "codemirror_mode": { 206 | "name": "ipython", 207 | "version": 3 208 | }, 209 | "file_extension": ".py", 210 | "mimetype": "text/x-python", 211 | "name": "python", 212 | "nbconvert_exporter": "python", 213 | "pygments_lexer": "ipython3", 214 | "version": "3.6.4" 215 | } 216 | }, 217 | "nbformat": 4, 218 | "nbformat_minor": 2 219 | } 220 | -------------------------------------------------------------------------------- /support_vector_machine/svm.py: -------------------------------------------------------------------------------- 1 | 2 | from sklearn.svm import SVC 3 | from utils.data_generater import * 4 | 5 | class SVM(object): 6 | def __init__(self, max_iter=100, kernel='linear'): 7 | self.max_iter = max_iter 8 | self._kernel = kernel 9 | 10 | def init_args(self, features, labels): 11 | self.m, self.n = features.shape 12 | self.X = features 13 | self.Y = labels 14 | self.b = 0.0 15 | 16 | # 将Ei保存在一个列表里 17 | self.alpha = np.ones(self.m) 18 | self.E = [self._E(i) for i in range(self.m)] 19 | # 松弛变量 20 | self.C = 1.0 21 | 22 | def _KKT(self, i): 23 | y_g = self._g(i) * self.Y[i] 24 | if self.alpha[i] == 0: 25 | return y_g >= 1 26 | elif 0 < self.alpha[i] < self.C: 27 | return y_g == 1 28 | else: 29 | return y_g <= 1 30 | 31 | # g(x)预测值,输入xi(X[i]) 32 | def _g(self, i): 33 | r = self.b 34 | for j in range(self.m): 35 | r += self.alpha[j] * self.Y[j] * self.kernel(self.X[i], self.X[j]) 36 | return r 37 | 38 | # 核函数 39 | def kernel(self, x1, x2): 40 | if self._kernel == 'linear': 41 | return sum([x1[k] * x2[k] for k in range(self.n)]) 42 | elif self._kernel == 'poly': 43 | return (sum([x1[k] * x2[k] for k in range(self.n)]) + 1) ** 2 44 | 45 | return 0 46 | 47 | # E(x)为g(x)对输入x的预测值和y的差 48 | def _E(self, i): 49 | return self._g(i) - self.Y[i] 50 | 51 | def _init_alpha(self): 52 | # 外层循环首先遍历所有满足0= 0: 65 | j = min(range(self.m), key=lambda x: self.E[x]) 66 | else: 67 | j = max(range(self.m), key=lambda x: self.E[x]) 68 | return i, j 69 | 70 | def _compare(self, _alpha, L, H): 71 | if _alpha > H: 72 | return H 73 | elif _alpha < L: 74 | return L 75 | else: 76 | return _alpha 77 | 78 | def fit(self, features, labels): 79 | self.init_args(features, labels) 80 | 81 | for t in range(self.max_iter): 82 | # train 83 | i1, i2 = self._init_alpha() 84 | 85 | # 边界 86 | if self.Y[i1] == self.Y[i2]: 87 | L = max(0, self.alpha[i1] + self.alpha[i2] - self.C) 88 | H = min(self.C, self.alpha[i1] + self.alpha[i2]) 89 | else: 90 | L = max(0, self.alpha[i2] - self.alpha[i1]) 91 | H = min(self.C, self.C + self.alpha[i2] - self.alpha[i1]) 92 | 93 | E1 = self.E[i1] 94 | E2 = self.E[i2] 95 | # eta=K11+K22-2K12 96 | eta = self.kernel(self.X[i1], self.X[i1]) + self.kernel(self.X[i2], self.X[i2]) - 2 * self.kernel( 97 | self.X[i1], 98 | self.X[i2]) 99 | if eta <= 0: 100 | # print('eta <= 0') 101 | continue 102 | 103 | alpha2_new_unc = self.alpha[i2] + self.Y[i2] * (E1 - E2) / eta # 此处有修改,根据书上应该是E1 - E2,书上130-131页 104 | alpha2_new = self._compare(alpha2_new_unc, L, H) 105 | 106 | alpha1_new = self.alpha[i1] + self.Y[i1] * self.Y[i2] * (self.alpha[i2] - alpha2_new) 107 | 108 | b1_new = -E1 - self.Y[i1] * self.kernel(self.X[i1], self.X[i1]) * (alpha1_new - self.alpha[i1]) - self.Y[ 109 | i2] * self.kernel(self.X[i2], self.X[i1]) * (alpha2_new - self.alpha[i2]) + self.b 110 | b2_new = -E2 - self.Y[i1] * self.kernel(self.X[i1], self.X[i2]) * (alpha1_new - self.alpha[i1]) - self.Y[ 111 | i2] * self.kernel(self.X[i2], self.X[i2]) * (alpha2_new - self.alpha[i2]) + self.b 112 | 113 | if 0 < alpha1_new < self.C: 114 | b_new = b1_new 115 | elif 0 < alpha2_new < self.C: 116 | b_new = b2_new 117 | else: 118 | # 选择中点 119 | b_new = (b1_new + b2_new) / 2 120 | 121 | # 更新参数 122 | self.alpha[i1] = alpha1_new 123 | self.alpha[i2] = alpha2_new 124 | self.b = b_new 125 | 126 | self.E[i1] = self._E(i1) 127 | self.E[i2] = self._E(i2) 128 | return 'train done!' 129 | 130 | def predict(self, data): 131 | r = self.b 132 | for i in range(self.m): 133 | r += self.alpha[i] * self.Y[i] * self.kernel(data, self.X[i]) 134 | 135 | return 1 if r > 0 else -1 136 | 137 | def score(self, X_test, y_test): 138 | right_count = 0 139 | for i in range(len(X_test)): 140 | result = self.predict(X_test[i]) 141 | if result == y_test[i]: 142 | right_count += 1 143 | return right_count / len(X_test) 144 | 145 | def _weight(self): 146 | # linear model 147 | yx = self.Y.reshape(-1, 1) * self.X 148 | self.w = np.dot(yx.T, self.alpha) 149 | return self.w 150 | 151 | 152 | if __name__ == "__main__": 153 | X_train, X_test, y_train, y_test = create_svm_data() 154 | 155 | # 我们的svm 156 | my_svm = svm = SVM(max_iter=200) 157 | my_svm.fit(X_train, y_train) 158 | print("my svm score", my_svm.score(X_test, y_test)) 159 | 160 | # sklearn的svc 161 | sklearn_svc = SVC() 162 | sklearn_svc.fit(X_train, y_train) 163 | print("sklearn svm score", sklearn_svc.score(X_test, y_test)) -------------------------------------------------------------------------------- /utils/data_generater.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from sklearn.datasets import load_iris 4 | from sklearn.model_selection import train_test_split 5 | from random import random 6 | 7 | np.random.seed(10) 8 | def makeRandomPoint(num, dim, upper): 9 | return np.random.normal(loc=upper, size=[num, dim]) 10 | 11 | 12 | # 产生一个k维随机向量,每维分量值在0~1之间 13 | def random_point(k): 14 | return [random() for _ in range(k)] 15 | 16 | 17 | # 产生n个k维随机向量 18 | def random_points(k, n): 19 | return [random_point(k) for _ in range(n)] 20 | 21 | 22 | # data 23 | def create_logistic_data(): 24 | iris = load_iris() 25 | df = pd.DataFrame(iris.data, columns=iris.feature_names) 26 | df['label'] = iris.target 27 | df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label'] 28 | data = np.array(df.iloc[:100, [0,1,-1]]) 29 | X, y = data[:,:2], data[:,-1] 30 | return train_test_split(X, y, test_size=0.3) 31 | 32 | def create_svm_data(): 33 | iris = load_iris() 34 | df = pd.DataFrame(iris.data, columns=iris.feature_names) 35 | df['label'] = iris.target 36 | df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label'] 37 | data = np.array(df.iloc[:100, [0, 1, -1]]) 38 | for i in range(len(data)): 39 | if data[i, -1] == 0: 40 | data[i, -1] = -1 41 | X, y = data[:,:2], data[:,-1] 42 | return train_test_split(X, y, test_size=0.3) -------------------------------------------------------------------------------- /utils/misc_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numbers 3 | from scipy.stats import multivariate_normal 4 | 5 | def distance(point1, point2): 6 | return np.sqrt(np.sum(np.square(point1 - point2), axis=1)) 7 | 8 | 9 | def check_random_state(seed): 10 | if seed is None or seed is np.random: 11 | return np.random.mtrand._rand 12 | if isinstance(seed, (numbers.Integral, np.integer)): 13 | return np.random.RandomState(seed) 14 | if isinstance(seed, np.random.RandomState): 15 | return seed 16 | raise ValueError('%r cannot be used to seed a numpy.random.RandomState' 17 | ' instance' % seed) 18 | 19 | 20 | def sortLabel(label): 21 | label = np.array(label) 22 | labelOld = [] 23 | labelNum = len(list(set(label))) 24 | for i in label: 25 | if i not in labelOld: 26 | labelOld.append(i) 27 | if len(labelOld) == labelNum: 28 | break 29 | 30 | labelNew = sorted(labelOld) 31 | for i, old in enumerate(labelOld): 32 | label[label == old] = labelNew[i] + 10000 33 | return label - 10000 34 | 35 | def prob(x, mu, cov): 36 | norm = multivariate_normal(mean=mu, cov=cov) 37 | return norm.pdf(x) 38 | 39 | def log_prob(x, mu, cov): 40 | norm = multivariate_normal(mean=mu, cov=cov) 41 | return norm.logpdf(x) 42 | 43 | 44 | def log_weight_prob(x, alpha, mu, cov): 45 | N = x.shape[0] 46 | return np.mat(np.log(alpha) + log_prob(x, mu, cov)).reshape([N, 1]) -------------------------------------------------------------------------------- /utils/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from matplotlib.colors import ListedColormap 3 | import numpy as np 4 | 5 | def plot_decision_regions(model, X, y, resolution=0.02): 6 | """ 7 | 拟合效果可视化 8 | :param X:training sets 9 | :param y:training labels 10 | :param resolution:分辨率 11 | :return:None 12 | """ 13 | # initialization colors map 14 | colors = ['red', 'blue'] 15 | markers = ['o', 'x'] 16 | cmap = ListedColormap(colors[:len(np.unique(y))]) 17 | 18 | # plot the decision regions 19 | x1_max, x1_min = max(X[:, 0]) + 1, min(X[:, 0]) - 1 20 | x2_max, x2_min = max(X[:, 1]) + 1, min(X[:, 1]) - 1 21 | xx1, xx2 = np.meshgrid(np.arange(x1_min, x1_max, resolution), 22 | np.arange(x2_min, x2_max, resolution)) 23 | Z = model.predict(np.array([xx1.ravel(), xx2.ravel()]).T) 24 | Z = Z.reshape(xx1.shape) 25 | plt.contourf(xx1, xx2, Z, alpha=0.4, cmap=cmap) 26 | plt.xlim(xx1.min(), xx1.max()) 27 | plt.ylim(xx2.min(), xx2.max()) 28 | 29 | # plot class samples 30 | for idx, cl in enumerate(np.unique(y)): 31 | plt.scatter(x=X[y == cl, 0], y=X[y == cl, 1], 32 | alpha=0.8, c=cmap(idx), 33 | marker=markers[idx], label=cl) 34 | plt.show() 35 | 36 | 37 | def plot_knn_predict(model, dataset, label, x): 38 | dataset = np.array(dataset) 39 | plt.scatter(x[0], x[1], c="r", marker='*', s = 40) # 测试点 40 | near, predict_label = model.predict(x) # 设置临近点的个数 41 | plt.scatter(dataset[:,0], dataset[:,1], c=label, s = 50) # 画所有的数据点 42 | for n in near: 43 | print(n) 44 | plt.scatter(n.data[0], n.data[1], c="r", marker='+', s = 40) # k个最近邻点 45 | plt.show() -------------------------------------------------------------------------------- /utils/word_utils.py: -------------------------------------------------------------------------------- 1 | def createVocabList(dataSet): 2 | ''' 3 | 创建所有文档中出现的不重复词汇列表 4 | Args: 5 | dataSet: 所有文档 6 | Return: 7 | 包含所有文档的不重复词列表,即词汇表 8 | ''' 9 | vocabSet = set([]) 10 | # 创建两个集合的并集 11 | for document in dataSet: 12 | vocabSet = vocabSet | set(document) 13 | return list(vocabSet) 14 | 15 | 16 | # 词袋模型(bag-of-words model):词在文档中出现的次数 17 | def bagOfWords2Vec(vocabList, inputSet): 18 | ''' 19 | 依据词汇表,将输入文本转化成词袋模型词向量 20 | Args: 21 | vocabList: 词汇表 22 | inputSet: 当前输入文档 23 | Return: 24 | returnVec: 转换成词向量的文档 25 | 例子: 26 | vocabList = ['I', 'love', 'python', 'and', 'machine', 'learning'] 27 | inputset = ['python', 'machine', 'learning', 'python', 'machine'] 28 | returnVec = [0, 0, 2, 0, 2, 1] 29 | 长度与词汇表一样长,出现了的位置为1,未出现为0,如果词汇表中无该单词则print 30 | ''' 31 | returnVec = [0] * len(vocabList) 32 | for word in inputSet: 33 | if word in vocabList: 34 | returnVec[vocabList.index(word)] += 1 35 | else: 36 | print("the word: %s is not in my vocabulary!" % word) 37 | return returnVec 38 | 39 | 40 | # 词集模型(set-of-words model):词在文档中是否存在,存在为1,不存在为0 41 | def setOfWord2Vec(vocabList, inputSet): 42 | ''' 43 | 依据词汇表,将输入文本转化成词集模型词向量 44 | Args: 45 | vocabList: 词汇表 46 | inputSet: 当前输入文档 47 | Return: 48 | returnVec: 转换成词向量的文档 49 | 例子: 50 | vocabList = ['I', 'love', 'python', 'and', 'machine', 'learning'] 51 | inputset = ['python', 'machine', 'learning'] 52 | returnVec = [0, 0, 1, 0, 1, 1] 53 | 长度与词汇表一样长,出现了的位置为1,未出现为0,如果词汇表中无该单词则print 54 | ''' 55 | returnVec = [0] * len(vocabList) 56 | for word in inputSet: 57 | if word in vocabList: 58 | returnVec[vocabList.index(word)] = 1 59 | else: 60 | print("the word: %s is not in my vocabulary!" % word) 61 | return returnVec 62 | 63 | 64 | --------------------------------------------------------------------------------