├── README.md ├── img ├── feature_extraction.png ├── ipython_notebook.png ├── iris_2d_plot.png ├── leastsq1.png ├── leastsq2.png ├── leastsq3.png ├── linearsvc_boundary.png ├── multiclass_target.png ├── multiclass_target2.png ├── perceptron1.png ├── perceptron2.png ├── perceptron3.png ├── perceptron4.png ├── perceptron5.png ├── perceptron6.png ├── perceptron7.png ├── perceptron8.png ├── perceptron9.png └── plot_iris_dataset_001.png ├── iris └── iris.py ├── leastsq ├── iris_leastsq.py ├── iris_multi_leastsq.py └── leastsq.py ├── perceptron ├── ml_perceptron.py ├── ml_perceptron1.png ├── ml_perceptron2.png ├── reflection1.png ├── reflection2.png ├── sigmoid.png ├── simple_perceptron.py ├── xor1.png ├── xor2.png └── xor3.png ├── svm ├── c_svm.py ├── linear_svm.py ├── svm.001.png ├── svm.002.png ├── svm.003.png ├── svm.004.png ├── svm.005.png ├── svm.006.png ├── svm.007.png ├── svm.008.png ├── svm.009.png ├── svm.010.png ├── svm.011.png ├── svm.012.png ├── svm.013.png ├── svm.014.png ├── svm.015.png ├── svm.016.png ├── svm.017.png ├── svm.018.png ├── svm.019.png ├── svm.020.png ├── svm.021.png ├── svm.022.png └── template_linear_svm.py └── wine ├── wine.data ├── wine_classify.py └── wine_xval.py /README.md: -------------------------------------------------------------------------------- 1 | machine-learning-2014 2 | ===================== 3 | 4 | 機械学習分科会 for TSG2014 winter 5 | 詳細は[Wiki](https://github.com/levelfour/machine-learning-2014/wiki)へ. 6 | 7 | # 内容 8 | + Pythonの環境構築 9 | + 線形代数の基礎知識 10 | + 機械学習の基礎知識 11 | + Naive Bayes 12 | + Perceptron 13 | + SVM 14 | + Decision Tree/Random Forest 15 | 16 | # 日程 17 | ### 第1回(10/30 Thu., 10/31 Fri.) 18 | + Pythonの環境構築 19 | + Python処理系(2.7がbetter) 20 | + Anaconda(numpy, scipy, matplotlib, scikit-learn) 21 | + 機械学習の基礎知識 22 | + iris classification 23 | 24 | ### 第2回 25 | + Naive Bayes 26 | + text classification 27 | 28 | ### 第3回 29 | + Logistic Regression 30 | + Perceptron 31 | 32 | ### 第4回 33 | + SVM 34 | + numeric classification 35 | 36 | ### 第5回 37 | + Decision Tree 38 | + Adaboost 39 | + Random Forest 40 | -------------------------------------------------------------------------------- /img/feature_extraction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/img/feature_extraction.png -------------------------------------------------------------------------------- /img/ipython_notebook.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/img/ipython_notebook.png -------------------------------------------------------------------------------- /img/iris_2d_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/img/iris_2d_plot.png -------------------------------------------------------------------------------- /img/leastsq1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/img/leastsq1.png -------------------------------------------------------------------------------- /img/leastsq2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/img/leastsq2.png -------------------------------------------------------------------------------- /img/leastsq3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/img/leastsq3.png -------------------------------------------------------------------------------- /img/linearsvc_boundary.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/img/linearsvc_boundary.png -------------------------------------------------------------------------------- /img/multiclass_target.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/img/multiclass_target.png -------------------------------------------------------------------------------- /img/multiclass_target2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/img/multiclass_target2.png -------------------------------------------------------------------------------- /img/perceptron1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/img/perceptron1.png -------------------------------------------------------------------------------- /img/perceptron2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/img/perceptron2.png -------------------------------------------------------------------------------- /img/perceptron3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/img/perceptron3.png -------------------------------------------------------------------------------- /img/perceptron4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/img/perceptron4.png -------------------------------------------------------------------------------- /img/perceptron5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/img/perceptron5.png -------------------------------------------------------------------------------- /img/perceptron6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/img/perceptron6.png -------------------------------------------------------------------------------- /img/perceptron7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/img/perceptron7.png -------------------------------------------------------------------------------- /img/perceptron8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/img/perceptron8.png -------------------------------------------------------------------------------- /img/perceptron9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/img/perceptron9.png -------------------------------------------------------------------------------- /img/plot_iris_dataset_001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/img/plot_iris_dataset_001.png -------------------------------------------------------------------------------- /iris/iris.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import matplotlib.pyplot as plt 5 | from sklearn import datasets 6 | import numpy as np 7 | 8 | iris = datasets.load_iris() 9 | # 先頭から100個のデータ(setosaとversicolorを抽出) 10 | # 特徴は0番目(sepal length)と2列目(petal length)を使用 11 | data = iris.data[0:100][:,::2] 12 | target= iris.target[0:100] 13 | 14 | from sklearn.svm import LinearSVC 15 | from sklearn import cross_validation 16 | 17 | # 学習データとテストデータを4:1に分割 18 | train_x, test_x, train_y, test_y = cross_validation.train_test_split(data, target, test_size=0.2) 19 | clf = LinearSVC() # 線形SVM 20 | clf.fit(train_x, train_y) # 学習 21 | pred = clf.predict(test_x) # テストデータの識別 22 | print list(pred == test_y).count(True) / float(len(test_y)) 23 | 24 | h = .02 25 | x_min, x_max = data[:,0].min()-1, data[:,0].max()+1 26 | y_min, y_max = data[:,1].min()-1, data[:,1].max()+1 27 | xx, yy = np.meshgrid(np.arange(x_min, x_max, h), 28 | np.arange(y_min, y_max, h)) 29 | Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) 30 | Z = Z.reshape(xx.shape) 31 | print xx.shape 32 | print xx 33 | # 等高線(contour)プロット 34 | plt.contourf(xx, yy, Z, cmap=plt.cm.Paired) 35 | 36 | # データの0列目(sepal length)と2列目(petal length)をプロット 37 | # enumerate関数はループを回すときにイテレータにインデックスをセットする 38 | for (i, d) in enumerate(data): 39 | if target[i] == 0: # target == setosa 40 | color = '#008800' 41 | elif target[i] == 1: # target == versicolor 42 | color = '#ff0000' 43 | plt.scatter(d[0], d[1], c=color) 44 | plt.xlabel("sepal length[cm]") 45 | plt.ylabel("petal length[cm]") 46 | plt.xlim((x_min, x_max)) 47 | plt.ylim((y_min, y_max)) 48 | plt.show() 49 | -------------------------------------------------------------------------------- /leastsq/iris_leastsq.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import numpy as np 5 | from sklearn import datasets 6 | from sklearn import cross_validation 7 | from sklearn import metrics 8 | 9 | iris = datasets.load_iris() 10 | data = iris.data[0:100] 11 | target = [1 if t == 1 else -1 for t in iris.target[0:100]] 12 | 13 | train_x, test_x, train_y, test_y = cross_validation.train_test_split(data, target, test_size=0.2) 14 | 15 | # 最小二乗法で学習 16 | w = np.linalg.inv(train_x.T.dot(train_x)).dot(train_x.T).dot(train_y) 17 | 18 | # 最小二乗法で推定 19 | pred_y = np.array([1 if w.dot(x) > 0 else -1 for x in test_x]) 20 | 21 | # テストデータに対する正答率 22 | print metrics.accuracy_score(test_y, pred_y) 23 | -------------------------------------------------------------------------------- /leastsq/iris_multi_leastsq.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import numpy as np 5 | from sklearn import datasets 6 | from sklearn import cross_validation 7 | from sklearn import metrics 8 | 9 | iris = datasets.load_iris() 10 | data = iris.data 11 | target = np.array([ 12 | [1 if t == 0 else 0 for t in iris.target], 13 | [1 if t == 1 else 0 for t in iris.target], 14 | [1 if t == 2 else 0 for t in iris.target]]).T 15 | 16 | train_x, test_x, train_y, test_y = cross_validation.train_test_split(data, target, test_size=0.2) 17 | 18 | # 最小二乗法で学習 19 | w = np.linalg.inv(train_x.T.dot(train_x)).dot(train_x.T).dot(train_y) 20 | 21 | print np.array([w.T.dot(d) for d in test_x]) 22 | # 最小二乗法で推定 23 | pred_y = np.array([np.argmax(w.T.dot(d)) for d in test_x]) 24 | true_y = np.array([np.argmax(d) for d in test_y]) 25 | 26 | # テストデータに対する正答率 27 | print metrics.accuracy_score(true_y, pred_y) 28 | -------------------------------------------------------------------------------- /leastsq/leastsq.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | c1 = [] 8 | c2 = [] 9 | m1 = (-1, 2) 10 | m2 = (1, -2) 11 | 12 | for i in range(10): 13 | c1.append([m1[0]+np.random.normal()/1.5, m1[1]+np.random.normal()/1.5]) 14 | c2.append([m2[0]+np.random.normal()/1.5, m2[1]+np.random.normal()/1.5]) 15 | 16 | plt.scatter(map(lambda x: x[0], c1), map(lambda x: x[1], c1), c='#008800') 17 | plt.scatter(map(lambda x: x[0], c2), map(lambda x: x[1], c2), c='#0000aa') 18 | 19 | w = (0.5, 0) 20 | xx = np.arange(-6, 6, 0.1) 21 | yy = [] 22 | for x in xx: 23 | yy.append(w[0]*x+w[1]) 24 | plt.plot(xx, yy) 25 | 26 | plt.quiver(0, 0, w[0], -1, angles='xy', scale_units='xy', scale=1) 27 | 28 | lim = (-6, 4) 29 | plt.xlim(lim) 30 | plt.ylim(lim) 31 | plt.grid() 32 | plt.draw() 33 | plt.show() 34 | -------------------------------------------------------------------------------- /perceptron/ml_perceptron.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import numpy as np 5 | from sklearn import datasets 6 | from sklearn import cross_validation 7 | from sklearn import metrics 8 | 9 | class MultiLayerPerceptron: 10 | def __init__(self, dim, n_mnodes, n_onodes=1, eta=1, beta=1): 11 | # expでオーバーフローが発生することがあるがひとまず無視する 12 | np.seterr(over='ignore') 13 | np.seterr(divide='raise') 14 | # 重みベクトルの初期値はランダムなベクトルとする 15 | self.w = np.random.uniform(-1., 1., (n_mnodes, dim)) 16 | self.v = np.random.uniform(-1., 1., (n_onodes, n_mnodes)) 17 | self.n_mnodes = n_mnodes # 中間層のノード数 18 | self.n_onodes = n_onodes # 出力層のノード数 19 | self.eta = eta 20 | self.beta = beta 21 | 22 | # シグモイド関数 23 | def g(self, x): 24 | return 1.0/(1.0+np.exp(-self.beta*x)) 25 | 26 | # シグモイド関数の1階導関数 27 | def g_(self, x): 28 | return self.beta*self.g(x)*(1-self.g(x)) 29 | 30 | # 入力ベクトルから中間層の出力を得る 31 | def hidden_layer(self, x): 32 | return np.array([self.g(self.w[m].T.dot(x)) for m in range(self.n_mnodes)]) 33 | 34 | def fit(self, x, t): 35 | z = self.predict(x) 36 | # 統一的に扱うためにラベルがスカラーの場合もベクトルで扱う 37 | if not isinstance(t, np.ndarray): 38 | t = np.array([t]) 39 | # 最急降下法で係数ベクトルを更新する 40 | # 計算式は頑張って読んでくださいm(_ _)m 41 | self.w = np.array([ 42 | self.w[m]-self.eta*2*np.dot(z-t,self.v.T[m])*self.g_(np.dot(self.w[m],x))*x 43 | for m in range(self.n_mnodes)]) 44 | self.v = np.array([ 45 | self.v[k]-self.eta*2*(z[k]-t[k])*self.hidden_layer(x) 46 | for k in range(self.n_onodes)]) 47 | 48 | def predict(self, x): 49 | yy = self.v.dot(self.hidden_layer(x)) 50 | return np.array([1 if y >= 0 else -1 for y in yy]) 51 | 52 | if __name__ == "__main__": 53 | # 数字パターンデータを用意 54 | digits = datasets.load_digits() 55 | data = digits.data 56 | # ラベルは10次元ベクトルで、正しいラベルのインデックスのみ1、他の要素は-1 57 | # 例えば1なら[-1,1,-1,-1,...] 58 | target = np.array([[-1 for i in range(10)] for j in range(len(data))]) 59 | for i in range(len(data)): 60 | target[i][digits.target[i]] = 1 61 | 62 | # 学習データとテストデータに分割 63 | train_x, test_x, train_y, test_y = cross_validation.train_test_split(data, target, test_size=0.2) 64 | 65 | p = MultiLayerPerceptron(dim=data[0].shape[0], eta=0.1, beta=0.01, n_mnodes=10, n_onodes=10) 66 | for i in range(len(train_x)): 67 | p.fit(train_x[i], train_y[i]) 68 | 69 | pred_y = np.array([p.predict(x) for x in test_x]) 70 | print metrics.accuracy_score(pred_y, test_y) 71 | -------------------------------------------------------------------------------- /perceptron/ml_perceptron1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/perceptron/ml_perceptron1.png -------------------------------------------------------------------------------- /perceptron/ml_perceptron2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/perceptron/ml_perceptron2.png -------------------------------------------------------------------------------- /perceptron/reflection1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/perceptron/reflection1.png -------------------------------------------------------------------------------- /perceptron/reflection2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/perceptron/reflection2.png -------------------------------------------------------------------------------- /perceptron/sigmoid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/perceptron/sigmoid.png -------------------------------------------------------------------------------- /perceptron/simple_perceptron.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import numpy as np 5 | from sklearn import datasets 6 | from sklearn import metrics 7 | from sklearn import cross_validation 8 | 9 | class SimplePerceptron: 10 | def __init__(self, dim, eta=0): 11 | # wに初期値として単位ベクトルを設定する 12 | self.w = np.array([1 for i in range(dim)]) 13 | self.eta = eta 14 | 15 | def fit(self, x, y): 16 | x = y * x # 学習データの反転 17 | if self.w.T.dot(x) >= 0: 18 | pass 19 | else: 20 | self.w = self.w + self.eta * x 21 | 22 | def predict(self, x): 23 | y = self.w.T.dot(x) 24 | return 1 if y >= 0 else -1 25 | 26 | if __name__ == '__main__': 27 | digits = datasets.load_digits() 28 | data = None 29 | target = None 30 | for i in range(len(digits.data)): 31 | # プリセットの手書き数字から0か1のデータのみを読み込む 32 | if digits.target[i] == 0 or digits.target[i] == 1: 33 | if data == None or target == None: 34 | data = np.array([digits.data[i]]) 35 | target = np.array([1 if digits.target[i] == 1 else -1]) 36 | else: 37 | data = np.r_[data, [digits.data[i]]] 38 | # 1と-1にラベリングし直す 39 | target = np.r_[target, 1 if digits.target[i] == 1 else -1] 40 | 41 | train_x, test_x, train_y, test_y = cross_validation.train_test_split(data, target, test_size=0.2) 42 | 43 | p = SimplePerceptron(dim = data[0].shape[0], eta = 0.3) 44 | for i in range(len(train_x)): 45 | # 学習データを逐次学習する 46 | p.fit(train_x[i], train_y[i]) 47 | 48 | # テストデータに対して推定する 49 | pred_y = np.array([p.predict(x) for x in test_x]) 50 | # 正答率を表示 51 | print metrics.accuracy_score(pred_y, test_y) 52 | -------------------------------------------------------------------------------- /perceptron/xor1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/perceptron/xor1.png -------------------------------------------------------------------------------- /perceptron/xor2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/perceptron/xor2.png -------------------------------------------------------------------------------- /perceptron/xor3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/perceptron/xor3.png -------------------------------------------------------------------------------- /svm/c_svm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import numpy as np 5 | from sklearn import datasets 6 | from sklearn import metrics 7 | from sklearn import cross_validation 8 | 9 | cmax = 1000 # 最急降下法の試行回数 10 | 11 | class LinearSVM: 12 | def __init__(self, dim, eta, C=0): 13 | self.dim = dim # 学習データの次元数 14 | self.C = C # コストパラメータ 15 | self.eta = eta # 最急降下法の学習率 16 | self.a = None # 未定乗数ベクトル 17 | self.w = np.zeros((dim,)) # 識別関数 18 | 19 | def fit(self, X, Y): 20 | data_n = X.shape[0] # 学習データ数 21 | # 未定乗数ベクトルαの初期値は0ベクトルとする 22 | self.a = np.zeros((data_n,)) 23 | # ones = (1,1,...,1)というベクトルを便宜上用意する 24 | ones = np.array([1 for _ in range(data_n)]) 25 | H = np.array([[Y[i]*Y[j]*np.dot(X[i],X[j]) \ 26 | for i in range(data_n)] for j in range(data_n)]) 27 | # αについてのLagrange関数 28 | def L(av): return np.dot(self.a,ones)-self.a.T.dot(H.dot(self.a))/2. 29 | # Lの導関数 30 | def dL(av): return ones - np.dot(H,av) 31 | 32 | # 勾配法でLを最大化→αを求める 33 | for _ in range(cmax): 34 | a = self.a + self.eta*dL(self.a) 35 | if self.C == 0 or np.product(np.array(map(lambda x: 0 <= x and x <= self.C, a))) != 0: 36 | self.a = a 37 | 38 | # 最適な線形識別関数を求める 39 | self.w = np.zeros((self.dim,)) 40 | for i in range(data_n): 41 | self.w = self.w + self.a[i]*Y[i]*X[i] 42 | 43 | def predict(self, x): 44 | if np.dot(self.w, x) >= 0: 45 | return 1 46 | else: 47 | return -1 48 | 49 | if __name__ == '__main__': 50 | digits = datasets.load_digits() 51 | data = [] 52 | target = [] 53 | for i in range(len(digits.data)): 54 | # プリセットの手書き数字から0か1のデータのみを読み込む 55 | if digits.target[i] == 0 or digits.target[i] == 1: 56 | if data == [] or target == []: 57 | data = np.array([digits.data[i]]) 58 | target = np.array([1 if digits.target[i] == 1 else -1]) 59 | else: 60 | data = np.r_[data, [digits.data[i]]] 61 | # 1と-1にラベリングし直す 62 | target = np.r_[target, 1 if digits.target[i] == 1 else -1] 63 | 64 | train_x, test_x, train_y, test_y = cross_validation.train_test_split(data, target, test_size=0.2) 65 | 66 | clf = LinearSVM(dim=train_x.shape[1], C=10, eta=1e-15) 67 | clf.fit(train_x, train_y) 68 | pred_y = np.array([clf.predict(x) for x in test_x]) 69 | print metrics.accuracy_score(pred_y, test_y) 70 | -------------------------------------------------------------------------------- /svm/linear_svm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import numpy as np 5 | from sklearn import datasets 6 | from sklearn import metrics 7 | from sklearn import cross_validation 8 | 9 | cmax = 1000 # 最急降下法の試行回数 10 | 11 | class LinearSVM: 12 | def __init__(self, dim, eta): 13 | self.dim = dim # 学習データの次元数 14 | self.eta = eta # 最急降下法の学習率 15 | self.a = None # 未定乗数ベクトル 16 | self.w = np.zeros((dim,)) # 識別関数 17 | 18 | def fit(self, X, Y): 19 | data_n = X.shape[0] # 学習データ数 20 | # 未定乗数ベクトルαの初期値は0ベクトルとする 21 | self.a = np.zeros((data_n,)) 22 | # ones = (1,1,...,1)というベクトルを便宜上用意する 23 | ones = np.array([1 for _ in range(data_n)]) 24 | H = np.array([[Y[i]*Y[j]*np.dot(X[i],X[j]) \ 25 | for i in range(data_n)] for j in range(data_n)]) 26 | # αについてのLagrange関数 27 | def L(av): return np.dot(self.a,ones)-self.a.T.dot(H.dot(self.a))/2. 28 | # Lの導関数 29 | def dL(av): return ones - np.dot(H,av) 30 | 31 | # 勾配法でLを最大化→αを求める 32 | for _ in range(cmax): 33 | self.a = self.a + self.eta*dL(self.a) 34 | 35 | # 最適な線形識別関数を求める 36 | self.w = np.zeros((self.dim,)) 37 | for i in range(data_n): 38 | self.w = self.w + self.a[i]*Y[i]*X[i] 39 | 40 | def predict(self, x): 41 | if np.dot(self.w, x) >= 0: 42 | return 1 43 | else: 44 | return -1 45 | 46 | if __name__ == '__main__': 47 | digits = datasets.load_digits() 48 | data = [] 49 | target = [] 50 | for i in range(len(digits.data)): 51 | # プリセットの手書き数字から0か1のデータのみを読み込む 52 | if digits.target[i] == 0 or digits.target[i] == 1: 53 | if data == [] or target == []: 54 | data = np.array([digits.data[i]]) 55 | target = np.array([1 if digits.target[i] == 1 else -1]) 56 | else: 57 | data = np.r_[data, [digits.data[i]]] 58 | # 1と-1にラベリングし直す 59 | target = np.r_[target, 1 if digits.target[i] == 1 else -1] 60 | 61 | train_x, test_x, train_y, test_y = cross_validation.train_test_split(data, target, test_size=0.2) 62 | 63 | clf = LinearSVM(dim=train_x.shape[1], eta=1e-15) 64 | clf.fit(train_x, train_y) 65 | pred_y = np.array([clf.predict(x) for x in test_x]) 66 | print metrics.accuracy_score(pred_y, test_y) 67 | -------------------------------------------------------------------------------- /svm/svm.001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/svm/svm.001.png -------------------------------------------------------------------------------- /svm/svm.002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/svm/svm.002.png -------------------------------------------------------------------------------- /svm/svm.003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/svm/svm.003.png -------------------------------------------------------------------------------- /svm/svm.004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/svm/svm.004.png -------------------------------------------------------------------------------- /svm/svm.005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/svm/svm.005.png -------------------------------------------------------------------------------- /svm/svm.006.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/svm/svm.006.png -------------------------------------------------------------------------------- /svm/svm.007.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/svm/svm.007.png -------------------------------------------------------------------------------- /svm/svm.008.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/svm/svm.008.png -------------------------------------------------------------------------------- /svm/svm.009.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/svm/svm.009.png -------------------------------------------------------------------------------- /svm/svm.010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/svm/svm.010.png -------------------------------------------------------------------------------- /svm/svm.011.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/svm/svm.011.png -------------------------------------------------------------------------------- /svm/svm.012.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/svm/svm.012.png -------------------------------------------------------------------------------- /svm/svm.013.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/svm/svm.013.png -------------------------------------------------------------------------------- /svm/svm.014.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/svm/svm.014.png -------------------------------------------------------------------------------- /svm/svm.015.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/svm/svm.015.png -------------------------------------------------------------------------------- /svm/svm.016.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/svm/svm.016.png -------------------------------------------------------------------------------- /svm/svm.017.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/svm/svm.017.png -------------------------------------------------------------------------------- /svm/svm.018.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/svm/svm.018.png -------------------------------------------------------------------------------- /svm/svm.019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/svm/svm.019.png -------------------------------------------------------------------------------- /svm/svm.020.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/svm/svm.020.png -------------------------------------------------------------------------------- /svm/svm.021.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/svm/svm.021.png -------------------------------------------------------------------------------- /svm/svm.022.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/levelfour/machine-learning-2014/7a0ac3d079f9faa499ce517eb52ef3ca10293a1b/svm/svm.022.png -------------------------------------------------------------------------------- /svm/template_linear_svm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import numpy as np 5 | from sklearn import datasets 6 | from sklearn import metrics 7 | from sklearn import cross_validation 8 | 9 | class LinearSVM: 10 | def __init__(self): 11 | pass 12 | 13 | def fit(self, x, y): 14 | pass 15 | 16 | def predict(self, x): 17 | pass 18 | 19 | if __name__ == '__main__': 20 | digits = datasets.load_digits() 21 | data = [] 22 | target = [] 23 | for i in range(len(digits.data)): 24 | # プリセットの手書き数字から0か1のデータのみを読み込む 25 | if digits.target[i] == 0 or digits.target[i] == 1: 26 | if data == [] or target == []: 27 | data = np.array([digits.data[i]]) 28 | target = np.array([1 if digits.target[i] == 1 else -1]) 29 | else: 30 | data = np.r_[data, [digits.data[i]]] 31 | # 1と-1にラベリングし直す 32 | target = np.r_[target, 1 if digits.target[i] == 1 else -1] 33 | 34 | train_x, test_x, train_y, test_y = cross_validation.train_test_split(data, target, test_size=0.2) 35 | 36 | clf = LinearSVM() 37 | clf.fit(train_x, train_y) 38 | pred_y = np.array([clf.predict(x) for x in test_x]) 39 | print metrics.accuracy_score(pred_y, test_y) 40 | -------------------------------------------------------------------------------- /wine/wine.data: -------------------------------------------------------------------------------- 1 | 1,14.23,1.71,2.43,15.6,127,2.8,3.06,.28,2.29,5.64,1.04,3.92,1065 2 | 1,13.2,1.78,2.14,11.2,100,2.65,2.76,.26,1.28,4.38,1.05,3.4,1050 3 | 1,13.16,2.36,2.67,18.6,101,2.8,3.24,.3,2.81,5.68,1.03,3.17,1185 4 | 1,14.37,1.95,2.5,16.8,113,3.85,3.49,.24,2.18,7.8,.86,3.45,1480 5 | 1,13.24,2.59,2.87,21,118,2.8,2.69,.39,1.82,4.32,1.04,2.93,735 6 | 1,14.2,1.76,2.45,15.2,112,3.27,3.39,.34,1.97,6.75,1.05,2.85,1450 7 | 1,14.39,1.87,2.45,14.6,96,2.5,2.52,.3,1.98,5.25,1.02,3.58,1290 8 | 1,14.06,2.15,2.61,17.6,121,2.6,2.51,.31,1.25,5.05,1.06,3.58,1295 9 | 1,14.83,1.64,2.17,14,97,2.8,2.98,.29,1.98,5.2,1.08,2.85,1045 10 | 1,13.86,1.35,2.27,16,98,2.98,3.15,.22,1.85,7.22,1.01,3.55,1045 11 | 1,14.1,2.16,2.3,18,105,2.95,3.32,.22,2.38,5.75,1.25,3.17,1510 12 | 1,14.12,1.48,2.32,16.8,95,2.2,2.43,.26,1.57,5,1.17,2.82,1280 13 | 1,13.75,1.73,2.41,16,89,2.6,2.76,.29,1.81,5.6,1.15,2.9,1320 14 | 1,14.75,1.73,2.39,11.4,91,3.1,3.69,.43,2.81,5.4,1.25,2.73,1150 15 | 1,14.38,1.87,2.38,12,102,3.3,3.64,.29,2.96,7.5,1.2,3,1547 16 | 1,13.63,1.81,2.7,17.2,112,2.85,2.91,.3,1.46,7.3,1.28,2.88,1310 17 | 1,14.3,1.92,2.72,20,120,2.8,3.14,.33,1.97,6.2,1.07,2.65,1280 18 | 1,13.83,1.57,2.62,20,115,2.95,3.4,.4,1.72,6.6,1.13,2.57,1130 19 | 1,14.19,1.59,2.48,16.5,108,3.3,3.93,.32,1.86,8.7,1.23,2.82,1680 20 | 1,13.64,3.1,2.56,15.2,116,2.7,3.03,.17,1.66,5.1,.96,3.36,845 21 | 1,14.06,1.63,2.28,16,126,3,3.17,.24,2.1,5.65,1.09,3.71,780 22 | 1,12.93,3.8,2.65,18.6,102,2.41,2.41,.25,1.98,4.5,1.03,3.52,770 23 | 1,13.71,1.86,2.36,16.6,101,2.61,2.88,.27,1.69,3.8,1.11,4,1035 24 | 1,12.85,1.6,2.52,17.8,95,2.48,2.37,.26,1.46,3.93,1.09,3.63,1015 25 | 1,13.5,1.81,2.61,20,96,2.53,2.61,.28,1.66,3.52,1.12,3.82,845 26 | 1,13.05,2.05,3.22,25,124,2.63,2.68,.47,1.92,3.58,1.13,3.2,830 27 | 1,13.39,1.77,2.62,16.1,93,2.85,2.94,.34,1.45,4.8,.92,3.22,1195 28 | 1,13.3,1.72,2.14,17,94,2.4,2.19,.27,1.35,3.95,1.02,2.77,1285 29 | 1,13.87,1.9,2.8,19.4,107,2.95,2.97,.37,1.76,4.5,1.25,3.4,915 30 | 1,14.02,1.68,2.21,16,96,2.65,2.33,.26,1.98,4.7,1.04,3.59,1035 31 | 1,13.73,1.5,2.7,22.5,101,3,3.25,.29,2.38,5.7,1.19,2.71,1285 32 | 1,13.58,1.66,2.36,19.1,106,2.86,3.19,.22,1.95,6.9,1.09,2.88,1515 33 | 1,13.68,1.83,2.36,17.2,104,2.42,2.69,.42,1.97,3.84,1.23,2.87,990 34 | 1,13.76,1.53,2.7,19.5,132,2.95,2.74,.5,1.35,5.4,1.25,3,1235 35 | 1,13.51,1.8,2.65,19,110,2.35,2.53,.29,1.54,4.2,1.1,2.87,1095 36 | 1,13.48,1.81,2.41,20.5,100,2.7,2.98,.26,1.86,5.1,1.04,3.47,920 37 | 1,13.28,1.64,2.84,15.5,110,2.6,2.68,.34,1.36,4.6,1.09,2.78,880 38 | 1,13.05,1.65,2.55,18,98,2.45,2.43,.29,1.44,4.25,1.12,2.51,1105 39 | 1,13.07,1.5,2.1,15.5,98,2.4,2.64,.28,1.37,3.7,1.18,2.69,1020 40 | 1,14.22,3.99,2.51,13.2,128,3,3.04,.2,2.08,5.1,.89,3.53,760 41 | 1,13.56,1.71,2.31,16.2,117,3.15,3.29,.34,2.34,6.13,.95,3.38,795 42 | 1,13.41,3.84,2.12,18.8,90,2.45,2.68,.27,1.48,4.28,.91,3,1035 43 | 1,13.88,1.89,2.59,15,101,3.25,3.56,.17,1.7,5.43,.88,3.56,1095 44 | 1,13.24,3.98,2.29,17.5,103,2.64,2.63,.32,1.66,4.36,.82,3,680 45 | 1,13.05,1.77,2.1,17,107,3,3,.28,2.03,5.04,.88,3.35,885 46 | 1,14.21,4.04,2.44,18.9,111,2.85,2.65,.3,1.25,5.24,.87,3.33,1080 47 | 1,14.38,3.59,2.28,16,102,3.25,3.17,.27,2.19,4.9,1.04,3.44,1065 48 | 1,13.9,1.68,2.12,16,101,3.1,3.39,.21,2.14,6.1,.91,3.33,985 49 | 1,14.1,2.02,2.4,18.8,103,2.75,2.92,.32,2.38,6.2,1.07,2.75,1060 50 | 1,13.94,1.73,2.27,17.4,108,2.88,3.54,.32,2.08,8.90,1.12,3.1,1260 51 | 1,13.05,1.73,2.04,12.4,92,2.72,3.27,.17,2.91,7.2,1.12,2.91,1150 52 | 1,13.83,1.65,2.6,17.2,94,2.45,2.99,.22,2.29,5.6,1.24,3.37,1265 53 | 1,13.82,1.75,2.42,14,111,3.88,3.74,.32,1.87,7.05,1.01,3.26,1190 54 | 1,13.77,1.9,2.68,17.1,115,3,2.79,.39,1.68,6.3,1.13,2.93,1375 55 | 1,13.74,1.67,2.25,16.4,118,2.6,2.9,.21,1.62,5.85,.92,3.2,1060 56 | 1,13.56,1.73,2.46,20.5,116,2.96,2.78,.2,2.45,6.25,.98,3.03,1120 57 | 1,14.22,1.7,2.3,16.3,118,3.2,3,.26,2.03,6.38,.94,3.31,970 58 | 1,13.29,1.97,2.68,16.8,102,3,3.23,.31,1.66,6,1.07,2.84,1270 59 | 1,13.72,1.43,2.5,16.7,108,3.4,3.67,.19,2.04,6.8,.89,2.87,1285 60 | 2,12.37,.94,1.36,10.6,88,1.98,.57,.28,.42,1.95,1.05,1.82,520 61 | 2,12.33,1.1,2.28,16,101,2.05,1.09,.63,.41,3.27,1.25,1.67,680 62 | 2,12.64,1.36,2.02,16.8,100,2.02,1.41,.53,.62,5.75,.98,1.59,450 63 | 2,13.67,1.25,1.92,18,94,2.1,1.79,.32,.73,3.8,1.23,2.46,630 64 | 2,12.37,1.13,2.16,19,87,3.5,3.1,.19,1.87,4.45,1.22,2.87,420 65 | 2,12.17,1.45,2.53,19,104,1.89,1.75,.45,1.03,2.95,1.45,2.23,355 66 | 2,12.37,1.21,2.56,18.1,98,2.42,2.65,.37,2.08,4.6,1.19,2.3,678 67 | 2,13.11,1.01,1.7,15,78,2.98,3.18,.26,2.28,5.3,1.12,3.18,502 68 | 2,12.37,1.17,1.92,19.6,78,2.11,2,.27,1.04,4.68,1.12,3.48,510 69 | 2,13.34,.94,2.36,17,110,2.53,1.3,.55,.42,3.17,1.02,1.93,750 70 | 2,12.21,1.19,1.75,16.8,151,1.85,1.28,.14,2.5,2.85,1.28,3.07,718 71 | 2,12.29,1.61,2.21,20.4,103,1.1,1.02,.37,1.46,3.05,.906,1.82,870 72 | 2,13.86,1.51,2.67,25,86,2.95,2.86,.21,1.87,3.38,1.36,3.16,410 73 | 2,13.49,1.66,2.24,24,87,1.88,1.84,.27,1.03,3.74,.98,2.78,472 74 | 2,12.99,1.67,2.6,30,139,3.3,2.89,.21,1.96,3.35,1.31,3.5,985 75 | 2,11.96,1.09,2.3,21,101,3.38,2.14,.13,1.65,3.21,.99,3.13,886 76 | 2,11.66,1.88,1.92,16,97,1.61,1.57,.34,1.15,3.8,1.23,2.14,428 77 | 2,13.03,.9,1.71,16,86,1.95,2.03,.24,1.46,4.6,1.19,2.48,392 78 | 2,11.84,2.89,2.23,18,112,1.72,1.32,.43,.95,2.65,.96,2.52,500 79 | 2,12.33,.99,1.95,14.8,136,1.9,1.85,.35,2.76,3.4,1.06,2.31,750 80 | 2,12.7,3.87,2.4,23,101,2.83,2.55,.43,1.95,2.57,1.19,3.13,463 81 | 2,12,.92,2,19,86,2.42,2.26,.3,1.43,2.5,1.38,3.12,278 82 | 2,12.72,1.81,2.2,18.8,86,2.2,2.53,.26,1.77,3.9,1.16,3.14,714 83 | 2,12.08,1.13,2.51,24,78,2,1.58,.4,1.4,2.2,1.31,2.72,630 84 | 2,13.05,3.86,2.32,22.5,85,1.65,1.59,.61,1.62,4.8,.84,2.01,515 85 | 2,11.84,.89,2.58,18,94,2.2,2.21,.22,2.35,3.05,.79,3.08,520 86 | 2,12.67,.98,2.24,18,99,2.2,1.94,.3,1.46,2.62,1.23,3.16,450 87 | 2,12.16,1.61,2.31,22.8,90,1.78,1.69,.43,1.56,2.45,1.33,2.26,495 88 | 2,11.65,1.67,2.62,26,88,1.92,1.61,.4,1.34,2.6,1.36,3.21,562 89 | 2,11.64,2.06,2.46,21.6,84,1.95,1.69,.48,1.35,2.8,1,2.75,680 90 | 2,12.08,1.33,2.3,23.6,70,2.2,1.59,.42,1.38,1.74,1.07,3.21,625 91 | 2,12.08,1.83,2.32,18.5,81,1.6,1.5,.52,1.64,2.4,1.08,2.27,480 92 | 2,12,1.51,2.42,22,86,1.45,1.25,.5,1.63,3.6,1.05,2.65,450 93 | 2,12.69,1.53,2.26,20.7,80,1.38,1.46,.58,1.62,3.05,.96,2.06,495 94 | 2,12.29,2.83,2.22,18,88,2.45,2.25,.25,1.99,2.15,1.15,3.3,290 95 | 2,11.62,1.99,2.28,18,98,3.02,2.26,.17,1.35,3.25,1.16,2.96,345 96 | 2,12.47,1.52,2.2,19,162,2.5,2.27,.32,3.28,2.6,1.16,2.63,937 97 | 2,11.81,2.12,2.74,21.5,134,1.6,.99,.14,1.56,2.5,.95,2.26,625 98 | 2,12.29,1.41,1.98,16,85,2.55,2.5,.29,1.77,2.9,1.23,2.74,428 99 | 2,12.37,1.07,2.1,18.5,88,3.52,3.75,.24,1.95,4.5,1.04,2.77,660 100 | 2,12.29,3.17,2.21,18,88,2.85,2.99,.45,2.81,2.3,1.42,2.83,406 101 | 2,12.08,2.08,1.7,17.5,97,2.23,2.17,.26,1.4,3.3,1.27,2.96,710 102 | 2,12.6,1.34,1.9,18.5,88,1.45,1.36,.29,1.35,2.45,1.04,2.77,562 103 | 2,12.34,2.45,2.46,21,98,2.56,2.11,.34,1.31,2.8,.8,3.38,438 104 | 2,11.82,1.72,1.88,19.5,86,2.5,1.64,.37,1.42,2.06,.94,2.44,415 105 | 2,12.51,1.73,1.98,20.5,85,2.2,1.92,.32,1.48,2.94,1.04,3.57,672 106 | 2,12.42,2.55,2.27,22,90,1.68,1.84,.66,1.42,2.7,.86,3.3,315 107 | 2,12.25,1.73,2.12,19,80,1.65,2.03,.37,1.63,3.4,1,3.17,510 108 | 2,12.72,1.75,2.28,22.5,84,1.38,1.76,.48,1.63,3.3,.88,2.42,488 109 | 2,12.22,1.29,1.94,19,92,2.36,2.04,.39,2.08,2.7,.86,3.02,312 110 | 2,11.61,1.35,2.7,20,94,2.74,2.92,.29,2.49,2.65,.96,3.26,680 111 | 2,11.46,3.74,1.82,19.5,107,3.18,2.58,.24,3.58,2.9,.75,2.81,562 112 | 2,12.52,2.43,2.17,21,88,2.55,2.27,.26,1.22,2,.9,2.78,325 113 | 2,11.76,2.68,2.92,20,103,1.75,2.03,.6,1.05,3.8,1.23,2.5,607 114 | 2,11.41,.74,2.5,21,88,2.48,2.01,.42,1.44,3.08,1.1,2.31,434 115 | 2,12.08,1.39,2.5,22.5,84,2.56,2.29,.43,1.04,2.9,.93,3.19,385 116 | 2,11.03,1.51,2.2,21.5,85,2.46,2.17,.52,2.01,1.9,1.71,2.87,407 117 | 2,11.82,1.47,1.99,20.8,86,1.98,1.6,.3,1.53,1.95,.95,3.33,495 118 | 2,12.42,1.61,2.19,22.5,108,2,2.09,.34,1.61,2.06,1.06,2.96,345 119 | 2,12.77,3.43,1.98,16,80,1.63,1.25,.43,.83,3.4,.7,2.12,372 120 | 2,12,3.43,2,19,87,2,1.64,.37,1.87,1.28,.93,3.05,564 121 | 2,11.45,2.4,2.42,20,96,2.9,2.79,.32,1.83,3.25,.8,3.39,625 122 | 2,11.56,2.05,3.23,28.5,119,3.18,5.08,.47,1.87,6,.93,3.69,465 123 | 2,12.42,4.43,2.73,26.5,102,2.2,2.13,.43,1.71,2.08,.92,3.12,365 124 | 2,13.05,5.8,2.13,21.5,86,2.62,2.65,.3,2.01,2.6,.73,3.1,380 125 | 2,11.87,4.31,2.39,21,82,2.86,3.03,.21,2.91,2.8,.75,3.64,380 126 | 2,12.07,2.16,2.17,21,85,2.6,2.65,.37,1.35,2.76,.86,3.28,378 127 | 2,12.43,1.53,2.29,21.5,86,2.74,3.15,.39,1.77,3.94,.69,2.84,352 128 | 2,11.79,2.13,2.78,28.5,92,2.13,2.24,.58,1.76,3,.97,2.44,466 129 | 2,12.37,1.63,2.3,24.5,88,2.22,2.45,.4,1.9,2.12,.89,2.78,342 130 | 2,12.04,4.3,2.38,22,80,2.1,1.75,.42,1.35,2.6,.79,2.57,580 131 | 3,12.86,1.35,2.32,18,122,1.51,1.25,.21,.94,4.1,.76,1.29,630 132 | 3,12.88,2.99,2.4,20,104,1.3,1.22,.24,.83,5.4,.74,1.42,530 133 | 3,12.81,2.31,2.4,24,98,1.15,1.09,.27,.83,5.7,.66,1.36,560 134 | 3,12.7,3.55,2.36,21.5,106,1.7,1.2,.17,.84,5,.78,1.29,600 135 | 3,12.51,1.24,2.25,17.5,85,2,.58,.6,1.25,5.45,.75,1.51,650 136 | 3,12.6,2.46,2.2,18.5,94,1.62,.66,.63,.94,7.1,.73,1.58,695 137 | 3,12.25,4.72,2.54,21,89,1.38,.47,.53,.8,3.85,.75,1.27,720 138 | 3,12.53,5.51,2.64,25,96,1.79,.6,.63,1.1,5,.82,1.69,515 139 | 3,13.49,3.59,2.19,19.5,88,1.62,.48,.58,.88,5.7,.81,1.82,580 140 | 3,12.84,2.96,2.61,24,101,2.32,.6,.53,.81,4.92,.89,2.15,590 141 | 3,12.93,2.81,2.7,21,96,1.54,.5,.53,.75,4.6,.77,2.31,600 142 | 3,13.36,2.56,2.35,20,89,1.4,.5,.37,.64,5.6,.7,2.47,780 143 | 3,13.52,3.17,2.72,23.5,97,1.55,.52,.5,.55,4.35,.89,2.06,520 144 | 3,13.62,4.95,2.35,20,92,2,.8,.47,1.02,4.4,.91,2.05,550 145 | 3,12.25,3.88,2.2,18.5,112,1.38,.78,.29,1.14,8.21,.65,2,855 146 | 3,13.16,3.57,2.15,21,102,1.5,.55,.43,1.3,4,.6,1.68,830 147 | 3,13.88,5.04,2.23,20,80,.98,.34,.4,.68,4.9,.58,1.33,415 148 | 3,12.87,4.61,2.48,21.5,86,1.7,.65,.47,.86,7.65,.54,1.86,625 149 | 3,13.32,3.24,2.38,21.5,92,1.93,.76,.45,1.25,8.42,.55,1.62,650 150 | 3,13.08,3.9,2.36,21.5,113,1.41,1.39,.34,1.14,9.40,.57,1.33,550 151 | 3,13.5,3.12,2.62,24,123,1.4,1.57,.22,1.25,8.60,.59,1.3,500 152 | 3,12.79,2.67,2.48,22,112,1.48,1.36,.24,1.26,10.8,.48,1.47,480 153 | 3,13.11,1.9,2.75,25.5,116,2.2,1.28,.26,1.56,7.1,.61,1.33,425 154 | 3,13.23,3.3,2.28,18.5,98,1.8,.83,.61,1.87,10.52,.56,1.51,675 155 | 3,12.58,1.29,2.1,20,103,1.48,.58,.53,1.4,7.6,.58,1.55,640 156 | 3,13.17,5.19,2.32,22,93,1.74,.63,.61,1.55,7.9,.6,1.48,725 157 | 3,13.84,4.12,2.38,19.5,89,1.8,.83,.48,1.56,9.01,.57,1.64,480 158 | 3,12.45,3.03,2.64,27,97,1.9,.58,.63,1.14,7.5,.67,1.73,880 159 | 3,14.34,1.68,2.7,25,98,2.8,1.31,.53,2.7,13,.57,1.96,660 160 | 3,13.48,1.67,2.64,22.5,89,2.6,1.1,.52,2.29,11.75,.57,1.78,620 161 | 3,12.36,3.83,2.38,21,88,2.3,.92,.5,1.04,7.65,.56,1.58,520 162 | 3,13.69,3.26,2.54,20,107,1.83,.56,.5,.8,5.88,.96,1.82,680 163 | 3,12.85,3.27,2.58,22,106,1.65,.6,.6,.96,5.58,.87,2.11,570 164 | 3,12.96,3.45,2.35,18.5,106,1.39,.7,.4,.94,5.28,.68,1.75,675 165 | 3,13.78,2.76,2.3,22,90,1.35,.68,.41,1.03,9.58,.7,1.68,615 166 | 3,13.73,4.36,2.26,22.5,88,1.28,.47,.52,1.15,6.62,.78,1.75,520 167 | 3,13.45,3.7,2.6,23,111,1.7,.92,.43,1.46,10.68,.85,1.56,695 168 | 3,12.82,3.37,2.3,19.5,88,1.48,.66,.4,.97,10.26,.72,1.75,685 169 | 3,13.58,2.58,2.69,24.5,105,1.55,.84,.39,1.54,8.66,.74,1.8,750 170 | 3,13.4,4.6,2.86,25,112,1.98,.96,.27,1.11,8.5,.67,1.92,630 171 | 3,12.2,3.03,2.32,19,96,1.25,.49,.4,.73,5.5,.66,1.83,510 172 | 3,12.77,2.39,2.28,19.5,86,1.39,.51,.48,.64,9.899999,.57,1.63,470 173 | 3,14.16,2.51,2.48,20,91,1.68,.7,.44,1.24,9.7,.62,1.71,660 174 | 3,13.71,5.65,2.45,20.5,95,1.68,.61,.52,1.06,7.7,.64,1.74,740 175 | 3,13.4,3.91,2.48,23,102,1.8,.75,.43,1.41,7.3,.7,1.56,750 176 | 3,13.27,4.28,2.26,20,120,1.59,.69,.43,1.35,10.2,.59,1.56,835 177 | 3,13.17,2.59,2.37,20,120,1.65,.68,.53,1.46,9.3,.6,1.62,840 178 | 3,14.13,4.1,2.74,24.5,96,2.05,.76,.56,1.35,9.2,.61,1.6,560 179 | -------------------------------------------------------------------------------- /wine/wine_classify.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import numpy as np 5 | from sklearn import cross_validation 6 | from sklearn import metrics 7 | 8 | with open('wine.data', 'r') as f: 9 | data = np.loadtxt(f, delimiter=',') 10 | 11 | # データの読み込み 12 | X = data[:,1:] 13 | # ラベルの読み込み 14 | Y = np.array([ 15 | [1 if t == i+1 else 0 for t in data[:,0]] 16 | for i in range(int(data[:,0].max())) 17 | ]).T 18 | 19 | # 最小二乗法による学習 20 | w = np.linalg.inv(X.T.dot(X)).dot(X.T).dot(Y) 21 | 22 | v = np.array([ 23 | 14, 24 | 0, 25 | 0, 26 | 0, 27 | 100, 28 | 0, 29 | 0, 30 | 0, 31 | 0, 32 | 0, 33 | 0, 34 | 0, 35 | 0]) 36 | # w.T.dot(v) = [-0.15889022 0.63867965 0.44747914] 37 | # np.argmax(w.T.dot(v)) = 1 38 | # したがってクラス2 39 | print np.argmax(w.T.dot(v)) 40 | -------------------------------------------------------------------------------- /wine/wine_xval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import numpy as np 5 | from sklearn import cross_validation 6 | from sklearn import metrics 7 | 8 | with open('wine.data', 'r') as f: 9 | data = np.loadtxt(f, delimiter=',') 10 | 11 | # データの読み込み 12 | X = data[:,1:] 13 | # ラベルの読み込み 14 | Y = np.array([ 15 | [1 if t == i+1 else 0 for t in data[:,0]] 16 | for i in range(int(data[:,0].max())) 17 | ]).T 18 | 19 | train_x, test_x, train_y, test_y = cross_validation.train_test_split(X, Y, test_size=0.2) 20 | 21 | # 最小二乗法による学習 22 | w = np.linalg.inv(train_x.T.dot(train_x)).dot(train_x.T).dot(train_y) 23 | 24 | # 最小二乗法による推定 25 | pred_y = np.array([np.argmax(w.T.dot(d)) for d in test_x]) 26 | true_y = np.array([np.argmax(d) for d in test_y]) 27 | 28 | # テストデータに対する正答率 29 | print metrics.accuracy_score(true_y, pred_y) 30 | --------------------------------------------------------------------------------