├── decision_tree ├── lenses.gif ├── sms_tree.gif ├── sms_tree.pkl ├── english_big.txt ├── sms_tree_2.gif ├── lenses.py ├── lenses.txt ├── sms_tree_2.dot ├── sms_tree.dot ├── lenses.dot ├── sms_tree.py └── trees.py ├── linear_regression ├── lasso_ws ├── lwlr_k01.png ├── lwlr_k05.png ├── lasso_traj.png ├── lwlr_k003.png ├── ridge_traj.png ├── stage_traj.png ├── ridge_traj_ori.png ├── std_linear_regression.png ├── local_weighted_linear_regression.py ├── stage_wise_regression.py ├── standard_linear_regression.py ├── ridge_regression.py ├── lasso_regression.py ├── ex0.txt ├── ex1.txt └── lasso_regression.ipynb ├── naive_bayes ├── english_big.txt ├── distribution.png ├── bayes.py └── sms.py ├── Reinforcement Learning ├── T.npy ├── Calculating Transition Probabilities.ipynb ├── Defining Initial Distribution.ipynb ├── Calculating State Utilities.ipynb ├── Value Iteration Algorithm.ipynb └── Policy Iteration Algorithm.ipynb ├── logistic_regression ├── english_big.txt ├── grad_ascent_params.png ├── grad_ascent_animation.gif ├── stoch_grad_ascent_params.png ├── stoch_grad_ascent_animation.gif ├── logreg_stoch_grad_ascent.py ├── testSet.txt ├── sms.py └── logreg_grad_ascent.py ├── support_vector_machine ├── svm_ga.png ├── platt_smo.png ├── simple_smo.png ├── testSet.txt ├── svm_ga.py ├── svm_simple_smo.py └── svm_platt_smo.py ├── classification_and_regression_trees ├── pic │ ├── Thumbs.db │ ├── ex0_data.png │ ├── ex0_tree.png │ ├── ex2_data.png │ ├── bike_data.png │ ├── ex00_tree.png │ ├── ex2_prune.png │ ├── ex2_tree_1.png │ ├── exp2_data.png │ ├── exp2_tree.png │ ├── ex2_tree_2000.png │ ├── bike_regression.png │ ├── ex00_regression.png │ ├── ex0_regression.png │ ├── exp2_regression.png │ ├── ex2_regression_ori.png │ └── ex2_regression_2000.png ├── dot │ ├── ex00.dot │ ├── exp2.dot │ ├── ex0.dot │ ├── ex2_prune.dot │ └── ex2.dot ├── exp2.dot ├── ex2.dot ├── prune.py ├── compare.py ├── model_tree.py ├── ex0.txt ├── exp2.txt ├── ex00.txt ├── exp.txt ├── ex2test.txt ├── ex2.txt 
├── bikeSpeedVsIq_train.txt ├── bikeSpeedVsIq_test.txt ├── regression_tree.py └── notebook │ └── 后剪枝.ipynb ├── .gitignore └── README.md /decision_tree/lenses.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/decision_tree/lenses.gif -------------------------------------------------------------------------------- /decision_tree/sms_tree.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/decision_tree/sms_tree.gif -------------------------------------------------------------------------------- /decision_tree/sms_tree.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/decision_tree/sms_tree.pkl -------------------------------------------------------------------------------- /linear_regression/lasso_ws: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/linear_regression/lasso_ws -------------------------------------------------------------------------------- /naive_bayes/english_big.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/naive_bayes/english_big.txt -------------------------------------------------------------------------------- /Reinforcement Learning/T.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/Reinforcement Learning/T.npy -------------------------------------------------------------------------------- /decision_tree/english_big.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/decision_tree/english_big.txt 
-------------------------------------------------------------------------------- /decision_tree/sms_tree_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/decision_tree/sms_tree_2.gif -------------------------------------------------------------------------------- /linear_regression/lwlr_k01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/linear_regression/lwlr_k01.png -------------------------------------------------------------------------------- /linear_regression/lwlr_k05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/linear_regression/lwlr_k05.png -------------------------------------------------------------------------------- /naive_bayes/distribution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/naive_bayes/distribution.png -------------------------------------------------------------------------------- /linear_regression/lasso_traj.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/linear_regression/lasso_traj.png -------------------------------------------------------------------------------- /linear_regression/lwlr_k003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/linear_regression/lwlr_k003.png -------------------------------------------------------------------------------- /linear_regression/ridge_traj.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/linear_regression/ridge_traj.png 
-------------------------------------------------------------------------------- /linear_regression/stage_traj.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/linear_regression/stage_traj.png -------------------------------------------------------------------------------- /logistic_regression/english_big.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/logistic_regression/english_big.txt -------------------------------------------------------------------------------- /support_vector_machine/svm_ga.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/support_vector_machine/svm_ga.png -------------------------------------------------------------------------------- /linear_regression/ridge_traj_ori.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/linear_regression/ridge_traj_ori.png -------------------------------------------------------------------------------- /support_vector_machine/platt_smo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/support_vector_machine/platt_smo.png -------------------------------------------------------------------------------- /support_vector_machine/simple_smo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/support_vector_machine/simple_smo.png -------------------------------------------------------------------------------- /logistic_regression/grad_ascent_params.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PytLab/MLBox/HEAD/logistic_regression/grad_ascent_params.png -------------------------------------------------------------------------------- /linear_regression/std_linear_regression.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/linear_regression/std_linear_regression.png -------------------------------------------------------------------------------- /logistic_regression/grad_ascent_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/logistic_regression/grad_ascent_animation.gif -------------------------------------------------------------------------------- /classification_and_regression_trees/pic/Thumbs.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/classification_and_regression_trees/pic/Thumbs.db -------------------------------------------------------------------------------- /logistic_regression/stoch_grad_ascent_params.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/logistic_regression/stoch_grad_ascent_params.png -------------------------------------------------------------------------------- /classification_and_regression_trees/pic/ex0_data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/classification_and_regression_trees/pic/ex0_data.png -------------------------------------------------------------------------------- /classification_and_regression_trees/pic/ex0_tree.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/classification_and_regression_trees/pic/ex0_tree.png 
-------------------------------------------------------------------------------- /classification_and_regression_trees/pic/ex2_data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/classification_and_regression_trees/pic/ex2_data.png -------------------------------------------------------------------------------- /logistic_regression/stoch_grad_ascent_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/logistic_regression/stoch_grad_ascent_animation.gif -------------------------------------------------------------------------------- /classification_and_regression_trees/pic/bike_data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/classification_and_regression_trees/pic/bike_data.png -------------------------------------------------------------------------------- /classification_and_regression_trees/pic/ex00_tree.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/classification_and_regression_trees/pic/ex00_tree.png -------------------------------------------------------------------------------- /classification_and_regression_trees/pic/ex2_prune.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/classification_and_regression_trees/pic/ex2_prune.png -------------------------------------------------------------------------------- /classification_and_regression_trees/pic/ex2_tree_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/classification_and_regression_trees/pic/ex2_tree_1.png 
-------------------------------------------------------------------------------- /classification_and_regression_trees/pic/exp2_data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/classification_and_regression_trees/pic/exp2_data.png -------------------------------------------------------------------------------- /classification_and_regression_trees/pic/exp2_tree.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/classification_and_regression_trees/pic/exp2_tree.png -------------------------------------------------------------------------------- /classification_and_regression_trees/pic/ex2_tree_2000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/classification_and_regression_trees/pic/ex2_tree_2000.png -------------------------------------------------------------------------------- /classification_and_regression_trees/pic/bike_regression.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/classification_and_regression_trees/pic/bike_regression.png -------------------------------------------------------------------------------- /classification_and_regression_trees/pic/ex00_regression.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/classification_and_regression_trees/pic/ex00_regression.png -------------------------------------------------------------------------------- /classification_and_regression_trees/pic/ex0_regression.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PytLab/MLBox/HEAD/classification_and_regression_trees/pic/ex0_regression.png 
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Fit a decision tree to the contact-lens dataset (lenses.txt)."""

from trees import DecisionTreeClassifier

# Feature names, one per column of lenses.txt; the last column is the label.
lense_labels = ['age', 'prescript', 'astigmatic', 'tearRate']

X = []
Y = []

# Each tab-separated line is one record: features first, class label last.
with open('lenses.txt', 'r') as f:
    for record in (line.strip().split('\t') for line in f):
        X.append(record[:-1])
        Y.append(record[-1])

clf = DecisionTreeClassifier()
clf.create_tree(X, Y, lense_labels)
"0b668e3e-42eb-4735-a6ba-420826ffc809" [label="0: 0.731636"]; 5 | "e1a950cd-cd46-4ce1-941e-c59c56377d2e" [label="107.69"]; 6 | "b2ee8f32-0401-4b83-a2ee-3f9212b6d8a1" [label="96.32"]; 7 | "e1b05249-eb8e-4afd-837c-d2f5a5299a6a" -> "b82d5e44-41de-40ec-8558-fad039b53058"; 8 | "e1b05249-eb8e-4afd-837c-d2f5a5299a6a" -> "0b668e3e-42eb-4735-a6ba-420826ffc809"; 9 | "0b668e3e-42eb-4735-a6ba-420826ffc809" -> "e1a950cd-cd46-4ce1-941e-c59c56377d2e"; 10 | "0b668e3e-42eb-4735-a6ba-420826ffc809" -> "b2ee8f32-0401-4b83-a2ee-3f9212b6d8a1"; 11 | } -------------------------------------------------------------------------------- /decision_tree/lenses.txt: -------------------------------------------------------------------------------- 1 | young myope no reduced no lenses 2 | young myope no normal soft 3 | young myope yes reduced no lenses 4 | young myope yes normal hard 5 | young hyper no reduced no lenses 6 | young hyper no normal soft 7 | young hyper yes reduced no lenses 8 | young hyper yes normal hard 9 | pre myope no reduced no lenses 10 | pre myope no normal soft 11 | pre myope yes reduced no lenses 12 | pre myope yes normal hard 13 | pre hyper no reduced no lenses 14 | pre hyper no normal soft 15 | pre hyper yes reduced no lenses 16 | pre hyper yes normal no lenses 17 | presbyopic myope no reduced no lenses 18 | presbyopic myope no normal no lenses 19 | presbyopic myope yes reduced no lenses 20 | presbyopic myope yes normal hard 21 | presbyopic hyper no reduced no lenses 22 | presbyopic hyper no normal soft 23 | presbyopic hyper yes reduced no lenses 24 | presbyopic hyper yes normal no lenses 25 | -------------------------------------------------------------------------------- /classification_and_regression_trees/dot/ex0.dot: -------------------------------------------------------------------------------- 1 | digraph decision_tree { 2 | "5db27cbb-29af-4987-9cd2-9217c781000d" [label="0: 0.400158"]; 3 | "a81daf61-ab07-4e65-8b8a-55ee0bd0b40c" [label="0: 0.208197"]; 4 | 
"1f1412f1-659b-4347-8013-f6e57e634c2b" [label="-0.02"]; 5 | "32292eec-1e38-4eff-9700-cba03d93d7d8" [label="1.03"]; 6 | "7cee5c66-0140-4be7-ab6e-01245b3c8199" [label="0: 0.609483"]; 7 | "3308a031-f17e-494c-b015-9b5f3d904dba" [label="1.98"]; 8 | "d53cb038-a0fa-4635-9a07-36d40c33d6b9" [label="0: 0.816742"]; 9 | "adeac9bb-ef8a-4a91-a821-bade5e047d0d" [label="2.98"]; 10 | "208da594-0e82-4b01-af38-b60bac08624d" [label="3.99"]; 11 | "5db27cbb-29af-4987-9cd2-9217c781000d" -> "a81daf61-ab07-4e65-8b8a-55ee0bd0b40c"; 12 | "a81daf61-ab07-4e65-8b8a-55ee0bd0b40c" -> "1f1412f1-659b-4347-8013-f6e57e634c2b"; 13 | "a81daf61-ab07-4e65-8b8a-55ee0bd0b40c" -> "32292eec-1e38-4eff-9700-cba03d93d7d8"; 14 | "5db27cbb-29af-4987-9cd2-9217c781000d" -> "7cee5c66-0140-4be7-ab6e-01245b3c8199"; 15 | "7cee5c66-0140-4be7-ab6e-01245b3c8199" -> "3308a031-f17e-494c-b015-9b5f3d904dba"; 16 | "7cee5c66-0140-4be7-ab6e-01245b3c8199" -> "d53cb038-a0fa-4635-9a07-36d40c33d6b9"; 17 | "d53cb038-a0fa-4635-9a07-36d40c33d6b9" -> "adeac9bb-ef8a-4a91-a821-bade5e047d0d"; 18 | "d53cb038-a0fa-4635-9a07-36d40c33d6b9" -> "208da594-0e82-4b01-af38-b60bac08624d"; 19 | } -------------------------------------------------------------------------------- /classification_and_regression_trees/prune.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from regression_tree import * 5 | 6 | def not_tree(tree): 7 | ''' 判断是否不是一棵树结构 8 | ''' 9 | return type(tree) is not dict 10 | 11 | def collapse(tree): 12 | ''' 对一棵树进行塌陷处理, 得到给定树结构的平均值 13 | ''' 14 | if not_tree(tree): 15 | return tree 16 | ltree, rtree = tree['left'], tree['right'] 17 | return (collapse(ltree) + collapse(rtree))/2 18 | 19 | def postprune(tree, test_data): 20 | ''' 根据测试数据对树结构进行后剪枝 21 | ''' 22 | if not_tree(tree): 23 | return tree 24 | 25 | # 若没有测试数据则直接返回树平均值 26 | if not test_data: 27 | return collapse(tree) 28 | 29 | ltree, rtree = tree['left'], tree['right'] 30 | 31 | if 
not_tree(ltree) and not_tree(rtree): 32 | # 分割数据用于测试 33 | ldata, rdata = split_dataset(test_data, tree['feat_idx'], tree['feat_val']) 34 | # 分别计算合并前和合并后的测试数据误差 35 | err_no_merge = (np.sum((np.array(ldata) - ltree)**2) + 36 | np.sum((np.array(rdata) - rtree)**2)) 37 | err_merge = np.sum((np.array(test_data) - (ltree + rtree)/2)**2) 38 | 39 | if err_merge < err_no_merge: 40 | print('merged') 41 | return (ltree + rtree)/2 42 | else: 43 | return tree 44 | 45 | tree['left'] = postprune(tree['left'], test_data) 46 | tree['right'] = postprune(tree['right'], test_data) 47 | 48 | return tree 49 | 50 | -------------------------------------------------------------------------------- /linear_regression/local_weighted_linear_regression.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from math import exp 5 | 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | 9 | from standard_linear_regression import load_data, get_corrcoef 10 | 11 | def lwlr(x, X, Y, k): 12 | ''' 局部加权线性回归,给定一个点,获取相应权重矩阵并返回回归系数 13 | ''' 14 | m = X.shape[0] 15 | 16 | # 创建针对x的权重矩阵 17 | W = np.matrix(np.zeros((m, m))) 18 | for i in range(m): 19 | xi = np.array(X[i][0]) 20 | x = np.array(x) 21 | W[i, i] = exp((np.linalg.norm(x - xi))/(-2*k**2)) 22 | 23 | # 获取此点相应的回归系数 24 | 25 | xWx = X.T*W*X 26 | if np.linalg.det(xWx) == 0: 27 | print('xWx is a singular matrix') 28 | return 29 | w = xWx.I*X.T*W*Y 30 | 31 | return w 32 | 33 | if '__main__' == __name__: 34 | k = 0.03 35 | 36 | X, Y = load_data('ex0.txt') 37 | 38 | y_prime = [] 39 | for x in X.tolist(): 40 | w = lwlr(x, X, Y, k).reshape(1, -1).tolist()[0] 41 | y_prime.append(np.dot(x, w)) 42 | 43 | corrcoef = get_corrcoef(np.array(Y.reshape(1, -1)), np.array(y_prime)) 44 | print('Correlation coefficient: {}'.format(corrcoef)) 45 | 46 | fig = plt.figure() 47 | ax = fig.add_subplot(111) 48 | 49 | # 绘制数据点 50 | x = X[:, 1].reshape(1, -1).tolist()[0] 51 | y = 
Y.reshape(1, -1).tolist()[0] 52 | ax.scatter(x, y) 53 | 54 | # 绘制拟合直线 55 | x, y = list(zip(*sorted(zip(x, y_prime), key=lambda x: x[0]))) 56 | ax.plot(x, y, c='r') 57 | 58 | plt.show() 59 | 60 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | 
import numpy as np

def stagewise_regression(X, y, eps=0.01, niter=100):
    ''' Forward stagewise linear regression (greedy coordinate steps).

    Each iteration tries moving every coefficient by +/-eps and keeps the
    single move with the lowest residual sum of squares for THAT pass.

    :param X: (m, n) design matrix (np.matrix).
    :param y: (m, 1) target column vector (np.matrix).
    :param eps: step size applied to one coefficient per trial.
    :param niter: number of greedy passes.
    :return: (niter, n) np.matrix; row i is the coefficient vector after
        iteration i.
    '''
    m, n = X.shape
    w = np.matrix(np.zeros((n, 1)))
    all_ws = np.matrix(np.zeros((niter, n)))

    # Residual sum of squares, as a plain float for clean comparisons.
    rss = lambda X, y, w: float((y - X*w).T*(y - X*w))

    for i in range(niter):
        print('{}: w = {}'.format(i, w.T[0, :]))
        # BUG FIX: reset the best error at every iteration.  The original
        # kept one global minimum, so once no step improved on the best
        # error ever seen, the coefficients froze for all remaining
        # iterations instead of taking the best available step each pass.
        min_error = float('inf')
        best_w = w
        for j in range(n):
            for sign in (-1, 1):
                w_test = w.copy()
                w_test[j, 0] += eps*sign
                test_error = rss(X, y, w_test)
                if test_error < min_error:
                    min_error = test_error
                    best_w = w_test
        w = best_w
        all_ws[i, :] = w.T

    return all_ws
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Multinomial naive Bayes text classifier (log-space, add-one smoothed)."""

from collections import defaultdict

import numpy as np

class NaiveBayesClassifier(object):
    ''' Naive Bayes classifier over document word vectors.
    '''

    def train(self, dataset, classes):
        ''' Estimate model parameters from labelled document vectors.

        :param dataset: all document vectors.
        :type dataset: MxN matrix containing all doc vectors.

        :param classes: class label of every document.
        :type classes: 1xN list

        :return cond_probs: per-class log conditional probability vectors.
        :type cond_probs: MxK matrix

        :return cls_probs: prior probability of each class.
        :type cls_probs: 1xK list
        '''
        # Group the document vectors by their class label.
        grouped = defaultdict(list)
        for doc_vect, cls in zip(dataset, classes):
            grouped[cls].append(doc_vect)

        # Class priors: P(c) = #docs in c / #docs.
        n_docs = len(classes)
        cls_probs = {cls: len(docs)/n_docs for cls, docs in grouped.items()}

        # Per-class log conditional probabilities with add-one smoothing.
        # NOTE(review): the denominator is the word total over the WHOLE
        # dataset (+2) and is identical for every class, so it cancels in
        # the argmax of classify(); the values are not calibrated
        # per-class probabilities — confirm this is intended.
        dataset = np.array(dataset)
        denom = np.sum(dataset) + 2
        cond_probs = {}
        for cls, docs in grouped.items():
            word_counts = np.sum(np.array(docs), axis=0)
            cond_probs[cls] = np.log((word_counts + 1)/denom)

        return cond_probs, cls_probs

    def classify(self, doc_vect, cond_probs, cls_probs):
        ''' Assign doc_vect to the most probable class.

        Scores each class with log P(c) + sum(doc_vect * log P(w|c)) and
        returns the argmax (first class wins ties).
        '''
        scores = {
            cls: np.sum(cond_probs[cls]*doc_vect) + np.log(prior)
            for cls, prior in cls_probs.items()
        }
        return max(scores, key=scores.get)
