├── README ├── setup.py └── src └── __init__.py /README: -------------------------------------------------------------------------------- 1 | == Support Vector Machines in Python == 2 | 3 | Author: Jeremy Stober 4 | Contact: stober@gmail.com 5 | Version: 0.1 6 | 7 | This is a simple support vector machine implementation based on the 8 | primal form of SVMs for linearly separable problems, and problems that 9 | also require slack variables. I used Bishop's PRML text as a basis for 10 | this implementation. This is meant as a guide for the basic ideas 11 | behind support vector machiens. The CVXOPT library is used for solving 12 | the quadratic program at the heart of the SVM. 13 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #! /suer/bin/env python 2 | ''' 3 | @author: stober 4 | ''' 5 | 6 | from distutils.core import setup 7 | 8 | setup(name='svm', 9 | version='0.1', 10 | description='Support Vector Machines', 11 | author='Jeremy Stober', 12 | author_email='stober@gmail.com', 13 | package_dir={'svm':'src'}, 14 | packages=['svm'], 15 | ) 16 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | """ 3 | Author: Jeremy M. Stober 4 | Program: __INIT__.PY 5 | Description: A simple SVM implementation. 6 | """ 7 | 8 | import numpy as np 9 | import numpy.random as npr 10 | import pylab 11 | from cvxopt import solvers, matrix 12 | from utils import plot_line # gist: https://gist.github.com/2778598 13 | 14 | def svm_slack(pts, labels, c = 1.0): 15 | """ 16 | Support Vector Machine using CVXOPT in Python. SVM with slack. 17 | """ 18 | n = len(pts[0]) 19 | m = len(pts) 20 | 21 | nvars = n + m + 1 22 | 23 | # x is a column vector [w b]^T 24 | 25 | # set up P 26 | P = matrix(0.0, (nvars, nvars)) 27 | for i in range(n): 28 | P[i,i] = 1.0 29 | 30 | # q^t x 31 | # set up q 32 | q = matrix(0.0,(nvars,1)) 33 | for i in range(n,n+m): 34 | q[i] = c 35 | q[-1] = 1.0 36 | 37 | # set up h 38 | h = matrix(-1.0,(m+m,1)) 39 | h[m:] = 0.0 40 | 41 | # set up G 42 | print m 43 | G = matrix(0.0, (m+m,nvars)) 44 | for i in range(m): 45 | G[i,:n] = -labels[i] * pts[i] 46 | G[i,n+i] = -1 47 | G[i,-1] = -labels[i] 48 | 49 | for i in range(m,m+m): 50 | G[i,n+i-m] = -1.0 51 | 52 | x = solvers.qp(P,q,G,h)['x'] 53 | 54 | return P, q, h, G, x 55 | 56 | 57 | def svm(pts, labels): 58 | """ 59 | Support Vector Machine using CVXOPT in Python. This example is 60 | mean to illustrate how SVMs work. 61 | """ 62 | n = len(pts[0]) 63 | 64 | # x is a column vector [w b]^T 65 | 66 | # set up P 67 | P = matrix(0.0, (n+1,n+1)) 68 | for i in range(n): 69 | P[i,i] = 1.0 70 | 71 | # q^t x 72 | # set up q 73 | q = matrix(0.0,(n+1,1)) 74 | q[-1] = 1.0 75 | 76 | m = len(pts) 77 | # set up h 78 | h = matrix(-1.0,(m,1)) 79 | 80 | # set up G 81 | G = matrix(0.0, (m,n+1)) 82 | for i in range(m): 83 | G[i,:n] = -labels[i] * pts[i] 84 | G[i,n] = -labels[i] 85 | 86 | x = solvers.qp(P,q,G,h)['x'] 87 | 88 | return P, q, h, G, x 89 | 90 | if __name__ == '__main__': 91 | 92 | def create_overlapping_classification_problem(n=100): 93 | import gmm 94 | 95 | n1 = gmm.Normal(2, mu = [0,0], sigma = [[1,0],[0,1]]) 96 | n2 = gmm.Normal(2, mu = [0,3], sigma = [[1,0],[0,1]]) 97 | class1 = n1.simulate(n/2) 98 | class2 = n2.simulate(n/2) 99 | 100 | samples = np.vstack([class1,class2]) 101 | 102 | labels = np.zeros(n) 103 | labels[:n/2] = -1 104 | labels[n/2:] = 1 105 | 106 | return samples, labels 107 | 108 | def create_classification_problem(n=100): 109 | class1 = npr.rand(n/2,2) 110 | class2 = npr.rand(n/2,2) + np.array([1.3,0.0]) 111 | 112 | theta = np.pi / 8.0 113 | r = np.cos(theta) 114 | s = np.sin(theta) 115 | rotation = np.array([[r,s],[s,-r]]) 116 | 117 | samples = np.dot(np.vstack([class1,class2]), rotation) 118 | 119 | labels = np.zeros(n) 120 | labels[:n/2] = -1 121 | labels[n/2:] = 1 122 | return samples, labels 123 | 124 | if True: 125 | samples, labels = create_overlapping_classification_problem() 126 | 127 | c = ['red'] * 50 + ['blue'] * 50 128 | pylab.scatter(samples[:,0], samples[:,1], color = c) 129 | 130 | #import pdb 131 | #pdb.set_trace() 132 | P,q,h,G,x = svm_slack(samples, labels, c = 2.0) 133 | #print P, q, h, G 134 | line_params = list(x[:2]) + [x[-1]] 135 | 136 | xlim = pylab.gca().get_xlim() 137 | ylim = pylab.gca().get_ylim() 138 | print xlim,ylim 139 | 140 | plot_line(line_params, xlim, ylim) 141 | print line_params 142 | 143 | pylab.show() 144 | 145 | 146 | if False: 147 | samples,labels = create_classification_problem() 148 | P,q,h,G,x = svm(samples, labels) 149 | print x 150 | 151 | 152 | if False: 153 | c = ['red'] * 50 + ['blue'] * 50 154 | pylab.scatter(samples[:,0], samples[:,1], color = c) 155 | 156 | xlim = pylab.gca().get_xlim() 157 | ylim = pylab.gca().get_ylim() 158 | print xlim,ylim 159 | 160 | plot_line(x, xlim, ylim) 161 | pylab.show() 162 | 163 | 164 | 165 | 166 | --------------------------------------------------------------------------------