55 | fig.savefig('stoch_grad_ascent_params.png') 56 | -------------------------------------------------------------------------------- /linear_regression/standard_linear_regression.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | def load_data(filename): 9 | ''' 加载数据 10 | ''' 11 | X, Y = [], [] 12 | with open(filename, 'r') as f: 13 | for line in f: 14 | splited_line = [float(i) for i in line.split()] 15 | x, y = splited_line[: -1], splited_line[-1] 16 | X.append(x) 17 | Y.append(y) 18 | X, Y = np.matrix(X), np.matrix(Y).T 19 | return X, Y 20 | 21 | def standarize(X): 22 | ''' 中心化 & 标准化数据 (零均值, 单位标准差) 23 | ''' 24 | std_deviation = np.std(X, 0) 25 | mean = np.mean(X, 0) 26 | return (X - mean)/std_deviation 27 | 28 | def std_linreg(X, Y): 29 | xTx = X.T*X 30 | if np.linalg.det(xTx) == 0: 31 | print('xTx is a singular matrix') 32 | return 33 | return xTx.I*X.T*Y 34 | 35 | def get_corrcoef(X, Y): 36 | # X Y 的协方差 37 | cov = np.mean(X*Y) - np.mean(X)*np.mean(Y) 38 | return cov/(np.var(X)*np.var(Y))**0.5 39 | 40 | if '__main__' == __name__: 41 | # 加载数据 42 | X, Y = load_data('abalone.txt') 43 | X, Y = standarize(X), standarize(Y) 44 | w = std_linreg(X, Y) 45 | Y_prime = X*w 46 | 47 | print('w: {}'.format(w)) 48 | 49 | # 计算相关系数 50 | corrcoef = get_corrcoef(np.array(Y.reshape(1, -1)), 51 | np.array(Y_prime.reshape(1, -1))) 52 | print('Correlation coeffient: {}'.format(corrcoef)) 53 | 54 | #fig = plt.figure() 55 | #ax = fig.add_subplot(111) 56 | 57 | ## 绘制数据点 58 | #x = X[:, 1].reshape(1, -1).tolist()[0] 59 | #y = Y.reshape(1, -1).tolist()[0] 60 | #ax.scatter(x, y) 61 | 62 | ## 绘制拟合直线 63 | #x1, x2 = min(x), max(x) 64 | #y1 = (np.matrix([1, x1])*w).tolist()[0][0] 65 | #y2 = (np.matrix([1, x2])*w).tolist()[0][0] 66 | #ax.plot([x1, x2], [y1, y2], c='r') 67 | 68 | #plt.show() 69 | 70 | 
-------------------------------------------------------------------------------- /linear_regression/ridge_regression.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from math import exp 5 | 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | 9 | from standard_linear_regression import load_data, get_corrcoef, standarize 10 | 11 | def ridge_regression(X, y, lambd=0.2): 12 | ''' 获取岭回归系数 13 | ''' 14 | XTX = X.T*X 15 | m, _ = XTX.shape 16 | I = np.matrix(np.eye(m)) 17 | w = (XTX + lambd*I).I*X.T*y 18 | return w 19 | 20 | def ridge_traj(X, y, ntest=30): 21 | ''' 获取岭轨迹矩阵 22 | ''' 23 | _, n = X.shape 24 | ws = np.zeros((ntest, n)) 25 | for i in range(ntest): 26 | w = ridge_regression(X, y, lambd=exp(i-10)) 27 | ws[i, :] = w.T 28 | return ws 29 | 30 | if '__main__' == __name__: 31 | ntest = 30 32 | # 加载数据 33 | X, y = load_data('abalone.txt') 34 | 35 | # 中心化 & 标准化 36 | X, y = standarize(X), standarize(y) 37 | 38 | # 测试数据和训练数据 39 | w_test, errors = [], [] 40 | for i in range(ntest): 41 | lambd = exp(i - 10) 42 | # 训练数据 43 | X_train, y_train = X[: 180, :], y[: 180, :] 44 | # 测试数据 45 | X_test, y_test = X[180: -1, :], y[180: -1, :] 46 | 47 | # 岭回归系数 48 | w = ridge_regression(X_train, y_train, lambd) 49 | error = np.std(X_test*w - y_test) 50 | w_test.append(w) 51 | errors.append(error) 52 | 53 | # 选择误差最小的回归系数 54 | w_best, e_best = min(zip(w_test, errors), key=lambda x: x[1]) 55 | print('Best w: {}, best error: {}'.format(w_best, e_best)) 56 | 57 | y_prime = X*w_best 58 | # 计算相关系数 59 | corrcoef = get_corrcoef(np.array(y.reshape(1, -1)), 60 | np.array(y_prime.reshape(1, -1))) 61 | print('Correlation coefficient: {}'.format(corrcoef)) 62 | 63 | # 绘制岭轨迹 64 | ws = ridge_traj(X, y, ntest) 65 | fig = plt.figure() 66 | ax = fig.add_subplot(111) 67 | 68 | lambdas = [i-10 for i in range(ntest)] 69 | ax.plot(lambdas, ws) 70 | 71 | plt.show() 72 | 73 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MLBox 2 | Machine Learning Algorithms implementations 3 | 4 | # Blogs 5 | - [机器学习算法实践-决策树(Decision Tree)](http://pytlab.github.io/2017/07/09/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E7%AE%97%E6%B3%95%E5%AE%9E%E8%B7%B5-%E5%86%B3%E7%AD%96%E6%A0%91/) 6 | - [机器学习算法实践-朴素贝叶斯(Naive Bayes)](http://pytlab.github.io/2017/07/11/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E5%AE%9E%E8%B7%B5-%E6%9C%B4%E7%B4%A0%E8%B4%9D%E5%8F%B6%E6%96%AF-Naive-Bayes/) 7 | - [机器学习算法实践-Logistic回归与梯度上升算法(上)](http://pytlab.github.io/2017/07/13/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E7%AE%97%E6%B3%95%E5%AE%9E%E8%B7%B5-Logistic%E5%9B%9E%E5%BD%92%E4%B8%8E%E6%A2%AF%E5%BA%A6%E4%B8%8A%E5%8D%87%E7%AE%97%E6%B3%95-%E4%B8%8A/) 8 | - [机器学习算法实践-Logistic回归与梯度上升算法(下)](http://pytlab.github.io/2017/07/15/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E7%AE%97%E6%B3%95%E5%AE%9E%E8%B7%B5-Logistic%E5%9B%9E%E5%BD%92%E4%B8%8E%E6%A2%AF%E5%BA%A6%E4%B8%8A%E5%8D%87%E7%AE%97%E6%B3%95-%E4%B8%8B/) 9 | - [机器学习算法实践-支持向量机(SVM)算法原理](http://pytlab.github.io/2017/08/15/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E7%AE%97%E6%B3%95%E5%AE%9E%E8%B7%B5-%E6%94%AF%E6%8C%81%E5%90%91%E9%87%8F%E6%9C%BA-SVM-%E7%AE%97%E6%B3%95%E5%8E%9F%E7%90%86/) 10 | - [机器学习算法实践-SVM核函数和软间隔](http://pytlab.github.io/2017/08/30/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E7%AE%97%E6%B3%95%E5%AE%9E%E8%B7%B5-SVM%E6%A0%B8%E5%87%BD%E6%95%B0%E5%92%8C%E8%BD%AF%E9%97%B4%E9%9A%94/) 11 | - [机器学习算法实践-SVM中的SMO算法](http://pytlab.github.io/2017/09/01/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E7%AE%97%E6%B3%95%E5%AE%9E%E8%B7%B5-SVM%E4%B8%AD%E7%9A%84SMO%E7%AE%97%E6%B3%95/) 12 | - [机器学习算法实践-Platt SMO和遗传算法优化SVM](http://pytlab.github.io/2017/10/15/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E7%AE%97%E6%B3%95%E5%AE%9E%E8%B7%B5-Platt-SMO%E5%92%8C%E9%81%97%E4%BC%A0%E7%AE%97%E6%B3%95%E4%BC%98%E5%8C%96SVM/) 13 | - 
[机器学习算法实践-标准与局部加权线性回归](http://pytlab.github.io/2017/10/24/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E7%AE%97%E6%B3%95%E5%AE%9E%E8%B7%B5-%E6%A0%87%E5%87%86%E4%B8%8E%E5%B1%80%E9%83%A8%E5%8A%A0%E6%9D%83%E7%BA%BF%E6%80%A7%E5%9B%9E%E5%BD%92/) 14 | - [机器学习算法实践-岭回归和LASSO](http://pytlab.github.io/2017/10/27/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E5%AE%9E%E8%B7%B5-%E5%B2%AD%E5%9B%9E%E5%BD%92%E5%92%8CLASSO%E5%9B%9E%E5%BD%92/) 15 | - [机器学习算法实践-树回归](http://pytlab.github.io/2017/11/03/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E7%AE%97%E6%B3%95%E5%AE%9E%E8%B7%B5-%E6%A0%91%E5%9B%9E%E5%BD%92/) 16 | -------------------------------------------------------------------------------- /decision_tree/sms_tree_2.dot: -------------------------------------------------------------------------------- 1 | digraph decision_tree { 2 | "8fbb40df-9b8c-4525-a34a-0ec254360649" [label="call"]; 3 | "9c2cf1a6-e34b-4f3c-9cc0-17a4e20f12f7" [label="to"]; 4 | "ef9e9738-2596-4bdf-a42d-28d7f5471ca7" [label="your"]; 5 | "626c0a8b-c1fe-42d9-ad4f-b79e03cf82f7" [label="from"]; 6 | "56bdce7c-b23c-4d52-a802-1c28377ec7f5" [label="explicit"]; 7 | "ebe24cea-1310-40fe-b164-25032c942aec" [label="ham"]; 8 | "1a56632b-860b-4ace-b604-59b9c3b06405" [label="spam"]; 9 | "d7636d96-6f9e-4883-a581-c8919088cbf2" [label="spam"]; 10 | "1d1933b4-12e1-41ea-b6c1-46f8bacb851c" [label="spam"]; 11 | "ac8ca11e-10f5-4a3f-8c1d-e31933d74a8d" [label="when"]; 12 | "00cb082b-b9c3-4417-9d25-209f2b4957c8" [label="spam"]; 13 | "b7cc5eda-0d6a-4893-ba1d-78641ed8a949" [label="ham"]; 14 | "577ef6a5-eb97-4dc1-9fae-741253db33aa" [label="dead"]; 15 | "ae9e2b6c-1bdb-4cdb-aaea-01f0d3e138c6" [label="spam"]; 16 | "6c303284-fb0a-44e7-b92d-dcae4ffd828d" [label="ham"]; 17 | "8fbb40df-9b8c-4525-a34a-0ec254360649" -> "9c2cf1a6-e34b-4f3c-9cc0-17a4e20f12f7" [label="0"]; 18 | "9c2cf1a6-e34b-4f3c-9cc0-17a4e20f12f7" -> "ef9e9738-2596-4bdf-a42d-28d7f5471ca7" [label="0"]; 19 | "ef9e9738-2596-4bdf-a42d-28d7f5471ca7" -> "626c0a8b-c1fe-42d9-ad4f-b79e03cf82f7" [label="0"]; 20 | 
"626c0a8b-c1fe-42d9-ad4f-b79e03cf82f7" -> "56bdce7c-b23c-4d52-a802-1c28377ec7f5" [label="0"]; 21 | "56bdce7c-b23c-4d52-a802-1c28377ec7f5" -> "ebe24cea-1310-40fe-b164-25032c942aec" [label="0"]; 22 | "56bdce7c-b23c-4d52-a802-1c28377ec7f5" -> "1a56632b-860b-4ace-b604-59b9c3b06405" [label="1"]; 23 | "626c0a8b-c1fe-42d9-ad4f-b79e03cf82f7" -> "d7636d96-6f9e-4883-a581-c8919088cbf2" [label="1"]; 24 | "ef9e9738-2596-4bdf-a42d-28d7f5471ca7" -> "1d1933b4-12e1-41ea-b6c1-46f8bacb851c" [label="1"]; 25 | "9c2cf1a6-e34b-4f3c-9cc0-17a4e20f12f7" -> "ac8ca11e-10f5-4a3f-8c1d-e31933d74a8d" [label="1"]; 26 | "ac8ca11e-10f5-4a3f-8c1d-e31933d74a8d" -> "00cb082b-b9c3-4417-9d25-209f2b4957c8" [label="0"]; 27 | "ac8ca11e-10f5-4a3f-8c1d-e31933d74a8d" -> "b7cc5eda-0d6a-4893-ba1d-78641ed8a949" [label="1"]; 28 | "8fbb40df-9b8c-4525-a34a-0ec254360649" -> "577ef6a5-eb97-4dc1-9fae-741253db33aa" [label="1"]; 29 | "577ef6a5-eb97-4dc1-9fae-741253db33aa" -> "ae9e2b6c-1bdb-4cdb-aaea-01f0d3e138c6" [label="0"]; 30 | "577ef6a5-eb97-4dc1-9fae-741253db33aa" -> "6c303284-fb0a-44e7-b92d-dcae4ffd828d" [label="1"]; 31 | } -------------------------------------------------------------------------------- /decision_tree/sms_tree.dot: -------------------------------------------------------------------------------- 1 | digraph decision_tree { 2 | "959b4c0c-1821-446d-94a1-c619c2decfcd" [label="call"]; 3 | "18665160-b058-437f-9b2e-05df2eb55661" [label="to"]; 4 | "2eb9860d-d241-45ca-85e6-cbd80fe2ebf7" [label="your"]; 5 | "bcbcc17c-9e2a-4bd4-a039-6e51fde5f8fd" [label="areyouunique"]; 6 | "ca091fc7-8a4e-4970-9ec3-485a4628ad29" [label="02073162414"]; 7 | "aac20872-1aac-499d-b2b5-caf0ef56eff3" [label="ham"]; 8 | "18aa8685-a6e8-4d76-bad5-ccea922bb14d" [label="spam"]; 9 | "3f7f30b1-4dbb-4459-9f25-358ad3c6d50b" [label="spam"]; 10 | "44d1f972-cd97-4636-b6e6-a389bf560656" [label="spam"]; 11 | "7f3c8562-69b5-47a9-8ee4-898bd4b6b506" [label="i"]; 12 | "a6f22325-8841-4a81-bc04-4e7485117aa1" [label="spam"]; 13 | 
"c181fe42-fd3c-48db-968a-502f8dd462a4" [label="ldn"]; 14 | "51b9477a-0326-4774-8622-24d1d869a283" [label="ham"]; 15 | "16f6aecd-c675-4291-867c-6c64d27eb3fc" [label="spam"]; 16 | "adb05303-813a-4fe0-bf98-c319eb70be48" [label="spam"]; 17 | "959b4c0c-1821-446d-94a1-c619c2decfcd" -> "18665160-b058-437f-9b2e-05df2eb55661" [label="0"]; 18 | "18665160-b058-437f-9b2e-05df2eb55661" -> "2eb9860d-d241-45ca-85e6-cbd80fe2ebf7" [label="0"]; 19 | "2eb9860d-d241-45ca-85e6-cbd80fe2ebf7" -> "bcbcc17c-9e2a-4bd4-a039-6e51fde5f8fd" [label="0"]; 20 | "bcbcc17c-9e2a-4bd4-a039-6e51fde5f8fd" -> "ca091fc7-8a4e-4970-9ec3-485a4628ad29" [label="0"]; 21 | "ca091fc7-8a4e-4970-9ec3-485a4628ad29" -> "aac20872-1aac-499d-b2b5-caf0ef56eff3" [label="0"]; 22 | "ca091fc7-8a4e-4970-9ec3-485a4628ad29" -> "18aa8685-a6e8-4d76-bad5-ccea922bb14d" [label="1"]; 23 | "bcbcc17c-9e2a-4bd4-a039-6e51fde5f8fd" -> "3f7f30b1-4dbb-4459-9f25-358ad3c6d50b" [label="1"]; 24 | "2eb9860d-d241-45ca-85e6-cbd80fe2ebf7" -> "44d1f972-cd97-4636-b6e6-a389bf560656" [label="1"]; 25 | "18665160-b058-437f-9b2e-05df2eb55661" -> "7f3c8562-69b5-47a9-8ee4-898bd4b6b506" [label="1"]; 26 | "7f3c8562-69b5-47a9-8ee4-898bd4b6b506" -> "a6f22325-8841-4a81-bc04-4e7485117aa1" [label="0"]; 27 | "7f3c8562-69b5-47a9-8ee4-898bd4b6b506" -> "c181fe42-fd3c-48db-968a-502f8dd462a4" [label="1"]; 28 | "c181fe42-fd3c-48db-968a-502f8dd462a4" -> "51b9477a-0326-4774-8622-24d1d869a283" [label="0"]; 29 | "c181fe42-fd3c-48db-968a-502f8dd462a4" -> "16f6aecd-c675-4291-867c-6c64d27eb3fc" [label="1"]; 30 | "959b4c0c-1821-446d-94a1-c619c2decfcd" -> "adb05303-813a-4fe0-bf98-c319eb70be48" [label="1"]; 31 | } -------------------------------------------------------------------------------- /decision_tree/lenses.dot: -------------------------------------------------------------------------------- 1 | digraph decision_tree { 2 | "99d3b650-7557-420c-be5f-037403909eef" [label="tearRate"]; 3 | "ccf5c62e-14ca-4cef-9525-4b8f026622dc" [label="no lenses"]; 4 | 
"6a72f3f9-51ce-4433-b052-34765c65a61e" [label="astigmatic"]; 5 | "91ea78df-9cfd-4334-a592-1c8b3c193f0d" [label="age"]; 6 | "b5d2e2b7-241b-4c46-a56b-61ba9a1e7678" [label="soft"]; 7 | "62193a33-c49d-4bce-b820-1613685e09ce" [label="soft"]; 8 | "01240d64-7b96-40fc-9a4b-185cc0fca9d6" [label="prescript"]; 9 | "5571119a-43b5-414e-9bf5-c9c62a9dee8c" [label="soft"]; 10 | "087246f9-495f-44ef-8ea0-5043b238c1c1" [label="no lenses"]; 11 | "c0b04ca3-692d-4498-8292-165ed4997ce5" [label="prescript"]; 12 | "8f6cfe1f-a0ea-46de-a456-f3f8b35bca8d" [label="age"]; 13 | "4d2b5c7f-e85e-4d44-8da9-0d88de048430" [label="hard"]; 14 | "cdc375a5-561f-48c9-a847-ccc73f1cc44c" [label="no lenses"]; 15 | "4600fda0-b8a8-45cc-8174-d554de9b7e84" [label="no lenses"]; 16 | "08a19fa5-952c-4ab3-a283-dbfe2e3e5870" [label="hard"]; 17 | "99d3b650-7557-420c-be5f-037403909eef" -> "ccf5c62e-14ca-4cef-9525-4b8f026622dc" [label="reduced"]; 18 | "99d3b650-7557-420c-be5f-037403909eef" -> "6a72f3f9-51ce-4433-b052-34765c65a61e" [label="normal"]; 19 | "6a72f3f9-51ce-4433-b052-34765c65a61e" -> "91ea78df-9cfd-4334-a592-1c8b3c193f0d" [label="no"]; 20 | "91ea78df-9cfd-4334-a592-1c8b3c193f0d" -> "b5d2e2b7-241b-4c46-a56b-61ba9a1e7678" [label="young"]; 21 | "91ea78df-9cfd-4334-a592-1c8b3c193f0d" -> "62193a33-c49d-4bce-b820-1613685e09ce" [label="pre"]; 22 | "91ea78df-9cfd-4334-a592-1c8b3c193f0d" -> "01240d64-7b96-40fc-9a4b-185cc0fca9d6" [label="presbyopic"]; 23 | "01240d64-7b96-40fc-9a4b-185cc0fca9d6" -> "5571119a-43b5-414e-9bf5-c9c62a9dee8c" [label="hyper"]; 24 | "01240d64-7b96-40fc-9a4b-185cc0fca9d6" -> "087246f9-495f-44ef-8ea0-5043b238c1c1" [label="myope"]; 25 | "6a72f3f9-51ce-4433-b052-34765c65a61e" -> "c0b04ca3-692d-4498-8292-165ed4997ce5" [label="yes"]; 26 | "c0b04ca3-692d-4498-8292-165ed4997ce5" -> "8f6cfe1f-a0ea-46de-a456-f3f8b35bca8d" [label="hyper"]; 27 | "8f6cfe1f-a0ea-46de-a456-f3f8b35bca8d" -> "4d2b5c7f-e85e-4d44-8da9-0d88de048430" [label="young"]; 28 | "8f6cfe1f-a0ea-46de-a456-f3f8b35bca8d" -> 
"cdc375a5-561f-48c9-a847-ccc73f1cc44c" [label="pre"]; 29 | "8f6cfe1f-a0ea-46de-a456-f3f8b35bca8d" -> "4600fda0-b8a8-45cc-8174-d554de9b7e84" [label="presbyopic"]; 30 | "c0b04ca3-692d-4498-8292-165ed4997ce5" -> "08a19fa5-952c-4ab3-a283-dbfe2e3e5870" [label="myope"]; 31 | } -------------------------------------------------------------------------------- /linear_regression/lasso_regression.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import itertools 5 | from math import exp 6 | 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | 10 | from standard_linear_regression import load_data, standarize, get_corrcoef 11 | 12 | 13 | def lasso_regression(X, y, lambd=0.2, threshold=0.1): 14 | ''' 通过坐标下降(coordinate descent)法获取LASSO回归系数 15 | ''' 16 | # 计算残差平方和 17 | rss = lambda X, y, w: (y - X*w).T*(y - X*w) 18 | 19 | # 初始化回归系数w. 20 | m, n = X.shape 21 | w = np.matrix(np.zeros((n, 1))) 22 | r = rss(X, y, w) 23 | 24 | # 使用坐标下降法优化回归系数w 25 | niter = itertools.count(1) 26 | 27 | for it in niter: 28 | for k in range(n): 29 | # 计算常量值z_k和p_k 30 | z_k = (X[:, k].T*X[:, k])[0, 0] 31 | p_k = 0 32 | for i in range(m): 33 | p_k += X[i, k]*(y[i, 0] - sum([X[i, j]*w[j, 0] for j in range(n) if j != k])) 34 | 35 | if p_k < -lambd/2: 36 | w_k = (p_k + lambd/2)/z_k 37 | elif p_k > lambd/2: 38 | w_k = (p_k - lambd/2)/z_k 39 | else: 40 | w_k = 0 41 | 42 | w[k, 0] = w_k 43 | 44 | r_prime = rss(X, y, w) 45 | delta = abs(r_prime - r)[0, 0] 46 | r = r_prime 47 | #print('Iteration: {}, delta = {}'.format(it, delta)) 48 | 49 | if delta < threshold: 50 | break 51 | 52 | return w 53 | 54 | def lasso_traj(X, y, ntest=30): 55 | ''' 获取回归系数轨迹矩阵 56 | ''' 57 | _, n = X.shape 58 | ws = np.zeros((ntest, n)) 59 | for i in range(ntest): 60 | w = lasso_regression(X, y, lambd=exp(i-10)) 61 | ws[i, :] = w.T 62 | print('lambda = e^({}), w = {}'.format(i-10, w.T[0, :])) 63 | return ws 64 | 65 | if '__main__' == 
__name__: 66 | X, y = load_data('abalone.txt') 67 | X, y = standarize(X), standarize(y) 68 | # w = lasso_regression(X, y, lambd=10) 69 | # 70 | # y_prime = X*w 71 | # # 计算相关系数 72 | # corrcoef = get_corrcoef(np.array(y.reshape(1, -1)), 73 | # np.array(y_prime.reshape(1, -1))) 74 | # print('Correlation coefficient: {}'.format(corrcoef)) 75 | 76 | 77 | ntest = 30 78 | 79 | # 绘制轨迹 80 | ws = lasso_traj(X, y, ntest) 81 | fig = plt.figure() 82 | ax = fig.add_subplot(111) 83 | 84 | lambdas = [i-10 for i in range(ntest)] 85 | ax.plot(lambdas, ws) 86 | 87 | plt.show() 88 | 89 | -------------------------------------------------------------------------------- /logistic_regression/testSet.txt: -------------------------------------------------------------------------------- 1 | -0.017612 14.053064 0 2 | -1.395634 4.662541 1 3 | -0.752157 6.538620 0 4 | -1.322371 7.152853 0 5 | 0.423363 11.054677 0 6 | 0.406704 7.067335 1 7 | 0.667394 12.741452 0 8 | -2.460150 6.866805 1 9 | 0.569411 9.548755 0 10 | -0.026632 10.427743 0 11 | 0.850433 6.920334 1 12 | 1.347183 13.175500 0 13 | 1.176813 3.167020 1 14 | -1.781871 9.097953 0 15 | -0.566606 5.749003 1 16 | 0.931635 1.589505 1 17 | -0.024205 6.151823 1 18 | -0.036453 2.690988 1 19 | -0.196949 0.444165 1 20 | 1.014459 5.754399 1 21 | 1.985298 3.230619 1 22 | -1.693453 -0.557540 1 23 | -0.576525 11.778922 0 24 | -0.346811 -1.678730 1 25 | -2.124484 2.672471 1 26 | 1.217916 9.597015 0 27 | -0.733928 9.098687 0 28 | -3.642001 -1.618087 1 29 | 0.315985 3.523953 1 30 | 1.416614 9.619232 0 31 | -0.386323 3.989286 1 32 | 0.556921 8.294984 1 33 | 1.224863 11.587360 0 34 | -1.347803 -2.406051 1 35 | 1.196604 4.951851 1 36 | 0.275221 9.543647 0 37 | 0.470575 9.332488 0 38 | -1.889567 9.542662 0 39 | -1.527893 12.150579 0 40 | -1.185247 11.309318 0 41 | -0.445678 3.297303 1 42 | 1.042222 6.105155 1 43 | -0.618787 10.320986 0 44 | 1.152083 0.548467 1 45 | 0.828534 2.676045 1 46 | -1.237728 10.549033 0 47 | -0.683565 -2.166125 1 48 | 0.229456 
5.921938 1 49 | -0.959885 11.555336 0 50 | 0.492911 10.993324 0 51 | 0.184992 8.721488 0 52 | -0.355715 10.325976 0 53 | -0.397822 8.058397 0 54 | 0.824839 13.730343 0 55 | 1.507278 5.027866 1 56 | 0.099671 6.835839 1 57 | -0.344008 10.717485 0 58 | 1.785928 7.718645 1 59 | -0.918801 11.560217 0 60 | -0.364009 4.747300 1 61 | -0.841722 4.119083 1 62 | 0.490426 1.960539 1 63 | -0.007194 9.075792 0 64 | 0.356107 12.447863 0 65 | 0.342578 12.281162 0 66 | -0.810823 -1.466018 1 67 | 2.530777 6.476801 1 68 | 1.296683 11.607559 0 69 | 0.475487 12.040035 0 70 | -0.783277 11.009725 0 71 | 0.074798 11.023650 0 72 | -1.337472 0.468339 1 73 | -0.102781 13.763651 0 74 | -0.147324 2.874846 1 75 | 0.518389 9.887035 0 76 | 1.015399 7.571882 0 77 | -1.658086 -0.027255 1 78 | 1.319944 2.171228 1 79 | 2.056216 5.019981 1 80 | -0.851633 4.375691 1 81 | -1.510047 6.061992 0 82 | -1.076637 -3.181888 1 83 | 1.821096 10.283990 0 84 | 3.010150 8.401766 1 85 | -1.099458 1.688274 1 86 | -0.834872 -1.733869 1 87 | -0.846637 3.849075 1 88 | 1.400102 12.628781 0 89 | 1.752842 5.468166 1 90 | 0.078557 0.059736 1 91 | 0.089392 -0.715300 1 92 | 1.825662 12.693808 0 93 | 0.197445 9.744638 0 94 | 0.126117 0.922311 1 95 | -0.679797 1.220530 1 96 | 0.677983 2.556666 1 97 | 0.761349 10.693862 0 98 | -2.168791 0.143632 1 99 | 1.388610 9.341997 0 100 | 0.317029 14.739025 0 101 | -------------------------------------------------------------------------------- /support_vector_machine/testSet.txt: -------------------------------------------------------------------------------- 1 | 3.542485 1.977398 -1 2 | 3.018896 2.556416 -1 3 | 7.551510 -1.580030 1 4 | 2.114999 -0.004466 -1 5 | 8.127113 1.274372 1 6 | 7.108772 -0.986906 1 7 | 8.610639 2.046708 1 8 | 2.326297 0.265213 -1 9 | 3.634009 1.730537 -1 10 | 0.341367 -0.894998 -1 11 | 3.125951 0.293251 -1 12 | 2.123252 -0.783563 -1 13 | 0.887835 -2.797792 -1 14 | 7.139979 -2.329896 1 15 | 1.696414 -1.212496 -1 16 | 8.117032 0.623493 1 17 | 8.497162 -0.266649 1 18 
| 4.658191 3.507396 -1 19 | 8.197181 1.545132 1 20 | 1.208047 0.213100 -1 21 | 1.928486 -0.321870 -1 22 | 2.175808 -0.014527 -1 23 | 7.886608 0.461755 1 24 | 3.223038 -0.552392 -1 25 | 3.628502 2.190585 -1 26 | 7.407860 -0.121961 1 27 | 7.286357 0.251077 1 28 | 2.301095 -0.533988 -1 29 | -0.232542 -0.547690 -1 30 | 3.457096 -0.082216 -1 31 | 3.023938 -0.057392 -1 32 | 8.015003 0.885325 1 33 | 8.991748 0.923154 1 34 | 7.916831 -1.781735 1 35 | 7.616862 -0.217958 1 36 | 2.450939 0.744967 -1 37 | 7.270337 -2.507834 1 38 | 1.749721 -0.961902 -1 39 | 1.803111 -0.176349 -1 40 | 8.804461 3.044301 1 41 | 1.231257 -0.568573 -1 42 | 2.074915 1.410550 -1 43 | -0.743036 -1.736103 -1 44 | 3.536555 3.964960 -1 45 | 8.410143 0.025606 1 46 | 7.382988 -0.478764 1 47 | 6.960661 -0.245353 1 48 | 8.234460 0.701868 1 49 | 8.168618 -0.903835 1 50 | 1.534187 -0.622492 -1 51 | 9.229518 2.066088 1 52 | 7.886242 0.191813 1 53 | 2.893743 -1.643468 -1 54 | 1.870457 -1.040420 -1 55 | 5.286862 -2.358286 1 56 | 6.080573 0.418886 1 57 | 2.544314 1.714165 -1 58 | 6.016004 -3.753712 1 59 | 0.926310 -0.564359 -1 60 | 0.870296 -0.109952 -1 61 | 2.369345 1.375695 -1 62 | 1.363782 -0.254082 -1 63 | 7.279460 -0.189572 1 64 | 1.896005 0.515080 -1 65 | 8.102154 -0.603875 1 66 | 2.529893 0.662657 -1 67 | 1.963874 -0.365233 -1 68 | 8.132048 0.785914 1 69 | 8.245938 0.372366 1 70 | 6.543888 0.433164 1 71 | -0.236713 -5.766721 -1 72 | 8.112593 0.295839 1 73 | 9.803425 1.495167 1 74 | 1.497407 -0.552916 -1 75 | 1.336267 -1.632889 -1 76 | 9.205805 -0.586480 1 77 | 1.966279 -1.840439 -1 78 | 8.398012 1.584918 1 79 | 7.239953 -1.764292 1 80 | 7.556201 0.241185 1 81 | 9.015509 0.345019 1 82 | 8.266085 -0.230977 1 83 | 8.545620 2.788799 1 84 | 9.295969 1.346332 1 85 | 2.404234 0.570278 -1 86 | 2.037772 0.021919 -1 87 | 1.727631 -0.453143 -1 88 | 1.979395 -0.050773 -1 89 | 8.092288 -1.372433 1 90 | 1.667645 0.239204 -1 91 | 9.854303 1.365116 1 92 | 7.921057 -1.327587 1 93 | 8.500757 1.492372 1 94 | 1.339746 
-0.291183 -1 95 | 3.107511 0.758367 -1 96 | 2.609525 0.902979 -1 97 | 3.263585 1.367898 -1 98 | 2.912122 -0.202359 -1 99 | 1.731786 0.589096 -1 100 | 2.387003 1.573131 -1 101 | -------------------------------------------------------------------------------- /support_vector_machine/svm_ga.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | ''' 使用遗传算法框架GAFT优化SVM. 5 | 6 | GAFT项目地址: https://github.com/PytLab/gaft 7 | ''' 8 | 9 | import random 10 | 11 | import numpy as np 12 | import matplotlib.pyplot as plt 13 | 14 | from gaft import GAEngine 15 | from gaft.components import GAIndividual, GAPopulation 16 | from gaft.operators import RouletteWheelSelection, UniformCrossover, FlipBitBigMutation 17 | 18 | from gaft.analysis.fitness_store import FitnessStore 19 | from gaft.analysis.console_output import ConsoleOutput 20 | 21 | 22 | def load_data(filename): 23 | dataset, labels = [], [] 24 | with open(filename, 'r') as f: 25 | for line in f: 26 | x, y, label = [float(i) for i in line.strip().split()] 27 | dataset.append([x, y]) 28 | labels.append(label) 29 | return dataset, labels 30 | 31 | def get_w(alphas, dataset, labels): 32 | ''' 通过已知数据点和拉格朗日乘子获得分割超平面参数w 33 | ''' 34 | alphas, dataset, labels = np.array(alphas), np.array(dataset), np.array(labels) 35 | yx = labels.reshape(1, -1).T*np.array([1, 1])*dataset 36 | w = np.dot(yx.T, alphas) 37 | 38 | return w.tolist() 39 | 40 | # Population definition. 41 | indv_template = GAIndividual(ranges=[(-2, 2), (-2, 2), (-5, 5)], 42 | encoding='binary', 43 | eps=[0.001, 0.001, 0.005]) 44 | population = GAPopulation(indv_template=indv_template, size=600).init() 45 | 46 | # Genetic operators. 
47 | selection = RouletteWheelSelection() 48 | crossover = UniformCrossover(pc=0.8, pe=0.5) 49 | mutation = FlipBitBigMutation(pm=0.1, pbm=0.55, alpha=0.6) 50 | 51 | engine = GAEngine(population=population, selection=selection, 52 | crossover=crossover, mutation=mutation, 53 | analysis=[ConsoleOutput, FitnessStore]) 54 | 55 | # 加载数据 56 | dataset, labels = load_data('testSet.txt') 57 | 58 | @engine.fitness_register 59 | def fitness(indv): 60 | w, b = indv.variants[: -1], indv.variants[-1] 61 | min_dis = min([y*(np.dot(w, x) + b) for x, y in zip(dataset, labels)]) 62 | return float(min_dis) 63 | 64 | if '__main__' == __name__: 65 | engine.run(300) 66 | 67 | variants = engine.population.best_indv(engine.fitness).variants 68 | w = variants[: -1] 69 | b = variants[-1] 70 | 71 | # 分类数据点 72 | classified_pts = {'+1': [], '-1': []} 73 | for point, label in zip(dataset, labels): 74 | if label == 1.0: 75 | classified_pts['+1'].append(point) 76 | else: 77 | classified_pts['-1'].append(point) 78 | 79 | fig = plt.figure() 80 | ax = fig.add_subplot(111) 81 | 82 | # 绘制数据点 83 | for label, pts in classified_pts.items(): 84 | pts = np.array(pts) 85 | ax.scatter(pts[:, 0], pts[:, 1], label=label) 86 | 87 | # 绘制分割线 88 | x1, _ = max(dataset, key=lambda x: x[0]) 89 | x2, _ = min(dataset, key=lambda x: x[0]) 90 | a1, a2 = w 91 | y1, y2 = (-b - a1*x1)/a2, (-b - a1*x2)/a2 92 | ax.plot([x1, x2], [y1, y2]) 93 | 94 | plt.show() 95 | 96 | -------------------------------------------------------------------------------- /Reinforcement Learning/Calculating Transition Probabilities.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | " 1. Set of possible states : S = {s0,s1,s2,......,sn} \n", 8 | " 2. Initial State: s0 \n", 9 | " 3. 
Transition Model: T(s,s')" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "Let’s suppose we have a chain with only two states s0 and s1, where s0 is the initial state. The process is in s0 90% of the time and it can move to s1 the remaining 10% of the time. When the process is in state s1 it will remain there 50% of the time. Given this data we can create a Transition Matrix T as follows:\n", 17 | "T=[[0.90 0.10]\n", 18 | " [0.50 0.50]]" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "#### Computing the k-step transition probability:" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 2, 31 | "metadata": {}, 32 | "outputs": [ 33 | { 34 | "name": "stdout", 35 | "output_type": "stream", 36 | "text": [ 37 | "T: [[0.9 0.1]\n", 38 | " [0.5 0.5]]\n", 39 | "T_5: [[0.83504 0.16496]\n", 40 | " [0.8248 0.1752 ]]\n", 41 | "T_25: [[0.83333333 0.16666667]\n", 42 | " [0.83333333 0.16666667]]\n", 43 | "T_50: [[0.83333333 0.16666667]\n", 44 | " [0.83333333 0.16666667]]\n", 45 | "T_100: [[0.83333333 0.16666667]\n", 46 | " [0.83333333 0.16666667]]\n" 47 | ] 48 | } 49 | ], 50 | "source": [ 51 | "import numpy as np\n", 52 | "\n", 53 | "#Here we declare the Transition Matrix T\n", 54 | "T = np.array([[0.90, 0.10],\n", 55 | " [0.50, 0.50]])\n", 56 | "\n", 57 | "#Obtain T after 5 steps\n", 58 | "T_5 = np.linalg.matrix_power(T, 5)\n", 59 | "\n", 60 | "#Obtain T after 25 steps\n", 61 | "T_25 = np.linalg.matrix_power(T, 25)\n", 62 | "\n", 63 | "#Obtain T after 50 steps\n", 64 | "T_50 = np.linalg.matrix_power(T, 50)\n", 65 | "\n", 66 | "#Obtain T after 100 steps\n", 67 | "T_100 = np.linalg.matrix_power(T, 100)\n", 68 | "\n", 69 | "#Print the matrices\n", 70 | "print(\"T: \" + str(T))\n", 71 | "print(\"T_5: \" + str(T_5))\n", 72 | "print(\"T_25: \" + str(T_25))\n", 73 | "print(\"T_50: \" + str(T_50))\n", 74 | "print(\"T_100: \" + str(T_100))" 75 | ] 76 | } 77 | ], 78 | 
"metadata": { 79 | "kernelspec": { 80 | "display_name": "Python 3", 81 | "language": "python", 82 | "name": "python3" 83 | }, 84 | "language_info": { 85 | "codemirror_mode": { 86 | "name": "ipython", 87 | "version": 3 88 | }, 89 | "file_extension": ".py", 90 | "mimetype": "text/x-python", 91 | "name": "python", 92 | "nbconvert_exporter": "python", 93 | "pygments_lexer": "ipython3", 94 | "version": "3.5.4" 95 | } 96 | }, 97 | "nbformat": 4, 98 | "nbformat_minor": 2 99 | } 100 | -------------------------------------------------------------------------------- /decision_tree/sms_tree.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | ''' 通过垃圾短信数据训练朴素贝叶斯模型,并进行留存交叉验证 5 | ''' 6 | 7 | import re 8 | import random 9 | import os 10 | 11 | import numpy as np 12 | import matplotlib.pyplot as plt 13 | 14 | from trees import DecisionTreeClassifier 15 | 16 | ENCODING = 'ISO-8859-1' 17 | TRAIN_PERCENTAGE = 0.9 18 | 19 | def get_doc_vector(words, vocabulary): 20 | ''' 根据词汇表将文档中的词条转换成文档向量 21 | 22 | :param words: 文档中的词条列表 23 | :type words: list of str 24 | 25 | :param vocabulary: 总的词汇列表 26 | :type vocabulary: list of str 27 | 28 | :return doc_vect: 用于贝叶斯分析的文档向量 29 | :type doc_vect: list of int 30 | ''' 31 | doc_vect = [0]*len(vocabulary) 32 | 33 | for word in words: 34 | if word in vocabulary: 35 | idx = vocabulary.index(word) 36 | doc_vect[idx] = 1 37 | 38 | return doc_vect 39 | 40 | def parse_line(line): 41 | ''' 解析数据集中的每一行返回词条向量和短信类型. 
42 | ''' 43 | cls = line.split(',')[-1].strip() 44 | content = ','.join(line.split(',')[: -1]) 45 | word_vect = [word.lower() for word in re.split(r'\W+', content) if word] 46 | return word_vect, cls 47 | 48 | def parse_file(filename): 49 | ''' 解析文件中的数据 50 | ''' 51 | vocabulary, word_vects, classes = [], [], [] 52 | with open(filename, 'r', encoding=ENCODING) as f: 53 | for line in f: 54 | if line: 55 | word_vect, cls = parse_line(line) 56 | vocabulary.extend(word_vect) 57 | word_vects.append(word_vect) 58 | classes.append(cls) 59 | vocabulary = list(set(vocabulary)) 60 | 61 | return vocabulary, word_vects, classes 62 | 63 | if '__main__' == __name__: 64 | clf = DecisionTreeClassifier() 65 | vocabulary, word_vects, classes = parse_file('english_big.txt') 66 | 67 | # 训练数据 & 测试数据 68 | ntest = int(len(classes)*(1-TRAIN_PERCENTAGE)) 69 | 70 | test_word_vects = [] 71 | test_classes = [] 72 | for i in range(ntest): 73 | idx = random.randint(0, len(word_vects)-1) 74 | test_word_vects.append(word_vects.pop(idx)) 75 | test_classes.append(classes.pop(idx)) 76 | 77 | train_word_vects = word_vects 78 | train_classes = classes 79 | 80 | train_dataset = [get_doc_vector(words, vocabulary) for words in train_word_vects] 81 | 82 | # 生成决策树 83 | if not os.path.exists('sms_tree.pkl'): 84 | clf.create_tree(train_dataset, train_classes, vocabulary) 85 | clf.dump_tree('sms_tree.pkl') 86 | else: 87 | clf.load_tree('sms_tree.pkl') 88 | 89 | # 测试模型 90 | error = 0 91 | for test_word_vect, test_cls in zip(test_word_vects, test_classes): 92 | test_data = get_doc_vector(test_word_vect, vocabulary) 93 | pred_cls = clf.classify(test_data, feat_names=vocabulary) 94 | if test_cls != pred_cls: 95 | print('Predict: {} -- Actual: {}'.format(pred_cls, test_cls)) 96 | error += 1 97 | 98 | print('Error Rate: {}'.format(error/len(test_classes))) 99 | 100 | -------------------------------------------------------------------------------- /logistic_regression/sms.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | ''' 通过垃圾短信数据训练Logistic回归模型,并进行留存交叉验证 5 | ''' 6 | 7 | import re 8 | import random 9 | 10 | import numpy as np 11 | import matplotlib.pyplot as plt 12 | 13 | from logreg_stoch_grad_ascent import LogisticRegressionClassifier 14 | 15 | ENCODING = 'ISO-8859-1' 16 | TRAIN_PERCENTAGE = 0.9 17 | 18 | def get_doc_vector(words, vocabulary): 19 | ''' 根据词汇表将文档中的词条转换成文档向量 20 | 21 | :param words: 文档中的词条列表 22 | :type words: list of str 23 | 24 | :param vocabulary: 总的词汇列表 25 | :type vocabulary: list of str 26 | 27 | :return doc_vect: 用于贝叶斯分析的文档向量 28 | :type doc_vect: list of int 29 | ''' 30 | doc_vect = [0]*len(vocabulary) 31 | 32 | for word in words: 33 | if word in vocabulary: 34 | idx = vocabulary.index(word) 35 | doc_vect[idx] += 1 36 | 37 | return doc_vect 38 | 39 | def parse_line(line): 40 | ''' 解析数据集中的每一行返回词条向量和短信类型. 41 | ''' 42 | cls = line.split(',')[-1].strip() 43 | content = ','.join(line.split(',')[: -1]) 44 | word_vect = [word.lower() for word in re.split(r'\W+', content) if word] 45 | return word_vect, cls 46 | 47 | def parse_file(filename): 48 | ''' 解析文件中的数据 49 | ''' 50 | vocabulary, word_vects, classes = [], [], [] 51 | with open(filename, 'r', encoding=ENCODING) as f: 52 | for line in f: 53 | if line: 54 | word_vect, cls = parse_line(line) 55 | vocabulary.extend(word_vect) 56 | word_vects.append(word_vect) 57 | classes.append(cls) 58 | vocabulary = list(set(vocabulary)) 59 | 60 | return vocabulary, word_vects, classes 61 | 62 | if '__main__' == __name__: 63 | clf = LogisticRegressionClassifier() 64 | vocabulary, word_vects, classes = parse_file('english_big.txt') 65 | 66 | # 训练数据 & 测试数据 67 | ntest = int(len(classes)*(1-TRAIN_PERCENTAGE)) 68 | 69 | test_word_vects = [] 70 | test_classes = [] 71 | for i in range(ntest): 72 | idx = random.randint(0, len(word_vects)-1) 73 | test_word_vects.append(word_vects.pop(idx)) 74 | 
def get_doc_vector(words, vocabulary):
    ''' Convert a message's tokens into a binary presence vector over the
    vocabulary (1 if the word occurs at least once, else 0).

    :param words: tokens of one document
    :type words: list of str

    :param vocabulary: the full vocabulary list
    :type vocabulary: list of str

    :return doc_vect: binary word-presence vector aligned with vocabulary
    :type doc_vect: list of int
    '''
    # Build a word -> first-index map once: O(1) lookups instead of the
    # original O(len(vocabulary)) list.index call per word.
    index = {}
    for i, word in enumerate(vocabulary):
        index.setdefault(word, i)  # keep the FIRST index, like list.index

    doc_vect = [0]*len(vocabulary)
    for word in words:
        i = index.get(word)
        if i is not None:  # silently skip out-of-vocabulary words
            doc_vect[i] = 1

    return doc_vect
def parse_file(filename):
    ''' Parse every line of the dataset file.

    Returns the deduplicated vocabulary, the per-message word lists and the
    per-message class labels.
    '''
    all_words = []
    word_vects, classes = [], []
    with open(filename, 'r', encoding=ENCODING) as f:
        # Lazily parse each non-empty line into (tokens, label).
        records = (parse_line(raw) for raw in f if raw)
        for words, label in records:
            all_words.extend(words)
            word_vects.append(words)
            classes.append(label)

    return list(set(all_words)), word_vects, classes
-------------------------------------------------------------------------------- /Reinforcement Learning/Defining Initial Distribution.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Let us now define the initial distribution which represents the state of the system at k=0.\n", 8 | "Our system is composed of two states and we can model the initial distribution as a vector with two elements, the first element of the vector represents the probability of staying in the state s0 and the second element the probability of staying in state s1. Let’s suppose that we start from s0, the vector v representing the initial distribution will have this form:\n", 9 | "v=(1,0)" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "Calculating the probability of being in a specific state after k iterations:" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 1, 22 | "metadata": {}, 23 | "outputs": [ 24 | { 25 | "name": "stdout", 26 | "output_type": "stream", 27 | "text": [ 28 | "v: [[1. 
0.]]\n", 29 | "v_1: [[0.9 0.1]]\n", 30 | "v_5: [[0.83504 0.16496]]\n", 31 | "v_25: [[0.83333333 0.16666667]]\n", 32 | "v_50: [[0.83333333 0.16666667]]\n", 33 | "v_100: [[0.83333333 0.16666667]]\n" 34 | ] 35 | } 36 | ], 37 | "source": [ 38 | "import numpy as np\n", 39 | "\n", 40 | "#Declare the initial distribution\n", 41 | "v = np.array([[1.0, 0.0]])\n", 42 | "\n", 43 | "#Declare the Transition Matrix T(this is the same matrix used as in the file'Calculating Transition Probabilities')\n", 44 | "T = np.array([[0.90, 0.10],\n", 45 | " [0.50, 0.50]])\n", 46 | "\n", 47 | "#Obtain T after 5 steps\n", 48 | "T_5 = np.linalg.matrix_power(T, 5)\n", 49 | "\n", 50 | "#Obtain T after 25 steps\n", 51 | "T_25 = np.linalg.matrix_power(T, 25)\n", 52 | "\n", 53 | "#Obtain T after 50 steps\n", 54 | "T_50 = np.linalg.matrix_power(T, 50)\n", 55 | "\n", 56 | "#Obtain T after 100 steps\n", 57 | "T_100 = np.linalg.matrix_power(T, 100)\n", 58 | "\n", 59 | "#Printing the initial distribution\n", 60 | "print(\"v: \" + str(v))\n", 61 | "print(\"v_1: \" + str(np.dot(v,T)))\n", 62 | "print(\"v_5: \" + str(np.dot(v,T_5)))\n", 63 | "print(\"v_25: \" + str(np.dot(v,T_25)))\n", 64 | "print(\"v_50: \" + str(np.dot(v,T_50)))\n", 65 | "print(\"v_100: \" + str(np.dot(v,T_100)))" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "metadata": {}, 71 | "source": [ 72 | "The result after 50 and 100 iterations are the same and v_50 is equal to v_100 no matter which starting distribution we have. The chain converged to equilibrium meaning that as the time progresses it forgets about the starting distribution." 
def linear_regression(dataset):
    ''' Fit ordinary least squares coefficients via the normal equation.

    :param dataset: rows of [feature_1, ..., feature_n, target]
    :return: (w, X, y) -- the (n+1) x 1 coefficient matrix (intercept
        first), the m x (n+1) design matrix with a leading column of ones,
        and the m x 1 target vector.
    '''
    dataset = np.matrix(dataset)
    # Slicing an np.matrix already yields matrices, so the original second
    # np.matrix(...) conversion of X_ori and y was redundant and is dropped.
    X_ori, y = dataset[:, :-1], dataset[:, -1]
    m, n = X_ori.shape
    # Design matrix with a constant column for the intercept term.
    X = np.matrix(np.ones((m, n + 1)))
    X[:, 1:] = X_ori

    # Normal equation: w = (X^T X)^-1 X^T y.
    # NOTE(review): .I raises numpy.linalg.LinAlgError when X^T X is
    # singular (e.g. duplicated samples/features); callers must keep leaf
    # datasets well-conditioned.
    w = (X.T*X).I*X.T*y
    return w, X, y

def fleaf(dataset):
    ''' Leaf-value function for the model tree: the linear regression
    coefficients fitted on the data that reaches the leaf.
    '''
    w, _, _ = linear_regression(dataset)
    return w

def ferr(dataset):
    ''' Error function for the model tree: regress on the dataset and
    return the variance of the residuals.

    NOTE(review): np.var subtracts the residual mean and divides by m, so
    this is NOT the raw sum of squared errors -- presumably intentional;
    confirm against create_tree's err_tolerance scale.
    '''
    w, X, y = linear_regression(dataset)
    y_prime = X*w
    return np.var(y_prime - y)
def tree_predict(data, tree):
    ''' Predict the target value of one sample by walking the model tree.

    `data` must already carry the leading 1.0 constant term, which is why
    the split feature is looked up at index feat_idx + 1.
    '''
    node = tree
    # Descend until a leaf (a coefficient matrix) is reached.
    while type(node) is dict:
        branch = 'left' if data[node['feat_idx'] + 1] < node['feat_val'] else 'right'
        node = node[branch]
    # Leaf: evaluate the linear model stored there.
    prediction = np.matrix(data)*node
    return prediction[0, 0]
class LogisticRegressionClassifier(object):
    ''' Logistic regression classifier trained with batch gradient ascent. '''

    @staticmethod
    def sigmoid(x):
        ''' Sigmoid step function 1/(1 + e^-x).

        The argument is clipped to [-500, 500] so np.exp never overflows;
        the original emitted a RuntimeWarning (exp -> inf) for large
        negative inputs. Outputs are unchanged for |x| < 500.
        '''
        return 1.0/(1 + np.exp(-np.clip(x, -500, 500)))

    def gradient_ascent(self, dataset, labels, max_iter=10000):
        ''' Optimize the Logistic regression weights with gradient ascent.

        :param dataset: m x n feature matrix (anything np.matrix accepts)
        :param labels: length-m sequence of 0/1 class labels
        :param max_iter: number of full-batch update steps
        :return: (w, ws) -- the final n x 1 weight matrix (also stored on
            self.w) and a max_iter x n array of weights after each step.
        '''
        dataset = np.matrix(dataset)
        vlabels = np.matrix(labels).reshape(-1, 1)
        m, n = dataset.shape
        # Keep w as np.matrix so every product below stays in matrix form
        # (the original mixed an ndarray w with matrix operands).
        w = np.matrix(np.ones((n, 1)))
        alpha = 0.001  # fixed learning rate
        ws = []
        for i in range(max_iter):
            # Log-likelihood gradient: X^T (y - sigmoid(X w)).
            error = vlabels - self.sigmoid(dataset*w)
            w += alpha*dataset.T*error
            ws.append(w.reshape(1, -1).tolist()[0])

        self.w = w

        return w, np.array(ws)

    def classify(self, data, w=None):
        ''' Predict the class (0 or 1) of an unseen sample.

        Uses the weights trained by gradient_ascent when w is None.
        '''
        if w is None:
            w = self.w

        data = np.matrix(data)
        prob = self.sigmoid((data*w).tolist()[0][0])
        return round(prob)
plt.scatter(data[:, 1], data[:, 2], label=label, alpha=0.5) 83 | 84 | # 分割线绘制 85 | def get_y(x, w): 86 | w0, w1, w2 = w 87 | return (-w0 - w1*x)/w2 88 | 89 | x = [-4.0, 3.0] 90 | y = [get_y(i, w) for i in x] 91 | 92 | plt.plot(x, y, linewidth=2, color='#FB4A42') 93 | 94 | pic_name = './snapshots/{}'.format(pic_name) 95 | fig.savefig(pic_name) 96 | plt.close(fig) 97 | 98 | if '__main__' == __name__: 99 | clf = LogisticRegressionClassifier() 100 | dataset, labels = load_data('testSet.txt') 101 | w, ws = clf.gradient_ascent(dataset, labels, max_iter=50000) 102 | m, n = ws.shape 103 | 104 | # 绘制分割线 105 | for i in range(300): 106 | if i % (30) == 0: 107 | print('{}.png saved'.format(i)) 108 | snapshot(ws[i].tolist(), dataset, labels, '{}.png'.format(i)) 109 | 110 | fig = plt.figure() 111 | for i in range(n): 112 | label = 'w{}'.format(i) 113 | ax = fig.add_subplot(n, 1, i+1) 114 | ax.plot(ws[:, i], label=label) 115 | ax.legend() 116 | 117 | fig.savefig('w_traj.png') 118 | 119 | -------------------------------------------------------------------------------- /Reinforcement Learning/Calculating State Utilities.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "A MDP is a reinterpretation of Markov chains which includes an agent and a decision making process. A MDP is defined by these components:\n", 8 | "1. Set of possible States: S={s0,s1,...,sm}\n", 9 | "2. Initial State:s0\n", 10 | "3. Set of possible Actions:A={a0,a1,...,an}\n", 11 | "4. Transition Model:T(s,a,s′)\n", 12 | "5. Reward Function: R(s)" 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "metadata": {}, 18 | "source": [ 19 | "We are going to implement MDP in a grid world of 3 x 4 space where our agent/robot is situated at (1,1) in the beginning and needs to reach (3,4) state which is its desired goal state. 
There is also a fault state at (2,4) which the robot needs to avoid at all costs. The movement of the robot from one state to another earns it a reward. Naturally, the reward for the goal state is the highest and the least for the fault state. The objective of the robot is to maximize its reward and thus plan its movements/actions accordingly. It can move in any direction and this is a stochastic process." 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": {}, 25 | "source": [ 26 | "To compare the states, we calculate the utility of these states and this is shown below:" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 1, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "import numpy as np\n", 36 | "\n", 37 | "def state_utility(v, T, u, reward, gamma):\n", 38 | " \n", 39 | " #v is the state vector\n", 40 | " #T is the transition matrix\n", 41 | " #u is the utility vector\n", 42 | " #reward consists of the rewards earned for moving to a particular state\n", 43 | " #gamma is the discount factor by which rewards are discounted over the time\n", 44 | "\n", 45 | " action_array = np.zeros(4)\n", 46 | " for action in range(0, 4):\n", 47 | " action_array[action] = np.sum(np.multiply(u, np.dot(v, T[:,:,action])))\n", 48 | " return reward + gamma * np.max(action_array)" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 3, 54 | "metadata": {}, 55 | "outputs": [ 56 | { 57 | "name": "stdout", 58 | "output_type": "stream", 59 | "text": [ 60 | "Utility of state (1,1): 0.7056\n" 61 | ] 62 | } 63 | ], 64 | "source": [ 65 | "def main():\n", 66 | " \n", 67 | " #The agent starts from (1, 1)\n", 68 | " v = np.array([[0.0, 0.0, 0.0, 0.0, \n", 69 | " 0.0, 0.0, 0.0, 0.0, \n", 70 | " 1.0, 0.0, 0.0, 0.0]])\n", 71 | " \n", 72 | " #file loaded from the folder\n", 73 | " T = np.load(\"T.npy\")\n", 74 | "\n", 75 | " #Utility vector\n", 76 | " u = np.array([[0.812, 0.868, 0.918, 1.0,\n", 77 | " 0.762, 0.0, 0.660, 
-1.0,\n", 78 | " 0.705, 0.655, 0.611, 0.388]])\n", 79 | "\n", 80 | " #Define the reward for state (1,1)\n", 81 | " reward = -0.04\n", 82 | " #Assume that the discount factor is equal to 1.0\n", 83 | " gamma = 1.0\n", 84 | "\n", 85 | " #Use the Bellman equation to find the utility of state (1,1)\n", 86 | " utility_11 = state_utility(v, T, u, reward, gamma)\n", 87 | " print(\"Utility of state (1,1): \" + str(utility_11))\n", 88 | "\n", 89 | "if __name__ == \"__main__\":\n", 90 | " main()" 91 | ] 92 | } 93 | ], 94 | "metadata": { 95 | "kernelspec": { 96 | "display_name": "Python 3", 97 | "language": "python", 98 | "name": "python3" 99 | }, 100 | "language_info": { 101 | "codemirror_mode": { 102 | "name": "ipython", 103 | "version": 3 104 | }, 105 | "file_extension": ".py", 106 | "mimetype": "text/x-python", 107 | "name": "python", 108 | "nbconvert_exporter": "python", 109 | "pygments_lexer": "ipython3", 110 | "version": "3.5.4" 111 | } 112 | }, 113 | "nbformat": 4, 114 | "nbformat_minor": 2 115 | } 116 | -------------------------------------------------------------------------------- /classification_and_regression_trees/ex0.txt: -------------------------------------------------------------------------------- 1 | 0.409175 1.883180 2 | 0.182603 0.063908 3 | 0.663687 3.042257 4 | 0.517395 2.305004 5 | 0.013643 -0.067698 6 | 0.469643 1.662809 7 | 0.725426 3.275749 8 | 0.394350 1.118077 9 | 0.507760 2.095059 10 | 0.237395 1.181912 11 | 0.057534 0.221663 12 | 0.369820 0.938453 13 | 0.976819 4.149409 14 | 0.616051 3.105444 15 | 0.413700 1.896278 16 | 0.105279 -0.121345 17 | 0.670273 3.161652 18 | 0.952758 4.135358 19 | 0.272316 0.859063 20 | 0.303697 1.170272 21 | 0.486698 1.687960 22 | 0.511810 1.979745 23 | 0.195865 0.068690 24 | 0.986769 4.052137 25 | 0.785623 3.156316 26 | 0.797583 2.950630 27 | 0.081306 0.068935 28 | 0.659753 2.854020 29 | 0.375270 0.999743 30 | 0.819136 4.048082 31 | 0.142432 0.230923 32 | 0.215112 0.816693 33 | 0.041270 0.130713 34 | 0.044136 
-0.537706 35 | 0.131337 -0.339109 36 | 0.463444 2.124538 37 | 0.671905 2.708292 38 | 0.946559 4.017390 39 | 0.904176 4.004021 40 | 0.306674 1.022555 41 | 0.819006 3.657442 42 | 0.845472 4.073619 43 | 0.156258 0.011994 44 | 0.857185 3.640429 45 | 0.400158 1.808497 46 | 0.375395 1.431404 47 | 0.885807 3.935544 48 | 0.239960 1.162152 49 | 0.148640 -0.227330 50 | 0.143143 -0.068728 51 | 0.321582 0.825051 52 | 0.509393 2.008645 53 | 0.355891 0.664566 54 | 0.938633 4.180202 55 | 0.348057 0.864845 56 | 0.438898 1.851174 57 | 0.781419 2.761993 58 | 0.911333 4.075914 59 | 0.032469 0.110229 60 | 0.499985 2.181987 61 | 0.771663 3.152528 62 | 0.670361 3.046564 63 | 0.176202 0.128954 64 | 0.392170 1.062726 65 | 0.911188 3.651742 66 | 0.872288 4.401950 67 | 0.733107 3.022888 68 | 0.610239 2.874917 69 | 0.732739 2.946801 70 | 0.714825 2.893644 71 | 0.076386 0.072131 72 | 0.559009 1.748275 73 | 0.427258 1.912047 74 | 0.841875 3.710686 75 | 0.558918 1.719148 76 | 0.533241 2.174090 77 | 0.956665 3.656357 78 | 0.620393 3.522504 79 | 0.566120 2.234126 80 | 0.523258 1.859772 81 | 0.476884 2.097017 82 | 0.176408 0.001794 83 | 0.303094 1.231928 84 | 0.609731 2.953862 85 | 0.017774 -0.116803 86 | 0.622616 2.638864 87 | 0.886539 3.943428 88 | 0.148654 -0.328513 89 | 0.104350 -0.099866 90 | 0.116868 -0.030836 91 | 0.516514 2.359786 92 | 0.664896 3.212581 93 | 0.004327 0.188975 94 | 0.425559 1.904109 95 | 0.743671 3.007114 96 | 0.935185 3.845834 97 | 0.697300 3.079411 98 | 0.444551 1.939739 99 | 0.683753 2.880078 100 | 0.755993 3.063577 101 | 0.902690 4.116296 102 | 0.094491 -0.240963 103 | 0.873831 4.066299 104 | 0.991810 4.011834 105 | 0.185611 0.077710 106 | 0.694551 3.103069 107 | 0.657275 2.811897 108 | 0.118746 -0.104630 109 | 0.084302 0.025216 110 | 0.945341 4.330063 111 | 0.785827 3.087091 112 | 0.530933 2.269988 113 | 0.879594 4.010701 114 | 0.652770 3.119542 115 | 0.879338 3.723411 116 | 0.764739 2.792078 117 | 0.504884 2.192787 118 | 0.554203 2.081305 119 | 0.493209 1.714463 120 | 
0.363783 0.885854 121 | 0.316465 1.028187 122 | 0.580283 1.951497 123 | 0.542898 1.709427 124 | 0.112661 0.144068 125 | 0.816742 3.880240 126 | 0.234175 0.921876 127 | 0.402804 1.979316 128 | 0.709423 3.085768 129 | 0.867298 3.476122 130 | 0.993392 3.993679 131 | 0.711580 3.077880 132 | 0.133643 -0.105365 133 | 0.052031 -0.164703 134 | 0.366806 1.096814 135 | 0.697521 3.092879 136 | 0.787262 2.987926 137 | 0.476710 2.061264 138 | 0.721417 2.746854 139 | 0.230376 0.716710 140 | 0.104397 0.103831 141 | 0.197834 0.023776 142 | 0.129291 -0.033299 143 | 0.528528 1.942286 144 | 0.009493 -0.006338 145 | 0.998533 3.808753 146 | 0.363522 0.652799 147 | 0.901386 4.053747 148 | 0.832693 4.569290 149 | 0.119002 -0.032773 150 | 0.487638 2.066236 151 | 0.153667 0.222785 152 | 0.238619 1.089268 153 | 0.208197 1.487788 154 | 0.750921 2.852033 155 | 0.183403 0.024486 156 | 0.995608 3.737750 157 | 0.151311 0.045017 158 | 0.126804 0.001238 159 | 0.983153 3.892763 160 | 0.772495 2.819376 161 | 0.784133 2.830665 162 | 0.056934 0.234633 163 | 0.425584 1.810782 164 | 0.998709 4.237235 165 | 0.707815 3.034768 166 | 0.413816 1.742106 167 | 0.217152 1.169250 168 | 0.360503 0.831165 169 | 0.977989 3.729376 170 | 0.507953 1.823205 171 | 0.920771 4.021970 172 | 0.210542 1.262939 173 | 0.928611 4.159518 174 | 0.580373 2.039114 175 | 0.841390 4.101837 176 | 0.681530 2.778672 177 | 0.292795 1.228284 178 | 0.456918 1.736620 179 | 0.134128 -0.195046 180 | 0.016241 -0.063215 181 | 0.691214 3.305268 182 | 0.582002 2.063627 183 | 0.303102 0.898840 184 | 0.622598 2.701692 185 | 0.525024 1.992909 186 | 0.996775 3.811393 187 | 0.881025 4.353857 188 | 0.723457 2.635641 189 | 0.676346 2.856311 190 | 0.254625 1.352682 191 | 0.488632 2.336459 192 | 0.519875 2.111651 193 | 0.160176 0.121726 194 | 0.609483 3.264605 195 | 0.531881 2.103446 196 | 0.321632 0.896855 197 | 0.845148 4.220850 198 | 0.012003 -0.217283 199 | 0.018883 -0.300577 200 | 0.071476 0.006014 201 | 
-------------------------------------------------------------------------------- /classification_and_regression_trees/exp2.txt: -------------------------------------------------------------------------------- 1 | 0.070670 3.470829 2 | 0.534076 6.377132 3 | 0.747221 8.949407 4 | 0.668970 8.034081 5 | 0.586082 6.997721 6 | 0.764962 9.318110 7 | 0.658125 7.880333 8 | 0.346734 4.213359 9 | 0.313967 3.762496 10 | 0.601418 7.188805 11 | 0.404396 4.893403 12 | 0.154345 3.683175 13 | 0.984061 11.712928 14 | 0.597514 7.146694 15 | 0.005144 3.333150 16 | 0.142295 3.743681 17 | 0.280007 3.737376 18 | 0.542008 6.494275 19 | 0.466781 5.532255 20 | 0.706970 8.476718 21 | 0.191038 3.673921 22 | 0.756591 9.176722 23 | 0.912879 10.850358 24 | 0.524701 6.067444 25 | 0.306090 3.681148 26 | 0.429009 5.032168 27 | 0.695091 8.209058 28 | 0.984495 11.909595 29 | 0.702748 8.298454 30 | 0.551771 6.715210 31 | 0.272894 3.983313 32 | 0.014611 3.559081 33 | 0.699852 8.417306 34 | 0.309710 3.739053 35 | 0.444877 5.219649 36 | 0.717509 8.483072 37 | 0.576550 6.894860 38 | 0.284200 3.792626 39 | 0.675922 8.067282 40 | 0.304401 3.671373 41 | 0.233675 3.795962 42 | 0.453779 5.477533 43 | 0.900938 10.701447 44 | 0.502418 6.046703 45 | 0.781843 9.254690 46 | 0.226271 3.546938 47 | 0.619535 7.703312 48 | 0.519998 6.202835 49 | 0.399447 4.934647 50 | 0.785298 9.497564 51 | 0.010767 3.565835 52 | 0.696399 8.307487 53 | 0.524366 6.266060 54 | 0.396583 4.611390 55 | 0.059988 3.484805 56 | 0.946702 11.263118 57 | 0.417559 4.895128 58 | 0.609194 7.239316 59 | 0.730687 8.858371 60 | 0.586694 7.061601 61 | 0.829567 9.937968 62 | 0.964229 11.521595 63 | 0.276813 3.756406 64 | 0.987041 11.947913 65 | 0.876107 10.440538 66 | 0.747582 8.942278 67 | 0.117348 3.567821 68 | 0.188617 3.976420 69 | 0.416655 4.928907 70 | 0.192995 3.978365 71 | 0.244888 3.777018 72 | 0.806349 9.685831 73 | 0.417555 4.990148 74 | 0.233805 3.740022 75 | 0.357325 4.325355 76 | 0.190201 3.638493 77 | 0.705127 8.432886 78 | 0.336599 
3.868493 79 | 0.473786 5.871813 80 | 0.384794 4.830712 81 | 0.502217 6.117244 82 | 0.788220 9.454959 83 | 0.478773 5.681631 84 | 0.064296 3.642040 85 | 0.332143 3.886628 86 | 0.618869 7.312725 87 | 0.854981 10.306697 88 | 0.570000 6.764615 89 | 0.512739 6.166836 90 | 0.112285 3.545863 91 | 0.723700 8.526944 92 | 0.192256 3.661033 93 | 0.181268 3.678579 94 | 0.196731 3.916622 95 | 0.510342 6.026652 96 | 0.263713 3.723018 97 | 0.141105 3.529595 98 | 0.150262 3.552314 99 | 0.824724 9.973690 100 | 0.588088 6.893128 101 | 0.411291 4.856380 102 | 0.763717 9.199101 103 | 0.212118 3.740024 104 | 0.264587 3.742917 105 | 0.973524 11.683243 106 | 0.250670 3.679117 107 | 0.823460 9.743861 108 | 0.253752 3.781488 109 | 0.838332 10.172180 110 | 0.501156 6.113263 111 | 0.097275 3.472367 112 | 0.667199 7.948868 113 | 0.487320 6.022060 114 | 0.654640 7.809457 115 | 0.906907 10.775188 116 | 0.821941 9.936140 117 | 0.859396 10.428255 118 | 0.078696 3.490510 119 | 0.938092 11.252471 120 | 0.998868 11.863062 121 | 0.025501 3.515624 122 | 0.451806 5.441171 123 | 0.883872 10.498912 124 | 0.583567 6.912334 125 | 0.823688 10.003723 126 | 0.891032 10.818109 127 | 0.879259 10.639263 128 | 0.163007 3.662715 129 | 0.344263 4.169705 130 | 0.796083 9.422591 131 | 0.903683 10.978834 132 | 0.050129 3.575105 133 | 0.605553 7.306014 134 | 0.628951 7.556742 135 | 0.877052 10.444055 136 | 0.829402 9.856432 137 | 0.121422 3.638276 138 | 0.721517 8.663569 139 | 0.066532 3.673471 140 | 0.996587 11.782002 141 | 0.653384 7.804568 142 | 0.739494 8.817809 143 | 0.640341 7.636812 144 | 0.337828 3.971613 145 | 0.220512 3.713645 146 | 0.368815 4.381696 147 | 0.782509 9.349428 148 | 0.645825 7.790882 149 | 0.277391 3.834258 150 | 0.092569 3.643274 151 | 0.284320 3.609353 152 | 0.344465 4.023259 153 | 0.182523 3.749195 154 | 0.385001 4.426970 155 | 0.747609 8.966676 156 | 0.188907 3.711018 157 | 0.806244 9.610438 158 | 0.014211 3.517818 159 | 0.574813 7.040672 160 | 0.714500 8.525624 161 | 0.538982 6.393940 162 | 
0.384638 4.649362 163 | 0.915586 10.936577 164 | 0.883513 10.441493 165 | 0.804148 9.742851 166 | 0.466011 5.833439 167 | 0.800574 9.638874 168 | 0.654980 8.028558 169 | 0.348564 4.064616 170 | 0.978595 11.720218 171 | 0.915906 10.833902 172 | 0.285477 3.818961 173 | 0.988631 11.684010 174 | 0.531069 6.305005 175 | 0.181658 3.806995 176 | 0.039657 3.356861 177 | 0.893344 10.776799 178 | 0.355214 4.263666 179 | 0.783508 9.475445 180 | 0.039768 3.429691 181 | 0.546308 6.472749 182 | 0.786882 9.398951 183 | 0.168282 3.564189 184 | 0.374900 4.399040 185 | 0.737767 8.888536 186 | 0.059849 3.431537 187 | 0.861891 10.246888 188 | 0.597578 7.112627 189 | 0.126050 3.611641 190 | 0.074795 3.609222 191 | 0.634401 7.627416 192 | 0.831633 9.926548 193 | 0.019095 3.470285 194 | 0.396533 4.773104 195 | 0.794973 9.492009 196 | 0.889088 10.420003 197 | 0.003174 3.587139 198 | 0.176767 3.554071 199 | 0.943730 11.227731 200 | 0.758564 8.885337 201 | -------------------------------------------------------------------------------- /classification_and_regression_trees/ex00.txt: -------------------------------------------------------------------------------- 1 | 0.036098 0.155096 2 | 0.993349 1.077553 3 | 0.530897 0.893462 4 | 0.712386 0.564858 5 | 0.343554 -0.371700 6 | 0.098016 -0.332760 7 | 0.691115 0.834391 8 | 0.091358 0.099935 9 | 0.727098 1.000567 10 | 0.951949 0.945255 11 | 0.768596 0.760219 12 | 0.541314 0.893748 13 | 0.146366 0.034283 14 | 0.673195 0.915077 15 | 0.183510 0.184843 16 | 0.339563 0.206783 17 | 0.517921 1.493586 18 | 0.703755 1.101678 19 | 0.008307 0.069976 20 | 0.243909 -0.029467 21 | 0.306964 -0.177321 22 | 0.036492 0.408155 23 | 0.295511 0.002882 24 | 0.837522 1.229373 25 | 0.202054 -0.087744 26 | 0.919384 1.029889 27 | 0.377201 -0.243550 28 | 0.814825 1.095206 29 | 0.611270 0.982036 30 | 0.072243 -0.420983 31 | 0.410230 0.331722 32 | 0.869077 1.114825 33 | 0.620599 1.334421 34 | 0.101149 0.068834 35 | 0.820802 1.325907 36 | 0.520044 0.961983 37 | 0.488130 
-0.097791 38 | 0.819823 0.835264 39 | 0.975022 0.673579 40 | 0.953112 1.064690 41 | 0.475976 -0.163707 42 | 0.273147 -0.455219 43 | 0.804586 0.924033 44 | 0.074795 -0.349692 45 | 0.625336 0.623696 46 | 0.656218 0.958506 47 | 0.834078 1.010580 48 | 0.781930 1.074488 49 | 0.009849 0.056594 50 | 0.302217 -0.148650 51 | 0.678287 0.907727 52 | 0.180506 0.103676 53 | 0.193641 -0.327589 54 | 0.343479 0.175264 55 | 0.145809 0.136979 56 | 0.996757 1.035533 57 | 0.590210 1.336661 58 | 0.238070 -0.358459 59 | 0.561362 1.070529 60 | 0.377597 0.088505 61 | 0.099142 0.025280 62 | 0.539558 1.053846 63 | 0.790240 0.533214 64 | 0.242204 0.209359 65 | 0.152324 0.132858 66 | 0.252649 -0.055613 67 | 0.895930 1.077275 68 | 0.133300 -0.223143 69 | 0.559763 1.253151 70 | 0.643665 1.024241 71 | 0.877241 0.797005 72 | 0.613765 1.621091 73 | 0.645762 1.026886 74 | 0.651376 1.315384 75 | 0.697718 1.212434 76 | 0.742527 1.087056 77 | 0.901056 1.055900 78 | 0.362314 -0.556464 79 | 0.948268 0.631862 80 | 0.000234 0.060903 81 | 0.750078 0.906291 82 | 0.325412 -0.219245 83 | 0.726828 1.017112 84 | 0.348013 0.048939 85 | 0.458121 -0.061456 86 | 0.280738 -0.228880 87 | 0.567704 0.969058 88 | 0.750918 0.748104 89 | 0.575805 0.899090 90 | 0.507940 1.107265 91 | 0.071769 -0.110946 92 | 0.553520 1.391273 93 | 0.401152 -0.121640 94 | 0.406649 -0.366317 95 | 0.652121 1.004346 96 | 0.347837 -0.153405 97 | 0.081931 -0.269756 98 | 0.821648 1.280895 99 | 0.048014 0.064496 100 | 0.130962 0.184241 101 | 0.773422 1.125943 102 | 0.789625 0.552614 103 | 0.096994 0.227167 104 | 0.625791 1.244731 105 | 0.589575 1.185812 106 | 0.323181 0.180811 107 | 0.822443 1.086648 108 | 0.360323 -0.204830 109 | 0.950153 1.022906 110 | 0.527505 0.879560 111 | 0.860049 0.717490 112 | 0.007044 0.094150 113 | 0.438367 0.034014 114 | 0.574573 1.066130 115 | 0.536689 0.867284 116 | 0.782167 0.886049 117 | 0.989888 0.744207 118 | 0.761474 1.058262 119 | 0.985425 1.227946 120 | 0.132543 -0.329372 121 | 0.346986 -0.150389 122 | 0.768784 
0.899705 123 | 0.848921 1.170959 124 | 0.449280 0.069098 125 | 0.066172 0.052439 126 | 0.813719 0.706601 127 | 0.661923 0.767040 128 | 0.529491 1.022206 129 | 0.846455 0.720030 130 | 0.448656 0.026974 131 | 0.795072 0.965721 132 | 0.118156 -0.077409 133 | 0.084248 -0.019547 134 | 0.845815 0.952617 135 | 0.576946 1.234129 136 | 0.772083 1.299018 137 | 0.696648 0.845423 138 | 0.595012 1.213435 139 | 0.648675 1.287407 140 | 0.897094 1.240209 141 | 0.552990 1.036158 142 | 0.332982 0.210084 143 | 0.065615 -0.306970 144 | 0.278661 0.253628 145 | 0.773168 1.140917 146 | 0.203693 -0.064036 147 | 0.355688 -0.119399 148 | 0.988852 1.069062 149 | 0.518735 1.037179 150 | 0.514563 1.156648 151 | 0.976414 0.862911 152 | 0.919074 1.123413 153 | 0.697777 0.827805 154 | 0.928097 0.883225 155 | 0.900272 0.996871 156 | 0.344102 -0.061539 157 | 0.148049 0.204298 158 | 0.130052 -0.026167 159 | 0.302001 0.317135 160 | 0.337100 0.026332 161 | 0.314924 -0.001952 162 | 0.269681 -0.165971 163 | 0.196005 -0.048847 164 | 0.129061 0.305107 165 | 0.936783 1.026258 166 | 0.305540 -0.115991 167 | 0.683921 1.414382 168 | 0.622398 0.766330 169 | 0.902532 0.861601 170 | 0.712503 0.933490 171 | 0.590062 0.705531 172 | 0.723120 1.307248 173 | 0.188218 0.113685 174 | 0.643601 0.782552 175 | 0.520207 1.209557 176 | 0.233115 -0.348147 177 | 0.465625 -0.152940 178 | 0.884512 1.117833 179 | 0.663200 0.701634 180 | 0.268857 0.073447 181 | 0.729234 0.931956 182 | 0.429664 -0.188659 183 | 0.737189 1.200781 184 | 0.378595 -0.296094 185 | 0.930173 1.035645 186 | 0.774301 0.836763 187 | 0.273940 -0.085713 188 | 0.824442 1.082153 189 | 0.626011 0.840544 190 | 0.679390 1.307217 191 | 0.578252 0.921885 192 | 0.785541 1.165296 193 | 0.597409 0.974770 194 | 0.014083 -0.132525 195 | 0.663870 1.187129 196 | 0.552381 1.369630 197 | 0.683886 0.999985 198 | 0.210334 -0.006899 199 | 0.604529 1.212685 200 | 0.250744 0.046297 201 | -------------------------------------------------------------------------------- 
/classification_and_regression_trees/exp.txt: -------------------------------------------------------------------------------- 1 | 0.529582 100.737303 2 | 0.985730 103.106872 3 | 0.797869 99.666151 4 | 0.393473 -1.773056 5 | 0.272568 -1.170222 6 | 0.758825 96.752440 7 | 0.218359 2.337347 8 | 0.926357 98.343231 9 | 0.726881 99.633009 10 | 0.805311 102.253834 11 | 0.208632 0.493174 12 | 0.184921 -2.231071 13 | 0.660135 100.139355 14 | 0.871875 96.637420 15 | 0.657182 100.345442 16 | 0.942481 97.751546 17 | 0.427843 -1.380170 18 | 0.845958 98.195303 19 | 0.878696 99.380485 20 | 0.582034 100.971036 21 | 0.118114 2.397033 22 | 0.144718 1.304535 23 | 0.576046 101.624714 24 | 0.750305 97.601324 25 | 0.518281 100.093634 26 | 0.260793 -1.361888 27 | 0.390245 -2.973759 28 | 0.963020 98.877859 29 | 0.880661 97.631997 30 | 0.291780 -1.638124 31 | 0.192903 -2.221257 32 | 0.461442 -1.074725 33 | 0.821171 99.372052 34 | 0.144557 2.589464 35 | 0.379346 0.991090 36 | 0.383822 1.832389 37 | 0.055406 -1.870700 38 | 0.084308 -0.611701 39 | 0.719578 100.087948 40 | 0.417471 -0.510292 41 | 0.477894 -3.426525 42 | 0.871228 100.307522 43 | 0.113074 -1.011079 44 | 0.409434 -0.616173 45 | 0.967141 96.551856 46 | 0.938254 97.052196 47 | 0.079989 2.083496 48 | 0.150207 1.285491 49 | 0.417339 -0.462985 50 | 0.038787 -2.237234 51 | 0.954657 102.111432 52 | 0.844894 98.350138 53 | 0.106770 -0.998182 54 | 0.247831 2.483594 55 | 0.108687 -0.920229 56 | 0.758165 98.079399 57 | 0.199978 -3.490410 58 | 0.600602 99.850119 59 | 0.026466 1.342825 60 | 0.141239 -0.949858 61 | 0.181437 -2.223725 62 | 0.352656 2.251362 63 | 0.803371 99.647157 64 | 0.677303 100.414859 65 | 0.561674 99.133372 66 | 0.497533 -3.764935 67 | 0.523327 98.452850 68 | 0.507075 103.807755 69 | 0.791978 99.414598 70 | 0.956890 95.977239 71 | 0.487927 1.199149 72 | 0.788795 100.012047 73 | 0.554283 98.522458 74 | 0.814361 97.642150 75 | 0.788940 97.399942 76 | 0.515845 102.240479 77 | 0.758538 97.461917 78 | 0.041824 -3.294141 79 | 
0.341352 1.246559 80 | 0.194801 -2.285278 81 | 0.805528 99.023113 82 | 0.435762 0.361749 83 | 0.941615 100.746547 84 | 0.478234 0.791146 85 | 0.057445 -4.266792 86 | 0.510079 98.845273 87 | 0.209900 -0.861890 88 | 0.902668 101.429190 89 | 0.456602 -2.856392 90 | 0.997595 99.828241 91 | 0.048240 -0.268920 92 | 0.319531 0.896696 93 | 0.264929 -1.000487 94 | 0.432727 -4.630489 95 | 0.419828 1.260534 96 | 0.667056 99.456518 97 | 0.488173 1.574322 98 | 0.746300 100.563503 99 | 0.528660 100.736739 100 | 0.624185 99.562872 101 | 0.169411 1.809929 102 | 0.011025 4.132846 103 | 0.974164 98.706049 104 | 0.267957 0.297803 105 | 0.726093 99.381040 106 | 0.465163 -2.344545 107 | 0.993698 101.507792 108 | 0.816513 99.903496 109 | 0.398756 0.378060 110 | 0.054974 -0.588770 111 | 0.857067 100.322945 112 | 0.362328 2.551786 113 | 0.316961 -0.528283 114 | 0.167881 -0.376517 115 | 0.393776 3.658204 116 | 0.739991 100.426554 117 | 0.457949 0.857428 118 | 0.060635 2.484776 119 | 0.942634 101.254420 120 | 0.553691 102.467820 121 | 0.394694 -0.248353 122 | 0.714625 99.650556 123 | 0.273503 1.111820 124 | 0.471886 -5.665559 125 | 0.746476 98.720163 126 | 0.140209 0.471820 127 | 0.024197 -2.854251 128 | 0.521287 99.703915 129 | 0.672280 100.463227 130 | 0.380342 -0.785713 131 | 0.956380 99.482209 132 | 0.455254 1.613841 133 | 0.647551 101.591193 134 | 0.682498 98.267734 135 | 0.054839 -2.286019 136 | 0.716849 100.614510 137 | 0.217732 -2.161633 138 | 0.918885 100.260067 139 | 0.576026 101.719788 140 | 0.868511 100.669152 141 | 0.661135 97.637969 142 | 0.166334 1.374014 143 | 0.106850 -3.658050 144 | 0.768242 104.193841 145 | 0.240916 -0.368100 146 | 0.124957 2.821672 147 | 0.984335 98.571444 148 | 0.908524 101.777344 149 | 0.861217 98.656403 150 | 0.944295 100.154508 151 | 0.527278 101.052710 152 | 0.717072 100.788373 153 | 0.130227 0.115694 154 | 0.494734 -1.220681 155 | 0.498733 0.961514 156 | 0.519411 101.331622 157 | 0.712409 104.891067 158 | 0.933858 98.180299 159 | 0.266051 0.398961 
160 | 0.153690 -0.657128 161 | 0.209181 1.486816 162 | 0.942699 102.187578 163 | 0.766799 100.213348 164 | 0.862578 101.816969 165 | 0.223266 2.854445 166 | 0.611394 103.428497 167 | 0.996212 98.494158 168 | 0.724945 99.098450 169 | 0.399346 0.879259 170 | 0.750510 98.729864 171 | 0.446060 0.639843 172 | 0.999913 101.502887 173 | 0.111561 3.256383 174 | 0.094755 0.170475 175 | 0.366547 0.488994 176 | 0.179924 -0.871567 177 | 0.969023 99.982789 178 | 0.941420 100.416754 179 | 0.656851 98.520940 180 | 0.983166 99.546591 181 | 0.167843 0.033922 182 | 0.316245 2.171137 183 | 0.817118 102.849575 184 | 0.173642 1.209173 185 | 0.411030 2.022640 186 | 0.265041 2.216470 187 | 0.779660 98.475428 188 | 0.059354 -0.929568 189 | 0.722092 97.974003 190 | 0.511958 101.924447 191 | 0.371938 -0.640602 192 | 0.851009 97.873330 193 | 0.375918 -5.308115 194 | 0.797332 99.763778 195 | 0.107749 -3.770092 196 | 0.156937 -0.876724 197 | 0.960447 99.597097 198 | 0.413434 2.408090 199 | 0.644257 100.453125 200 | 0.119332 -0.495588 201 | -------------------------------------------------------------------------------- /classification_and_regression_trees/ex2test.txt: -------------------------------------------------------------------------------- 1 | 0.421862 10.830241 2 | 0.105349 -2.241611 3 | 0.155196 21.872976 4 | 0.161152 2.015418 5 | 0.382632 -38.778979 6 | 0.017710 20.109113 7 | 0.129656 15.266887 8 | 0.613926 111.900063 9 | 0.409277 1.874731 10 | 0.807556 111.223754 11 | 0.593722 133.835486 12 | 0.953239 110.465070 13 | 0.257402 15.332899 14 | 0.645385 93.983054 15 | 0.563460 93.645277 16 | 0.408338 -30.719878 17 | 0.874394 91.873505 18 | 0.263805 -0.192752 19 | 0.411198 10.751118 20 | 0.449884 9.211901 21 | 0.646315 113.533660 22 | 0.673718 125.135638 23 | 0.805148 113.300462 24 | 0.759327 72.668572 25 | 0.519172 82.131698 26 | 0.741031 106.777146 27 | 0.030937 9.859127 28 | 0.268848 -34.137955 29 | 0.474901 -11.201301 30 | 0.588266 120.501998 31 | 0.893936 142.826476 32 | 0.870990 
105.751746 33 | 0.430763 39.146258 34 | 0.057665 15.371897 35 | 0.100076 9.131761 36 | 0.980716 116.145896 37 | 0.235289 -13.691224 38 | 0.228098 16.089151 39 | 0.622248 99.345551 40 | 0.401467 -1.694383 41 | 0.960334 110.795415 42 | 0.031214 -5.330042 43 | 0.504228 96.003525 44 | 0.779660 75.921582 45 | 0.504496 101.341462 46 | 0.850974 96.293064 47 | 0.701119 102.333839 48 | 0.191551 5.072326 49 | 0.667116 92.310019 50 | 0.555584 80.367129 51 | 0.680006 132.965442 52 | 0.393899 38.605283 53 | 0.048940 -9.861871 54 | 0.963282 115.407485 55 | 0.655496 104.269918 56 | 0.576463 141.127267 57 | 0.675708 96.227996 58 | 0.853457 114.252288 59 | 0.003933 -12.182861 60 | 0.549512 97.927224 61 | 0.218967 -4.712462 62 | 0.659972 120.950439 63 | 0.008256 8.026816 64 | 0.099500 -14.318434 65 | 0.352215 -3.747546 66 | 0.874926 89.247356 67 | 0.635084 99.496059 68 | 0.039641 14.147109 69 | 0.665111 103.298719 70 | 0.156583 -2.540703 71 | 0.648843 119.333019 72 | 0.893237 95.209585 73 | 0.128807 5.558479 74 | 0.137438 5.567685 75 | 0.630538 98.462792 76 | 0.296084 -41.799438 77 | 0.632099 84.895098 78 | 0.987681 106.726447 79 | 0.744909 111.279705 80 | 0.862030 104.581156 81 | 0.080649 -7.679985 82 | 0.831277 59.053356 83 | 0.198716 26.878801 84 | 0.860932 90.632930 85 | 0.883250 92.759595 86 | 0.818003 110.272219 87 | 0.949216 115.200237 88 | 0.460078 -35.957981 89 | 0.561077 93.545761 90 | 0.863767 114.125786 91 | 0.476891 -29.774060 92 | 0.537826 81.587922 93 | 0.686224 110.911198 94 | 0.982327 119.114523 95 | 0.944453 92.033481 96 | 0.078227 30.216873 97 | 0.782937 92.588646 98 | 0.465886 2.222139 99 | 0.885024 90.247890 100 | 0.186077 7.144415 101 | 0.915828 84.010074 102 | 0.796649 115.572156 103 | 0.127821 28.933688 104 | 0.433429 6.782575 105 | 0.946796 108.574116 106 | 0.386915 -17.404601 107 | 0.561192 92.142700 108 | 0.182490 10.764616 109 | 0.878792 95.289476 110 | 0.381342 -6.177464 111 | 0.358474 -11.731754 112 | 0.270647 13.793201 113 | 0.488904 -17.641832 114 | 
0.106773 5.684757 115 | 0.270112 4.335675 116 | 0.754985 75.860433 117 | 0.585174 111.640154 118 | 0.458821 12.029692 119 | 0.218017 -26.234872 120 | 0.583887 99.413850 121 | 0.923626 107.802298 122 | 0.833620 104.179678 123 | 0.870691 93.132591 124 | 0.249896 -8.618404 125 | 0.748230 109.160652 126 | 0.019365 34.048884 127 | 0.837588 101.239275 128 | 0.529251 115.514729 129 | 0.742898 67.038771 130 | 0.522034 64.160799 131 | 0.498982 3.983061 132 | 0.479439 24.355908 133 | 0.314834 -14.256200 134 | 0.753251 85.017092 135 | 0.479362 -17.480446 136 | 0.950593 99.072784 137 | 0.718623 58.080256 138 | 0.218720 -19.605593 139 | 0.664113 94.437159 140 | 0.942900 131.725134 141 | 0.314226 18.904871 142 | 0.284509 11.779346 143 | 0.004962 -14.624176 144 | 0.224087 -50.547649 145 | 0.974331 112.822725 146 | 0.894610 112.863995 147 | 0.167350 0.073380 148 | 0.753644 105.024456 149 | 0.632241 108.625812 150 | 0.314189 -6.090797 151 | 0.965527 87.418343 152 | 0.820919 94.610538 153 | 0.144107 -4.748387 154 | 0.072556 -5.682008 155 | 0.002447 29.685714 156 | 0.851007 79.632376 157 | 0.458024 -12.326026 158 | 0.627503 139.458881 159 | 0.422259 -29.827405 160 | 0.714659 63.480271 161 | 0.672320 93.608554 162 | 0.498592 37.112975 163 | 0.698906 96.282845 164 | 0.861441 99.699230 165 | 0.112425 -12.419909 166 | 0.164784 5.244704 167 | 0.481531 -18.070497 168 | 0.375482 1.779411 169 | 0.089325 -14.216755 170 | 0.036609 -6.264372 171 | 0.945004 54.723563 172 | 0.136608 14.970936 173 | 0.292285 -41.723711 174 | 0.029195 -0.660279 175 | 0.998307 100.124230 176 | 0.303928 -5.492264 177 | 0.957863 117.824392 178 | 0.815089 113.377704 179 | 0.466399 -10.249874 180 | 0.876693 115.617275 181 | 0.536121 102.997087 182 | 0.373984 -37.359936 183 | 0.565162 74.967476 184 | 0.085412 -21.449563 185 | 0.686411 64.859620 186 | 0.908752 107.983366 187 | 0.982829 98.005424 188 | 0.052766 -42.139502 189 | 0.777552 91.899340 190 | 0.374316 -3.522501 191 | 0.060231 10.008227 192 | 0.526225 87.317722 
193 | 0.583872 67.104433 194 | 0.238276 10.615159 195 | 0.678747 60.624273 196 | 0.067649 15.947398 197 | 0.530182 105.030933 198 | 0.869389 104.969996 199 | 0.698410 75.460417 200 | 0.549430 82.558068 201 | -------------------------------------------------------------------------------- /classification_and_regression_trees/ex2.txt: -------------------------------------------------------------------------------- 1 | 0.228628 -2.266273 2 | 0.965969 112.386764 3 | 0.342761 -31.584855 4 | 0.901444 87.300625 5 | 0.585413 125.295113 6 | 0.334900 18.976650 7 | 0.769043 64.041941 8 | 0.297107 -1.798377 9 | 0.901421 100.133819 10 | 0.176523 0.946348 11 | 0.710234 108.553919 12 | 0.981980 86.399637 13 | 0.085873 -10.137104 14 | 0.537834 90.995536 15 | 0.806158 62.877698 16 | 0.708890 135.416767 17 | 0.787755 118.642009 18 | 0.463241 17.171057 19 | 0.300318 -18.051318 20 | 0.815215 118.319942 21 | 0.139880 7.336784 22 | 0.068373 -15.160836 23 | 0.457563 -34.044555 24 | 0.665652 105.547997 25 | 0.084661 -24.132226 26 | 0.954711 100.935789 27 | 0.953902 130.926480 28 | 0.487381 27.729263 29 | 0.759504 81.106762 30 | 0.454312 -20.360067 31 | 0.295993 -14.988279 32 | 0.156067 7.557349 33 | 0.428582 15.224266 34 | 0.847219 76.240984 35 | 0.499171 11.924204 36 | 0.203993 -22.379119 37 | 0.548539 83.114502 38 | 0.790312 110.159730 39 | 0.937766 119.949824 40 | 0.218321 1.410768 41 | 0.223200 15.501642 42 | 0.896683 107.001620 43 | 0.582311 82.589328 44 | 0.698920 92.470636 45 | 0.823848 59.342323 46 | 0.385021 24.816941 47 | 0.061219 6.695567 48 | 0.841547 115.669032 49 | 0.763328 115.199195 50 | 0.934853 115.753994 51 | 0.222271 -9.255852 52 | 0.217214 -3.958752 53 | 0.706961 106.180427 54 | 0.888426 94.896354 55 | 0.549814 137.267576 56 | 0.107960 -1.293195 57 | 0.085111 37.820659 58 | 0.388789 21.578007 59 | 0.467383 -9.712925 60 | 0.623909 87.181863 61 | 0.373501 -8.228297 62 | 0.513332 101.075609 63 | 0.350725 -40.086564 64 | 0.716211 103.345308 65 | 0.731636 73.912028 66 | 
0.273863 -9.457556 67 | 0.211633 -8.332207 68 | 0.944221 100.120253 69 | 0.053764 -13.731698 70 | 0.126833 22.891675 71 | 0.952833 100.649591 72 | 0.391609 3.001104 73 | 0.560301 82.903945 74 | 0.124723 -1.402796 75 | 0.465680 -23.777531 76 | 0.699873 115.586605 77 | 0.164134 -27.405211 78 | 0.455761 9.841938 79 | 0.508542 96.403373 80 | 0.138619 -29.087463 81 | 0.335182 2.768225 82 | 0.908629 118.513475 83 | 0.546601 96.319043 84 | 0.378965 13.583555 85 | 0.968621 98.648346 86 | 0.637999 91.656617 87 | 0.350065 -1.319852 88 | 0.632691 93.645293 89 | 0.936524 65.548418 90 | 0.310956 -49.939516 91 | 0.437652 19.745224 92 | 0.166765 -14.740059 93 | 0.571214 114.872056 94 | 0.952377 73.520802 95 | 0.665329 121.980607 96 | 0.258070 -20.425137 97 | 0.912161 85.005351 98 | 0.777582 100.838446 99 | 0.642707 82.500766 100 | 0.885676 108.045948 101 | 0.080061 2.229873 102 | 0.039914 11.220099 103 | 0.958512 135.837013 104 | 0.377383 5.241196 105 | 0.661073 115.687524 106 | 0.454375 3.043912 107 | 0.412516 -26.419289 108 | 0.854970 89.209930 109 | 0.698472 120.521925 110 | 0.465561 30.051931 111 | 0.328890 39.783113 112 | 0.309133 8.814725 113 | 0.418943 44.161493 114 | 0.553797 120.857321 115 | 0.799873 91.368473 116 | 0.811363 112.981216 117 | 0.785574 107.024467 118 | 0.949198 105.752508 119 | 0.666452 120.014736 120 | 0.652462 112.715799 121 | 0.290749 -14.391613 122 | 0.508548 93.292829 123 | 0.680486 110.367074 124 | 0.356790 -19.526539 125 | 0.199903 -3.372472 126 | 0.264926 5.280579 127 | 0.166431 -6.512506 128 | 0.370042 -32.124495 129 | 0.628061 117.628346 130 | 0.228473 19.425158 131 | 0.044737 3.855393 132 | 0.193282 18.208423 133 | 0.519150 116.176162 134 | 0.351478 -0.461116 135 | 0.872199 111.552716 136 | 0.115150 13.795828 137 | 0.324274 -13.189243 138 | 0.446196 -5.108172 139 | 0.613004 168.180746 140 | 0.533511 129.766743 141 | 0.740859 93.773929 142 | 0.667851 92.449664 143 | 0.900699 109.188248 144 | 0.599142 130.378529 145 | 0.232802 1.222318 146 | 
0.838587 134.089674 147 | 0.284794 35.623746 148 | 0.130626 -39.524461 149 | 0.642373 140.613941 150 | 0.786865 100.598825 151 | 0.403228 -1.729244 152 | 0.883615 95.348184 153 | 0.910975 106.814667 154 | 0.819722 70.054508 155 | 0.798198 76.853728 156 | 0.606417 93.521396 157 | 0.108801 -16.106164 158 | 0.318309 -27.605424 159 | 0.856421 107.166848 160 | 0.842940 95.893131 161 | 0.618868 76.917665 162 | 0.531944 124.795495 163 | 0.028546 -8.377094 164 | 0.915263 96.717610 165 | 0.925782 92.074619 166 | 0.624827 105.970743 167 | 0.331364 -1.290825 168 | 0.341700 -23.547711 169 | 0.342155 -16.930416 170 | 0.729397 110.902830 171 | 0.640515 82.713621 172 | 0.228751 -30.812912 173 | 0.948822 69.318649 174 | 0.706390 105.062147 175 | 0.079632 29.420068 176 | 0.451087 -28.724685 177 | 0.833026 76.723835 178 | 0.589806 98.674874 179 | 0.426711 -21.594268 180 | 0.872883 95.887712 181 | 0.866451 94.402102 182 | 0.960398 123.559747 183 | 0.483803 5.224234 184 | 0.811602 99.841379 185 | 0.757527 63.549854 186 | 0.569327 108.435392 187 | 0.841625 60.552308 188 | 0.264639 2.557923 189 | 0.202161 -1.983889 190 | 0.055862 -3.131497 191 | 0.543843 98.362010 192 | 0.689099 112.378209 193 | 0.956951 82.016541 194 | 0.382037 -29.007783 195 | 0.131833 22.478291 196 | 0.156273 0.225886 197 | 0.000256 9.668106 198 | 0.892999 82.436686 199 | 0.206207 -12.619036 200 | 0.487537 5.149336 201 | -------------------------------------------------------------------------------- /classification_and_regression_trees/bikeSpeedVsIq_train.txt: -------------------------------------------------------------------------------- 1 | 3.000000 46.852122 2 | 23.000000 178.676107 3 | 0.000000 86.154024 4 | 6.000000 68.707614 5 | 15.000000 139.737693 6 | 17.000000 141.988903 7 | 12.000000 94.477135 8 | 8.000000 86.083788 9 | 9.000000 97.265824 10 | 7.000000 80.400027 11 | 8.000000 83.414554 12 | 1.000000 52.525471 13 | 16.000000 127.060008 14 | 9.000000 101.639269 15 | 14.000000 146.412680 16 | 15.000000 
144.157101 17 | 17.000000 152.699910 18 | 19.000000 136.669023 19 | 21.000000 166.971736 20 | 21.000000 165.467251 21 | 3.000000 38.455193 22 | 6.000000 75.557721 23 | 4.000000 22.171763 24 | 5.000000 50.321915 25 | 0.000000 74.412428 26 | 5.000000 42.052392 27 | 1.000000 42.489057 28 | 14.000000 139.185416 29 | 21.000000 140.713725 30 | 5.000000 63.222944 31 | 5.000000 56.294626 32 | 9.000000 91.674826 33 | 22.000000 173.497655 34 | 17.000000 152.692482 35 | 9.000000 113.920633 36 | 1.000000 51.552411 37 | 9.000000 100.075315 38 | 16.000000 137.803868 39 | 18.000000 135.925777 40 | 3.000000 45.550762 41 | 16.000000 149.933224 42 | 2.000000 27.914173 43 | 6.000000 62.103546 44 | 20.000000 173.942381 45 | 12.000000 119.200505 46 | 6.000000 70.730214 47 | 16.000000 156.260832 48 | 15.000000 132.467643 49 | 19.000000 161.164086 50 | 17.000000 138.031844 51 | 23.000000 169.747881 52 | 11.000000 116.761920 53 | 4.000000 34.305905 54 | 6.000000 68.841160 55 | 10.000000 119.535227 56 | 20.000000 158.104763 57 | 18.000000 138.390511 58 | 5.000000 59.375794 59 | 7.000000 80.802300 60 | 11.000000 108.611485 61 | 10.000000 91.169028 62 | 15.000000 154.104819 63 | 5.000000 51.100287 64 | 3.000000 32.334330 65 | 15.000000 150.551655 66 | 10.000000 111.023073 67 | 0.000000 87.489950 68 | 2.000000 46.726299 69 | 7.000000 92.540440 70 | 15.000000 135.715438 71 | 19.000000 152.960552 72 | 19.000000 162.789223 73 | 21.000000 167.176240 74 | 22.000000 164.323358 75 | 12.000000 104.823071 76 | 1.000000 35.554328 77 | 11.000000 114.784640 78 | 1.000000 36.819570 79 | 12.000000 130.266826 80 | 12.000000 126.053312 81 | 18.000000 153.378289 82 | 7.000000 70.089159 83 | 15.000000 139.528624 84 | 19.000000 157.137999 85 | 23.000000 183.595248 86 | 7.000000 73.431043 87 | 11.000000 128.176167 88 | 22.000000 183.181247 89 | 13.000000 112.685801 90 | 18.000000 161.634783 91 | 6.000000 63.169478 92 | 7.000000 63.393975 93 | 19.000000 165.779578 94 | 14.000000 143.973398 95 | 22.000000 
185.131852 96 | 3.000000 45.275591 97 | 6.000000 62.018003 98 | 0.000000 83.193398 99 | 7.000000 76.847802 100 | 19.000000 147.087386 101 | 7.000000 62.812086 102 | 1.000000 49.910068 103 | 11.000000 102.169335 104 | 11.000000 105.108121 105 | 6.000000 63.429817 106 | 12.000000 121.301542 107 | 17.000000 163.253962 108 | 13.000000 119.588698 109 | 0.000000 87.333807 110 | 20.000000 144.484066 111 | 21.000000 168.792482 112 | 23.000000 159.751246 113 | 20.000000 162.843592 114 | 14.000000 145.664069 115 | 19.000000 146.838515 116 | 12.000000 132.049377 117 | 18.000000 155.756119 118 | 22.000000 155.686345 119 | 7.000000 73.913958 120 | 1.000000 66.761881 121 | 7.000000 65.855450 122 | 6.000000 56.271026 123 | 19.000000 155.308523 124 | 12.000000 124.372873 125 | 17.000000 136.025960 126 | 14.000000 132.996861 127 | 21.000000 172.639791 128 | 17.000000 135.672594 129 | 8.000000 90.323742 130 | 5.000000 62.462698 131 | 16.000000 159.048794 132 | 14.000000 139.991227 133 | 3.000000 37.026678 134 | 9.000000 100.839901 135 | 9.000000 93.097395 136 | 15.000000 123.645221 137 | 15.000000 147.327185 138 | 1.000000 40.055830 139 | 0.000000 88.192829 140 | 17.000000 139.174517 141 | 22.000000 169.354493 142 | 17.000000 136.354272 143 | 9.000000 90.692829 144 | 7.000000 63.987997 145 | 14.000000 128.972231 146 | 10.000000 108.433394 147 | 2.000000 49.321034 148 | 19.000000 171.615671 149 | 9.000000 97.894855 150 | 0.000000 68.962453 151 | 9.000000 72.063371 152 | 22.000000 157.000070 153 | 12.000000 114.461754 154 | 6.000000 58.239465 155 | 9.000000 104.601048 156 | 8.000000 90.772359 157 | 22.000000 164.428791 158 | 5.000000 34.804083 159 | 5.000000 37.089459 160 | 22.000000 177.987605 161 | 10.000000 89.439608 162 | 6.000000 70.711362 163 | 23.000000 181.731482 164 | 20.000000 151.538932 165 | 7.000000 66.067228 166 | 6.000000 61.565125 167 | 20.000000 184.441687 168 | 9.000000 91.569158 169 | 9.000000 98.833425 170 | 17.000000 144.352866 171 | 9.000000 94.498314 172 | 
15.000000 121.922732 173 | 18.000000 166.408274 174 | 10.000000 89.571299 175 | 8.000000 75.373772 176 | 22.000000 161.001478 177 | 8.000000 90.594227 178 | 5.000000 57.180933 179 | 20.000000 161.643007 180 | 8.000000 87.197370 181 | 8.000000 95.584308 182 | 15.000000 126.207221 183 | 7.000000 84.528209 184 | 18.000000 161.056986 185 | 10.000000 86.762615 186 | 1.000000 33.325906 187 | 9.000000 105.095502 188 | 2.000000 22.440421 189 | 9.000000 93.449284 190 | 14.000000 106.249595 191 | 21.000000 163.254385 192 | 22.000000 161.746628 193 | 20.000000 152.973085 194 | 17.000000 122.918987 195 | 7.000000 58.536412 196 | 1.000000 45.013277 197 | 13.000000 137.294148 198 | 10.000000 88.123737 199 | 2.000000 45.847376 200 | 20.000000 163.385797 201 | -------------------------------------------------------------------------------- /classification_and_regression_trees/bikeSpeedVsIq_test.txt: -------------------------------------------------------------------------------- 1 | 12.000000 121.010516 2 | 19.000000 157.337044 3 | 12.000000 116.031825 4 | 15.000000 132.124872 5 | 2.000000 52.719612 6 | 6.000000 39.058368 7 | 3.000000 50.757763 8 | 20.000000 166.740333 9 | 11.000000 115.808227 10 | 21.000000 165.582995 11 | 3.000000 41.956087 12 | 3.000000 34.432370 13 | 13.000000 116.954676 14 | 1.000000 32.112553 15 | 7.000000 50.380243 16 | 7.000000 94.107791 17 | 23.000000 188.943179 18 | 18.000000 152.637773 19 | 9.000000 104.122082 20 | 18.000000 127.805226 21 | 0.000000 83.083232 22 | 15.000000 148.180104 23 | 3.000000 38.480247 24 | 8.000000 77.597839 25 | 7.000000 75.625803 26 | 11.000000 124.620208 27 | 13.000000 125.186698 28 | 5.000000 51.165922 29 | 3.000000 31.179113 30 | 15.000000 132.505727 31 | 19.000000 137.978043 32 | 9.000000 106.481123 33 | 20.000000 172.149955 34 | 11.000000 104.116556 35 | 4.000000 22.457996 36 | 20.000000 175.735047 37 | 18.000000 165.350412 38 | 22.000000 177.461724 39 | 16.000000 138.672986 40 | 17.000000 156.791788 41 | 19.000000 
150.327544 42 | 19.000000 156.992196 43 | 23.000000 163.624262 44 | 8.000000 92.537227 45 | 3.000000 32.341399 46 | 16.000000 144.445614 47 | 11.000000 119.985586 48 | 16.000000 145.149335 49 | 12.000000 113.284662 50 | 5.000000 47.742716 51 | 11.000000 115.852585 52 | 3.000000 31.579325 53 | 1.000000 43.758671 54 | 1.000000 61.049125 55 | 13.000000 132.751826 56 | 23.000000 163.233087 57 | 12.000000 115.134296 58 | 8.000000 91.370839 59 | 8.000000 86.137955 60 | 14.000000 120.857934 61 | 3.000000 33.777477 62 | 10.000000 110.831763 63 | 10.000000 104.174775 64 | 20.000000 155.920696 65 | 4.000000 30.619132 66 | 0.000000 71.880474 67 | 7.000000 86.399516 68 | 7.000000 72.632906 69 | 5.000000 58.632985 70 | 18.000000 143.584511 71 | 23.000000 187.059504 72 | 6.000000 65.067119 73 | 6.000000 69.110280 74 | 19.000000 142.388056 75 | 15.000000 137.174489 76 | 21.000000 159.719092 77 | 9.000000 102.179638 78 | 20.000000 176.416294 79 | 21.000000 146.516385 80 | 18.000000 147.808343 81 | 23.000000 154.790810 82 | 16.000000 137.385285 83 | 18.000000 166.885975 84 | 15.000000 136.989000 85 | 20.000000 144.668679 86 | 14.000000 137.060671 87 | 19.000000 140.468283 88 | 11.000000 98.344084 89 | 16.000000 132.497910 90 | 1.000000 59.143101 91 | 20.000000 152.299381 92 | 13.000000 134.487271 93 | 0.000000 77.805718 94 | 3.000000 28.543764 95 | 10.000000 97.751817 96 | 4.000000 41.223659 97 | 11.000000 110.017015 98 | 12.000000 119.391386 99 | 20.000000 158.872126 100 | 2.000000 38.776222 101 | 19.000000 150.496148 102 | 15.000000 131.505967 103 | 22.000000 179.856157 104 | 13.000000 143.090102 105 | 14.000000 142.611861 106 | 13.000000 120.757410 107 | 4.000000 27.929324 108 | 16.000000 151.530849 109 | 15.000000 148.149702 110 | 5.000000 44.188084 111 | 16.000000 141.135406 112 | 12.000000 119.817665 113 | 8.000000 80.991524 114 | 3.000000 29.308640 115 | 6.000000 48.203468 116 | 8.000000 92.179834 117 | 22.000000 162.720371 118 | 10.000000 91.971158 119 | 2.000000 33.481943 
120 | 8.000000 88.528612 121 | 1.000000 54.042173 122 | 8.000000 92.002928 123 | 5.000000 45.614646 124 | 3.000000 34.319635 125 | 14.000000 129.140558 126 | 17.000000 146.807901 127 | 17.000000 157.694058 128 | 4.000000 37.080929 129 | 20.000000 169.942381 130 | 10.000000 114.675638 131 | 5.000000 34.913029 132 | 14.000000 137.889747 133 | 0.000000 79.043129 134 | 16.000000 139.084390 135 | 6.000000 53.340135 136 | 13.000000 142.772612 137 | 0.000000 73.103173 138 | 3.000000 37.717487 139 | 15.000000 134.116395 140 | 18.000000 138.748257 141 | 23.000000 180.779121 142 | 10.000000 93.721894 143 | 23.000000 166.958335 144 | 6.000000 74.473589 145 | 6.000000 73.006291 146 | 3.000000 34.178656 147 | 1.000000 33.395482 148 | 22.000000 149.933384 149 | 18.000000 154.858982 150 | 6.000000 66.121084 151 | 1.000000 60.816800 152 | 5.000000 55.681020 153 | 6.000000 61.251558 154 | 15.000000 125.452206 155 | 16.000000 134.310255 156 | 19.000000 167.999681 157 | 5.000000 40.074830 158 | 22.000000 162.658997 159 | 12.000000 109.473909 160 | 4.000000 44.743405 161 | 11.000000 122.419496 162 | 14.000000 139.852014 163 | 21.000000 160.045407 164 | 15.000000 131.999358 165 | 15.000000 135.577799 166 | 20.000000 173.494629 167 | 8.000000 82.497177 168 | 12.000000 123.122032 169 | 10.000000 97.592026 170 | 16.000000 141.345706 171 | 8.000000 79.588881 172 | 3.000000 54.308878 173 | 4.000000 36.112937 174 | 19.000000 165.005336 175 | 23.000000 172.198031 176 | 15.000000 127.699625 177 | 1.000000 47.305217 178 | 13.000000 115.489379 179 | 8.000000 103.956569 180 | 4.000000 53.669477 181 | 0.000000 76.220652 182 | 12.000000 114.153306 183 | 6.000000 74.608728 184 | 3.000000 41.339299 185 | 5.000000 21.944048 186 | 22.000000 181.455655 187 | 20.000000 171.691444 188 | 10.000000 104.299002 189 | 21.000000 168.307123 190 | 20.000000 169.556523 191 | 23.000000 175.960552 192 | 1.000000 42.554778 193 | 14.000000 137.286185 194 | 16.000000 136.126561 195 | 12.000000 119.269042 196 | 6.000000 
63.426977 197 | 4.000000 27.728212 198 | 4.000000 32.687588 199 | 23.000000 151.153204 200 | 15.000000 129.767331 201 | -------------------------------------------------------------------------------- /Reinforcement Learning/Value Iteration Algorithm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Value Iteration algorithm uses the calculated utilities of all the states and compares them after an equilibrium is reached to calculate which is the best move to be taken. The algorithm reaches an equlibrium and this can be known using a stopping criteria. The stopping criteria taken is when no state's utility gets changed by much between two consecutive iterations." 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "### Implementing the Value Iteration algorithm:" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import numpy as np\n", 24 | "\n", 25 | "def state_utility(v, T, u, reward, gamma):\n", 26 | " \n", 27 | " #v is the state vector\n", 28 | " #T is the transition matrix\n", 29 | " #u is the utility vector\n", 30 | " #reward consists of the rewards earned for moving to a particular state\n", 31 | " #gamma is the discount factor by which rewards are discounted over the time\n", 32 | "\n", 33 | " action_array = np.zeros(4)\n", 34 | " for action in range(0, 4):\n", 35 | " action_array[action] = np.sum(np.multiply(u, np.dot(v, T[:,:,action])))\n", 36 | " return reward + gamma * np.max(action_array)" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 2, 42 | "metadata": {}, 43 | "outputs": [ 44 | { 45 | "name": "stdout", 46 | "output_type": "stream", 47 | "text": [ 48 | "=================== FINAL RESULT ==================\n", 49 | "Iterations: 26\n", 50 | "Delta: 
9.511968687869743e-06\n", 51 | "Gamma: 0.999\n", 52 | "Epsilon: 0.01\n", 53 | "===================================================\n", 54 | "[0.80796341 0.86539911 0.91653199 1. ]\n", 55 | "[ 0.75696613 0. 0.65836281 -1. ]\n", 56 | "[0.69968168 0.64881721 0.60471137 0.3814863 ]\n", 57 | "===================================================\n" 58 | ] 59 | } 60 | ], 61 | "source": [ 62 | "def main():\n", 63 | " \n", 64 | " tot_states = 12\n", 65 | " gamma = 0.999 \n", 66 | " iteration = 0 #Iteration counter\n", 67 | " epsilon = 0.01 #Stopping criteria given a small value\n", 68 | "\n", 69 | " #List containing the data for each iteation\n", 70 | " graph_list = list()\n", 71 | "\n", 72 | " #Transition matrix loaded from file\n", 73 | " T = np.load(\"T.npy\")\n", 74 | "\n", 75 | " #Reward vector\n", 76 | " r = np.array([-0.04, -0.04, -0.04, +1.0,\n", 77 | " -0.04, 0.0, -0.04, -1.0,\n", 78 | " -0.04, -0.04, -0.04, -0.04]) \n", 79 | "\n", 80 | " #Utility vectors\n", 81 | " u = np.array([0.0, 0.0, 0.0, 0.0,\n", 82 | " 0.0, 0.0, 0.0, 0.0,\n", 83 | " 0.0, 0.0, 0.0, 0.0])\n", 84 | " \n", 85 | " u1 = np.array([0.0, 0.0, 0.0, 0.0,\n", 86 | " 0.0, 0.0, 0.0, 0.0,\n", 87 | " 0.0, 0.0, 0.0, 0.0])\n", 88 | "\n", 89 | " while True:\n", 90 | " delta = 0\n", 91 | " u = u1.copy()\n", 92 | " iteration += 1\n", 93 | " graph_list.append(u)\n", 94 | " for s in range(tot_states):\n", 95 | " reward = r[s]\n", 96 | " v = np.zeros((1,tot_states))\n", 97 | " v[0,s] = 1.0\n", 98 | " u1[s] = state_utility(v, T, u, reward, gamma)\n", 99 | " delta = max(delta, np.abs(u1[s] - u[s])) #Stopping criteria checked \n", 100 | " \n", 101 | " if delta < epsilon * (1 - gamma) / gamma:\n", 102 | " print(\"=================== FINAL RESULT ==================\")\n", 103 | " print(\"Iterations: \" + str(iteration))\n", 104 | " print(\"Delta: \" + str(delta))\n", 105 | " print(\"Gamma: \" + str(gamma))\n", 106 | " print(\"Epsilon: \" + str(epsilon))\n", 107 | " 
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Simplified SMO (Sequential Minimal Optimization) for a linear SVM.

Alpha pairs are selected without heuristics: the first index is swept in
order and the second index is drawn uniformly at random.
"""

import random

import numpy as np


def load_data(filename):
    """Parse whitespace-separated 'x y label' rows from a text file.

    :param filename: path of the file to read
    :return: (dataset, labels) where dataset is a list of [x, y] pairs
    """
    dataset, labels = [], []
    with open(filename, 'r') as f:
        for line in f:
            x, y, label = [float(i) for i in line.strip().split()]
            dataset.append([x, y])
            labels.append(label)
    return dataset, labels

def clip(alpha, L, H):
    """Clip alpha into the closed interval [L, H]."""
    if alpha < L:
        return L
    elif alpha > H:
        return H
    else:
        return alpha

def select_j(i, m):
    """Randomly select an index from range(m), excluding i."""
    l = list(range(m))
    seq = l[: i] + l[i+1:]
    return random.choice(seq)

def get_w(alphas, dataset, labels):
    """Recover the separating hyperplane weights from the multipliers.

    w = sum_i alpha_i * y_i * x_i
    """
    alphas, dataset, labels = np.array(alphas), np.array(dataset), np.array(labels)
    yx = labels.reshape(1, -1).T*np.array([1, 1])*dataset
    w = np.dot(yx.T, alphas)

    return w.tolist()

def simple_smo(dataset, labels, C, max_iter):
    """Optimize the SVM dual with the simplified SMO algorithm
    (no heuristic alpha-pair selection).

    :param dataset: list of feature vectors
    :param labels: class labels (+1.0 / -1.0)
    :param C: soft-margin constant, 0 <= alpha_i <= C
    :param max_iter: number of consecutive full sweeps without any alpha
        update required before the outer loop terminates
    :return: (alphas, b) -- optimized Lagrange multipliers and bias
    """
    dataset = np.array(dataset)
    m, n = dataset.shape
    labels = np.array(labels)

    # Initialize multipliers and bias.
    alphas = np.zeros(m)
    b = 0
    it = 0

    def f(x):
        "SVM decision function y = w^T x + b"
        # Kernel (here: linear) function vector.
        x = np.matrix(x).T
        data = np.matrix(dataset)
        ks = data*x

        # Predictive value.
        wx = np.matrix(alphas*labels)*ks
        fx = wx + b

        return fx[0, 0]

    all_alphas, all_bs = [], []

    while it < max_iter:
        pair_changed = 0
        for i in range(m):
            a_i, x_i, y_i = alphas[i], dataset[i], labels[i]
            fx_i = f(x_i)
            E_i = fx_i - y_i

            j = select_j(i, m)
            a_j, x_j, y_j = alphas[j], dataset[j], labels[j]
            fx_j = f(x_j)
            E_j = fx_j - y_j

            K_ii, K_jj, K_ij = np.dot(x_i, x_i), np.dot(x_j, x_j), np.dot(x_i, x_j)
            eta = K_ii + K_jj - 2*K_ij
            if eta <= 0:
                print('WARNING eta <= 0')
                continue

            # Compute the updated alpha pair.
            a_i_old, a_j_old = a_i, a_j
            a_j_new = a_j_old + y_j*(E_i - E_j)/eta

            # Clip alpha_j into its feasible box.
            if y_i != y_j:
                L = max(0, a_j_old - a_i_old)
                H = min(C, C + a_j_old - a_i_old)
            else:
                L = max(0, a_i_old + a_j_old - C)
                H = min(C, a_j_old + a_i_old)

            a_j_new = clip(a_j_new, L, H)
            a_i_new = a_i_old + y_i*y_j*(a_j_old - a_j_new)

            if abs(a_j_new - a_j_old) < 0.00001:
                print('WARNING alpha_j not moving enough')
                continue

            alphas[i], alphas[j] = a_i_new, a_j_new

            # Update the threshold b.
            b_i = -E_i - y_i*K_ii*(a_i_new - a_i_old) - y_j*K_ij*(a_j_new - a_j_old) + b
            b_j = -E_j - y_i*K_ij*(a_i_new - a_i_old) - y_j*K_jj*(a_j_new - a_j_old) + b

            if 0 < a_i_new < C:
                b = b_i
            elif 0 < a_j_new < C:
                b = b_j
            else:
                b = (b_i + b_j)/2

            # BUGFIX: append a snapshot. Appending `alphas` itself stored
            # aliases of a single mutable array, so every history entry
            # ended up identical to the final state.
            all_alphas.append(alphas.copy())
            all_bs.append(b)

            pair_changed += 1
            print('INFO iteration:{} i:{} pair_changed:{}'.format(it, i, pair_changed))

        if pair_changed == 0:
            it += 1
        else:
            it = 0
        print('iteration number: {}'.format(it))

    return alphas, b

if '__main__' == __name__:
    # Load training data.
    dataset, labels = load_data('testSet.txt')
    # Optimize the SVM with simplified SMO.
    alphas, b = simple_smo(dataset, labels, 0.6, 40)

    # Group data points by class.
    classified_pts = {'+1': [], '-1': []}
    for point, label in zip(dataset, labels):
        if label == 1.0:
            classified_pts['+1'].append(point)
        else:
            classified_pts['-1'].append(point)

    # matplotlib is only needed for plotting; import locally so the
    # module stays importable without it.
    import matplotlib.pyplot as plt

    fig = plt.figure()
    ax = fig.add_subplot(111)

    # Scatter the data points.
    for label, pts in classified_pts.items():
        pts = np.array(pts)
        ax.scatter(pts[:, 0], pts[:, 1], label=label)

    # Draw the separating line.
    w = get_w(alphas, dataset, labels)
    x1, _ = max(dataset, key=lambda x: x[0])
    x2, _ = min(dataset, key=lambda x: x[0])
    a1, a2 = w
    y1, y2 = (-b - a1*x1)/a2, (-b - a1*x2)/a2
    ax.plot([x1, x2], [y1, y2])

    # Highlight support vectors.
    for i, alpha in enumerate(alphas):
        if abs(alpha) > 1e-3:
            x, y = dataset[i]
            ax.scatter([x], [y], s=150, c='none', alpha=0.7,
                       linewidth=1.5, edgecolor='#AB3319')

    plt.show()
#!/usr/bin/env python
# -*- coding: utf-8 -*-

''' Regression tree (CART) implementation.
'''

import uuid
# BUGFIX: namedtuple lives in collections, not functools
# (the previous import raised ImportError).
from collections import namedtuple

import numpy as np


def load_data(filename):
    ''' Load whitespace-separated numeric rows from a text file.
    '''
    dataset = []
    with open(filename, 'r') as f:
        for line in f:
            line_data = [float(data) for data in line.split()]
            dataset.append(line_data)
    return dataset

def split_dataset(dataset, feat_idx, value):
    ''' Split the dataset on the given feature index and value:
        rows whose feature is below `value` go left, the rest go right.
    '''
    ldata, rdata = [], []
    for data in dataset:
        if data[feat_idx] < value:
            ldata.append(data)
        else:
            rdata.append(data)
    return ldata, rdata

def create_tree(dataset, fleaf, ferr, opt=None):
    ''' Recursively build the tree structure.

    dataset: data to split
    fleaf: leaf-creation function
    ferr: dataset-error function
    opt: tree parameters:
        err_tolerance: minimal error reduction;
        n_tolerance: minimal number of samples per split
    '''
    if opt is None:
        opt = {'err_tolerance': 1, 'n_tolerance': 4}

    # Pick the best split feature and value.
    feat_idx, value = choose_best_feature(dataset, fleaf, ferr, opt)

    # Base case: no worthwhile split -> return a leaf value.
    if feat_idx is None:
        return value

    # Internal node.
    tree = {'feat_idx': feat_idx, 'feat_val': value}

    # Recurse into both halves.
    ldata, rdata = split_dataset(dataset, feat_idx, value)
    ltree = create_tree(ldata, fleaf, ferr, opt)
    rtree = create_tree(rdata, fleaf, ferr, opt)
    tree['left'] = ltree
    tree['right'] = rtree

    return tree

def fleaf(dataset):
    ''' Leaf value for the given data: the mean of the target column.
    '''
    dataset = np.array(dataset)
    return np.mean(dataset[:, -1])

def ferr(dataset):
    ''' Error of a dataset: variance of the target column scaled by the
        sample count (i.e. the total squared deviation).
    '''
    # NOTE: an unused `m, _ = dataset.shape` unpacking was removed.
    dataset = np.array(dataset)
    return np.var(dataset[:, -1])*dataset.shape[0]

def choose_best_feature(dataset, fleaf, ferr, opt):
    ''' Select the best split feature and split value.

    dataset: data to split
    fleaf: leaf-creation function
    ferr: dataset-error function
    opt: tree parameters:
        err_tolerance: minimal error reduction;
        n_tolerance: minimal number of samples per split
    '''
    dataset = np.array(dataset)
    m, n = dataset.shape
    err_tolerance, n_tolerance = opt['err_tolerance'], opt['n_tolerance']

    err = ferr(dataset)
    best_feat_idx, best_feat_val, best_err = 0, 0, float('inf')

    # Scan every feature (last column is the target).
    for feat_idx in range(n-1):
        values = dataset[:, feat_idx]
        # Try every observed value as the split point.
        for val in values:
            ldata, rdata = split_dataset(dataset.tolist(), feat_idx, val)
            if len(ldata) < n_tolerance or len(rdata) < n_tolerance:
                # Split would leave too few samples on one side.
                continue

            # Total error of the candidate split.
            new_err = ferr(ldata) + ferr(rdata)
            if new_err < best_err:
                best_feat_idx = feat_idx
                best_feat_val = val
                best_err = new_err

    # If the error barely improves, collapse into a leaf.
    if abs(err - best_err) < err_tolerance:
        return None, fleaf(dataset)

    # Guard against undersized splits.
    ldata, rdata = split_dataset(dataset.tolist(), best_feat_idx, best_feat_val)
    if len(ldata) < n_tolerance or len(rdata) < n_tolerance:
        return None, fleaf(dataset)

    return best_feat_idx, best_feat_val

def get_nodes_edges(tree, root_node=None):
    ''' Return all nodes and edges of the tree (for Graphviz export).
    '''
    Node = namedtuple('Node', ['id', 'label'])
    Edge = namedtuple('Edge', ['start', 'end'])

    nodes, edges = [], []

    if type(tree) is not dict:
        return nodes, edges

    if root_node is None:
        label = '{}: {}'.format(tree['feat_idx'], tree['feat_val'])
        root_node = Node._make([uuid.uuid4(), label])
        nodes.append(root_node)

    for sub_tree in (tree['left'], tree['right']):
        if type(sub_tree) is dict:
            node_label = '{}: {}'.format(sub_tree['feat_idx'], sub_tree['feat_val'])
        else:
            node_label = '{:.2f}'.format(sub_tree)
        sub_node = Node._make([uuid.uuid4(), node_label])
        nodes.append(sub_node)

        edge = Edge._make([root_node, sub_node])
        edges.append(edge)

        sub_nodes, sub_edges = get_nodes_edges(sub_tree, root_node=sub_node)
        nodes.extend(sub_nodes)
        edges.extend(sub_edges)

    return nodes, edges

def dotify(tree):
    ''' Render the tree as Graphviz Dot file content.
    '''
    content = 'digraph decision_tree {\n'
    nodes, edges = get_nodes_edges(tree)

    for node in nodes:
        content += '    "{}" [label="{}"];\n'.format(node.id, node.label)

    for edge in edges:
        start, end = edge.start, edge.end
        content += '    "{}" -> "{}";\n'.format(start.id, end.id)
    content += '}'

    return content

def tree_predict(data, tree):
    ''' Predict a value for `data` by walking the regression tree.
    '''
    if type(tree) is not dict:
        return tree

    feat_idx, feat_val = tree['feat_idx'], tree['feat_val']
    if data[feat_idx] < feat_val:
        sub_tree = tree['left']
    else:
        sub_tree = tree['right']

    return tree_predict(data, sub_tree)

if '__main__' == __name__:
    datafile = 'ex0.txt'
    dataset = load_data(datafile)
    tree = create_tree(dataset, fleaf, ferr, opt={'n_tolerance': 4,
                                                  'err_tolerance': 1})

    dotfile = '{}.dot'.format(datafile.split('.')[0])
    with open(dotfile, 'w') as f:
        content = dotify(tree)
        f.write(content)

    # matplotlib is only needed for plotting; import locally so the
    # module stays importable without it.
    import matplotlib.pyplot as plt

    dataset = np.array(dataset)
    # Scatter the raw data.
    plt.scatter(dataset[:, 0], dataset[:, 1])
    # Plot the regression curve.
    x = np.linspace(0, 1, 50)
    y = [tree_predict([i], tree) for i in x]
    plt.plot(x, y, c='r')
    plt.show()
1.000000 0.033859 3.156393 9 | 1.000000 0.132791 3.110301 10 | 1.000000 0.138306 3.149813 11 | 1.000000 0.247809 3.476346 12 | 1.000000 0.648270 4.119688 13 | 1.000000 0.731209 4.282233 14 | 1.000000 0.236833 3.486582 15 | 1.000000 0.969788 4.655492 16 | 1.000000 0.607492 3.965162 17 | 1.000000 0.358622 3.514900 18 | 1.000000 0.147846 3.125947 19 | 1.000000 0.637820 4.094115 20 | 1.000000 0.230372 3.476039 21 | 1.000000 0.070237 3.210610 22 | 1.000000 0.067154 3.190612 23 | 1.000000 0.925577 4.631504 24 | 1.000000 0.717733 4.295890 25 | 1.000000 0.015371 3.085028 26 | 1.000000 0.335070 3.448080 27 | 1.000000 0.040486 3.167440 28 | 1.000000 0.212575 3.364266 29 | 1.000000 0.617218 3.993482 30 | 1.000000 0.541196 3.891471 31 | 1.000000 0.045353 3.143259 32 | 1.000000 0.126762 3.114204 33 | 1.000000 0.556486 3.851484 34 | 1.000000 0.901144 4.621899 35 | 1.000000 0.958476 4.580768 36 | 1.000000 0.274561 3.620992 37 | 1.000000 0.394396 3.580501 38 | 1.000000 0.872480 4.618706 39 | 1.000000 0.409932 3.676867 40 | 1.000000 0.908969 4.641845 41 | 1.000000 0.166819 3.175939 42 | 1.000000 0.665016 4.264980 43 | 1.000000 0.263727 3.558448 44 | 1.000000 0.231214 3.436632 45 | 1.000000 0.552928 3.831052 46 | 1.000000 0.047744 3.182853 47 | 1.000000 0.365746 3.498906 48 | 1.000000 0.495002 3.946833 49 | 1.000000 0.493466 3.900583 50 | 1.000000 0.792101 4.238522 51 | 1.000000 0.769660 4.233080 52 | 1.000000 0.251821 3.521557 53 | 1.000000 0.181951 3.203344 54 | 1.000000 0.808177 4.278105 55 | 1.000000 0.334116 3.555705 56 | 1.000000 0.338630 3.502661 57 | 1.000000 0.452584 3.859776 58 | 1.000000 0.694770 4.275956 59 | 1.000000 0.590902 3.916191 60 | 1.000000 0.307928 3.587961 61 | 1.000000 0.148364 3.183004 62 | 1.000000 0.702180 4.225236 63 | 1.000000 0.721544 4.231083 64 | 1.000000 0.666886 4.240544 65 | 1.000000 0.124931 3.222372 66 | 1.000000 0.618286 4.021445 67 | 1.000000 0.381086 3.567479 68 | 1.000000 0.385643 3.562580 69 | 1.000000 0.777175 4.262059 70 | 1.000000 
0.116089 3.208813 71 | 1.000000 0.115487 3.169825 72 | 1.000000 0.663510 4.193949 73 | 1.000000 0.254884 3.491678 74 | 1.000000 0.993888 4.533306 75 | 1.000000 0.295434 3.550108 76 | 1.000000 0.952523 4.636427 77 | 1.000000 0.307047 3.557078 78 | 1.000000 0.277261 3.552874 79 | 1.000000 0.279101 3.494159 80 | 1.000000 0.175724 3.206828 81 | 1.000000 0.156383 3.195266 82 | 1.000000 0.733165 4.221292 83 | 1.000000 0.848142 4.413372 84 | 1.000000 0.771184 4.184347 85 | 1.000000 0.429492 3.742878 86 | 1.000000 0.162176 3.201878 87 | 1.000000 0.917064 4.648964 88 | 1.000000 0.315044 3.510117 89 | 1.000000 0.201473 3.274434 90 | 1.000000 0.297038 3.579622 91 | 1.000000 0.336647 3.489244 92 | 1.000000 0.666109 4.237386 93 | 1.000000 0.583888 3.913749 94 | 1.000000 0.085031 3.228990 95 | 1.000000 0.687006 4.286286 96 | 1.000000 0.949655 4.628614 97 | 1.000000 0.189912 3.239536 98 | 1.000000 0.844027 4.457997 99 | 1.000000 0.333288 3.513384 100 | 1.000000 0.427035 3.729674 101 | 1.000000 0.466369 3.834274 102 | 1.000000 0.550659 3.811155 103 | 1.000000 0.278213 3.598316 104 | 1.000000 0.918769 4.692514 105 | 1.000000 0.886555 4.604859 106 | 1.000000 0.569488 3.864912 107 | 1.000000 0.066379 3.184236 108 | 1.000000 0.335751 3.500796 109 | 1.000000 0.426863 3.743365 110 | 1.000000 0.395746 3.622905 111 | 1.000000 0.694221 4.310796 112 | 1.000000 0.272760 3.583357 113 | 1.000000 0.503495 3.901852 114 | 1.000000 0.067119 3.233521 115 | 1.000000 0.038326 3.105266 116 | 1.000000 0.599122 3.865544 117 | 1.000000 0.947054 4.628625 118 | 1.000000 0.671279 4.231213 119 | 1.000000 0.434811 3.791149 120 | 1.000000 0.509381 3.968271 121 | 1.000000 0.749442 4.253910 122 | 1.000000 0.058014 3.194710 123 | 1.000000 0.482978 3.996503 124 | 1.000000 0.466776 3.904358 125 | 1.000000 0.357767 3.503976 126 | 1.000000 0.949123 4.557545 127 | 1.000000 0.417320 3.699876 128 | 1.000000 0.920461 4.613614 129 | 1.000000 0.156433 3.140401 130 | 1.000000 0.656662 4.206717 131 | 1.000000 0.616418 
3.969524 132 | 1.000000 0.853428 4.476096 133 | 1.000000 0.133295 3.136528 134 | 1.000000 0.693007 4.279071 135 | 1.000000 0.178449 3.200603 136 | 1.000000 0.199526 3.299012 137 | 1.000000 0.073224 3.209873 138 | 1.000000 0.286515 3.632942 139 | 1.000000 0.182026 3.248361 140 | 1.000000 0.621523 3.995783 141 | 1.000000 0.344584 3.563262 142 | 1.000000 0.398556 3.649712 143 | 1.000000 0.480369 3.951845 144 | 1.000000 0.153350 3.145031 145 | 1.000000 0.171846 3.181577 146 | 1.000000 0.867082 4.637087 147 | 1.000000 0.223855 3.404964 148 | 1.000000 0.528301 3.873188 149 | 1.000000 0.890192 4.633648 150 | 1.000000 0.106352 3.154768 151 | 1.000000 0.917886 4.623637 152 | 1.000000 0.014855 3.078132 153 | 1.000000 0.567682 3.913596 154 | 1.000000 0.068854 3.221817 155 | 1.000000 0.603535 3.938071 156 | 1.000000 0.532050 3.880822 157 | 1.000000 0.651362 4.176436 158 | 1.000000 0.901225 4.648161 159 | 1.000000 0.204337 3.332312 160 | 1.000000 0.696081 4.240614 161 | 1.000000 0.963924 4.532224 162 | 1.000000 0.981390 4.557105 163 | 1.000000 0.987911 4.610072 164 | 1.000000 0.990947 4.636569 165 | 1.000000 0.736021 4.229813 166 | 1.000000 0.253574 3.500860 167 | 1.000000 0.674722 4.245514 168 | 1.000000 0.939368 4.605182 169 | 1.000000 0.235419 3.454340 170 | 1.000000 0.110521 3.180775 171 | 1.000000 0.218023 3.380820 172 | 1.000000 0.869778 4.565020 173 | 1.000000 0.196830 3.279973 174 | 1.000000 0.958178 4.554241 175 | 1.000000 0.972673 4.633520 176 | 1.000000 0.745797 4.281037 177 | 1.000000 0.445674 3.844426 178 | 1.000000 0.470557 3.891601 179 | 1.000000 0.549236 3.849728 180 | 1.000000 0.335691 3.492215 181 | 1.000000 0.884739 4.592374 182 | 1.000000 0.918916 4.632025 183 | 1.000000 0.441815 3.756750 184 | 1.000000 0.116598 3.133555 185 | 1.000000 0.359274 3.567919 186 | 1.000000 0.814811 4.363382 187 | 1.000000 0.387125 3.560165 188 | 1.000000 0.982243 4.564305 189 | 1.000000 0.780880 4.215055 190 | 1.000000 0.652565 4.174999 191 | 1.000000 0.870030 4.586640 192 | 
1.000000 0.604755 3.960008 193 | 1.000000 0.255212 3.529963 194 | 1.000000 0.730546 4.213412 195 | 1.000000 0.493829 3.908685 196 | 1.000000 0.257017 3.585821 197 | 1.000000 0.833735 4.374394 198 | 1.000000 0.070095 3.213817 199 | 1.000000 0.527070 3.952681 200 | 1.000000 0.116163 3.129283 201 | -------------------------------------------------------------------------------- /linear_regression/ex1.txt: -------------------------------------------------------------------------------- 1 | 1.000000 0.635975 4.093119 2 | 1.000000 0.552438 3.804358 3 | 1.000000 0.855922 4.456531 4 | 1.000000 0.083386 3.187049 5 | 1.000000 0.975802 4.506176 6 | 1.000000 0.181269 3.171914 7 | 1.000000 0.129156 3.053996 8 | 1.000000 0.605648 3.974659 9 | 1.000000 0.301625 3.542525 10 | 1.000000 0.698805 4.234199 11 | 1.000000 0.226419 3.405937 12 | 1.000000 0.519290 3.932469 13 | 1.000000 0.354424 3.514051 14 | 1.000000 0.118380 3.105317 15 | 1.000000 0.512811 3.843351 16 | 1.000000 0.236795 3.576074 17 | 1.000000 0.353509 3.544471 18 | 1.000000 0.481447 3.934625 19 | 1.000000 0.060509 3.228226 20 | 1.000000 0.174090 3.300232 21 | 1.000000 0.806818 4.331785 22 | 1.000000 0.531462 3.908166 23 | 1.000000 0.853167 4.386918 24 | 1.000000 0.304804 3.617260 25 | 1.000000 0.612021 4.082411 26 | 1.000000 0.620880 3.949470 27 | 1.000000 0.580245 3.984041 28 | 1.000000 0.742443 4.251907 29 | 1.000000 0.110770 3.115214 30 | 1.000000 0.742687 4.234319 31 | 1.000000 0.574390 3.947544 32 | 1.000000 0.986378 4.532519 33 | 1.000000 0.294867 3.510392 34 | 1.000000 0.472125 3.927832 35 | 1.000000 0.872321 4.631825 36 | 1.000000 0.843537 4.482263 37 | 1.000000 0.864577 4.487656 38 | 1.000000 0.341874 3.486371 39 | 1.000000 0.097980 3.137514 40 | 1.000000 0.757874 4.212660 41 | 1.000000 0.877656 4.506268 42 | 1.000000 0.457993 3.800973 43 | 1.000000 0.475341 3.975979 44 | 1.000000 0.848391 4.494447 45 | 1.000000 0.746059 4.244715 46 | 1.000000 0.153462 3.019251 47 | 1.000000 0.694256 4.277945 48 | 1.000000 
0.498712 3.812414 49 | 1.000000 0.023580 3.116973 50 | 1.000000 0.976826 4.617363 51 | 1.000000 0.624004 4.005158 52 | 1.000000 0.472220 3.874188 53 | 1.000000 0.390551 3.630228 54 | 1.000000 0.021349 3.145849 55 | 1.000000 0.173488 3.192618 56 | 1.000000 0.971028 4.540226 57 | 1.000000 0.595302 3.835879 58 | 1.000000 0.097638 3.141948 59 | 1.000000 0.745972 4.323316 60 | 1.000000 0.676390 4.204829 61 | 1.000000 0.488949 3.946710 62 | 1.000000 0.982873 4.666332 63 | 1.000000 0.296060 3.482348 64 | 1.000000 0.228008 3.451286 65 | 1.000000 0.671059 4.186388 66 | 1.000000 0.379419 3.595223 67 | 1.000000 0.285170 3.534446 68 | 1.000000 0.236314 3.420891 69 | 1.000000 0.629803 4.115553 70 | 1.000000 0.770272 4.257463 71 | 1.000000 0.493052 3.934798 72 | 1.000000 0.631592 4.154963 73 | 1.000000 0.965676 4.587470 74 | 1.000000 0.598675 3.944766 75 | 1.000000 0.351997 3.480517 76 | 1.000000 0.342001 3.481382 77 | 1.000000 0.661424 4.253286 78 | 1.000000 0.140912 3.131670 79 | 1.000000 0.373574 3.527099 80 | 1.000000 0.223166 3.378051 81 | 1.000000 0.908785 4.578960 82 | 1.000000 0.915102 4.551773 83 | 1.000000 0.410940 3.634259 84 | 1.000000 0.754921 4.167016 85 | 1.000000 0.764453 4.217570 86 | 1.000000 0.101534 3.237201 87 | 1.000000 0.780368 4.353163 88 | 1.000000 0.819868 4.342184 89 | 1.000000 0.173990 3.236950 90 | 1.000000 0.330472 3.509404 91 | 1.000000 0.162656 3.242535 92 | 1.000000 0.476283 3.907937 93 | 1.000000 0.636391 4.108455 94 | 1.000000 0.758737 4.181959 95 | 1.000000 0.778372 4.251103 96 | 1.000000 0.936287 4.538462 97 | 1.000000 0.510904 3.848193 98 | 1.000000 0.515737 3.974757 99 | 1.000000 0.437823 3.708323 100 | 1.000000 0.828607 4.385210 101 | 1.000000 0.556100 3.927788 102 | 1.000000 0.038209 3.187881 103 | 1.000000 0.321993 3.444542 104 | 1.000000 0.067288 3.199263 105 | 1.000000 0.774989 4.285745 106 | 1.000000 0.566077 3.878557 107 | 1.000000 0.796314 4.155745 108 | 1.000000 0.746600 4.197772 109 | 1.000000 0.360778 3.524928 110 | 1.000000 
0.397321 3.525692 111 | 1.000000 0.062142 3.211318 112 | 1.000000 0.379250 3.570495 113 | 1.000000 0.248238 3.462431 114 | 1.000000 0.682561 4.206177 115 | 1.000000 0.355393 3.562322 116 | 1.000000 0.889051 4.595215 117 | 1.000000 0.733806 4.182694 118 | 1.000000 0.153949 3.320695 119 | 1.000000 0.036104 3.122670 120 | 1.000000 0.388577 3.541312 121 | 1.000000 0.274481 3.502135 122 | 1.000000 0.319401 3.537559 123 | 1.000000 0.431653 3.712609 124 | 1.000000 0.960398 4.504875 125 | 1.000000 0.083660 3.262164 126 | 1.000000 0.122098 3.105583 127 | 1.000000 0.415299 3.742634 128 | 1.000000 0.854192 4.566589 129 | 1.000000 0.925574 4.630884 130 | 1.000000 0.109306 3.190539 131 | 1.000000 0.805161 4.289105 132 | 1.000000 0.344474 3.406602 133 | 1.000000 0.769116 4.251899 134 | 1.000000 0.182003 3.183214 135 | 1.000000 0.225972 3.342508 136 | 1.000000 0.413088 3.747926 137 | 1.000000 0.964444 4.499998 138 | 1.000000 0.203334 3.350089 139 | 1.000000 0.285574 3.539554 140 | 1.000000 0.850209 4.443465 141 | 1.000000 0.061561 3.290370 142 | 1.000000 0.426935 3.733302 143 | 1.000000 0.389376 3.614803 144 | 1.000000 0.096918 3.175132 145 | 1.000000 0.148938 3.164284 146 | 1.000000 0.893738 4.619629 147 | 1.000000 0.195527 3.426648 148 | 1.000000 0.407248 3.670722 149 | 1.000000 0.224357 3.412571 150 | 1.000000 0.045963 3.110330 151 | 1.000000 0.944647 4.647928 152 | 1.000000 0.756552 4.164515 153 | 1.000000 0.432098 3.730603 154 | 1.000000 0.990511 4.609868 155 | 1.000000 0.649699 4.094111 156 | 1.000000 0.584879 3.907636 157 | 1.000000 0.785934 4.240814 158 | 1.000000 0.029945 3.106915 159 | 1.000000 0.075747 3.201181 160 | 1.000000 0.408408 3.872302 161 | 1.000000 0.583851 3.860890 162 | 1.000000 0.497759 3.884108 163 | 1.000000 0.421301 3.696816 164 | 1.000000 0.140320 3.114540 165 | 1.000000 0.546465 3.791233 166 | 1.000000 0.843181 4.443487 167 | 1.000000 0.295390 3.535337 168 | 1.000000 0.825059 4.417975 169 | 1.000000 0.946343 4.742471 170 | 1.000000 0.350404 3.470964 
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Author: PytLab
# Date: 2017-07-07

import copy
import uuid
import pickle
from collections import defaultdict, namedtuple
from math import log2


class DecisionTreeClassifier(object):
    ''' Decision-tree classifier that splits data with the ID3 algorithm.
    '''

    @staticmethod
    def split_dataset(dataset, classes, feat_idx):
        ''' Partition the dataset by the values of one feature.

        :param dataset: list of data vectors to split
        :param classes: class labels, same length as dataset
        :param feat_idx: index of the split feature

        :return splited_dict: {feature value: [sub dataset, sub classes]}
        '''
        splited_dict = {}
        for data_vect, cls in zip(dataset, classes):
            feat_val = data_vect[feat_idx]
            sub_dataset, sub_classes = splited_dict.setdefault(feat_val, [[], []])
            # Drop the consumed feature from the data vector.
            sub_dataset.append(data_vect[: feat_idx] + data_vect[feat_idx+1: ])
            sub_classes.append(cls)

        return splited_dict

    def get_shanno_entropy(self, values):
        ''' Shannon entropy of the given list of values.
        '''
        uniq_vals = set(values)
        val_nums = {key: values.count(key) for key in uniq_vals}
        probs = [v/len(values) for k, v in val_nums.items()]
        entropy = sum([-prob*log2(prob) for prob in probs])
        return entropy

    def choose_best_split_feature(self, dataset, classes):
        ''' Pick the split feature with the largest information gain.

        :param dataset: data to split
        :param classes: corresponding class labels

        :return: index of the feature whose split maximizes the gain
        '''
        base_entropy = self.get_shanno_entropy(classes)

        feat_num = len(dataset[0])
        entropy_gains = []
        for i in range(feat_num):
            splited_dict = self.split_dataset(dataset, classes, i)
            # Weighted entropy of the sub-datasets after the split.
            new_entropy = sum([
                len(sub_classes)/len(classes)*self.get_shanno_entropy(sub_classes)
                for _, (_, sub_classes) in splited_dict.items()
            ])
            entropy_gains.append(base_entropy - new_entropy)

        return entropy_gains.index(max(entropy_gains))

    @staticmethod
    def get_majority(classes):
        ''' Return the majority class in the class list.

        BUGFIX: this was a plain function inside the class body (no
        `self` and no decorator) and was called as a bare name from
        create_tree, which raised NameError at runtime.
        '''
        cls_num = defaultdict(lambda: 0)
        for cls in classes:
            cls_num[cls] += 1

        return max(cls_num, key=cls_num.get)

    def create_tree(self, dataset, classes, feat_names):
        ''' Recursively build the decision tree for the current dataset.

        :param dataset: data vectors
        :param feat_names: feature names for the data vectors
        :param classes: class labels for the data vectors

        :return tree: the decision tree as a nested dict
        '''
        # Stop splitting when only one class remains.
        if len(set(classes)) == 1:
            return classes[0]

        # All features consumed: return the majority class.
        if len(feat_names) == 0:
            # BUGFIX: call through self (was a bare `get_majority(...)`).
            return self.get_majority(classes)

        # Split and create a new subtree.
        tree = {}
        best_feat_idx = self.choose_best_split_feature(dataset, classes)
        feature = feat_names[best_feat_idx]
        tree[feature] = {}

        # Feature names for the recursive calls (chosen one removed).
        sub_feat_names = feat_names[:]
        sub_feat_names.pop(best_feat_idx)

        splited_dict = self.split_dataset(dataset, classes, best_feat_idx)
        for feat_val, (sub_dataset, sub_classes) in splited_dict.items():
            tree[feature][feat_val] = self.create_tree(sub_dataset,
                                                       sub_classes,
                                                       sub_feat_names)
        self.tree = tree
        self.feat_names = feat_names

        return tree

    def get_nodes_edges(self, tree=None, root_node=None):
        ''' Return all nodes and edges of the tree (for Graphviz export).
        '''
        Node = namedtuple('Node', ['id', 'label'])
        Edge = namedtuple('Edge', ['start', 'end', 'label'])

        if tree is None:
            tree = self.tree

        if type(tree) is not dict:
            return [], []

        nodes, edges = [], []

        if root_node is None:
            label = list(tree.keys())[0]
            root_node = Node._make([uuid.uuid4(), label])
            nodes.append(root_node)

        for edge_label, sub_tree in tree[root_node.label].items():
            node_label = list(sub_tree.keys())[0] if type(sub_tree) is dict else sub_tree
            sub_node = Node._make([uuid.uuid4(), node_label])
            nodes.append(sub_node)

            edge = Edge._make([root_node, sub_node, edge_label])
            edges.append(edge)

            sub_nodes, sub_edges = self.get_nodes_edges(sub_tree, root_node=sub_node)
            nodes.extend(sub_nodes)
            edges.extend(sub_edges)

        return nodes, edges

    def dotify(self, tree=None):
        ''' Render the tree as Graphviz Dot file content.
        '''
        if tree is None:
            tree = self.tree

        content = 'digraph decision_tree {\n'
        nodes, edges = self.get_nodes_edges(tree)

        for node in nodes:
            content += '    "{}" [label="{}"];\n'.format(node.id, node.label)

        for edge in edges:
            start, label, end = edge.start, edge.label, edge.end
            content += '    "{}" -> "{}" [label="{}"];\n'.format(start.id, end.id, label)
        content += '}'

        return content

    def classify(self, data_vect, feat_names=None, tree=None):
        ''' Classify a data vector with the built decision tree.
        '''
        if tree is None:
            tree = self.tree

        if feat_names is None:
            feat_names = self.feat_names

        # Recursive base case: leaves are plain class labels.
        if type(tree) is not dict:
            return tree

        feature = list(tree.keys())[0]
        value = data_vect[feat_names.index(feature)]
        sub_tree = tree[feature][value]

        return self.classify(data_vect, feat_names, sub_tree)

    def dump_tree(self, filename, tree=None):
        ''' Pickle the decision tree to a file.
        '''
        if tree is None:
            tree = self.tree

        with open(filename, 'wb') as f:
            pickle.dump(tree, f)

    def load_tree(self, filename):
        ''' Load a pickled decision tree from a file.
        '''
        with open(filename, 'rb') as f:
            tree = pickle.load(f)
            self.tree = tree
        return tree
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Platt SMO for a linear SVM with heuristic alpha-pair selection."""

import random

import numpy as np


class SVMUtil(object):
    '''
    Struct that holds all the state SMO needs: data, multipliers,
    bias and a cache of prediction errors f(x_i) - y_i.
    '''
    def __init__(self, dataset, labels, C, tolerance=0.001):
        self.dataset, self.labels, self.C = dataset, labels, C

        self.m, self.n = np.array(dataset).shape
        self.alphas = np.zeros(self.m)
        self.b = 0
        self.tolerance = tolerance
        # Cached errors: f(x_i) - y_i for every training sample.
        self.errors = [self.get_error(i) for i in range(self.m)]

    def f(self, x):
        '''SVM decision function y = w^T x + b
        '''
        # Kernel (here: linear) function vector.
        x = np.matrix(x).T
        data = np.matrix(self.dataset)
        ks = data*x

        # Predictive value.
        wx = np.matrix(self.alphas*self.labels)*ks
        fx = wx + self.b

        return fx[0, 0]

    def get_error(self, i):
        ''' Prediction error for the i-th sample.
        '''
        x, y = self.dataset[i], self.labels[i]
        fx = self.f(x)
        return fx - y

    def update_errors(self):
        ''' Refresh the whole error cache.
        '''
        self.errors = [self.get_error(i) for i in range(self.m)]

    def meet_kkt(self, i):
        ''' Whether sample i currently satisfies the KKT conditions.
        '''
        alpha = self.alphas[i]
        x = self.dataset[i]

        if alpha == 0:
            return self.f(x) >= 1
        elif alpha == self.C:
            return self.f(x) <= 1
        else:
            return self.f(x) == 1

def load_data(filename):
    """Parse whitespace-separated 'x y label' rows from a text file."""
    dataset, labels = [], []
    with open(filename, 'r') as f:
        for line in f:
            x, y, label = [float(i) for i in line.strip().split()]
            dataset.append([x, y])
            labels.append(label)
    return dataset, labels

def clip(alpha, L, H):
    ''' Clip alpha into the closed interval [L, H].
    '''
    if alpha < L:
        return L
    elif alpha > H:
        return H
    else:
        return alpha

def select_j_rand(i, m):
    ''' Randomly select an index from range(m), excluding i.
    '''
    l = list(range(m))
    seq = l[: i] + l[i+1:]
    return random.choice(seq)

def select_j(i, svm_util):
    ''' Choose the second alpha index by maximizing the step size
        |E_i - E_k| over the non-bound samples.
    '''
    errors = svm_util.errors
    valid_indices = [k for k, a in enumerate(svm_util.alphas) if 0 < a < svm_util.C]

    if len(valid_indices) > 1:
        j = -1
        max_delta = 0
        for k in valid_indices:
            if k == i:
                continue
            # BUGFIX: compare against candidate k's error. The old code
            # used errors[j] with j initialized to -1, so it compared
            # against the wrong (previous/last) error and could even
            # return -1 when all deltas were zero.
            delta = abs(errors[i] - errors[k])
            if delta > max_delta:
                j = k
                max_delta = delta
        if j < 0:
            # No candidate improved on a zero step; fall back to random.
            j = select_j_rand(i, svm_util.m)
    else:
        j = select_j_rand(i, svm_util.m)
    return j

def get_w(alphas, dataset, labels):
    ''' Recover the hyperplane weights w = sum_i alpha_i*y_i*x_i.
    '''
    alphas, dataset, labels = np.array(alphas), np.array(dataset), np.array(labels)
    yx = labels.reshape(1, -1).T*np.array([1, 1])*dataset
    w = np.dot(yx.T, alphas)

    return w.tolist()

def take_step(i, j, svm_util):
    ''' Jointly optimize the selected pair of alphas.

    :return: 1 if the pair was updated, 0 otherwise.
    '''
    svm_util.update_errors()
    alphas, dataset, labels = svm_util.alphas, svm_util.dataset, svm_util.labels
    C, b = svm_util.C, svm_util.b

    a_i, x_i, y_i, E_i = alphas[i], dataset[i], labels[i], svm_util.errors[i]
    a_j, x_j, y_j, E_j = alphas[j], dataset[j], labels[j], svm_util.errors[j]

    K_ii, K_jj, K_ij = np.dot(x_i, x_i), np.dot(x_j, x_j), np.dot(x_i, x_j)
    eta = K_ii + K_jj - 2*K_ij
    if eta <= 0:
        print('WARNING  eta <= 0')
        return 0

    a_i_old, a_j_old = a_i, a_j
    a_j_new = a_j_old + y_j*(E_i - E_j)/eta

    # Clip alpha_j into its feasible box.
    if y_i != y_j:
        L = max(0, a_j_old - a_i_old)
        H = min(C, C + a_j_old - a_i_old)
    else:
        L = max(0, a_i_old + a_j_old - C)
        H = min(C, a_j_old + a_i_old)

    a_j_new = clip(a_j_new, L, H)
    a_i_new = a_i_old + y_i*y_j*(a_j_old - a_j_new)

    if abs(a_j_new - a_j_old) < 0.00001:
        #print('WARNING   alpha_j not moving enough')
        return 0

    alphas[i], alphas[j] = a_i_new, a_j_new
    svm_util.update_errors()

    # Update the threshold b.
    b_i = -E_i - y_i*K_ii*(a_i_new - a_i_old) - y_j*K_ij*(a_j_new - a_j_old) + b
    b_j = -E_j - y_i*K_ij*(a_i_new - a_i_old) - y_j*K_jj*(a_j_new - a_j_old) + b

    if 0 < a_i_new < C:
        b = b_i
    elif 0 < a_j_new < C:
        b = b_j
    else:
        b = (b_i + b_j)/2

    svm_util.b = b
    print(b)

    return 1

def examine_example(i, svm_util):
    ''' Check whether alpha_i violates the KKT conditions and, if so,
        pick a second alpha and run one optimization step.
    '''
    E_i, y_i, alpha = svm_util.errors[i], svm_util.labels[i], svm_util.alphas[i]
    r = E_i*y_i
    C, tolerance = svm_util.C, svm_util.tolerance

    # KKT violation check.
    if (r < -tolerance and alpha < C) or (r > tolerance and alpha > 0):
        j = select_j(i, svm_util)
        #j = select_j_rand(i, svm_util.m)
        return take_step(i, j, svm_util)
    else:
        return 0

def platt_smo(dataset, labels, C, max_iter):
    ''' Platt SMO with heuristic alpha-pair selection.

    :param dataset: list of feature vectors
    :param labels: class labels (+1.0 / -1.0)
    :param C: soft-margin constant, 0 <= alpha_i <= C
    :param max_iter: maximal number of outer-loop iterations
    :return: (alphas, b)
    '''
    # Initialize the shared SVM state.
    svm_util = SVMUtil(dataset, labels, C)
    it = 0

    # Flag: sweep the full set vs. only the non-bound alphas.
    entire = True

    pair_changed = 0
    while (it < max_iter): #and (pair_changed > 0 or entire):
        pair_changed = 0
        if entire:
            for i in range(svm_util.m):
                pair_changed += examine_example(i, svm_util)
                print('Full set - iter: {}, pair changed: {}'.format(i, pair_changed))
        else:
            alphas = svm_util.alphas
            non_bound_indices = [i for i in range(svm_util.m)
                                 if alphas[i] > 0 and alphas[i] < C]
            for i in non_bound_indices:
                pair_changed += examine_example(i, svm_util)
                print('Non-bound - iter:{}, pair changed: {}'.format(i, pair_changed))
        it += 1

        # Alternate between full sweeps and non-bound sweeps.
        if entire:
            entire = False
        elif pair_changed == 0:
            entire = True

        print('iteration number: {}'.format(it))

    return svm_util.alphas, svm_util.b

if '__main__' == __name__:
    # Load training data.
    dataset, labels = load_data('testSet.txt')
    # Optimize the SVM with Platt SMO.
    alphas, b = platt_smo(dataset, labels, 0.8, 40)

    # Group data points by class.
    classified_pts = {'+1': [], '-1': []}
    for point, label in zip(dataset, labels):
        if label == 1.0:
            classified_pts['+1'].append(point)
        else:
            classified_pts['-1'].append(point)

    # matplotlib is only needed for plotting; import locally so the
    # module stays importable without it.
    import matplotlib.pyplot as plt

    fig = plt.figure()
    ax = fig.add_subplot(111)

    # Scatter the data points.
    for label, pts in classified_pts.items():
        pts = np.array(pts)
        ax.scatter(pts[:, 0], pts[:, 1], label=label)

    # Draw the separating line.
    w = get_w(alphas, dataset, labels)
    x1, _ = max(dataset, key=lambda x: x[0])
    x2, _ = min(dataset, key=lambda x: x[0])
    a1, a2 = w
    y1, y2 = (-b - a1*x1)/a2, (-b - a1*x2)/a2
    ax.plot([x1, x2], [y1, y2])

    # Highlight support vectors.
    for i, alpha in enumerate(alphas):
        if abs(alpha) > 1e-3:
            x, y = dataset[i]
            ax.scatter([x], [y], s=150, c='none', alpha=0.7,
                       linewidth=1.5, edgecolor='#AB3319')

    plt.show()
delta = 24.39248658850147\n", 29 | "Iteration: 9, delta = 18.094550266671376\n", 30 | "Iteration: 10, delta = 13.54904027267321\n", 31 | "Iteration: 11, delta = 10.252379488046245\n", 32 | "Iteration: 12, delta = 7.848669709154365\n", 33 | "Iteration: 13, delta = 6.086454636246799\n", 34 | "Iteration: 14, delta = 4.788458307723886\n", 35 | "Iteration: 15, delta = 3.8299053756531976\n", 36 | "Iteration: 16, delta = 3.122873178243026\n", 37 | "Iteration: 17, delta = 2.6050408381288435\n", 38 | "Iteration: 18, delta = 2.2316904209010318\n", 39 | "Iteration: 19, delta = 1.9701101960431515\n", 40 | "Iteration: 20, delta = 1.795752091414215\n", 41 | "Iteration: 21, delta = 1.6896477099976437\n", 42 | "Iteration: 22, delta = 1.6367076550845923\n", 43 | "Iteration: 23, delta = 1.6246251283932907\n", 44 | "Iteration: 24, delta = 1.6431809395098753\n", 45 | "Iteration: 25, delta = 1.6838062089104824\n", 46 | "Iteration: 26, delta = 1.7393038856876046\n", 47 | "Iteration: 27, delta = 1.8036633204969803\n", 48 | "Iteration: 28, delta = 1.8719259246654474\n", 49 | "Iteration: 29, delta = 1.94007652367236\n", 50 | "Iteration: 30, delta = 2.00494617126742\n", 51 | "Iteration: 31, delta = 2.0641193924147956\n", 52 | "Iteration: 32, delta = 2.1158432226679906\n", 53 | "Iteration: 33, delta = 2.1589378902735916\n", 54 | "Iteration: 34, delta = 2.1954186438620127\n", 55 | "Iteration: 35, delta = 2.1792291130373087\n", 56 | "Iteration: 36, delta = 2.161131873532213\n", 57 | "Iteration: 37, delta = 2.138616526509395\n", 58 | "Iteration: 38, delta = 2.1107247941065452\n", 59 | "Iteration: 39, delta = 2.077248416043858\n", 60 | "Iteration: 40, delta = 2.0384724195671424\n", 61 | "Iteration: 41, delta = 1.994984235397851\n", 62 | "Iteration: 42, delta = 1.94753433583287\n", 63 | "Iteration: 43, delta = 1.8969373637091849\n", 64 | "Iteration: 44, delta = 1.8440046355945015\n", 65 | "Iteration: 45, delta = 1.7895004867714306\n", 66 | "Iteration: 46, delta = 1.7341163178480201\n", 67 | 
"Iteration: 47, delta = 1.6784574320126922\n", 68 | "Iteration: 48, delta = 1.6230388169699381\n", 69 | "Iteration: 49, delta = 1.5682869251297689\n", 70 | "Iteration: 50, delta = 1.5145452458264117\n", 71 | "Iteration: 51, delta = 1.4620820576578808\n", 72 | "Iteration: 52, delta = 1.4110992154357973\n", 73 | "Iteration: 53, delta = 1.361741184467519\n", 74 | "Iteration: 54, delta = 1.3141038041451338\n", 75 | "Iteration: 55, delta = 1.268242460967258\n", 76 | "Iteration: 56, delta = 1.2241794934923291\n", 77 | "Iteration: 57, delta = 1.1819107512205846\n", 78 | "Iteration: 58, delta = 1.1414112967506753\n", 79 | "Iteration: 59, delta = 1.1026402839911498\n", 80 | "Iteration: 60, delta = 1.0655450714154995\n", 81 | "Iteration: 61, delta = 1.0300646433558995\n", 82 | "Iteration: 62, delta = 0.9961324178657378\n", 83 | "Iteration: 63, delta = 0.9636785196698838\n", 84 | "Iteration: 64, delta = 0.9326315931257341\n", 85 | "Iteration: 65, delta = 0.9029202244089447\n", 86 | "Iteration: 66, delta = 0.8744740354454734\n", 87 | "Iteration: 67, delta = 0.84722450494246\n", 88 | "Iteration: 68, delta = 0.8211055648937418\n", 89 | "Iteration: 69, delta = 0.7960540142573791\n", 90 | "Iteration: 70, delta = 0.7720097853969037\n", 91 | "Iteration: 71, delta = 0.7489160933666881\n", 92 | "Iteration: 72, delta = 0.7267194932596794\n", 93 | "Iteration: 73, delta = 0.7053698665672528\n", 94 | "Iteration: 74, delta = 0.6777633374424568\n", 95 | "Iteration: 75, delta = 0.6420322119474804\n", 96 | "Iteration: 76, delta = 0.6227444190744791\n", 97 | "Iteration: 77, delta = 0.6042345772832505\n", 98 | "Iteration: 78, delta = 0.586459970063288\n", 99 | "Iteration: 79, delta = 0.5694015038363887\n", 100 | "Iteration: 80, delta = 0.5530330654548834\n", 101 | "Iteration: 81, delta = 0.5373242592779661\n", 102 | "Iteration: 82, delta = 0.5222424281671465\n", 103 | "Iteration: 83, delta = 0.507754099076692\n", 104 | "Iteration: 84, delta = 0.4938259987636684\n", 105 | "Iteration: 85, delta = 
0.4804257497094113\n", 106 | "Iteration: 86, delta = 0.4675223292865667\n", 107 | "Iteration: 87, delta = 0.45508635469013825\n", 108 | "Iteration: 88, delta = 0.44309024068775216\n", 109 | "Iteration: 89, delta = 0.4315082654932212\n", 110 | "Iteration: 90, delta = 0.42031657128563893\n", 111 | "Iteration: 91, delta = 0.4094931191989417\n", 112 | "Iteration: 92, delta = 0.399017613593287\n", 113 | "Iteration: 93, delta = 0.38887140662154707\n", 114 | "Iteration: 94, delta = 0.37903739121543367\n", 115 | "Iteration: 95, delta = 0.3694998884739107\n", 116 | "Iteration: 96, delta = 0.3602445337601239\n", 117 | "Iteration: 97, delta = 0.3512581645982209\n", 118 | "Iteration: 98, delta = 0.3425287125085106\n", 119 | "Iteration: 99, delta = 0.3340451002179634\n", 120 | "Iteration: 100, delta = 0.32579714515327396\n", 121 | "Iteration: 101, delta = 0.3177754697128421\n", 122 | "Iteration: 102, delta = 0.3099714185511857\n", 123 | "Iteration: 103, delta = 0.30237698284577164\n", 124 | "Iteration: 104, delta = 0.29498473141893555\n", 125 | "Iteration: 105, delta = 0.28778774842476196\n", 126 | "Iteration: 106, delta = 0.2807795772657755\n", 127 | "Iteration: 107, delta = 0.2739541703631403\n", 128 | "Iteration: 108, delta = 0.26730584434994853\n", 129 | "Iteration: 109, delta = 0.2608292402769621\n", 130 | "Iteration: 110, delta = 0.25451928839697757\n", 131 | "Iteration: 111, delta = 0.24837117712013423\n", 132 | "Iteration: 112, delta = 0.24238032574362478\n", 133 | "Iteration: 113, delta = 0.23654236056131595\n", 134 | "Iteration: 114, delta = 0.23085309402540588\n", 135 | "Iteration: 115, delta = 0.22530850659109092\n", 136 | "Iteration: 116, delta = 0.21990473097775975\n", 137 | "Iteration: 117, delta = 0.21463803852952879\n", 138 | "Iteration: 118, delta = 0.20950482743864995\n", 139 | "Iteration: 119, delta = 0.20450161259304878\n", 140 | "Iteration: 120, delta = 0.19962501683926348\n", 141 | "Iteration: 121, delta = 0.1948717634638797\n", 142 | "Iteration: 122, 
delta = 0.19023866974180237\n", 143 | "Iteration: 123, delta = 0.18572264138401806\n", 144 | "Iteration: 124, delta = 0.1813206677618382\n", 145 | "Iteration: 125, delta = 0.17702981779780202\n", 146 | "Iteration: 126, delta = 0.1728472364031859\n", 147 | "Iteration: 127, delta = 0.1687701413927698\n", 148 | "Iteration: 128, delta = 0.1647958207904594\n", 149 | "Iteration: 129, delta = 0.16092163045755115\n", 150 | "Iteration: 130, delta = 0.1571449919990755\n", 151 | "Iteration: 131, delta = 0.15346339087727756\n", 152 | "Iteration: 132, delta = 0.14987437472564125\n", 153 | "Iteration: 133, delta = 0.14637555178751427\n", 154 | "Iteration: 134, delta = 0.1429645894920668\n", 155 | "Iteration: 135, delta = 0.1396392131052835\n", 156 | "Iteration: 136, delta = 0.1363972044705406\n", 157 | "Iteration: 137, delta = 0.1332364007942033\n", 158 | "Iteration: 138, delta = 0.13015469348692932\n", 159 | "Iteration: 139, delta = 0.12715002702884703\n", 160 | "Iteration: 140, delta = 0.12422039787293215\n", 161 | "Iteration: 141, delta = 0.1213638533536141\n", 162 | "Iteration: 142, delta = 0.11857849063085268\n", 163 | "Iteration: 143, delta = 0.11586245562853037\n", 164 | "Iteration: 144, delta = 0.11321394199057977\n", 165 | "Iteration: 145, delta = 0.11063119004825239\n", 166 | "Iteration: 146, delta = 0.1081124857935265\n", 167 | "Iteration: 147, delta = 0.10565615985365184\n", 168 | "Iteration: 148, delta = 0.10326058648411163\n", 169 | "Iteration: 149, delta = 0.10092418256476776\n", 170 | "Iteration: 150, delta = 0.09864540659987142\n", 171 | "Correlation coefficient: 0.7255254877587117\n" 172 | ] 173 | } 174 | ], 175 | "source": [ 176 | "run lasso_regression.py" 177 | ] 178 | } 179 | ], 180 | "metadata": { 181 | "kernelspec": { 182 | "display_name": "Python 3", 183 | "language": "python", 184 | "name": "python3" 185 | }, 186 | "language_info": { 187 | "codemirror_mode": { 188 | "name": "ipython", 189 | "version": 3 190 | }, 191 | "file_extension": ".py", 192 | 
"mimetype": "text/x-python", 193 | "name": "python", 194 | "nbconvert_exporter": "python", 195 | "pygments_lexer": "ipython3", 196 | "version": "3.5.3" 197 | } 198 | }, 199 | "nbformat": 4, 200 | "nbformat_minor": 2 201 | } 202 | -------------------------------------------------------------------------------- /Reinforcement Learning/Policy Iteration Algorithm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Policy iteration is guaranteed to converge and at convergence, the current policy and its utility function are the optimal policy and the optimal utility function. First of all, we define a policy π which assigns an action to each state. We can assign random actions to this policy, it does not matter.\n", 8 | "Once we evaluate the policy we can improve it. The policy improvement is the second and last step of the algorithm. Our environment has a finite number of states and then a finite number of policies. Each iteration yields to a better policy." 
9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "metadata": {}, 14 | "source": [ 15 | "### Implementing the policy iteration algorithm:" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 1, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "import numpy as np\n", 25 | "\n", 26 | "def return_policy_evaluation(p, u, r, T, gamma):\n", 27 | "\n", 28 | " #v is the state vector\n", 29 | " #T is the transition matrix\n", 30 | " #u is the utility vector\n", 31 | " #reward consists of the rewards earned for moving to a particular state\n", 32 | " #gamma is the discount factor by which rewards are discounted over the time\n", 33 | " for s in range(12):\n", 34 | " if not np.isnan(p[s]):\n", 35 | " v = np.zeros((1,12))\n", 36 | " v[0,s] = 1.0\n", 37 | " action = int(p[s])\n", 38 | " u[s] = r[s] + gamma * np.sum(np.multiply(u, np.dot(v, T[:,:,action])))\n", 39 | " return u" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 2, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "def return_expected_action(u, T, v):\n", 49 | " \n", 50 | "# It returns an action based on the\n", 51 | "# expected utility of doing a in state s, \n", 52 | "# according to T and u. 
This action is\n", 53 | "# the one that maximize the expected\n", 54 | "# utility.\n", 55 | " \n", 56 | " actions_array = np.zeros(4)\n", 57 | " for action in range(4):\n", 58 | " #Expected utility of doing a in state s, according to T and u.\n", 59 | " actions_array[action] = np.sum(np.multiply(u, np.dot(v, T[:,:,action])))\n", 60 | " return np.argmax(actions_array)" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 3, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "def print_policy(p, shape):\n", 70 | " \"\"\"Printing utility.\n", 71 | "\n", 72 | " Print the policy actions using symbols:\n", 73 | " ^, v, <, > up, down, left, right\n", 74 | " * terminal states\n", 75 | " # obstacles\n", 76 | " \"\"\"\n", 77 | " counter = 0\n", 78 | " policy_string = \"\"\n", 79 | " for row in range(shape[0]):\n", 80 | " for col in range(shape[1]):\n", 81 | " if(p[counter] == -1): policy_string += \" * \" \n", 82 | " elif(p[counter] == 0): policy_string += \" ^ \"\n", 83 | " elif(p[counter] == 1): policy_string += \" < \"\n", 84 | " elif(p[counter] == 2): policy_string += \" v \" \n", 85 | " elif(p[counter] == 3): policy_string += \" > \"\n", 86 | " elif(np.isnan(p[counter])): policy_string += \" # \"\n", 87 | " counter += 1\n", 88 | " policy_string += '\\n'\n", 89 | " print(policy_string)" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 4, 95 | "metadata": {}, 96 | "outputs": [ 97 | { 98 | "name": "stdout", 99 | "output_type": "stream", 100 | "text": [ 101 | " v < > * \n", 102 | " ^ # < * \n", 103 | " < < ^ v \n", 104 | "\n", 105 | " ^ > > * \n", 106 | " ^ # ^ * \n", 107 | " < > ^ v \n", 108 | "\n", 109 | " > > > * \n", 110 | " ^ # ^ * \n", 111 | " > > ^ < \n", 112 | "\n", 113 | " > > > * \n", 114 | " ^ # ^ * \n", 115 | " > > ^ < \n", 116 | "\n", 117 | " > > > * \n", 118 | " ^ # ^ * \n", 119 | " ^ > ^ < \n", 120 | "\n", 121 | " > > > * \n", 122 | " ^ # ^ * \n", 123 | " ^ > ^ < \n", 124 | "\n", 125 | " > > > * \n", 126 | 
" ^ # ^ * \n", 127 | " ^ < ^ < \n", 128 | "\n", 129 | " > > > * \n", 130 | " ^ # ^ * \n", 131 | " ^ < ^ < \n", 132 | "\n", 133 | " > > > * \n", 134 | " ^ # ^ * \n", 135 | " ^ < ^ < \n", 136 | "\n", 137 | " > > > * \n", 138 | " ^ # ^ * \n", 139 | " ^ < < < \n", 140 | "\n", 141 | " > > > * \n", 142 | " ^ # ^ * \n", 143 | " ^ < < < \n", 144 | "\n", 145 | " > > > * \n", 146 | " ^ # ^ * \n", 147 | " ^ < < < \n", 148 | "\n", 149 | " > > > * \n", 150 | " ^ # ^ * \n", 151 | " ^ < < < \n", 152 | "\n", 153 | " > > > * \n", 154 | " ^ # ^ * \n", 155 | " ^ < < < \n", 156 | "\n", 157 | " > > > * \n", 158 | " ^ # ^ * \n", 159 | " ^ < < < \n", 160 | "\n", 161 | " > > > * \n", 162 | " ^ # ^ * \n", 163 | " ^ < < < \n", 164 | "\n", 165 | " > > > * \n", 166 | " ^ # ^ * \n", 167 | " ^ < < < \n", 168 | "\n", 169 | " > > > * \n", 170 | " ^ # ^ * \n", 171 | " ^ < < < \n", 172 | "\n", 173 | " > > > * \n", 174 | " ^ # ^ * \n", 175 | " ^ < < < \n", 176 | "\n", 177 | " > > > * \n", 178 | " ^ # ^ * \n", 179 | " ^ < < < \n", 180 | "\n", 181 | " > > > * \n", 182 | " ^ # ^ * \n", 183 | " ^ < < < \n", 184 | "\n", 185 | "=================== FINAL RESULT ==================\n", 186 | "Iterations: 22\n", 187 | "Delta: 9.043213450299348e-08\n", 188 | "Gamma: 0.999\n", 189 | "Epsilon: 0.0001\n", 190 | "===================================================\n", 191 | "[0.80796344 0.86539911 0.91653199 1. ]\n", 192 | "[ 0.75696624 0. 0.65836281 -1. 
]\n", 193 | "[0.69968295 0.64882105 0.60471972 0.38150427]\n", 194 | "===================================================\n", 195 | " > > > * \n", 196 | " ^ # ^ * \n", 197 | " ^ < < < \n", 198 | "\n", 199 | "===================================================\n" 200 | ] 201 | } 202 | ], 203 | "source": [ 204 | "def main():\n", 205 | " gamma = 0.999\n", 206 | " epsilon = 0.0001\n", 207 | " iteration = 0\n", 208 | " T = np.load(\"T.npy\")\n", 209 | " #Generate the first policy randomly\n", 210 | " # NaN=Nothing, -1=Terminal, 0=Up, 1=Left, 2=Down, 3=Right\n", 211 | " p = np.random.randint(0, 4, size=(12)).astype(np.float32)\n", 212 | " p[5] = np.NaN\n", 213 | " p[3] = p[7] = -1\n", 214 | " #Utility vectors\n", 215 | " u = np.array([0.0, 0.0, 0.0, 0.0,\n", 216 | " 0.0, 0.0, 0.0, 0.0,\n", 217 | " 0.0, 0.0, 0.0, 0.0])\n", 218 | " #Reward vector\n", 219 | " r = np.array([-0.04, -0.04, -0.04, +1.0,\n", 220 | " -0.04, 0.0, -0.04, -1.0,\n", 221 | " -0.04, -0.04, -0.04, -0.04])\n", 222 | "\n", 223 | " while True:\n", 224 | " iteration += 1\n", 225 | " #1- Policy evaluation\n", 226 | " u_0 = u.copy()\n", 227 | " u = return_policy_evaluation(p, u, r, T, gamma)\n", 228 | " #Stopping criteria\n", 229 | " delta = np.absolute(u - u_0).max()\n", 230 | " if delta < epsilon * (1 - gamma) / gamma: break\n", 231 | " for s in range(12):\n", 232 | " if not np.isnan(p[s]) and not p[s]==-1:\n", 233 | " v = np.zeros((1,12))\n", 234 | " v[0,s] = 1.0\n", 235 | " #2- Policy improvement\n", 236 | " a = return_expected_action(u, T, v) \n", 237 | " if a != p[s]: p[s] = a\n", 238 | " print_policy(p, shape=(3,4))\n", 239 | "\n", 240 | " print(\"=================== FINAL RESULT ==================\")\n", 241 | " print(\"Iterations: \" + str(iteration))\n", 242 | " print(\"Delta: \" + str(delta))\n", 243 | " print(\"Gamma: \" + str(gamma))\n", 244 | " print(\"Epsilon: \" + str(epsilon))\n", 245 | " print(\"===================================================\")\n", 246 | " print(u[0:4])\n", 247 | " 
print(u[4:8])\n", 248 | " print(u[8:12])\n", 249 | " print(\"===================================================\")\n", 250 | " print_policy(p, shape=(3,4))\n", 251 | " print(\"===================================================\")\n", 252 | "\n", 253 | "if __name__ == \"__main__\":\n", 254 | " main()" 255 | ] 256 | } 257 | ], 258 | "metadata": { 259 | "kernelspec": { 260 | "display_name": "Python 3", 261 | "language": "python", 262 | "name": "python3" 263 | }, 264 | "language_info": { 265 | "codemirror_mode": { 266 | "name": "ipython", 267 | "version": 3 268 | }, 269 | "file_extension": ".py", 270 | "mimetype": "text/x-python", 271 | "name": "python", 272 | "nbconvert_exporter": "python", 273 | "pygments_lexer": "ipython3", 274 | "version": "3.5.4" 275 | } 276 | }, 277 | "nbformat": 4, 278 | "nbformat_minor": 2 279 | } 280 | -------------------------------------------------------------------------------- /classification_and_regression_trees/dot/ex2_prune.dot: -------------------------------------------------------------------------------- 1 | digraph decision_tree { 2 | "c4bff19d-b75d-4b50-99e8-34f696a77644" [label="0: 0.508542"]; 3 | "68b83894-3568-462c-a8c2-a2cfa600d44c" [label="0: 0.463241"]; 4 | "8fb7d681-5bb9-487b-804f-592a8760babb" [label="0: 0.130626"]; 5 | "ffb50925-bc5a-405b-aae5-15ba23da800d" [label="0: 0.085111"]; 6 | "92c528b2-54f3-488d-a817-b99003f2be3a" [label="0.77"]; 7 | "8f89c68f-ec97-4e89-9757-6d4d349dfb0b" [label="6.51"]; 8 | "ca28f207-af7e-43c9-8e5d-36a5e9cd3be5" [label="0: 0.377383"]; 9 | "14779c7f-cb2a-4452-b325-a27aef8b1af8" [label="0: 0.3417"]; 10 | "560ae6ed-a17a-46b8-a400-9e78920cf41b" [label="0: 0.32889"]; 11 | "288dea95-80b6-4638-8fb1-a1b657a19a73" [label="0: 0.300318"]; 12 | "9ac61729-5c19-4e91-915b-20beb665ceaa" [label="0: 0.176523"]; 13 | "42f1a782-408b-4c15-8bf0-13a01f968f81" [label="-9.18"]; 14 | "5f466b9c-fe6c-4465-97af-2520e6b66286" [label="0: 0.203993"]; 15 | "0a18504c-63a2-4af1-868c-38c591a02a60" [label="3.45"]; 16 | 
"74c75b17-a2a7-4bf0-b7d0-c4e4e916a29e" [label="0: 0.218321"]; 17 | "d1373766-52cd-4196-babe-eff0093695fe" [label="-11.82"]; 18 | "cfdc11bb-2ae8-4807-b3f1-c8e27c6f2317" [label="0: 0.228628"]; 19 | "480b7cfa-89e7-434d-9980-09192f674081" [label="6.77"]; 20 | "d56fe489-c3a4-4bcd-ace9-56c0e2510ec4" [label="0: 0.264639"]; 21 | "675438ce-b6e6-4489-8756-653ee5e37021" [label="-13.07"]; 22 | "2a8f2806-4eac-45ea-a820-eb776a0d8689" [label="0.40"]; 23 | "7993ad91-42aa-4b70-91db-5fbfed178b89" [label="-19.99"]; 24 | "f55749c7-c885-4d1a-b741-c3bd5c09f2dd" [label="15.06"]; 25 | "4f8fcefe-a83c-4aae-bff7-8ff98dec1643" [label="0: 0.351478"]; 26 | "e87b5ff0-822f-48c0-ae82-d32f92dd2111" [label="-22.69"]; 27 | "37a37975-c5da-4c85-9c02-d597c9041252" [label="-15.09"]; 28 | "0521b42e-1682-418f-a005-fe3d5098505a" [label="0: 0.446196"]; 29 | "c7762b6e-66e6-4a15-8458-f9ff2cc5cac6" [label="0: 0.418943"]; 30 | "266b231b-6083-474d-9ba8-8dbc1e4fec60" [label="1.38"]; 31 | "5981dfef-12f2-4ecf-83ef-fb111c6db412" [label="14.38"]; 32 | "fa23660c-72b3-49e7-a1f6-11038e6d0c2b" [label="-12.56"]; 33 | "c5ec6450-a9d7-4173-bca6-bb19e0eb3816" [label="0: 0.483803"]; 34 | "f31655f3-1003-4d88-8fcb-d57d75cf914e" [label="3.43"]; 35 | "f29cc8c5-ed1d-4b57-b16c-c657dd787edc" [label="12.51"]; 36 | "e74f8324-62e8-4d39-bd77-75ff0f2998f4" [label="0: 0.731636"]; 37 | "22f47068-79d7-4f15-8375-fca2b236b75f" [label="0: 0.642373"]; 38 | "d706789e-3e3b-4a19-8c5d-f97c17092013" [label="0: 0.618868"]; 39 | "72712f51-6089-4f84-bfac-5211971a9785" [label="0: 0.585413"]; 40 | "ac1505e9-2492-4269-a36d-bb93042e0af7" [label="0: 0.560301"]; 41 | "fc6c2bb4-7b7b-4d3f-b979-8ca27d29be5b" [label="0: 0.531944"]; 42 | "74d4917b-5afd-4078-bf91-22911bf0286f" [label="101.74"]; 43 | "eb5af9d7-1dc0-4e12-9b6f-d2de3d1dd6a7" [label="110.18"]; 44 | "d84a0c88-9b2c-4932-9f0f-c4c52d82de76" [label="97.20"]; 45 | "31a5cd3b-dbf9-4943-b90e-9a6d3a0a2a21" [label="123.21"]; 46 | "0790ed24-11c3-4310-a28c-fdd2d20ebed1" [label="93.67"]; 47 | 
"993e9dfe-e330-48c9-a4d5-44b6124403f0" [label="0: 0.667851"]; 48 | "8f434490-e990-4724-98b8-53c4b2cdbed5" [label="114.15"]; 49 | "4dc134f3-4882-45c0-bdac-eddc21bc10df" [label="0: 0.70889"]; 50 | "48da8c69-8998-4d8d-bd99-b268ed3281b9" [label="106.88"]; 51 | "393f82d6-a9a0-40e7-9c48-b3dc22233ef6" [label="114.55"]; 52 | "87ee99b7-59a7-49af-816d-ac7a0dc7d0cb" [label="0: 0.953902"]; 53 | "c5768cbc-7bae-45f3-80a0-ed03b930766b" [label="0: 0.763328"]; 54 | "633c832d-b589-4db2-a15e-0369eee79d08" [label="78.09"]; 55 | "e81e22e5-514c-42f7-859f-d21419008366" [label="0: 0.798198"]; 56 | "16ac6188-471c-439d-8503-da094a54ec82" [label="102.36"]; 57 | "0ed901ab-9dfb-4e60-b878-62b692e7dddc" [label="0: 0.838587"]; 58 | "e6c982d1-2c92-4787-9af6-d0f25432b035" [label="84.95"]; 59 | "7fa3f95d-e8a4-452d-bd78-dae6ad591cca" [label="0: 0.948822"]; 60 | "37d790b6-4448-4f11-b6c4-fa117cba9992" [label="0: 0.856421"]; 61 | "e068d29d-a65a-4c65-a5e6-3321f93076ba" [label="95.28"]; 62 | "7031a55c-261c-486b-ad1c-c94a66412706" [label="0: 0.912161"]; 63 | "5a91e296-ef36-4790-9e42-d191c8f5830a" [label="0: 0.896683"]; 64 | "c63246a4-4508-41f8-b7fe-412ac81e918d" [label="98.72"]; 65 | "0615432d-ca65-458e-971a-bfd6201edbaa" [label="104.83"]; 66 | "64df91ab-df9b-4c98-9e0d-4f8b2d2a4794" [label="96.45"]; 67 | "b1351d64-cf46-4654-9446-1bf4b89c0c32" [label="87.31"]; 68 | "e0dd7e12-6468-41b5-b63b-1c9a04ab682b" [label="108.84"]; 69 | "c4bff19d-b75d-4b50-99e8-34f696a77644" -> "68b83894-3568-462c-a8c2-a2cfa600d44c"; 70 | "68b83894-3568-462c-a8c2-a2cfa600d44c" -> "8fb7d681-5bb9-487b-804f-592a8760babb"; 71 | "8fb7d681-5bb9-487b-804f-592a8760babb" -> "ffb50925-bc5a-405b-aae5-15ba23da800d"; 72 | "ffb50925-bc5a-405b-aae5-15ba23da800d" -> "92c528b2-54f3-488d-a817-b99003f2be3a"; 73 | "ffb50925-bc5a-405b-aae5-15ba23da800d" -> "8f89c68f-ec97-4e89-9757-6d4d349dfb0b"; 74 | "8fb7d681-5bb9-487b-804f-592a8760babb" -> "ca28f207-af7e-43c9-8e5d-36a5e9cd3be5"; 75 | "ca28f207-af7e-43c9-8e5d-36a5e9cd3be5" -> 
"14779c7f-cb2a-4452-b325-a27aef8b1af8"; 76 | "14779c7f-cb2a-4452-b325-a27aef8b1af8" -> "560ae6ed-a17a-46b8-a400-9e78920cf41b"; 77 | "560ae6ed-a17a-46b8-a400-9e78920cf41b" -> "288dea95-80b6-4638-8fb1-a1b657a19a73"; 78 | "288dea95-80b6-4638-8fb1-a1b657a19a73" -> "9ac61729-5c19-4e91-915b-20beb665ceaa"; 79 | "9ac61729-5c19-4e91-915b-20beb665ceaa" -> "42f1a782-408b-4c15-8bf0-13a01f968f81"; 80 | "9ac61729-5c19-4e91-915b-20beb665ceaa" -> "5f466b9c-fe6c-4465-97af-2520e6b66286"; 81 | "5f466b9c-fe6c-4465-97af-2520e6b66286" -> "0a18504c-63a2-4af1-868c-38c591a02a60"; 82 | "5f466b9c-fe6c-4465-97af-2520e6b66286" -> "74c75b17-a2a7-4bf0-b7d0-c4e4e916a29e"; 83 | "74c75b17-a2a7-4bf0-b7d0-c4e4e916a29e" -> "d1373766-52cd-4196-babe-eff0093695fe"; 84 | "74c75b17-a2a7-4bf0-b7d0-c4e4e916a29e" -> "cfdc11bb-2ae8-4807-b3f1-c8e27c6f2317"; 85 | "cfdc11bb-2ae8-4807-b3f1-c8e27c6f2317" -> "480b7cfa-89e7-434d-9980-09192f674081"; 86 | "cfdc11bb-2ae8-4807-b3f1-c8e27c6f2317" -> "d56fe489-c3a4-4bcd-ace9-56c0e2510ec4"; 87 | "d56fe489-c3a4-4bcd-ace9-56c0e2510ec4" -> "675438ce-b6e6-4489-8756-653ee5e37021"; 88 | "d56fe489-c3a4-4bcd-ace9-56c0e2510ec4" -> "2a8f2806-4eac-45ea-a820-eb776a0d8689"; 89 | "288dea95-80b6-4638-8fb1-a1b657a19a73" -> "7993ad91-42aa-4b70-91db-5fbfed178b89"; 90 | "560ae6ed-a17a-46b8-a400-9e78920cf41b" -> "f55749c7-c885-4d1a-b741-c3bd5c09f2dd"; 91 | "14779c7f-cb2a-4452-b325-a27aef8b1af8" -> "4f8fcefe-a83c-4aae-bff7-8ff98dec1643"; 92 | "4f8fcefe-a83c-4aae-bff7-8ff98dec1643" -> "e87b5ff0-822f-48c0-ae82-d32f92dd2111"; 93 | "4f8fcefe-a83c-4aae-bff7-8ff98dec1643" -> "37a37975-c5da-4c85-9c02-d597c9041252"; 94 | "ca28f207-af7e-43c9-8e5d-36a5e9cd3be5" -> "0521b42e-1682-418f-a005-fe3d5098505a"; 95 | "0521b42e-1682-418f-a005-fe3d5098505a" -> "c7762b6e-66e6-4a15-8458-f9ff2cc5cac6"; 96 | "c7762b6e-66e6-4a15-8458-f9ff2cc5cac6" -> "266b231b-6083-474d-9ba8-8dbc1e4fec60"; 97 | "c7762b6e-66e6-4a15-8458-f9ff2cc5cac6" -> "5981dfef-12f2-4ecf-83ef-fb111c6db412"; 98 | "0521b42e-1682-418f-a005-fe3d5098505a" 
-> "fa23660c-72b3-49e7-a1f6-11038e6d0c2b"; 99 | "68b83894-3568-462c-a8c2-a2cfa600d44c" -> "c5ec6450-a9d7-4173-bca6-bb19e0eb3816"; 100 | "c5ec6450-a9d7-4173-bca6-bb19e0eb3816" -> "f31655f3-1003-4d88-8fcb-d57d75cf914e"; 101 | "c5ec6450-a9d7-4173-bca6-bb19e0eb3816" -> "f29cc8c5-ed1d-4b57-b16c-c657dd787edc"; 102 | "c4bff19d-b75d-4b50-99e8-34f696a77644" -> "e74f8324-62e8-4d39-bd77-75ff0f2998f4"; 103 | "e74f8324-62e8-4d39-bd77-75ff0f2998f4" -> "22f47068-79d7-4f15-8375-fca2b236b75f"; 104 | "22f47068-79d7-4f15-8375-fca2b236b75f" -> "d706789e-3e3b-4a19-8c5d-f97c17092013"; 105 | "d706789e-3e3b-4a19-8c5d-f97c17092013" -> "72712f51-6089-4f84-bfac-5211971a9785"; 106 | "72712f51-6089-4f84-bfac-5211971a9785" -> "ac1505e9-2492-4269-a36d-bb93042e0af7"; 107 | "ac1505e9-2492-4269-a36d-bb93042e0af7" -> "fc6c2bb4-7b7b-4d3f-b979-8ca27d29be5b"; 108 | "fc6c2bb4-7b7b-4d3f-b979-8ca27d29be5b" -> "74d4917b-5afd-4078-bf91-22911bf0286f"; 109 | "fc6c2bb4-7b7b-4d3f-b979-8ca27d29be5b" -> "eb5af9d7-1dc0-4e12-9b6f-d2de3d1dd6a7"; 110 | "ac1505e9-2492-4269-a36d-bb93042e0af7" -> "d84a0c88-9b2c-4932-9f0f-c4c52d82de76"; 111 | "72712f51-6089-4f84-bfac-5211971a9785" -> "31a5cd3b-dbf9-4943-b90e-9a6d3a0a2a21"; 112 | "d706789e-3e3b-4a19-8c5d-f97c17092013" -> "0790ed24-11c3-4310-a28c-fdd2d20ebed1"; 113 | "22f47068-79d7-4f15-8375-fca2b236b75f" -> "993e9dfe-e330-48c9-a4d5-44b6124403f0"; 114 | "993e9dfe-e330-48c9-a4d5-44b6124403f0" -> "8f434490-e990-4724-98b8-53c4b2cdbed5"; 115 | "993e9dfe-e330-48c9-a4d5-44b6124403f0" -> "4dc134f3-4882-45c0-bdac-eddc21bc10df"; 116 | "4dc134f3-4882-45c0-bdac-eddc21bc10df" -> "48da8c69-8998-4d8d-bd99-b268ed3281b9"; 117 | "4dc134f3-4882-45c0-bdac-eddc21bc10df" -> "393f82d6-a9a0-40e7-9c48-b3dc22233ef6"; 118 | "e74f8324-62e8-4d39-bd77-75ff0f2998f4" -> "87ee99b7-59a7-49af-816d-ac7a0dc7d0cb"; 119 | "87ee99b7-59a7-49af-816d-ac7a0dc7d0cb" -> "c5768cbc-7bae-45f3-80a0-ed03b930766b"; 120 | "c5768cbc-7bae-45f3-80a0-ed03b930766b" -> "633c832d-b589-4db2-a15e-0369eee79d08"; 121 | 
"c5768cbc-7bae-45f3-80a0-ed03b930766b" -> "e81e22e5-514c-42f7-859f-d21419008366"; 122 | "e81e22e5-514c-42f7-859f-d21419008366" -> "16ac6188-471c-439d-8503-da094a54ec82"; 123 | "e81e22e5-514c-42f7-859f-d21419008366" -> "0ed901ab-9dfb-4e60-b878-62b692e7dddc"; 124 | "0ed901ab-9dfb-4e60-b878-62b692e7dddc" -> "e6c982d1-2c92-4787-9af6-d0f25432b035"; 125 | "0ed901ab-9dfb-4e60-b878-62b692e7dddc" -> "7fa3f95d-e8a4-452d-bd78-dae6ad591cca"; 126 | "7fa3f95d-e8a4-452d-bd78-dae6ad591cca" -> "37d790b6-4448-4f11-b6c4-fa117cba9992"; 127 | "37d790b6-4448-4f11-b6c4-fa117cba9992" -> "e068d29d-a65a-4c65-a5e6-3321f93076ba"; 128 | "37d790b6-4448-4f11-b6c4-fa117cba9992" -> "7031a55c-261c-486b-ad1c-c94a66412706"; 129 | "7031a55c-261c-486b-ad1c-c94a66412706" -> "5a91e296-ef36-4790-9e42-d191c8f5830a"; 130 | "5a91e296-ef36-4790-9e42-d191c8f5830a" -> "c63246a4-4508-41f8-b7fe-412ac81e918d"; 131 | "5a91e296-ef36-4790-9e42-d191c8f5830a" -> "0615432d-ca65-458e-971a-bfd6201edbaa"; 132 | "7031a55c-261c-486b-ad1c-c94a66412706" -> "64df91ab-df9b-4c98-9e0d-4f8b2d2a4794"; 133 | "7fa3f95d-e8a4-452d-bd78-dae6ad591cca" -> "b1351d64-cf46-4654-9446-1bf4b89c0c32"; 134 | "87ee99b7-59a7-49af-816d-ac7a0dc7d0cb" -> "e0dd7e12-6468-41b5-b63b-1c9a04ab682b"; 135 | } -------------------------------------------------------------------------------- /classification_and_regression_trees/dot/ex2.dot: -------------------------------------------------------------------------------- 1 | digraph decision_tree { 2 | "bdbe6f68-a446-4539-8a80-860f22663afe" [label="0: 0.508542"]; 3 | "acef94b2-b18f-4c9c-bb21-4caa8279f319" [label="0: 0.463241"]; 4 | "31303341-5a1f-4167-83ce-22e8ea1e462f" [label="0: 0.130626"]; 5 | "1ee2e839-eb76-48b6-8149-a694de0bc740" [label="0: 0.085111"]; 6 | "1294ba07-97b4-44da-b30f-6dc8de1f2506" [label="0: 0.053764"]; 7 | "bb2ea4a5-e090-43bd-a89f-47459726a906" [label="4.09"]; 8 | "cf4e8349-aae6-4f47-b2e3-bdfad76140b9" [label="-2.54"]; 9 | "746064bd-117c-46d6-bbbd-8f3943d2b418" [label="6.51"]; 10 | 
"3c7d308c-4094-44ed-a689-23971c62ea5a" [label="0: 0.377383"]; 11 | "8cd38fe6-64f6-48f5-b3da-1755767c9c9e" [label="0: 0.3417"]; 12 | "8fa7515c-8c27-400c-b2eb-fa95acd8f8c0" [label="0: 0.32889"]; 13 | "c7820d96-a919-408a-a48d-6472c6b6cbe4" [label="0: 0.300318"]; 14 | "046be329-b0aa-49f7-b3d6-7fc216d478c9" [label="0: 0.176523"]; 15 | "0a30919d-85b8-4c93-bfbb-f1fcff96dba6" [label="0: 0.156273"]; 16 | "4ad75fa1-e0b7-4678-be73-5de3775d397e" [label="-6.25"]; 17 | "398d6223-1ef9-4c45-9ac7-44f61850dc45" [label="-12.11"]; 18 | "3a51d811-12ca-46b5-8548-f40a6571b263" [label="0: 0.203993"]; 19 | "2460b0c3-0648-4d24-8a57-e996df15a425" [label="3.45"]; 20 | "a595f63d-5bc6-4901-8802-749a77568531" [label="0: 0.218321"]; 21 | "a4cedd79-a7f9-4d67-acd2-bba51563dc30" [label="-11.82"]; 22 | "9e2b62f5-877e-4c05-82e1-8264fe924106" [label="0: 0.228628"]; 23 | "47a064c2-9c82-4e09-a069-80e6cd741aba" [label="6.77"]; 24 | "78788a82-c815-4891-8933-d0d2d39e8ace" [label="0: 0.264639"]; 25 | "83554092-3407-44a5-89d9-b1954126079b" [label="-13.07"]; 26 | "26d04a07-eb77-4691-aca6-3bf94fae2063" [label="0.40"]; 27 | "aabebca3-28f6-4bb4-aa90-9393dfa9479c" [label="-19.99"]; 28 | "217e5a9d-99f6-48f6-b60d-4940de06b873" [label="15.06"]; 29 | "322a229e-98a5-4f9d-b6ce-5be216bb9b66" [label="0: 0.351478"]; 30 | "20c1cd26-5c8f-45ef-8557-34bdc7e7e936" [label="-22.69"]; 31 | "d142ca3c-8c45-4ff5-a975-5614ec532660" [label="-15.09"]; 32 | "dd4807c3-e5dd-4139-85a3-673bc6384037" [label="0: 0.446196"]; 33 | "98190fc7-96ed-4c83-b200-1305f1663d12" [label="0: 0.418943"]; 34 | "a2389757-85a4-4e78-82e5-3ddcb540cb5f" [label="0: 0.388789"]; 35 | "3d1b079c-d685-4810-b7f4-c88824625a53" [label="3.66"]; 36 | "0d29f899-b06a-4c70-aa65-9fbf6794eca4" [label="-0.89"]; 37 | "25f74f36-c5be-4b83-a233-0c555c86ead9" [label="14.38"]; 38 | "cadc3d94-2545-4a48-be70-6915c3fa2795" [label="-12.56"]; 39 | "a57b9084-97cc-450d-bc7d-8feefef45a82" [label="0: 0.483803"]; 40 | "8b548b72-2bc2-499c-bb77-12b597d0e793" [label="3.43"]; 41 | 
"a8a356da-2c12-474c-a463-8110c1f6b8bf" [label="12.51"]; 42 | "e9657d5a-ad84-4e6f-acfe-6cfaa5f88780" [label="0: 0.731636"]; 43 | "39dce041-ce2f-44d7-bdbd-45cf38da4af8" [label="0: 0.642373"]; 44 | "9aebff68-5a03-4060-899d-345ef1fddace" [label="0: 0.618868"]; 45 | "5e1c0e83-b7d6-48b0-9ef5-560cbca09f44" [label="0: 0.585413"]; 46 | "7daf39a6-a515-4506-a06e-23af1edbec6a" [label="0: 0.560301"]; 47 | "097d7583-bea8-4b59-8cf6-bb1a7ec82295" [label="0: 0.531944"]; 48 | "f4e8defb-d0a3-4328-9fb1-835d3f85e406" [label="101.74"]; 49 | "b5524dcb-e110-4fca-95ba-eec513d60fdb" [label="0: 0.546601"]; 50 | "c68f0abe-5d5a-4493-8cf7-6818f9601baa" [label="110.98"]; 51 | "6bb0d9b4-4be8-468e-a888-19c08287fce3" [label="109.39"]; 52 | "d51fb426-e5ec-416b-b844-e069df5f6dc8" [label="97.20"]; 53 | "09721c80-fa8e-4cec-b962-32a79348d015" [label="123.21"]; 54 | "841fee93-2890-490a-b2c8-8bb720c32ea6" [label="93.67"]; 55 | "c53bc2c2-dfc2-4375-9f88-9736f075c294" [label="0: 0.667851"]; 56 | "c55165f2-d3c8-4b70-8d8c-ace5415b172f" [label="114.15"]; 57 | "a0ae2570-2f84-4f3f-aa45-689cdc3f64cf" [label="0: 0.70889"]; 58 | "b640a2a1-5fd8-4063-869a-ff60c957dbc2" [label="0: 0.69892"]; 59 | "1c5d7da8-d76c-444c-aa07-c8811130765d" [label="108.93"]; 60 | "1fe9a864-a467-4aaf-b9f1-87bfc45c25b5" [label="104.82"]; 61 | "c7358239-c55b-413d-aa76-540e5994f84f" [label="114.55"]; 62 | "709c2602-24e3-4f9a-bb0d-8c089399c018" [label="0: 0.953902"]; 63 | "80f583a9-c4eb-41af-bef3-cfb7d38d41ee" [label="0: 0.763328"]; 64 | "9c72f602-2240-4ea7-9c0b-f7269fc86618" [label="78.09"]; 65 | "6b400587-213d-483b-8ad0-110dc6d85507" [label="0: 0.798198"]; 66 | "73c06578-7739-4122-98e0-63cb9dc39f81" [label="102.36"]; 67 | "8ac0c1b7-2b2d-44b4-99f4-282aae0fe308" [label="0: 0.838587"]; 68 | "633bb485-db8e-4d5d-8d68-1fc08619a673" [label="0: 0.815215"]; 69 | "0b343b88-8a2a-49ae-a2d7-bd624044698b" [label="88.78"]; 70 | "7da8e0bd-b3b8-4092-a34e-1ba8cd3ffbf7" [label="81.11"]; 71 | "d3bb79be-7b85-4c0f-b5b7-dbb6bd4c54a2" [label="0: 0.948822"]; 72 | 
"f3081916-79ff-4e39-8b30-fa054670c50f" [label="0: 0.856421"]; 73 | "829fae81-4a91-46f9-8c6f-7559dd670841" [label="95.28"]; 74 | "4d529db0-b857-4910-a482-9e61bc87f959" [label="0: 0.912161"]; 75 | "818b003f-8abd-4ee8-a4ea-b596111a4f77" [label="0: 0.896683"]; 76 | "c5f07126-a0a7-4d03-8ee2-b60dd344013f" [label="0: 0.883615"]; 77 | "f9034542-ada8-48be-9db2-e4c9b2be70de" [label="102.25"]; 78 | "3257b4c9-7a8a-4cb7-a620-0980005ecb6d" [label="95.18"]; 79 | "6f8ddb81-d225-41e7-b1ec-fa6185f55097" [label="104.83"]; 80 | "b5e03953-ad6c-4a39-981a-140d7fafc46a" [label="96.45"]; 81 | "e85378cb-074a-4d35-ae82-a9a99f9daf50" [label="87.31"]; 82 | "dfe3dd59-018c-41ab-8046-21ed7d136b4a" [label="0: 0.960398"]; 83 | "2d37edbb-d4d3-4b55-a593-8a2b1ab3ff25" [label="112.43"]; 84 | "f075116f-c182-4b0d-9f7e-2dae61982c18" [label="105.25"]; 85 | "bdbe6f68-a446-4539-8a80-860f22663afe" -> "acef94b2-b18f-4c9c-bb21-4caa8279f319"; 86 | "acef94b2-b18f-4c9c-bb21-4caa8279f319" -> "31303341-5a1f-4167-83ce-22e8ea1e462f"; 87 | "31303341-5a1f-4167-83ce-22e8ea1e462f" -> "1ee2e839-eb76-48b6-8149-a694de0bc740"; 88 | "1ee2e839-eb76-48b6-8149-a694de0bc740" -> "1294ba07-97b4-44da-b30f-6dc8de1f2506"; 89 | "1294ba07-97b4-44da-b30f-6dc8de1f2506" -> "bb2ea4a5-e090-43bd-a89f-47459726a906"; 90 | "1294ba07-97b4-44da-b30f-6dc8de1f2506" -> "cf4e8349-aae6-4f47-b2e3-bdfad76140b9"; 91 | "1ee2e839-eb76-48b6-8149-a694de0bc740" -> "746064bd-117c-46d6-bbbd-8f3943d2b418"; 92 | "31303341-5a1f-4167-83ce-22e8ea1e462f" -> "3c7d308c-4094-44ed-a689-23971c62ea5a"; 93 | "3c7d308c-4094-44ed-a689-23971c62ea5a" -> "8cd38fe6-64f6-48f5-b3da-1755767c9c9e"; 94 | "8cd38fe6-64f6-48f5-b3da-1755767c9c9e" -> "8fa7515c-8c27-400c-b2eb-fa95acd8f8c0"; 95 | "8fa7515c-8c27-400c-b2eb-fa95acd8f8c0" -> "c7820d96-a919-408a-a48d-6472c6b6cbe4"; 96 | "c7820d96-a919-408a-a48d-6472c6b6cbe4" -> "046be329-b0aa-49f7-b3d6-7fc216d478c9"; 97 | "046be329-b0aa-49f7-b3d6-7fc216d478c9" -> "0a30919d-85b8-4c93-bfbb-f1fcff96dba6"; 98 | "0a30919d-85b8-4c93-bfbb-f1fcff96dba6" -> 
"4ad75fa1-e0b7-4678-be73-5de3775d397e"; 99 | "0a30919d-85b8-4c93-bfbb-f1fcff96dba6" -> "398d6223-1ef9-4c45-9ac7-44f61850dc45"; 100 | "046be329-b0aa-49f7-b3d6-7fc216d478c9" -> "3a51d811-12ca-46b5-8548-f40a6571b263"; 101 | "3a51d811-12ca-46b5-8548-f40a6571b263" -> "2460b0c3-0648-4d24-8a57-e996df15a425"; 102 | "3a51d811-12ca-46b5-8548-f40a6571b263" -> "a595f63d-5bc6-4901-8802-749a77568531"; 103 | "a595f63d-5bc6-4901-8802-749a77568531" -> "a4cedd79-a7f9-4d67-acd2-bba51563dc30"; 104 | "a595f63d-5bc6-4901-8802-749a77568531" -> "9e2b62f5-877e-4c05-82e1-8264fe924106"; 105 | "9e2b62f5-877e-4c05-82e1-8264fe924106" -> "47a064c2-9c82-4e09-a069-80e6cd741aba"; 106 | "9e2b62f5-877e-4c05-82e1-8264fe924106" -> "78788a82-c815-4891-8933-d0d2d39e8ace"; 107 | "78788a82-c815-4891-8933-d0d2d39e8ace" -> "83554092-3407-44a5-89d9-b1954126079b"; 108 | "78788a82-c815-4891-8933-d0d2d39e8ace" -> "26d04a07-eb77-4691-aca6-3bf94fae2063"; 109 | "c7820d96-a919-408a-a48d-6472c6b6cbe4" -> "aabebca3-28f6-4bb4-aa90-9393dfa9479c"; 110 | "8fa7515c-8c27-400c-b2eb-fa95acd8f8c0" -> "217e5a9d-99f6-48f6-b60d-4940de06b873"; 111 | "8cd38fe6-64f6-48f5-b3da-1755767c9c9e" -> "322a229e-98a5-4f9d-b6ce-5be216bb9b66"; 112 | "322a229e-98a5-4f9d-b6ce-5be216bb9b66" -> "20c1cd26-5c8f-45ef-8557-34bdc7e7e936"; 113 | "322a229e-98a5-4f9d-b6ce-5be216bb9b66" -> "d142ca3c-8c45-4ff5-a975-5614ec532660"; 114 | "3c7d308c-4094-44ed-a689-23971c62ea5a" -> "dd4807c3-e5dd-4139-85a3-673bc6384037"; 115 | "dd4807c3-e5dd-4139-85a3-673bc6384037" -> "98190fc7-96ed-4c83-b200-1305f1663d12"; 116 | "98190fc7-96ed-4c83-b200-1305f1663d12" -> "a2389757-85a4-4e78-82e5-3ddcb540cb5f"; 117 | "a2389757-85a4-4e78-82e5-3ddcb540cb5f" -> "3d1b079c-d685-4810-b7f4-c88824625a53"; 118 | "a2389757-85a4-4e78-82e5-3ddcb540cb5f" -> "0d29f899-b06a-4c70-aa65-9fbf6794eca4"; 119 | "98190fc7-96ed-4c83-b200-1305f1663d12" -> "25f74f36-c5be-4b83-a233-0c555c86ead9"; 120 | "dd4807c3-e5dd-4139-85a3-673bc6384037" -> "cadc3d94-2545-4a48-be70-6915c3fa2795"; 121 | 
"acef94b2-b18f-4c9c-bb21-4caa8279f319" -> "a57b9084-97cc-450d-bc7d-8feefef45a82"; 122 | "a57b9084-97cc-450d-bc7d-8feefef45a82" -> "8b548b72-2bc2-499c-bb77-12b597d0e793"; 123 | "a57b9084-97cc-450d-bc7d-8feefef45a82" -> "a8a356da-2c12-474c-a463-8110c1f6b8bf"; 124 | "bdbe6f68-a446-4539-8a80-860f22663afe" -> "e9657d5a-ad84-4e6f-acfe-6cfaa5f88780"; 125 | "e9657d5a-ad84-4e6f-acfe-6cfaa5f88780" -> "39dce041-ce2f-44d7-bdbd-45cf38da4af8"; 126 | "39dce041-ce2f-44d7-bdbd-45cf38da4af8" -> "9aebff68-5a03-4060-899d-345ef1fddace"; 127 | "9aebff68-5a03-4060-899d-345ef1fddace" -> "5e1c0e83-b7d6-48b0-9ef5-560cbca09f44"; 128 | "5e1c0e83-b7d6-48b0-9ef5-560cbca09f44" -> "7daf39a6-a515-4506-a06e-23af1edbec6a"; 129 | "7daf39a6-a515-4506-a06e-23af1edbec6a" -> "097d7583-bea8-4b59-8cf6-bb1a7ec82295"; 130 | "097d7583-bea8-4b59-8cf6-bb1a7ec82295" -> "f4e8defb-d0a3-4328-9fb1-835d3f85e406"; 131 | "097d7583-bea8-4b59-8cf6-bb1a7ec82295" -> "b5524dcb-e110-4fca-95ba-eec513d60fdb"; 132 | "b5524dcb-e110-4fca-95ba-eec513d60fdb" -> "c68f0abe-5d5a-4493-8cf7-6818f9601baa"; 133 | "b5524dcb-e110-4fca-95ba-eec513d60fdb" -> "6bb0d9b4-4be8-468e-a888-19c08287fce3"; 134 | "7daf39a6-a515-4506-a06e-23af1edbec6a" -> "d51fb426-e5ec-416b-b844-e069df5f6dc8"; 135 | "5e1c0e83-b7d6-48b0-9ef5-560cbca09f44" -> "09721c80-fa8e-4cec-b962-32a79348d015"; 136 | "9aebff68-5a03-4060-899d-345ef1fddace" -> "841fee93-2890-490a-b2c8-8bb720c32ea6"; 137 | "39dce041-ce2f-44d7-bdbd-45cf38da4af8" -> "c53bc2c2-dfc2-4375-9f88-9736f075c294"; 138 | "c53bc2c2-dfc2-4375-9f88-9736f075c294" -> "c55165f2-d3c8-4b70-8d8c-ace5415b172f"; 139 | "c53bc2c2-dfc2-4375-9f88-9736f075c294" -> "a0ae2570-2f84-4f3f-aa45-689cdc3f64cf"; 140 | "a0ae2570-2f84-4f3f-aa45-689cdc3f64cf" -> "b640a2a1-5fd8-4063-869a-ff60c957dbc2"; 141 | "b640a2a1-5fd8-4063-869a-ff60c957dbc2" -> "1c5d7da8-d76c-444c-aa07-c8811130765d"; 142 | "b640a2a1-5fd8-4063-869a-ff60c957dbc2" -> "1fe9a864-a467-4aaf-b9f1-87bfc45c25b5"; 143 | "a0ae2570-2f84-4f3f-aa45-689cdc3f64cf" -> 
"c7358239-c55b-413d-aa76-540e5994f84f"; 144 | "e9657d5a-ad84-4e6f-acfe-6cfaa5f88780" -> "709c2602-24e3-4f9a-bb0d-8c089399c018"; 145 | "709c2602-24e3-4f9a-bb0d-8c089399c018" -> "80f583a9-c4eb-41af-bef3-cfb7d38d41ee"; 146 | "80f583a9-c4eb-41af-bef3-cfb7d38d41ee" -> "9c72f602-2240-4ea7-9c0b-f7269fc86618"; 147 | "80f583a9-c4eb-41af-bef3-cfb7d38d41ee" -> "6b400587-213d-483b-8ad0-110dc6d85507"; 148 | "6b400587-213d-483b-8ad0-110dc6d85507" -> "73c06578-7739-4122-98e0-63cb9dc39f81"; 149 | "6b400587-213d-483b-8ad0-110dc6d85507" -> "8ac0c1b7-2b2d-44b4-99f4-282aae0fe308"; 150 | "8ac0c1b7-2b2d-44b4-99f4-282aae0fe308" -> "633bb485-db8e-4d5d-8d68-1fc08619a673"; 151 | "633bb485-db8e-4d5d-8d68-1fc08619a673" -> "0b343b88-8a2a-49ae-a2d7-bd624044698b"; 152 | "633bb485-db8e-4d5d-8d68-1fc08619a673" -> "7da8e0bd-b3b8-4092-a34e-1ba8cd3ffbf7"; 153 | "8ac0c1b7-2b2d-44b4-99f4-282aae0fe308" -> "d3bb79be-7b85-4c0f-b5b7-dbb6bd4c54a2"; 154 | "d3bb79be-7b85-4c0f-b5b7-dbb6bd4c54a2" -> "f3081916-79ff-4e39-8b30-fa054670c50f"; 155 | "f3081916-79ff-4e39-8b30-fa054670c50f" -> "829fae81-4a91-46f9-8c6f-7559dd670841"; 156 | "f3081916-79ff-4e39-8b30-fa054670c50f" -> "4d529db0-b857-4910-a482-9e61bc87f959"; 157 | "4d529db0-b857-4910-a482-9e61bc87f959" -> "818b003f-8abd-4ee8-a4ea-b596111a4f77"; 158 | "818b003f-8abd-4ee8-a4ea-b596111a4f77" -> "c5f07126-a0a7-4d03-8ee2-b60dd344013f"; 159 | "c5f07126-a0a7-4d03-8ee2-b60dd344013f" -> "f9034542-ada8-48be-9db2-e4c9b2be70de"; 160 | "c5f07126-a0a7-4d03-8ee2-b60dd344013f" -> "3257b4c9-7a8a-4cb7-a620-0980005ecb6d"; 161 | "818b003f-8abd-4ee8-a4ea-b596111a4f77" -> "6f8ddb81-d225-41e7-b1ec-fa6185f55097"; 162 | "4d529db0-b857-4910-a482-9e61bc87f959" -> "b5e03953-ad6c-4a39-981a-140d7fafc46a"; 163 | "d3bb79be-7b85-4c0f-b5b7-dbb6bd4c54a2" -> "e85378cb-074a-4d35-ae82-a9a99f9daf50"; 164 | "709c2602-24e3-4f9a-bb0d-8c089399c018" -> "dfe3dd59-018c-41ab-8046-21ed7d136b4a"; 165 | "dfe3dd59-018c-41ab-8046-21ed7d136b4a" -> "2d37edbb-d4d3-4b55-a593-8a2b1ab3ff25"; 166 | 
"dfe3dd59-018c-41ab-8046-21ed7d136b4a" -> "f075116f-c182-4b0d-9f7e-2dae61982c18"; 167 | } -------------------------------------------------------------------------------- /classification_and_regression_trees/notebook/后剪枝.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from prune import *" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "## 加载数据" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 2, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "data = load_data('ex2.txt')\n", 26 | "tree = create_tree(data, fleaf, ferr)" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "## 判断树结构" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 3, 39 | "metadata": {}, 40 | "outputs": [ 41 | { 42 | "data": { 43 | "text/plain": [ 44 | "False" 45 | ] 46 | }, 47 | "execution_count": 3, 48 | "metadata": {}, 49 | "output_type": "execute_result" 50 | } 51 | ], 52 | "source": [ 53 | "not_tree(tree)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "## 对树结构进行塌陷处理" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 4, 66 | "metadata": {}, 67 | "outputs": [ 68 | { 69 | "data": { 70 | "text/plain": [ 71 | "53.136107929136443" 72 | ] 73 | }, 74 | "execution_count": 4, 75 | "metadata": {}, 76 | "output_type": "execute_result" 77 | } 78 | ], 79 | "source": [ 80 | "collapse(tree)" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": {}, 86 | "source": [ 87 | "## 输出树结构" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 5, 93 | "metadata": {}, 94 | "outputs": [ 95 | { 96 | "data": { 97 | "text/plain": [ 98 | "{'feat_idx': 0,\n", 99 | " 'feat_val': 0.50854200000000005,\n", 
100 | " 'left': {'feat_idx': 0,\n", 101 | " 'feat_val': 0.46324100000000001,\n", 102 | " 'left': {'feat_idx': 0,\n", 103 | " 'feat_val': 0.13062599999999999,\n", 104 | " 'left': {'feat_idx': 0,\n", 105 | " 'feat_val': 0.085111000000000006,\n", 106 | " 'left': {'feat_idx': 0,\n", 107 | " 'feat_val': 0.053763999999999999,\n", 108 | " 'left': 4.0916259999999998,\n", 109 | " 'right': -2.5443927142857148},\n", 110 | " 'right': 6.5098432857142843},\n", 111 | " 'right': {'feat_idx': 0,\n", 112 | " 'feat_val': 0.37738300000000002,\n", 113 | " 'left': {'feat_idx': 0,\n", 114 | " 'feat_val': 0.3417,\n", 115 | " 'left': {'feat_idx': 0,\n", 116 | " 'feat_val': 0.32889000000000002,\n", 117 | " 'left': {'feat_idx': 0,\n", 118 | " 'feat_val': 0.30031799999999997,\n", 119 | " 'left': {'feat_idx': 0,\n", 120 | " 'feat_val': 0.17652300000000001,\n", 121 | " 'left': {'feat_idx': 0,\n", 122 | " 'feat_val': 0.156273,\n", 123 | " 'left': -6.2479000000000013,\n", 124 | " 'right': -12.107972500000001},\n", 125 | " 'right': {'feat_idx': 0,\n", 126 | " 'feat_val': 0.20399300000000001,\n", 127 | " 'left': 3.4496025000000001,\n", 128 | " 'right': {'feat_idx': 0,\n", 129 | " 'feat_val': 0.21832099999999999,\n", 130 | " 'left': -11.822278500000001,\n", 131 | " 'right': {'feat_idx': 0,\n", 132 | " 'feat_val': 0.228628,\n", 133 | " 'left': 6.770429,\n", 134 | " 'right': {'feat_idx': 0,\n", 135 | " 'feat_val': 0.26463900000000001,\n", 136 | " 'left': -13.070501,\n", 137 | " 'right': 0.40377471428571476}}}}},\n", 138 | " 'right': -19.994155200000002},\n", 139 | " 'right': 15.059290750000001},\n", 140 | " 'right': {'feat_idx': 0,\n", 141 | " 'feat_val': 0.35147800000000001,\n", 142 | " 'left': -22.693879600000002,\n", 143 | " 'right': -15.085111749999999}},\n", 144 | " 'right': {'feat_idx': 0,\n", 145 | " 'feat_val': 0.44619599999999998,\n", 146 | " 'left': {'feat_idx': 0,\n", 147 | " 'feat_val': 0.41894300000000001,\n", 148 | " 'left': {'feat_idx': 0,\n", 149 | " 'feat_val': 0.388789,\n", 150 | " 
'left': 3.6584772500000016,\n", 151 | " 'right': -0.89235549999999952},\n", 152 | " 'right': 14.38417875},\n", 153 | " 'right': -12.558604833333334}}},\n", 154 | " 'right': {'feat_idx': 0,\n", 155 | " 'feat_val': 0.48380299999999998,\n", 156 | " 'left': 3.4331330000000007,\n", 157 | " 'right': 12.50675925}},\n", 158 | " 'right': {'feat_idx': 0,\n", 159 | " 'feat_val': 0.73163599999999995,\n", 160 | " 'left': {'feat_idx': 0,\n", 161 | " 'feat_val': 0.64237299999999997,\n", 162 | " 'left': {'feat_idx': 0,\n", 163 | " 'feat_val': 0.61886799999999997,\n", 164 | " 'left': {'feat_idx': 0,\n", 165 | " 'feat_val': 0.58541299999999996,\n", 166 | " 'left': {'feat_idx': 0,\n", 167 | " 'feat_val': 0.56030100000000005,\n", 168 | " 'left': {'feat_idx': 0,\n", 169 | " 'feat_val': 0.53194399999999997,\n", 170 | " 'left': 101.73699325000001,\n", 171 | " 'right': {'feat_idx': 0,\n", 172 | " 'feat_val': 0.546601,\n", 173 | " 'left': 110.979946,\n", 174 | " 'right': 109.38961049999999}},\n", 175 | " 'right': 97.200180249999988},\n", 176 | " 'right': 123.2101316},\n", 177 | " 'right': 93.673449714285724},\n", 178 | " 'right': {'feat_idx': 0,\n", 179 | " 'feat_val': 0.66785099999999997,\n", 180 | " 'left': 114.15162428571431,\n", 181 | " 'right': {'feat_idx': 0,\n", 182 | " 'feat_val': 0.70889000000000002,\n", 183 | " 'left': {'feat_idx': 0,\n", 184 | " 'feat_val': 0.69891999999999999,\n", 185 | " 'left': 108.92921799999999,\n", 186 | " 'right': 104.82495374999999},\n", 187 | " 'right': 114.554706}}},\n", 188 | " 'right': {'feat_idx': 0,\n", 189 | " 'feat_val': 0.95390200000000003,\n", 190 | " 'left': {'feat_idx': 0,\n", 191 | " 'feat_val': 0.76332800000000001,\n", 192 | " 'left': 78.085643250000004,\n", 193 | " 'right': {'feat_idx': 0,\n", 194 | " 'feat_val': 0.79819799999999996,\n", 195 | " 'left': 102.35780185714285,\n", 196 | " 'right': {'feat_idx': 0,\n", 197 | " 'feat_val': 0.83858699999999997,\n", 198 | " 'left': {'feat_idx': 0,\n", 199 | " 'feat_val': 0.81521500000000002,\n", 
200 | " 'left': 88.784498800000009,\n", 201 | " 'right': 81.110151999999999},\n", 202 | " 'right': {'feat_idx': 0,\n", 203 | " 'feat_val': 0.94882200000000005,\n", 204 | " 'left': {'feat_idx': 0,\n", 205 | " 'feat_val': 0.85642099999999999,\n", 206 | " 'left': 95.275843166666661,\n", 207 | " 'right': {'feat_idx': 0,\n", 208 | " 'feat_val': 0.912161,\n", 209 | " 'left': {'feat_idx': 0,\n", 210 | " 'feat_val': 0.89668300000000001,\n", 211 | " 'left': {'feat_idx': 0,\n", 212 | " 'feat_val': 0.88361500000000004,\n", 213 | " 'left': 102.25234449999999,\n", 214 | " 'right': 95.181792999999999},\n", 215 | " 'right': 104.82540899999999},\n", 216 | " 'right': 96.452866999999998}},\n", 217 | " 'right': 87.310387500000004}}}},\n", 218 | " 'right': {'feat_idx': 0,\n", 219 | " 'feat_val': 0.96039799999999997,\n", 220 | " 'left': 112.42895575000001,\n", 221 | " 'right': 105.24862350000001}}}}" 222 | ] 223 | }, 224 | "execution_count": 5, 225 | "metadata": {}, 226 | "output_type": "execute_result" 227 | } 228 | ], 229 | "source": [ 230 | "tree" 231 | ] 232 | }, 233 | { 234 | "cell_type": "markdown", 235 | "metadata": {}, 236 | "source": [ 237 | "## 使用测试数据进行后剪枝" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": 6, 243 | "metadata": {}, 244 | "outputs": [], 245 | "source": [ 246 | "data_test = load_data('ex2test.txt')" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 7, 252 | "metadata": {}, 253 | "outputs": [ 254 | { 255 | "name": "stdout", 256 | "output_type": "stream", 257 | "text": [ 258 | "merged\n", 259 | "merged\n", 260 | "merged\n", 261 | "merged\n", 262 | "merged\n", 263 | "merged\n", 264 | "merged\n", 265 | "merged\n" 266 | ] 267 | } 268 | ], 269 | "source": [ 270 | "pruned_tree = postprune(tree, data_test)" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": 8, 276 | "metadata": {}, 277 | "outputs": [ 278 | { 279 | "data": { 280 | "text/plain": [ 281 | "{'feat_idx': 0,\n", 282 | " 'feat_val': 
0.50854200000000005,\n", 283 | " 'left': {'feat_idx': 0,\n", 284 | " 'feat_val': 0.46324100000000001,\n", 285 | " 'left': {'feat_idx': 0,\n", 286 | " 'feat_val': 0.13062599999999999,\n", 287 | " 'left': {'feat_idx': 0,\n", 288 | " 'feat_val': 0.085111000000000006,\n", 289 | " 'left': 0.77361664285714249,\n", 290 | " 'right': 6.5098432857142843},\n", 291 | " 'right': {'feat_idx': 0,\n", 292 | " 'feat_val': 0.37738300000000002,\n", 293 | " 'left': {'feat_idx': 0,\n", 294 | " 'feat_val': 0.3417,\n", 295 | " 'left': {'feat_idx': 0,\n", 296 | " 'feat_val': 0.32889000000000002,\n", 297 | " 'left': {'feat_idx': 0,\n", 298 | " 'feat_val': 0.30031799999999997,\n", 299 | " 'left': {'feat_idx': 0,\n", 300 | " 'feat_val': 0.17652300000000001,\n", 301 | " 'left': -9.1779362500000019,\n", 302 | " 'right': {'feat_idx': 0,\n", 303 | " 'feat_val': 0.20399300000000001,\n", 304 | " 'left': 3.4496025000000001,\n", 305 | " 'right': {'feat_idx': 0,\n", 306 | " 'feat_val': 0.21832099999999999,\n", 307 | " 'left': -11.822278500000001,\n", 308 | " 'right': {'feat_idx': 0,\n", 309 | " 'feat_val': 0.228628,\n", 310 | " 'left': 6.770429,\n", 311 | " 'right': {'feat_idx': 0,\n", 312 | " 'feat_val': 0.26463900000000001,\n", 313 | " 'left': -13.070501,\n", 314 | " 'right': 0.40377471428571476}}}}},\n", 315 | " 'right': -19.994155200000002},\n", 316 | " 'right': 15.059290750000001},\n", 317 | " 'right': {'feat_idx': 0,\n", 318 | " 'feat_val': 0.35147800000000001,\n", 319 | " 'left': -22.693879600000002,\n", 320 | " 'right': -15.085111749999999}},\n", 321 | " 'right': {'feat_idx': 0,\n", 322 | " 'feat_val': 0.44619599999999998,\n", 323 | " 'left': {'feat_idx': 0,\n", 324 | " 'feat_val': 0.41894300000000001,\n", 325 | " 'left': 1.3830608750000011,\n", 326 | " 'right': 14.38417875},\n", 327 | " 'right': -12.558604833333334}}},\n", 328 | " 'right': {'feat_idx': 0,\n", 329 | " 'feat_val': 0.48380299999999998,\n", 330 | " 'left': 3.4331330000000007,\n", 331 | " 'right': 12.50675925}},\n", 332 | " 
'right': {'feat_idx': 0,\n", 333 | " 'feat_val': 0.73163599999999995,\n", 334 | " 'left': {'feat_idx': 0,\n", 335 | " 'feat_val': 0.64237299999999997,\n", 336 | " 'left': {'feat_idx': 0,\n", 337 | " 'feat_val': 0.61886799999999997,\n", 338 | " 'left': {'feat_idx': 0,\n", 339 | " 'feat_val': 0.58541299999999996,\n", 340 | " 'left': {'feat_idx': 0,\n", 341 | " 'feat_val': 0.56030100000000005,\n", 342 | " 'left': {'feat_idx': 0,\n", 343 | " 'feat_val': 0.53194399999999997,\n", 344 | " 'left': 101.73699325000001,\n", 345 | " 'right': 110.18477824999999},\n", 346 | " 'right': 97.200180249999988},\n", 347 | " 'right': 123.2101316},\n", 348 | " 'right': 93.673449714285724},\n", 349 | " 'right': {'feat_idx': 0,\n", 350 | " 'feat_val': 0.66785099999999997,\n", 351 | " 'left': 114.15162428571431,\n", 352 | " 'right': {'feat_idx': 0,\n", 353 | " 'feat_val': 0.70889000000000002,\n", 354 | " 'left': 106.87708587499999,\n", 355 | " 'right': 114.554706}}},\n", 356 | " 'right': {'feat_idx': 0,\n", 357 | " 'feat_val': 0.95390200000000003,\n", 358 | " 'left': {'feat_idx': 0,\n", 359 | " 'feat_val': 0.76332800000000001,\n", 360 | " 'left': 78.085643250000004,\n", 361 | " 'right': {'feat_idx': 0,\n", 362 | " 'feat_val': 0.79819799999999996,\n", 363 | " 'left': 102.35780185714285,\n", 364 | " 'right': {'feat_idx': 0,\n", 365 | " 'feat_val': 0.83858699999999997,\n", 366 | " 'left': 84.947325400000011,\n", 367 | " 'right': {'feat_idx': 0,\n", 368 | " 'feat_val': 0.94882200000000005,\n", 369 | " 'left': {'feat_idx': 0,\n", 370 | " 'feat_val': 0.85642099999999999,\n", 371 | " 'left': 95.275843166666661,\n", 372 | " 'right': {'feat_idx': 0,\n", 373 | " 'feat_val': 0.912161,\n", 374 | " 'left': {'feat_idx': 0,\n", 375 | " 'feat_val': 0.89668300000000001,\n", 376 | " 'left': 98.717068749999996,\n", 377 | " 'right': 104.82540899999999},\n", 378 | " 'right': 96.452866999999998}},\n", 379 | " 'right': 87.310387500000004}}}},\n", 380 | " 'right': 108.838789625}}}" 381 | ] 382 | }, 383 | 
"execution_count": 8, 384 | "metadata": {}, 385 | "output_type": "execute_result" 386 | } 387 | ], 388 | "source": [ 389 | "pruned_tree" 390 | ] 391 | }, 392 | { 393 | "cell_type": "markdown", 394 | "metadata": {}, 395 | "source": [ 396 | "## 生成树结构dot文件用于显示" 397 | ] 398 | }, 399 | { 400 | "cell_type": "code", 401 | "execution_count": 9, 402 | "metadata": {}, 403 | "outputs": [], 404 | "source": [ 405 | "with open('ex2_prune.dot', 'w') as f:\n", 406 | " content = dotify(pruned_tree)\n", 407 | " f.write(content)" 408 | ] 409 | }, 410 | { 411 | "cell_type": "code", 412 | "execution_count": 10, 413 | "metadata": {}, 414 | "outputs": [ 415 | { 416 | "name": "stdout", 417 | "output_type": "stream", 418 | "text": [ 419 | "\u001b[1m\u001b[36m__pycache__\u001b[m\u001b[m/ ex2test.txt\r\n", 420 | "\u001b[1m\u001b[36mdot\u001b[m\u001b[m/ \u001b[1m\u001b[36mpic\u001b[m\u001b[m/\r\n", 421 | "ex0.txt prune.py\r\n", 422 | "ex00.txt regression_tree.py\r\n", 423 | "ex2.txt 后剪枝.ipynb\r\n", 424 | "ex2_prune.dot 分段函数回归树.ipynb\r\n" 425 | ] 426 | } 427 | ], 428 | "source": [ 429 | "ls" 430 | ] 431 | } 432 | ], 433 | "metadata": { 434 | "kernelspec": { 435 | "display_name": "Python 3", 436 | "language": "python", 437 | "name": "python3" 438 | }, 439 | "language_info": { 440 | "codemirror_mode": { 441 | "name": "ipython", 442 | "version": 3 443 | }, 444 | "file_extension": ".py", 445 | "mimetype": "text/x-python", 446 | "name": "python", 447 | "nbconvert_exporter": "python", 448 | "pygments_lexer": "ipython3", 449 | "version": "3.5.3" 450 | } 451 | }, 452 | "nbformat": 4, 453 | "nbformat_minor": 2 454 | } 455 | --------------------------------------------------------------------------------