├── .github └── workflows │ └── wheel.yml ├── COPYRIGHT ├── FAQ.html ├── Makefile ├── Makefile.win ├── README ├── heart_scale ├── java ├── Makefile ├── libsvm.jar ├── libsvm │ ├── svm.java │ ├── svm.m4 │ ├── svm_model.java │ ├── svm_node.java │ ├── svm_parameter.java │ ├── svm_print_interface.java │ └── svm_problem.java ├── svm_predict.java ├── svm_scale.java ├── svm_toy.java └── svm_train.java ├── matlab ├── Makefile ├── README ├── libsvmread.c ├── libsvmwrite.c ├── make.m ├── svm_model_matlab.c ├── svm_model_matlab.h ├── svmpredict.c └── svmtrain.c ├── python ├── MANIFEST.in ├── Makefile ├── README ├── libsvm │ ├── __init__.py │ ├── commonutil.py │ ├── svm.py │ └── svmutil.py └── setup.py ├── svm-predict.c ├── svm-scale.c ├── svm-toy ├── qt │ ├── Makefile │ └── svm-toy.cpp └── windows │ └── svm-toy.cpp ├── svm-train.c ├── svm.cpp ├── svm.def ├── svm.h ├── tools ├── README ├── checkdata.py ├── easy.py ├── grid.py └── subset.py └── windows ├── libsvm.dll ├── libsvmread.mexw64 ├── libsvmwrite.mexw64 ├── svm-predict.exe ├── svm-scale.exe ├── svm-toy.exe ├── svm-train.exe ├── svmpredict.mexw64 └── svmtrain.mexw64 /.github/workflows/wheel.yml: -------------------------------------------------------------------------------- 1 | name: Build wheels 2 | 3 | on: 4 | # on new tag 5 | push: 6 | tags: 7 | - "*" 8 | 9 | # manually trigger 10 | workflow_dispatch: 11 | 12 | jobs: 13 | build_wheels: 14 | name: Build wheels on ${{ matrix.os }} 15 | runs-on: ${{ matrix.os }} 16 | strategy: 17 | matrix: 18 | os: [windows-2022, macos-13] 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | 23 | - name: Set MacOS compiler 24 | if: runner.os == 'macOS' 25 | run: | 26 | brew install gcc@13; 27 | echo "CXX=gcc-13" >> $GITHUB_ENV 28 | 29 | - name: Build wheels 30 | uses: pypa/cibuildwheel@v2.10.2 31 | env: 32 | # don't build for PyPython and windows 32-bit 33 | CIBW_SKIP: pp* *win32* 34 | # force compiler on macOS 35 | CXX: ${{ env.CXX }} 36 | CC: ${{ env.CXX }} 37 | with: 38 | package-dir: ./python 39 | output-dir: ./python/wheelhouse 40 | 41 | - name: Upload a Build Artifact 42 | uses: actions/upload-artifact@v4 43 | with: 44 | name: wheels-${{ matrix.os }} 45 | path: ./python/wheelhouse 46 | -------------------------------------------------------------------------------- /COPYRIGHT: -------------------------------------------------------------------------------- 1 | 2 | Copyright (c) 2000-2023 Chih-Chung Chang and Chih-Jen Lin 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions 7 | are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright 10 | notice, this list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright 13 | notice, this list of conditions and the following disclaimer in the 14 | documentation and/or other materials provided with the distribution. 15 | 16 | 3. Neither name of copyright holders nor the names of its contributors 17 | may be used to endorse or promote products derived from this software 18 | without specific prior written permission. 19 | 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 22 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 23 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 24 | A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE REGENTS OR 25 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 26 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 27 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 28 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 29 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 30 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 31 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 32 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | CXX ?= g++ 2 | CFLAGS = -Wall -Wconversion -O3 -fPIC 3 | SHVER = 4 4 | OS = $(shell uname) 5 | ifeq ($(OS),Darwin) 6 | SHARED_LIB_FLAG = -dynamiclib -Wl,-install_name,libsvm.so.$(SHVER) 7 | else 8 | SHARED_LIB_FLAG = -shared -Wl,-soname,libsvm.so.$(SHVER) 9 | endif 10 | 11 | # Uncomment the following lines to enable parallelization with OpenMP 12 | # CFLAGS += -fopenmp 13 | # SHARED_LIB_FLAG += -fopenmp 14 | 15 | all: svm-train svm-predict svm-scale 16 | 17 | lib: svm.o 18 | $(CXX) $(SHARED_LIB_FLAG) svm.o -o libsvm.so.$(SHVER) 19 | svm-predict: svm-predict.c svm.o 20 | $(CXX) $(CFLAGS) svm-predict.c svm.o -o svm-predict -lm 21 | svm-train: svm-train.c svm.o 22 | $(CXX) $(CFLAGS) svm-train.c svm.o -o svm-train -lm 23 | svm-scale: svm-scale.c 24 | $(CXX) $(CFLAGS) svm-scale.c -o svm-scale 25 | svm.o: svm.cpp svm.h 26 | $(CXX) $(CFLAGS) -c svm.cpp 27 | clean: 28 | rm -f *~ svm.o svm-train svm-predict svm-scale libsvm.so.$(SHVER) 29 | -------------------------------------------------------------------------------- /Makefile.win: -------------------------------------------------------------------------------- 1 | #You must ensure nmake.exe, cl.exe, link.exe are in system path. 2 | #VCVARS64.bat 3 | #Under dosbox prompt 4 | #nmake -f Makefile.win 5 | 6 | ########################################## 7 | CXX = cl.exe 8 | CFLAGS = /nologo /O2 /EHsc /I. /D _WIN64 /D _CRT_SECURE_NO_DEPRECATE 9 | TARGET = windows 10 | 11 | # Uncomment the following lines to enable parallelization with OpenMP 12 | # CFLAGS = /nologo /O2 /EHsc /I. 
/D _WIN64 /D _CRT_SECURE_NO_DEPRECATE /openmp 13 | 14 | all: $(TARGET)\svm-train.exe $(TARGET)\svm-predict.exe $(TARGET)\svm-scale.exe $(TARGET)\svm-toy.exe lib 15 | 16 | $(TARGET)\svm-predict.exe: svm.h svm-predict.c svm.obj 17 | $(CXX) $(CFLAGS) svm-predict.c svm.obj -Fe$(TARGET)\svm-predict.exe 18 | 19 | $(TARGET)\svm-train.exe: svm.h svm-train.c svm.obj 20 | $(CXX) $(CFLAGS) svm-train.c svm.obj -Fe$(TARGET)\svm-train.exe 21 | 22 | $(TARGET)\svm-scale.exe: svm.h svm-scale.c 23 | $(CXX) $(CFLAGS) svm-scale.c -Fe$(TARGET)\svm-scale.exe 24 | 25 | $(TARGET)\svm-toy.exe: svm.h svm.obj svm-toy\windows\svm-toy.cpp 26 | $(CXX) $(CFLAGS) svm-toy\windows\svm-toy.cpp svm.obj user32.lib gdi32.lib comdlg32.lib -Fe$(TARGET)\svm-toy.exe 27 | 28 | svm.obj: svm.cpp svm.h 29 | $(CXX) $(CFLAGS) -c svm.cpp 30 | 31 | lib: svm.cpp svm.h svm.def 32 | $(CXX) $(CFLAGS) -LD svm.cpp -Fe$(TARGET)\libsvm -link -DEF:svm.def 33 | 34 | clean: 35 | -erase /Q *.obj $(TARGET)\*.exe $(TARGET)\*.dll $(TARGET)\*.exp $(TARGET)\*.lib 36 | 37 | -------------------------------------------------------------------------------- /java/Makefile: -------------------------------------------------------------------------------- 1 | .SUFFIXES: .class .java 2 | FILES = libsvm/svm.class libsvm/svm_model.class libsvm/svm_node.class \ 3 | libsvm/svm_parameter.class libsvm/svm_problem.class \ 4 | libsvm/svm_print_interface.class \ 5 | svm_train.class svm_predict.class svm_toy.class svm_scale.class 6 | 7 | #JAVAC = jikes 8 | JAVAC_FLAGS = --release 11 9 | JAVAC = javac 10 | # JAVAC_FLAGS = 11 | export CLASSPATH := .:$(CLASSPATH) 12 | 13 | all: $(FILES) 14 | jar cvf libsvm.jar *.class libsvm/*.class 15 | 16 | .java.class: 17 | $(JAVAC) $(JAVAC_FLAGS) $< 18 | 19 | libsvm/svm.java: libsvm/svm.m4 20 | m4 libsvm/svm.m4 > libsvm/svm.java 21 | 22 | clean: 23 | rm -f libsvm/*.class *.class *.jar libsvm/*~ *~ libsvm/svm.java 24 | 25 | dist: clean all 26 | rm *.class libsvm/*.class 27 | -------------------------------------------------------------------------------- /java/libsvm.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjlin1/libsvm/65367d015d9e368cc79f5c56924455c3a4fb5e48/java/libsvm.jar -------------------------------------------------------------------------------- /java/libsvm/svm_model.java: -------------------------------------------------------------------------------- 1 | // 2 | // svm_model 3 | // 4 | package libsvm; 5 | public class svm_model implements java.io.Serializable 6 | { 7 | public svm_parameter param; // parameter 8 | public int nr_class; // number of classes, = 2 in regression/one class svm 9 | public int l; // total #SV 10 | public svm_node[][] SV; // SVs (SV[l]) 11 | public double[][] sv_coef; // coefficients for SVs in decision functions (sv_coef[k-1][l]) 12 | public double[] rho; // constants in decision functions (rho[k*(k-1)/2]) 13 | public double[] probA; // pariwise probability information 14 | public double[] probB; 15 | public double[] prob_density_marks; // probability information for ONE_CLASS 16 | public int[] sv_indices; // sv_indices[0,...,nSV-1] are values in [1,...,num_traning_data] to indicate SVs in the training set 17 | 18 | // for classification only 19 | 20 | public int[] label; // label of each class (label[k]) 21 | public int[] nSV; // number of SVs for each class (nSV[k]) 22 | // nSV[0] + nSV[1] + ... 
+ nSV[k-1] = l 23 | }; 24 | -------------------------------------------------------------------------------- /java/libsvm/svm_node.java: -------------------------------------------------------------------------------- 1 | package libsvm; 2 | public class svm_node implements java.io.Serializable 3 | { 4 | public int index; 5 | public double value; 6 | } 7 | -------------------------------------------------------------------------------- /java/libsvm/svm_parameter.java: -------------------------------------------------------------------------------- 1 | package libsvm; 2 | public class svm_parameter implements Cloneable,java.io.Serializable 3 | { 4 | /* svm_type */ 5 | public static final int C_SVC = 0; 6 | public static final int NU_SVC = 1; 7 | public static final int ONE_CLASS = 2; 8 | public static final int EPSILON_SVR = 3; 9 | public static final int NU_SVR = 4; 10 | 11 | /* kernel_type */ 12 | public static final int LINEAR = 0; 13 | public static final int POLY = 1; 14 | public static final int RBF = 2; 15 | public static final int SIGMOID = 3; 16 | public static final int PRECOMPUTED = 4; 17 | 18 | public int svm_type; 19 | public int kernel_type; 20 | public int degree; // for poly 21 | public double gamma; // for poly/rbf/sigmoid 22 | public double coef0; // for poly/sigmoid 23 | 24 | // these are for training only 25 | public double cache_size; // in MB 26 | public double eps; // stopping criteria 27 | public double C; // for C_SVC, EPSILON_SVR and NU_SVR 28 | public int nr_weight; // for C_SVC 29 | public int[] weight_label; // for C_SVC 30 | public double[] weight; // for C_SVC 31 | public double nu; // for NU_SVC, ONE_CLASS, and NU_SVR 32 | public double p; // for EPSILON_SVR 33 | public int shrinking; // use the shrinking heuristics 34 | public int probability; // do probability estimates 35 | 36 | public Object clone() 37 | { 38 | try 39 | { 40 | return super.clone(); 41 | } catch (CloneNotSupportedException e) 42 | { 43 | return null; 44 | } 45 | } 46 | 47 | } 48 | -------------------------------------------------------------------------------- /java/libsvm/svm_print_interface.java: -------------------------------------------------------------------------------- 1 | package libsvm; 2 | public interface svm_print_interface 3 | { 4 | public void print(String s); 5 | } 6 | -------------------------------------------------------------------------------- /java/libsvm/svm_problem.java: -------------------------------------------------------------------------------- 1 | package libsvm; 2 | public class svm_problem implements java.io.Serializable 3 | { 4 | public int l; 5 | public double[] y; 6 | public svm_node[][] x; 7 | } 8 | -------------------------------------------------------------------------------- /java/svm_predict.java: -------------------------------------------------------------------------------- 1 | import libsvm.*; 2 | import java.io.*; 3 | import java.util.*; 4 | 5 | class svm_predict { 6 | private static svm_print_interface svm_print_null = new svm_print_interface() 7 | { 8 | public void print(String s) {} 9 | }; 10 | 11 | private static svm_print_interface svm_print_stdout = new svm_print_interface() 12 | { 13 | public void print(String s) 14 | { 15 | System.out.print(s); 16 | } 17 | }; 18 | 19 | private static svm_print_interface svm_print_string = svm_print_stdout; 20 | 21 | static void info(String s) 22 | { 23 | svm_print_string.print(s); 24 | } 25 | 26 | private static double atof(String s) 27 | { 28 | return Double.valueOf(s).doubleValue(); 29 | } 30 | 31 | 
private static int atoi(String s) 32 | { 33 | return Integer.parseInt(s); 34 | } 35 | 36 | private static void predict(BufferedReader input, DataOutputStream output, svm_model model, int predict_probability) throws IOException 37 | { 38 | int correct = 0; 39 | int total = 0; 40 | double error = 0; 41 | double sump = 0, sumt = 0, sumpp = 0, sumtt = 0, sumpt = 0; 42 | 43 | int svm_type=svm.svm_get_svm_type(model); 44 | int nr_class=svm.svm_get_nr_class(model); 45 | double[] prob_estimates=null; 46 | 47 | if(predict_probability == 1) 48 | { 49 | if(svm_type == svm_parameter.EPSILON_SVR || 50 | svm_type == svm_parameter.NU_SVR) 51 | { 52 | svm_predict.info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma="+svm.svm_get_svr_probability(model)+"\n"); 53 | } 54 | else if(svm_type == svm_parameter.ONE_CLASS) 55 | { 56 | // nr_class = 2 for ONE_CLASS 57 | prob_estimates = new double[nr_class]; 58 | output.writeBytes("label normal outlier\n"); 59 | } 60 | else 61 | { 62 | int[] labels=new int[nr_class]; 63 | svm.svm_get_labels(model,labels); 64 | prob_estimates = new double[nr_class]; 65 | output.writeBytes("labels"); 66 | for(int j=0;j=argv.length-2) 161 | exit_with_help(); 162 | try 163 | { 164 | BufferedReader input = new BufferedReader(new FileReader(argv[i])); 165 | DataOutputStream output = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(argv[i+2]))); 166 | svm_model model = svm.svm_load_model(argv[i+1]); 167 | if (model == null) 168 | { 169 | System.err.print("can't open model file "+argv[i+1]+"\n"); 170 | System.exit(1); 171 | } 172 | if(predict_probability == 1) 173 | { 174 | if(svm.svm_check_probability_model(model)==0) 175 | { 176 | System.err.print("Model does not support probabiliy estimates\n"); 177 | System.exit(1); 178 | } 179 | } 180 | else 181 | { 182 | if(svm.svm_check_probability_model(model)!=0) 183 | { 184 | svm_predict.info("Model supports probability estimates, but disabled in prediction.\n"); 185 | } 186 | } 187 | predict(input,output,model,predict_probability); 188 | input.close(); 189 | output.close(); 190 | } 191 | catch(FileNotFoundException e) 192 | { 193 | exit_with_help(); 194 | } 195 | catch(ArrayIndexOutOfBoundsException e) 196 | { 197 | exit_with_help(); 198 | } 199 | } 200 | } 201 | -------------------------------------------------------------------------------- /java/svm_scale.java: -------------------------------------------------------------------------------- 1 | import libsvm.*; 2 | import java.io.*; 3 | import java.util.*; 4 | import java.text.DecimalFormat; 5 | 6 | class svm_scale 7 | { 8 | private String line = null; 9 | private double lower = -1.0; 10 | private double upper = 1.0; 11 | private double y_lower; 12 | private double y_upper; 13 | private boolean y_scaling = false; 14 | private double[] feature_max; 15 | private double[] feature_min; 16 | private double y_max = -Double.MAX_VALUE; 17 | private double y_min = Double.MAX_VALUE; 18 | private int max_index; 19 | private long num_nonzeros = 0; 20 | private long new_num_nonzeros = 0; 21 | 22 | private static void exit_with_help() 23 | { 24 | System.out.print( 25 | "Usage: svm-scale [options] data_filename\n" 26 | +"options:\n" 27 | +"-l lower : x scaling lower limit (default -1)\n" 28 | +"-u upper : x scaling upper limit (default +1)\n" 29 | +"-y y_lower y_upper : y scaling limits (default: no y scaling)\n" 30 | +"-s save_filename : save scaling parameters to save_filename\n" 31 | +"-r restore_filename : restore 
scaling parameters from restore_filename\n" 32 | ); 33 | System.exit(1); 34 | } 35 | 36 | private BufferedReader rewind(BufferedReader fp, String filename) throws IOException 37 | { 38 | fp.close(); 39 | return new BufferedReader(new FileReader(filename)); 40 | } 41 | 42 | private void output_target(double value) 43 | { 44 | if(y_scaling) 45 | { 46 | if(value == y_min) 47 | value = y_lower; 48 | else if(value == y_max) 49 | value = y_upper; 50 | else 51 | value = y_lower + (y_upper-y_lower) * 52 | (value-y_min) / (y_max-y_min); 53 | } 54 | 55 | System.out.print(value + " "); 56 | } 57 | 58 | private void output(int index, double value) 59 | { 60 | /* skip single-valued attribute */ 61 | if(feature_max[index] == feature_min[index]) 62 | return; 63 | 64 | if(value == feature_min[index]) 65 | value = lower; 66 | else if(value == feature_max[index]) 67 | value = upper; 68 | else 69 | value = lower + (upper-lower) * 70 | (value-feature_min[index])/ 71 | (feature_max[index]-feature_min[index]); 72 | 73 | if(value != 0) 74 | { 75 | System.out.print(index + ":" + value + " "); 76 | new_num_nonzeros++; 77 | } 78 | } 79 | 80 | private String readline(BufferedReader fp) throws IOException 81 | { 82 | line = fp.readLine(); 83 | return line; 84 | } 85 | 86 | private void run(String []argv) throws IOException 87 | { 88 | int i,index; 89 | BufferedReader fp = null, fp_restore = null; 90 | String save_filename = null; 91 | String restore_filename = null; 92 | String data_filename = null; 93 | 94 | 95 | for(i=0;i lower) || (y_scaling && !(y_upper > y_lower))) 118 | { 119 | System.err.println("inconsistent lower/upper specification"); 120 | System.exit(1); 121 | } 122 | if(restore_filename != null && save_filename != null) 123 | { 124 | System.err.println("cannot use -r and -s simultaneously"); 125 | System.exit(1); 126 | } 127 | 128 | if(argv.length != i+1) 129 | exit_with_help(); 130 | 131 | data_filename = argv[i]; 132 | try { 133 | fp = new BufferedReader(new FileReader(data_filename)); 134 | } catch (Exception e) { 135 | System.err.println("can't open file " + data_filename); 136 | System.exit(1); 137 | } 138 | 139 | /* assumption: min index of attributes is 1 */ 140 | /* pass 1: find out max index of attributes */ 141 | max_index = 0; 142 | 143 | if(restore_filename != null) 144 | { 145 | int idx, c; 146 | 147 | try { 148 | fp_restore = new BufferedReader(new FileReader(restore_filename)); 149 | } 150 | catch (Exception e) { 151 | System.err.println("can't open file " + restore_filename); 152 | System.exit(1); 153 | } 154 | if((c = fp_restore.read()) == 'y') 155 | { 156 | fp_restore.readLine(); 157 | fp_restore.readLine(); 158 | fp_restore.readLine(); 159 | } 160 | fp_restore.readLine(); 161 | fp_restore.readLine(); 162 | 163 | String restore_line = null; 164 | while((restore_line = fp_restore.readLine())!=null) 165 | { 166 | StringTokenizer st2 = new StringTokenizer(restore_line); 167 | idx = Integer.parseInt(st2.nextToken()); 168 | max_index = Math.max(max_index, idx); 169 | } 170 | fp_restore = rewind(fp_restore, restore_filename); 171 | } 172 | 173 | while (readline(fp) != null) 174 | { 175 | StringTokenizer st = new StringTokenizer(line," \t\n\r\f:"); 176 | st.nextToken(); 177 | while(st.hasMoreTokens()) 178 | { 179 | index = Integer.parseInt(st.nextToken()); 180 | max_index = Math.max(max_index, index); 181 | st.nextToken(); 182 | num_nonzeros++; 183 | } 184 | } 185 | 186 | try { 187 | feature_max = new double[(max_index+1)]; 188 | feature_min = new double[(max_index+1)]; 189 | } 
catch(OutOfMemoryError e) { 190 | System.err.println("can't allocate enough memory"); 191 | System.exit(1); 192 | } 193 | 194 | for(i=0;i<=max_index;i++) 195 | { 196 | feature_max[i] = -Double.MAX_VALUE; 197 | feature_min[i] = Double.MAX_VALUE; 198 | } 199 | 200 | fp = rewind(fp, data_filename); 201 | 202 | /* pass 2: find out min/max value */ 203 | while(readline(fp) != null) 204 | { 205 | int next_index = 1; 206 | double target; 207 | double value; 208 | 209 | StringTokenizer st = new StringTokenizer(line," \t\n\r\f:"); 210 | target = Double.parseDouble(st.nextToken()); 211 | y_max = Math.max(y_max, target); 212 | y_min = Math.min(y_min, target); 213 | 214 | while (st.hasMoreTokens()) 215 | { 216 | index = Integer.parseInt(st.nextToken()); 217 | value = Double.parseDouble(st.nextToken()); 218 | 219 | for (i = next_index; i num_nonzeros) 337 | System.err.print( 338 | "WARNING: original #nonzeros " + num_nonzeros+"\n" 339 | +" new #nonzeros " + new_num_nonzeros+"\n" 340 | +"Use -l 0 if many original feature values are zeros\n"); 341 | 342 | fp.close(); 343 | } 344 | 345 | public static void main(String argv[]) throws IOException 346 | { 347 | svm_scale s = new svm_scale(); 348 | s.run(argv); 349 | } 350 | } 351 | -------------------------------------------------------------------------------- /java/svm_toy.java: -------------------------------------------------------------------------------- 1 | import libsvm.*; 2 | import java.awt.*; 3 | import java.util.*; 4 | import java.awt.event.*; 5 | import java.io.*; 6 | 7 | public class svm_toy { 8 | public static void main(String[] args) { 9 | svm_toy_frame frame = new svm_toy_frame("svm_toy", 500, 500+50); 10 | } 11 | } 12 | class svm_toy_frame extends Frame { 13 | 14 | static final String DEFAULT_PARAM="-t 2 -c 100"; 15 | int XLEN; 16 | int YLEN; 17 | 18 | // off-screen buffer 19 | 20 | Image buffer; 21 | Graphics buffer_gc; 22 | 23 | // pre-allocated colors 24 | 25 | final static Color colors[] = 26 | { 27 | new Color(0,0,0), 28 | new Color(0,120,120), 29 | new Color(120,120,0), 30 | new Color(120,0,120), 31 | new Color(0,200,200), 32 | new Color(200,200,0), 33 | new Color(200,0,200) 34 | }; 35 | 36 | class point { 37 | point(double x, double y, byte value) 38 | { 39 | this.x = x; 40 | this.y = y; 41 | this.value = value; 42 | } 43 | double x, y; 44 | byte value; 45 | } 46 | 47 | Vector point_list = new Vector(); 48 | byte current_value = 1; 49 | 50 | svm_toy_frame(String title, int width, int height) 51 | { 52 | super(title); 53 | this.addWindowListener(new WindowAdapter() { 54 | public void windowClosing(WindowEvent e) { 55 | System.exit(0); 56 | } 57 | }); 58 | this.init(); 59 | this.setSize(width, height); 60 | XLEN = width; 61 | YLEN = height-50; 62 | this.clear_all(); 63 | this.setVisible(true); 64 | } 65 | 66 | void init() 67 | { 68 | final Button button_change = new Button("Change"); 69 | Button button_run = new Button("Run"); 70 | Button button_clear = new Button("Clear"); 71 | Button button_save = new Button("Save"); 72 | Button button_load = new Button("Load"); 73 | final TextField input_line = new TextField(DEFAULT_PARAM); 74 | 75 | BorderLayout layout = new BorderLayout(); 76 | this.setLayout(layout); 77 | 78 | Panel p = new Panel(); 79 | GridBagLayout gridbag = new GridBagLayout(); 80 | p.setLayout(gridbag); 81 | 82 | GridBagConstraints c = new GridBagConstraints(); 83 | c.fill = GridBagConstraints.HORIZONTAL; 84 | c.weightx = 1; 85 | c.gridwidth = 1; 86 | gridbag.setConstraints(button_change,c); 87 | 
gridbag.setConstraints(button_run,c); 88 | gridbag.setConstraints(button_clear,c); 89 | gridbag.setConstraints(button_save,c); 90 | gridbag.setConstraints(button_load,c); 91 | c.weightx = 5; 92 | c.gridwidth = 5; 93 | gridbag.setConstraints(input_line,c); 94 | 95 | button_change.setBackground(colors[current_value]); 96 | 97 | p.add(button_change); 98 | p.add(button_run); 99 | p.add(button_clear); 100 | p.add(button_save); 101 | p.add(button_load); 102 | p.add(input_line); 103 | this.add(p,BorderLayout.SOUTH); 104 | 105 | button_change.addActionListener(new ActionListener() 106 | { public void actionPerformed (ActionEvent e) 107 | { button_change_clicked(); button_change.setBackground(colors[current_value]); }}); 108 | 109 | button_run.addActionListener(new ActionListener() 110 | { public void actionPerformed (ActionEvent e) 111 | { button_run_clicked(input_line.getText()); }}); 112 | 113 | button_clear.addActionListener(new ActionListener() 114 | { public void actionPerformed (ActionEvent e) 115 | { button_clear_clicked(); }}); 116 | 117 | button_save.addActionListener(new ActionListener() 118 | { public void actionPerformed (ActionEvent e) 119 | { button_save_clicked(input_line.getText()); }}); 120 | 121 | button_load.addActionListener(new ActionListener() 122 | { public void actionPerformed (ActionEvent e) 123 | { button_load_clicked(); }}); 124 | 125 | input_line.addActionListener(new ActionListener() 126 | { public void actionPerformed (ActionEvent e) 127 | { button_run_clicked(input_line.getText()); }}); 128 | 129 | this.enableEvents(AWTEvent.MOUSE_EVENT_MASK); 130 | } 131 | 132 | void draw_point(point p) 133 | { 134 | Color c = colors[p.value+3]; 135 | 136 | Graphics window_gc = getGraphics(); 137 | buffer_gc.setColor(c); 138 | buffer_gc.fillRect((int)(p.x*XLEN),(int)(p.y*YLEN),4,4); 139 | window_gc.setColor(c); 140 | window_gc.fillRect((int)(p.x*XLEN),(int)(p.y*YLEN),4,4); 141 | } 142 | 143 | void clear_all() 144 | { 145 | point_list.removeAllElements(); 146 | if(buffer != null) 147 | { 148 | buffer_gc.setColor(colors[0]); 149 | buffer_gc.fillRect(0,0,XLEN,YLEN); 150 | } 151 | repaint(); 152 | } 153 | 154 | void draw_all_points() 155 | { 156 | int n = point_list.size(); 157 | for(int i=0;i 3) current_value = 1; 165 | } 166 | 167 | private static double atof(String s) 168 | { 169 | return Double.valueOf(s).doubleValue(); 170 | } 171 | 172 | private static int atoi(String s) 173 | { 174 | return Integer.parseInt(s); 175 | } 176 | 177 | void button_run_clicked(String args) 178 | { 179 | // guard 180 | if(point_list.isEmpty()) return; 181 | 182 | svm_parameter param = new svm_parameter(); 183 | 184 | // default values 185 | param.svm_type = svm_parameter.C_SVC; 186 | param.kernel_type = svm_parameter.RBF; 187 | param.degree = 3; 188 | param.gamma = 0; 189 | param.coef0 = 0; 190 | param.nu = 0.5; 191 | param.cache_size = 40; 192 | param.C = 1; 193 | param.eps = 1e-3; 194 | param.p = 0.1; 195 | param.shrinking = 1; 196 | param.probability = 0; 197 | param.nr_weight = 0; 198 | param.weight_label = new int[0]; 199 | param.weight = new double[0]; 200 | 201 | // parse options 202 | StringTokenizer st = new StringTokenizer(args); 203 | String[] argv = new String[st.countTokens()]; 204 | for(int i=0;i=argv.length) 211 | { 212 | System.err.print("unknown option\n"); 213 | break; 214 | } 215 | switch(argv[i-1].charAt(1)) 216 | { 217 | case 's': 218 | param.svm_type = atoi(argv[i]); 219 | break; 220 | case 't': 221 | param.kernel_type = atoi(argv[i]); 222 | break; 223 | case 'd': 224 | 
param.degree = atoi(argv[i]); 225 | break; 226 | case 'g': 227 | param.gamma = atof(argv[i]); 228 | break; 229 | case 'r': 230 | param.coef0 = atof(argv[i]); 231 | break; 232 | case 'n': 233 | param.nu = atof(argv[i]); 234 | break; 235 | case 'm': 236 | param.cache_size = atof(argv[i]); 237 | break; 238 | case 'c': 239 | param.C = atof(argv[i]); 240 | break; 241 | case 'e': 242 | param.eps = atof(argv[i]); 243 | break; 244 | case 'p': 245 | param.p = atof(argv[i]); 246 | break; 247 | case 'h': 248 | param.shrinking = atoi(argv[i]); 249 | break; 250 | case 'b': 251 | param.probability = atoi(argv[i]); 252 | break; 253 | case 'w': 254 | ++param.nr_weight; 255 | { 256 | int[] old = param.weight_label; 257 | param.weight_label = new int[param.nr_weight]; 258 | System.arraycopy(old,0,param.weight_label,0,param.nr_weight-1); 259 | } 260 | 261 | { 262 | double[] old = param.weight; 263 | param.weight = new double[param.nr_weight]; 264 | System.arraycopy(old,0,param.weight,0,param.nr_weight-1); 265 | } 266 | 267 | param.weight_label[param.nr_weight-1] = atoi(argv[i-1].substring(2)); 268 | param.weight[param.nr_weight-1] = atof(argv[i]); 269 | break; 270 | default: 271 | System.err.print("unknown option\n"); 272 | } 273 | } 274 | 275 | // build problem 276 | svm_problem prob = new svm_problem(); 277 | prob.l = point_list.size(); 278 | prob.y = new double[prob.l]; 279 | 280 | if(param.kernel_type == svm_parameter.PRECOMPUTED) 281 | { 282 | } 283 | else if(param.svm_type == svm_parameter.EPSILON_SVR || 284 | param.svm_type == svm_parameter.NU_SVR) 285 | { 286 | if(param.gamma == 0) param.gamma = 1; 287 | prob.x = new svm_node[prob.l][1]; 288 | for(int i=0;i= XLEN || e.getY() >= YLEN) return; 468 | point p = new point((double)e.getX()/XLEN, 469 | (double)e.getY()/YLEN, 470 | current_value); 471 | point_list.addElement(p); 472 | draw_point(p); 473 | } 474 | } 475 | 476 | public void paint(Graphics g) 477 | { 478 | // create buffer first time 479 | if(buffer == null) { 480 | buffer = this.createImage(XLEN,YLEN); 481 | buffer_gc = buffer.getGraphics(); 482 | buffer_gc.setColor(colors[0]); 483 | buffer_gc.fillRect(0,0,XLEN,YLEN); 484 | } 485 | g.drawImage(buffer,0,0,this); 486 | } 487 | } -------------------------------------------------------------------------------- /java/svm_train.java: -------------------------------------------------------------------------------- 1 | import libsvm.*; 2 | import java.io.*; 3 | import java.util.*; 4 | 5 | class svm_train { 6 | private svm_parameter param; // set by parse_command_line 7 | private svm_problem prob; // set by read_problem 8 | private svm_model model; 9 | private String input_file_name; // set by parse_command_line 10 | private String model_file_name; // set by parse_command_line 11 | private String error_msg; 12 | private int cross_validation; 13 | private int nr_fold; 14 | 15 | private static svm_print_interface svm_print_null = new svm_print_interface() 16 | { 17 | public void print(String s) {} 18 | }; 19 | 20 | private static void exit_with_help() 21 | { 22 | System.out.print( 23 | "Usage: svm_train [options] training_set_file [model_file]\n" 24 | +"options:\n" 25 | +"-s svm_type : set type of SVM (default 0)\n" 26 | +" 0 -- C-SVC (multi-class classification)\n" 27 | +" 1 -- nu-SVC (multi-class classification)\n" 28 | +" 2 -- one-class SVM\n" 29 | +" 3 -- epsilon-SVR (regression)\n" 30 | +" 4 -- nu-SVR (regression)\n" 31 | +"-t kernel_type : set type of kernel function (default 2)\n" 32 | +" 0 -- linear: u'*v\n" 33 | +" 1 -- polynomial: (gamma*u'*v + 
coef0)^degree\n" 34 | +" 2 -- radial basis function: exp(-gamma*|u-v|^2)\n" 35 | +" 3 -- sigmoid: tanh(gamma*u'*v + coef0)\n" 36 | +" 4 -- precomputed kernel (kernel values in training_set_file)\n" 37 | +"-d degree : set degree in kernel function (default 3)\n" 38 | +"-g gamma : set gamma in kernel function (default 1/num_features)\n" 39 | +"-r coef0 : set coef0 in kernel function (default 0)\n" 40 | +"-c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)\n" 41 | +"-n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5)\n" 42 | +"-p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)\n" 43 | +"-m cachesize : set cache memory size in MB (default 100)\n" 44 | +"-e epsilon : set tolerance of termination criterion (default 0.001)\n" 45 | +"-h shrinking : whether to use the shrinking heuristics, 0 or 1 (default 1)\n" 46 | +"-b probability_estimates : whether to train a SVC or SVR model for probability estimates, 0 or 1 (default 0)\n" 47 | +"-wi weight : set the parameter C of class i to weight*C, for C-SVC (default 1)\n" 48 | +"-v n : n-fold cross validation mode\n" 49 | +"-q : quiet mode (no outputs)\n" 50 | ); 51 | System.exit(1); 52 | } 53 | 54 | private void do_cross_validation() 55 | { 56 | int i; 57 | int total_correct = 0; 58 | double total_error = 0; 59 | double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0; 60 | double[] target = new double[prob.l]; 61 | 62 | svm.svm_cross_validation(prob,param,nr_fold,target); 63 | if(param.svm_type == svm_parameter.EPSILON_SVR || 64 | param.svm_type == svm_parameter.NU_SVR) 65 | { 66 | for(i=0;i=argv.length) 166 | exit_with_help(); 167 | switch(argv[i-1].charAt(1)) 168 | { 169 | case 's': 170 | param.svm_type = atoi(argv[i]); 171 | break; 172 | case 't': 173 | param.kernel_type = atoi(argv[i]); 174 | break; 175 | case 'd': 176 | param.degree = atoi(argv[i]); 177 | break; 178 | case 'g': 179 | param.gamma = atof(argv[i]); 180 | break; 181 | case 'r': 182 | param.coef0 = atof(argv[i]); 183 | break; 184 | case 'n': 185 | param.nu = atof(argv[i]); 186 | break; 187 | case 'm': 188 | param.cache_size = atof(argv[i]); 189 | break; 190 | case 'c': 191 | param.C = atof(argv[i]); 192 | break; 193 | case 'e': 194 | param.eps = atof(argv[i]); 195 | break; 196 | case 'p': 197 | param.p = atof(argv[i]); 198 | break; 199 | case 'h': 200 | param.shrinking = atoi(argv[i]); 201 | break; 202 | case 'b': 203 | param.probability = atoi(argv[i]); 204 | break; 205 | case 'q': 206 | print_func = svm_print_null; 207 | i--; 208 | break; 209 | case 'v': 210 | cross_validation = 1; 211 | nr_fold = atoi(argv[i]); 212 | if(nr_fold < 2) 213 | { 214 | System.err.print("n-fold cross validation: n must >= 2\n"); 215 | exit_with_help(); 216 | } 217 | break; 218 | case 'w': 219 | ++param.nr_weight; 220 | { 221 | int[] old = param.weight_label; 222 | param.weight_label = new int[param.nr_weight]; 223 | System.arraycopy(old,0,param.weight_label,0,param.nr_weight-1); 224 | } 225 | 226 | { 227 | double[] old = param.weight; 228 | param.weight = new double[param.nr_weight]; 229 | System.arraycopy(old,0,param.weight,0,param.nr_weight-1); 230 | } 231 | 232 | param.weight_label[param.nr_weight-1] = atoi(argv[i-1].substring(2)); 233 | param.weight[param.nr_weight-1] = atof(argv[i]); 234 | break; 235 | default: 236 | System.err.print("Unknown option: " + argv[i-1] + "\n"); 237 | exit_with_help(); 238 | } 239 | } 240 | 241 | svm.svm_set_print_string_function(print_func); 242 | 243 | // determine filenames 244 | 245 | 
if(i>=argv.length) 246 | exit_with_help(); 247 | 248 | input_file_name = argv[i]; 249 | 250 | if(i vy = new Vector(); 266 | Vector vx = new Vector(); 267 | int max_index = 0; 268 | 269 | while(true) 270 | { 271 | String line = fp.readLine(); 272 | if(line == null) break; 273 | 274 | StringTokenizer st = new StringTokenizer(line," \t\n\r\f:"); 275 | 276 | vy.addElement(atof(st.nextToken())); 277 | int m = st.countTokens()/2; 278 | svm_node[] x = new svm_node[m]; 279 | for(int j=0;j0) max_index = Math.max(max_index, x[m-1].index); 286 | vx.addElement(x); 287 | } 288 | 289 | prob = new svm_problem(); 290 | prob.l = vy.size(); 291 | prob.x = new svm_node[prob.l][]; 292 | for(int i=0;i 0) 299 | param.gamma = 1.0/max_index; 300 | 301 | if(param.kernel_type == svm_parameter.PRECOMPUTED) 302 | for(int i=0;i max_index) 310 | { 311 | System.err.print("Wrong input format: sample_serial_number out of range\n"); 312 | System.exit(1); 313 | } 314 | } 315 | 316 | fp.close(); 317 | } 318 | } 319 | -------------------------------------------------------------------------------- /matlab/Makefile: -------------------------------------------------------------------------------- 1 | # This Makefile is used under Linux 2 | 3 | MATLABDIR ?= /usr/local/matlab 4 | # for Mac 5 | # MATLABDIR ?= /opt/local/matlab 6 | 7 | CXX ?= g++ 8 | #CXX = g++-4.1 9 | CFLAGS = -Wall -Wconversion -O3 -fPIC -I$(MATLABDIR)/extern/include -I.. 10 | 11 | MEX = $(MATLABDIR)/bin/mex 12 | MEX_OPTION = CC="$(CXX)" CXX="$(CXX)" CFLAGS="$(CFLAGS)" CXXFLAGS="$(CFLAGS)" 13 | # comment the following line if you use MATLAB on 32-bit computer 14 | MEX_OPTION += -largeArrayDims 15 | MEX_EXT = $(shell $(MATLABDIR)/bin/mexext) 16 | 17 | all: matlab 18 | 19 | matlab: binary 20 | 21 | octave: 22 | @echo "please type make under Octave" 23 | 24 | binary: svmpredict.$(MEX_EXT) svmtrain.$(MEX_EXT) libsvmread.$(MEX_EXT) libsvmwrite.$(MEX_EXT) 25 | 26 | svmpredict.$(MEX_EXT): svmpredict.c ../svm.h ../svm.cpp svm_model_matlab.c 27 | $(MEX) $(MEX_OPTION) svmpredict.c ../svm.cpp svm_model_matlab.c 28 | 29 | svmtrain.$(MEX_EXT): svmtrain.c ../svm.h ../svm.cpp svm_model_matlab.c 30 | $(MEX) $(MEX_OPTION) svmtrain.c ../svm.cpp svm_model_matlab.c 31 | 32 | libsvmread.$(MEX_EXT): libsvmread.c 33 | $(MEX) $(MEX_OPTION) libsvmread.c 34 | 35 | libsvmwrite.$(MEX_EXT): libsvmwrite.c 36 | $(MEX) $(MEX_OPTION) libsvmwrite.c 37 | 38 | clean: 39 | rm -f *~ *.o *.mex* *.obj 40 | -------------------------------------------------------------------------------- /matlab/README: -------------------------------------------------------------------------------- 1 | ----------------------------------------- 2 | --- MATLAB/OCTAVE interface of LIBSVM --- 3 | ----------------------------------------- 4 | 5 | Table of Contents 6 | ================= 7 | 8 | - Introduction 9 | - Installation 10 | - Usage 11 | - Returned Model Structure 12 | - Other Utilities 13 | - Examples 14 | - Additional Information 15 | 16 | 17 | Introduction 18 | ============ 19 | 20 | This tool provides a simple interface to LIBSVM, a library for support vector 21 | machines (http://www.csie.ntu.edu.tw/~cjlin/libsvm). It is very easy to use as 22 | the usage and the way of specifying parameters are the same as that of LIBSVM. 23 | 24 | Installation 25 | ============ 26 | 27 | On Windows systems, pre-built mex files are already in the 28 | directory '..\windows', so please just copy them to the matlab 29 | directory. Now we provide binary files only for 64bit MATLAB on 30 | Windows. 
If you would like to re-build the package, please rely on the 31 | following steps. 32 | 33 | We recommend using make.m on both MATLAB and OCTAVE. Just type 'make' 34 | to build 'libsvmread.mex', 'libsvmwrite.mex', 'svmtrain.mex', and 35 | 'svmpredict.mex'. 36 | 37 | On MATLAB or Octave: 38 | 39 | >> make 40 | 41 | If make.m does not work on MATLAB (especially for Windows), try 'mex 42 | -setup' to choose a suitable compiler for mex. Make sure your compiler 43 | is accessible and workable. Then type 'make' to do the installation. 44 | 45 | Example: 46 | 47 | matlab>> mex -setup 48 | 49 | MATLAB will choose the default compiler. If you have multiple compliers, 50 | a list is given and you can choose one from the list. For more details, 51 | please check the following page: 52 | 53 | https://www.mathworks.com/help/matlab/matlab_external/choose-c-or-c-compilers.html 54 | 55 | On Windows, make.m has been tested via using Visual C++. 56 | 57 | On Unix systems, if neither make.m nor 'mex -setup' works, please use 58 | Makefile and type 'make' in a command window. Note that we assume 59 | your MATLAB is installed in '/usr/local/matlab'. If not, please change 60 | MATLABDIR in Makefile. 61 | 62 | Example: 63 | linux> make 64 | 65 | To use octave, type 'make octave': 66 | 67 | Example: 68 | linux> make octave 69 | 70 | For a list of supported/compatible compilers for MATLAB, please check 71 | the following page: 72 | 73 | http://www.mathworks.com/support/compilers/current_release/ 74 | 75 | Usage 76 | ===== 77 | 78 | matlab> model = svmtrain(training_label_vector, training_instance_matrix [, 'libsvm_options']); 79 | 80 | -training_label_vector: 81 | An m by 1 vector of training labels (type must be double). 82 | -training_instance_matrix: 83 | An m by n matrix of m training instances with n features. 84 | It can be dense or sparse (type must be double). 85 | -libsvm_options: 86 | A string of training options in the same format as that of LIBSVM. 87 | 88 | matlab> [predicted_label, accuracy, decision_values/prob_estimates] = svmpredict(testing_label_vector, testing_instance_matrix, model [, 'libsvm_options']); 89 | matlab> [predicted_label] = svmpredict(testing_label_vector, testing_instance_matrix, model [, 'libsvm_options']); 90 | 91 | -testing_label_vector: 92 | An m by 1 vector of prediction labels. If labels of test 93 | data are unknown, simply use any random values. (type must be double) 94 | -testing_instance_matrix: 95 | An m by n matrix of m testing instances with n features. 96 | It can be dense or sparse. (type must be double) 97 | -model: 98 | The output of svmtrain. 99 | -libsvm_options: 100 | A string of testing options in the same format as that of LIBSVM. 101 | 102 | Returned Model Structure 103 | ======================== 104 | 105 | The 'svmtrain' function returns a model which can be used for future 106 | prediction. 
It is a structure and is organized as [Parameters, nr_class, 107 | totalSV, rho, Label, ProbA, ProbB, Prob_density_marks, nSV, sv_coef, SVs]: 108 | 109 | -Parameters: parameters 110 | -nr_class: number of classes; = 2 for regression/one-class svm 111 | -totalSV: total #SV 112 | -rho: -b of the decision function(s) wx+b 113 | -Label: label of each class; empty for regression/one-class SVM 114 | -sv_indices: values in [1,...,num_traning_data] to indicate SVs in the training set 115 | -ProbA: pairwise probability information; empty if -b 0 or in one-class SVM 116 | -ProbB: pairwise probability information; empty if -b 0 or in one-class SVM 117 | -Prob_density_marks: probability information for one-class SVM; empty if -b 0 or not in one-class SVM 118 | -nSV: number of SVs for each class; empty for regression/one-class SVM 119 | -sv_coef: coefficients for SVs in decision functions 120 | -SVs: support vectors 121 | 122 | If you do not use the option '-b 1', ProbA and ProbB are empty 123 | matrices. If the '-v' option is specified, cross validation is 124 | conducted and the returned model is just a scalar: cross-validation 125 | accuracy for classification and mean-squared error for regression. 126 | 127 | More details about this model can be found in LIBSVM FAQ 128 | (http://www.csie.ntu.edu.tw/~cjlin/libsvm/faq.html) and LIBSVM 129 | implementation document 130 | (http://www.csie.ntu.edu.tw/~cjlin/papers/libsvm.pdf). 131 | 132 | Result of Prediction 133 | ==================== 134 | 135 | The function 'svmpredict' has three outputs. The first one, 136 | predictd_label, is a vector of predicted labels. The second output, 137 | accuracy, is a vector including accuracy (for classification), mean 138 | squared error, and squared correlation coefficient (for regression). 139 | The third is a matrix containing decision values or probability 140 | estimates (if '-b 1' is specified). If k is the number of classes 141 | in training data, for decision values, each row includes results of 142 | predicting k(k-1)/2 binary-class SVMs. For classification, k = 1 is a 143 | special case. Decision value +1 is returned for each testing instance, 144 | instead of an empty vector. For probabilities, each row contains k values 145 | indicating the probability that the testing instance is in each class. 146 | Note that the order of classes here is the same as 'Label' field 147 | in the model structure. 148 | For one-class SVM, each row contains two elements for probabilities 149 | of normal instance/outlier. 150 | 151 | Other Utilities 152 | =============== 153 | 154 | A matlab function libsvmread reads files in LIBSVM format: 155 | 156 | [label_vector, instance_matrix] = libsvmread('data.txt'); 157 | 158 | Two outputs are labels and instances, which can then be used as inputs 159 | of svmtrain or svmpredict. 160 | 161 | A matlab function libsvmwrite writes Matlab matrix to a file in LIBSVM format: 162 | 163 | libsvmwrite('data.txt', label_vector, instance_matrix) 164 | 165 | The instance_matrix must be a sparse matrix. (type must be double) 166 | For 32bit and 64bit MATLAB on Windows, pre-built binary files are ready 167 | in the directory `..\windows', but in future releases, we will only 168 | include 64bit MATLAB binary files. 169 | 170 | These codes are prepared by Rong-En Fan and Kai-Wei Chang from National 171 | Taiwan University. 
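 A small end-to-end illustration of these two utilities (a minimal sketch; the file name 'toy.txt' and the variables 'labels' and 'features' are hypothetical, and because libsvmwrite only accepts a sparse double matrix, dense data is converted with MATLAB's built-in sparse() first): matlab> labels = [1; -1; 1]; % m by 1 double label vector matlab> features = [0 0.5 1; 1 0 0; 0.2 0 0.3]; % m by n dense instance matrix matlab> libsvmwrite('toy.txt', labels, sparse(features)); matlab> [labels2, features2] = libsvmread('toy.txt'); The matrix read back is sparse and can be passed directly to svmtrain or svmpredict; zero feature values are simply not stored in the LIBSVM-format file.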
172 | 173 | Examples 174 | ======== 175 | 176 | Train and test on the provided data heart_scale: 177 | 178 | matlab> [heart_scale_label, heart_scale_inst] = libsvmread('../heart_scale'); 179 | matlab> model = svmtrain(heart_scale_label, heart_scale_inst, '-c 1 -g 0.07'); 180 | matlab> [predict_label, accuracy, dec_values] = svmpredict(heart_scale_label, heart_scale_inst, model); % test the training data 181 | 182 | For probability estimates, you need '-b 1' for training and testing: 183 | 184 | matlab> [heart_scale_label, heart_scale_inst] = libsvmread('../heart_scale'); 185 | matlab> model = svmtrain(heart_scale_label, heart_scale_inst, '-c 1 -g 0.07 -b 1'); 186 | matlab> [heart_scale_label, heart_scale_inst] = libsvmread('../heart_scale'); 187 | matlab> [predict_label, accuracy, prob_estimates] = svmpredict(heart_scale_label, heart_scale_inst, model, '-b 1'); 188 | 189 | To use precomputed kernel, you must include sample serial number as 190 | the first column of the training and testing data (assume your kernel 191 | matrix is K, # of instances is n): 192 | 193 | matlab> K1 = [(1:n)', K]; % include sample serial number as first column 194 | matlab> model = svmtrain(label_vector, K1, '-t 4'); 195 | matlab> [predict_label, accuracy, dec_values] = svmpredict(label_vector, K1, model); % test the training data 196 | 197 | We give the following detailed example by splitting heart_scale into 198 | 150 training and 120 testing data. Constructing a linear kernel 199 | matrix and then using the precomputed kernel gives exactly the same 200 | testing error as using the LIBSVM built-in linear kernel. 201 | 202 | matlab> [heart_scale_label, heart_scale_inst] = libsvmread('../heart_scale'); 203 | matlab> 204 | matlab> % Split Data 205 | matlab> train_data = heart_scale_inst(1:150,:); 206 | matlab> train_label = heart_scale_label(1:150,:); 207 | matlab> test_data = heart_scale_inst(151:270,:); 208 | matlab> test_label = heart_scale_label(151:270,:); 209 | matlab> 210 | matlab> % Linear Kernel 211 | matlab> model_linear = svmtrain(train_label, train_data, '-t 0'); 212 | matlab> [predict_label_L, accuracy_L, dec_values_L] = svmpredict(test_label, test_data, model_linear); 213 | matlab> 214 | matlab> % Precomputed Kernel 215 | matlab> model_precomputed = svmtrain(train_label, [(1:150)', train_data*train_data'], '-t 4'); 216 | matlab> [predict_label_P, accuracy_P, dec_values_P] = svmpredict(test_label, [(1:120)', test_data*train_data'], model_precomputed); 217 | matlab> 218 | matlab> accuracy_L % Display the accuracy using linear kernel 219 | matlab> accuracy_P % Display the accuracy using precomputed kernel 220 | 221 | Note that for testing, you can put anything in the 222 | testing_label_vector. For more details of precomputed kernels, please 223 | read the section ``Precomputed Kernels'' in the README of the LIBSVM 224 | package. 225 | 226 | Additional Information 227 | ====================== 228 | 229 | This interface was initially written by Jun-Cheng Chen, Kuan-Jen Peng, 230 | Chih-Yuan Yang and Chih-Huai Cheng from Department of Computer 231 | Science, National Taiwan University. The current version was prepared 232 | by Rong-En Fan and Ting-Fan Wu. If you find this tool useful, please 233 | cite LIBSVM as follows 234 | 235 | Chih-Chung Chang and Chih-Jen Lin, LIBSVM : a library for support 236 | vector machines. ACM Transactions on Intelligent Systems and 237 | Technology, 2:27:1--27:27, 2011. 
Software available at 238 | http://www.csie.ntu.edu.tw/~cjlin/libsvm 239 | 240 | For any question, please contact Chih-Jen Lin , 241 | or check the FAQ page: 242 | 243 | http://www.csie.ntu.edu.tw/~cjlin/libsvm/faq.html#/Q10:_MATLAB_interface 244 | -------------------------------------------------------------------------------- /matlab/libsvmread.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #include "mex.h" 8 | 9 | #ifdef MX_API_VER 10 | #if MX_API_VER < 0x07030000 11 | typedef int mwIndex; 12 | #endif 13 | #endif 14 | #ifndef max 15 | #define max(x,y) (((x)>(y))?(x):(y)) 16 | #endif 17 | #ifndef min 18 | #define min(x,y) (((x)<(y))?(x):(y)) 19 | #endif 20 | 21 | void exit_with_help() 22 | { 23 | mexPrintf( 24 | "Usage: [label_vector, instance_matrix] = libsvmread('filename');\n" 25 | ); 26 | } 27 | 28 | static void fake_answer(int nlhs, mxArray *plhs[]) 29 | { 30 | int i; 31 | for(i=0;i start from 0 86 | strtok(line," \t"); // label 87 | while (1) 88 | { 89 | idx = strtok(NULL,":"); // index:value 90 | val = strtok(NULL," \t"); 91 | if(val == NULL) 92 | break; 93 | 94 | errno = 0; 95 | index = (int) strtol(idx,&endptr,10); 96 | if(endptr == idx || errno != 0 || *endptr != '\0' || index <= inst_max_index) 97 | { 98 | mexPrintf("Wrong input format at line %d\n",l+1); 99 | fake_answer(nlhs, plhs); 100 | return; 101 | } 102 | else 103 | inst_max_index = index; 104 | 105 | min_index = min(min_index, index); 106 | elements++; 107 | } 108 | max_index = max(max_index, inst_max_index); 109 | l++; 110 | } 111 | rewind(fp); 112 | 113 | // y 114 | plhs[0] = mxCreateDoubleMatrix(l, 1, mxREAL); 115 | // x^T 116 | if (min_index <= 0) 117 | plhs[1] = mxCreateSparse(max_index-min_index+1, l, elements, mxREAL); 118 | else 119 | plhs[1] = mxCreateSparse(max_index, l, elements, mxREAL); 120 | 121 | labels = mxGetPr(plhs[0]); 122 | samples = mxGetPr(plhs[1]); 123 | ir = mxGetIr(plhs[1]); 124 | jc = mxGetJc(plhs[1]); 125 | 126 | k=0; 127 | for(i=0;i start from 0 158 | 159 | errno = 0; 160 | samples[k] = strtod(val,&endptr); 161 | if (endptr == val || errno != 0 || (*endptr != '\0' && !isspace(*endptr))) 162 | { 163 | mexPrintf("Wrong input format at line %d\n",i+1); 164 | fake_answer(nlhs, plhs); 165 | return; 166 | } 167 | ++k; 168 | } 169 | } 170 | jc[l] = k; 171 | 172 | fclose(fp); 173 | free(line); 174 | 175 | { 176 | mxArray *rhs[1], *lhs[1]; 177 | rhs[0] = plhs[1]; 178 | if(mexCallMATLAB(1, lhs, 1, rhs, "transpose")) 179 | { 180 | mexPrintf("Error: cannot transpose problem\n"); 181 | fake_answer(nlhs, plhs); 182 | return; 183 | } 184 | plhs[1] = lhs[0]; 185 | } 186 | } 187 | 188 | void mexFunction( int nlhs, mxArray *plhs[], 189 | int nrhs, const mxArray *prhs[] ) 190 | { 191 | #define filename_size 256 192 | 193 | char filename[filename_size]; 194 | 195 | if(nrhs != 1 || nlhs != 2) 196 | { 197 | exit_with_help(); 198 | fake_answer(nlhs, plhs); 199 | return; 200 | } 201 | 202 | if(mxGetString(prhs[0], filename, filename_size) == 1){ 203 | mexPrintf("Error: wrong or too long filename\n"); 204 | fake_answer(nlhs, plhs); 205 | return; 206 | } 207 | 208 | read_problem(filename, nlhs, plhs); 209 | 210 | return; 211 | } 212 | 213 | -------------------------------------------------------------------------------- /matlab/libsvmwrite.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "mex.h" 5 | 6 | #ifdef MX_API_VER 7 
| #if MX_API_VER < 0x07030000 8 | typedef int mwIndex; 9 | #endif 10 | #endif 11 | 12 | void exit_with_help() 13 | { 14 | mexPrintf( 15 | "Usage: libsvmwrite('filename', label_vector, instance_matrix);\n" 16 | ); 17 | } 18 | 19 | static void fake_answer(int nlhs, mxArray *plhs[]) 20 | { 21 | int i; 22 | for(i=0;i 0) 88 | { 89 | exit_with_help(); 90 | fake_answer(nlhs, plhs); 91 | return; 92 | } 93 | 94 | // Transform the input Matrix to libsvm format 95 | if(nrhs == 3) 96 | { 97 | char filename[256]; 98 | if(!mxIsDouble(prhs[1]) || !mxIsDouble(prhs[2])) 99 | { 100 | mexPrintf("Error: label vector and instance matrix must be double\n"); 101 | return; 102 | } 103 | 104 | mxGetString(prhs[0], filename, mxGetN(prhs[0])+1); 105 | 106 | if(mxIsSparse(prhs[2])) 107 | libsvmwrite(filename, prhs[1], prhs[2]); 108 | else 109 | { 110 | mexPrintf("Instance_matrix must be sparse\n"); 111 | return; 112 | } 113 | } 114 | else 115 | { 116 | exit_with_help(); 117 | return; 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /matlab/make.m: -------------------------------------------------------------------------------- 1 | % This make.m is for MATLAB and OCTAVE under Windows, Mac, and Unix 2 | function make() 3 | try 4 | % This part is for OCTAVE 5 | if (exist ('OCTAVE_VERSION', 'builtin')) 6 | mex libsvmread.c 7 | mex libsvmwrite.c 8 | mex -I.. svmtrain.c ../svm.cpp svm_model_matlab.c 9 | mex -I.. svmpredict.c ../svm.cpp svm_model_matlab.c 10 | % This part is for MATLAB 11 | % Add -largeArrayDims on 64-bit machines of MATLAB 12 | else 13 | mex -largeArrayDims libsvmread.c 14 | mex -largeArrayDims libsvmwrite.c 15 | mex -I.. -largeArrayDims svmtrain.c ../svm.cpp svm_model_matlab.c 16 | mex -I.. -largeArrayDims svmpredict.c ../svm.cpp svm_model_matlab.c 17 | end 18 | catch err 19 | fprintf('Error: %s failed (line %d)\n', err.stack(1).file, err.stack(1).line); 20 | disp(err.message); 21 | fprintf('=> Please check README for detailed instructions.\n'); 22 | end 23 | -------------------------------------------------------------------------------- /matlab/svm_model_matlab.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "svm.h" 4 | 5 | #include "mex.h" 6 | 7 | #ifdef MX_API_VER 8 | #if MX_API_VER < 0x07030000 9 | typedef int mwIndex; 10 | #endif 11 | #endif 12 | 13 | #define NUM_OF_RETURN_FIELD 12 14 | 15 | #define Malloc(type,n) (type *)malloc((n)*sizeof(type)) 16 | 17 | static const char *field_names[] = { 18 | "Parameters", 19 | "nr_class", 20 | "totalSV", 21 | "rho", 22 | "Label", 23 | "sv_indices", 24 | "ProbA", 25 | "ProbB", 26 | "Prob_density_marks", 27 | "nSV", 28 | "sv_coef", 29 | "SVs" 30 | }; 31 | 32 | const char *model_to_matlab_structure(mxArray *plhs[], int num_of_feature, struct svm_model *model) 33 | { 34 | int i, j, n; 35 | double *ptr; 36 | mxArray *return_model, **rhs; 37 | int out_id = 0; 38 | 39 | rhs = (mxArray **)mxMalloc(sizeof(mxArray *)*NUM_OF_RETURN_FIELD); 40 | 41 | // Parameters 42 | rhs[out_id] = mxCreateDoubleMatrix(5, 1, mxREAL); 43 | ptr = mxGetPr(rhs[out_id]); 44 | ptr[0] = model->param.svm_type; 45 | ptr[1] = model->param.kernel_type; 46 | ptr[2] = model->param.degree; 47 | ptr[3] = model->param.gamma; 48 | ptr[4] = model->param.coef0; 49 | out_id++; 50 | 51 | // nr_class 52 | rhs[out_id] = mxCreateDoubleMatrix(1, 1, mxREAL); 53 | ptr = mxGetPr(rhs[out_id]); 54 | ptr[0] = model->nr_class; 55 | out_id++; 56 | 57 | // total SV 58 | rhs[out_id] = 
mxCreateDoubleMatrix(1, 1, mxREAL); 59 | ptr = mxGetPr(rhs[out_id]); 60 | ptr[0] = model->l; 61 | out_id++; 62 | 63 | // rho 64 | n = model->nr_class*(model->nr_class-1)/2; 65 | rhs[out_id] = mxCreateDoubleMatrix(n, 1, mxREAL); 66 | ptr = mxGetPr(rhs[out_id]); 67 | for(i = 0; i < n; i++) 68 | ptr[i] = model->rho[i]; 69 | out_id++; 70 | 71 | // Label 72 | if(model->label) 73 | { 74 | rhs[out_id] = mxCreateDoubleMatrix(model->nr_class, 1, mxREAL); 75 | ptr = mxGetPr(rhs[out_id]); 76 | for(i = 0; i < model->nr_class; i++) 77 | ptr[i] = model->label[i]; 78 | } 79 | else 80 | rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL); 81 | out_id++; 82 | 83 | // sv_indices 84 | if(model->sv_indices) 85 | { 86 | rhs[out_id] = mxCreateDoubleMatrix(model->l, 1, mxREAL); 87 | ptr = mxGetPr(rhs[out_id]); 88 | for(i = 0; i < model->l; i++) 89 | ptr[i] = model->sv_indices[i]; 90 | } 91 | else 92 | rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL); 93 | out_id++; 94 | 95 | // probA 96 | if(model->probA != NULL) 97 | { 98 | rhs[out_id] = mxCreateDoubleMatrix(n, 1, mxREAL); 99 | ptr = mxGetPr(rhs[out_id]); 100 | for(i = 0; i < n; i++) 101 | ptr[i] = model->probA[i]; 102 | } 103 | else 104 | rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL); 105 | out_id ++; 106 | 107 | // probB 108 | if(model->probB != NULL) 109 | { 110 | rhs[out_id] = mxCreateDoubleMatrix(n, 1, mxREAL); 111 | ptr = mxGetPr(rhs[out_id]); 112 | for(i = 0; i < n; i++) 113 | ptr[i] = model->probB[i]; 114 | } 115 | else 116 | rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL); 117 | out_id++; 118 | 119 | // prob_density_marks 120 | if(model->prob_density_marks != NULL) 121 | { 122 | int nr_marks = 10; 123 | rhs[out_id] = mxCreateDoubleMatrix(nr_marks, 1, mxREAL); 124 | ptr = mxGetPr(rhs[out_id]); 125 | for(i = 0; i < nr_marks; i++) 126 | ptr[i] = model->prob_density_marks[i]; 127 | } 128 | else 129 | rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL); 130 | out_id++; 131 | 132 | // nSV 133 | if(model->nSV) 134 | { 135 | rhs[out_id] = mxCreateDoubleMatrix(model->nr_class, 1, mxREAL); 136 | ptr = mxGetPr(rhs[out_id]); 137 | for(i = 0; i < model->nr_class; i++) 138 | ptr[i] = model->nSV[i]; 139 | } 140 | else 141 | rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL); 142 | out_id++; 143 | 144 | // sv_coef 145 | rhs[out_id] = mxCreateDoubleMatrix(model->l, model->nr_class-1, mxREAL); 146 | ptr = mxGetPr(rhs[out_id]); 147 | for(i = 0; i < model->nr_class-1; i++) 148 | for(j = 0; j < model->l; j++) 149 | ptr[(i*(model->l))+j] = model->sv_coef[i][j]; 150 | out_id++; 151 | 152 | // SVs 153 | { 154 | int ir_index, nonzero_element; 155 | mwIndex *ir, *jc; 156 | mxArray *pprhs[1], *pplhs[1]; 157 | 158 | if(model->param.kernel_type == PRECOMPUTED) 159 | { 160 | nonzero_element = model->l; 161 | num_of_feature = 1; 162 | } 163 | else 164 | { 165 | nonzero_element = 0; 166 | for(i = 0; i < model->l; i++) { 167 | j = 0; 168 | while(model->SV[i][j].index != -1) 169 | { 170 | nonzero_element++; 171 | j++; 172 | } 173 | } 174 | } 175 | 176 | // SV in column, easier accessing 177 | rhs[out_id] = mxCreateSparse(num_of_feature, model->l, nonzero_element, mxREAL); 178 | ir = mxGetIr(rhs[out_id]); 179 | jc = mxGetJc(rhs[out_id]); 180 | ptr = mxGetPr(rhs[out_id]); 181 | jc[0] = ir_index = 0; 182 | for(i = 0;i < model->l; i++) 183 | { 184 | if(model->param.kernel_type == PRECOMPUTED) 185 | { 186 | // make a (1 x model->l) matrix 187 | ir[ir_index] = 0; 188 | ptr[ir_index] = model->SV[i][0].value; 189 | ir_index++; 190 | jc[i+1] = jc[i] + 1; 191 | } 192 | else 193 | { 194 | int x_index 
= 0; 195 | while (model->SV[i][x_index].index != -1) 196 | { 197 | ir[ir_index] = model->SV[i][x_index].index - 1; 198 | ptr[ir_index] = model->SV[i][x_index].value; 199 | ir_index++, x_index++; 200 | } 201 | jc[i+1] = jc[i] + x_index; 202 | } 203 | } 204 | // transpose back to SV in row 205 | pprhs[0] = rhs[out_id]; 206 | if(mexCallMATLAB(1, pplhs, 1, pprhs, "transpose")) 207 | return "cannot transpose SV matrix"; 208 | rhs[out_id] = pplhs[0]; 209 | out_id++; 210 | } 211 | 212 | /* Create a struct matrix contains NUM_OF_RETURN_FIELD fields */ 213 | return_model = mxCreateStructMatrix(1, 1, NUM_OF_RETURN_FIELD, field_names); 214 | 215 | /* Fill struct matrix with input arguments */ 216 | for(i = 0; i < NUM_OF_RETURN_FIELD; i++) 217 | mxSetField(return_model,0,field_names[i],mxDuplicateArray(rhs[i])); 218 | /* return */ 219 | plhs[0] = return_model; 220 | mxFree(rhs); 221 | 222 | return NULL; 223 | } 224 | 225 | struct svm_model *matlab_matrix_to_model(const mxArray *matlab_struct, const char **msg) 226 | { 227 | int i, j, n, num_of_fields; 228 | double *ptr; 229 | int id = 0; 230 | struct svm_node *x_space; 231 | struct svm_model *model; 232 | mxArray **rhs; 233 | 234 | num_of_fields = mxGetNumberOfFields(matlab_struct); 235 | if(num_of_fields != NUM_OF_RETURN_FIELD) 236 | { 237 | *msg = "number of return field is not correct"; 238 | return NULL; 239 | } 240 | rhs = (mxArray **) mxMalloc(sizeof(mxArray *)*num_of_fields); 241 | 242 | for(i=0;irho = NULL; 247 | model->probA = NULL; 248 | model->probB = NULL; 249 | model->prob_density_marks = NULL; 250 | model->label = NULL; 251 | model->sv_indices = NULL; 252 | model->nSV = NULL; 253 | model->free_sv = 1; // XXX 254 | 255 | ptr = mxGetPr(rhs[id]); 256 | model->param.svm_type = (int)ptr[0]; 257 | model->param.kernel_type = (int)ptr[1]; 258 | model->param.degree = (int)ptr[2]; 259 | model->param.gamma = ptr[3]; 260 | model->param.coef0 = ptr[4]; 261 | id++; 262 | 263 | ptr = mxGetPr(rhs[id]); 264 | model->nr_class = (int)ptr[0]; 265 | id++; 266 | 267 | ptr = mxGetPr(rhs[id]); 268 | model->l = (int)ptr[0]; 269 | id++; 270 | 271 | // rho 272 | n = model->nr_class * (model->nr_class-1)/2; 273 | model->rho = (double*) malloc(n*sizeof(double)); 274 | ptr = mxGetPr(rhs[id]); 275 | for(i=0;irho[i] = ptr[i]; 277 | id++; 278 | 279 | // label 280 | if(mxIsEmpty(rhs[id]) == 0) 281 | { 282 | model->label = (int*) malloc(model->nr_class*sizeof(int)); 283 | ptr = mxGetPr(rhs[id]); 284 | for(i=0;inr_class;i++) 285 | model->label[i] = (int)ptr[i]; 286 | } 287 | id++; 288 | 289 | // sv_indices 290 | if(mxIsEmpty(rhs[id]) == 0) 291 | { 292 | model->sv_indices = (int*) malloc(model->l*sizeof(int)); 293 | ptr = mxGetPr(rhs[id]); 294 | for(i=0;il;i++) 295 | model->sv_indices[i] = (int)ptr[i]; 296 | } 297 | id++; 298 | 299 | // probA 300 | if(mxIsEmpty(rhs[id]) == 0) 301 | { 302 | model->probA = (double*) malloc(n*sizeof(double)); 303 | ptr = mxGetPr(rhs[id]); 304 | for(i=0;iprobA[i] = ptr[i]; 306 | } 307 | id++; 308 | 309 | // probB 310 | if(mxIsEmpty(rhs[id]) == 0) 311 | { 312 | model->probB = (double*) malloc(n*sizeof(double)); 313 | ptr = mxGetPr(rhs[id]); 314 | for(i=0;iprobB[i] = ptr[i]; 316 | } 317 | id++; 318 | 319 | // prob_density_marks 320 | if(mxIsEmpty(rhs[id]) == 0) 321 | { 322 | int nr_marks = 10; 323 | model->prob_density_marks = (double*) malloc(nr_marks*sizeof(double)); 324 | ptr = mxGetPr(rhs[id]); 325 | for(i=0;iprob_density_marks[i] = ptr[i]; 327 | } 328 | id++; 329 | 330 | // nSV 331 | if(mxIsEmpty(rhs[id]) == 0) 332 | { 333 | model->nSV = 
(int*) malloc(model->nr_class*sizeof(int)); 334 | ptr = mxGetPr(rhs[id]); 335 | for(i=0;inr_class;i++) 336 | model->nSV[i] = (int)ptr[i]; 337 | } 338 | id++; 339 | 340 | // sv_coef 341 | ptr = mxGetPr(rhs[id]); 342 | model->sv_coef = (double**) malloc((model->nr_class-1)*sizeof(double)); 343 | for( i=0 ; i< model->nr_class -1 ; i++ ) 344 | model->sv_coef[i] = (double*) malloc((model->l)*sizeof(double)); 345 | for(i = 0; i < model->nr_class - 1; i++) 346 | for(j = 0; j < model->l; j++) 347 | model->sv_coef[i][j] = ptr[i*(model->l)+j]; 348 | id++; 349 | 350 | // SV 351 | { 352 | int sr, elements; 353 | int num_samples; 354 | mwIndex *ir, *jc; 355 | mxArray *pprhs[1], *pplhs[1]; 356 | 357 | // transpose SV 358 | pprhs[0] = rhs[id]; 359 | if(mexCallMATLAB(1, pplhs, 1, pprhs, "transpose")) 360 | { 361 | svm_free_and_destroy_model(&model); 362 | *msg = "cannot transpose SV matrix"; 363 | return NULL; 364 | } 365 | rhs[id] = pplhs[0]; 366 | 367 | sr = (int)mxGetN(rhs[id]); 368 | 369 | ptr = mxGetPr(rhs[id]); 370 | ir = mxGetIr(rhs[id]); 371 | jc = mxGetJc(rhs[id]); 372 | 373 | num_samples = (int)mxGetNzmax(rhs[id]); 374 | 375 | elements = num_samples + sr; 376 | 377 | model->SV = (struct svm_node **) malloc(sr * sizeof(struct svm_node *)); 378 | x_space = (struct svm_node *)malloc(elements * sizeof(struct svm_node)); 379 | 380 | // SV is in column 381 | for(i=0;iSV[i] = &x_space[low+i]; 386 | for(j=low;jSV[i][x_index].index = (int)ir[j] + 1; 389 | model->SV[i][x_index].value = ptr[j]; 390 | x_index++; 391 | } 392 | model->SV[i][x_index].index = -1; 393 | } 394 | 395 | id++; 396 | } 397 | mxFree(rhs); 398 | 399 | return model; 400 | } 401 | -------------------------------------------------------------------------------- /matlab/svm_model_matlab.h: -------------------------------------------------------------------------------- 1 | const char *model_to_matlab_structure(mxArray *plhs[], int num_of_feature, struct svm_model *model); 2 | struct svm_model *matlab_matrix_to_model(const mxArray *matlab_struct, const char **error_message); 3 | -------------------------------------------------------------------------------- /matlab/svmpredict.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "svm.h" 5 | 6 | #include "mex.h" 7 | #include "svm_model_matlab.h" 8 | 9 | #ifdef MX_API_VER 10 | #if MX_API_VER < 0x07030000 11 | typedef int mwIndex; 12 | #endif 13 | #endif 14 | 15 | #define CMD_LEN 2048 16 | 17 | int print_null(const char *s,...) {return 0;} 18 | int (*info)(const char *fmt,...) 
= &mexPrintf; 19 | 20 | void read_sparse_instance(const mxArray *prhs, int index, struct svm_node *x) 21 | { 22 | int i, j, low, high; 23 | mwIndex *ir, *jc; 24 | double *samples; 25 | 26 | ir = mxGetIr(prhs); 27 | jc = mxGetJc(prhs); 28 | samples = mxGetPr(prhs); 29 | 30 | // each column is one instance 31 | j = 0; 32 | low = (int)jc[index], high = (int)jc[index+1]; 33 | for(i=low;iparam.kernel_type == PRECOMPUTED) 95 | { 96 | // precomputed kernel requires dense matrix, so we make one 97 | mxArray *rhs[1], *lhs[1]; 98 | rhs[0] = mxDuplicateArray(prhs[1]); 99 | if(mexCallMATLAB(1, lhs, 1, rhs, "full")) 100 | { 101 | mexPrintf("Error: cannot full testing instance matrix\n"); 102 | fake_answer(nlhs, plhs); 103 | return; 104 | } 105 | ptr_instance = mxGetPr(lhs[0]); 106 | mxDestroyArray(rhs[0]); 107 | } 108 | else 109 | { 110 | mxArray *pprhs[1]; 111 | pprhs[0] = mxDuplicateArray(prhs[1]); 112 | if(mexCallMATLAB(1, pplhs, 1, pprhs, "transpose")) 113 | { 114 | mexPrintf("Error: cannot transpose testing instance matrix\n"); 115 | fake_answer(nlhs, plhs); 116 | return; 117 | } 118 | } 119 | } 120 | 121 | if(predict_probability) 122 | { 123 | if(svm_type==NU_SVR || svm_type==EPSILON_SVR) 124 | info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma=%g\n",svm_get_svr_probability(model)); 125 | else 126 | prob_estimates = (double *) malloc(nr_class*sizeof(double)); 127 | } 128 | 129 | tplhs[0] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL); 130 | if(predict_probability) 131 | { 132 | // prob estimates are in plhs[2] 133 | if(svm_type==C_SVC || svm_type==NU_SVC || svm_type==ONE_CLASS) 134 | { 135 | // nr_class = 2 for ONE_CLASS 136 | tplhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class, mxREAL); 137 | } 138 | else 139 | tplhs[2] = mxCreateDoubleMatrix(0, 0, mxREAL); 140 | } 141 | else 142 | { 143 | // decision values are in plhs[2] 144 | if(svm_type == ONE_CLASS || 145 | svm_type == EPSILON_SVR || 146 | svm_type == NU_SVR || 147 | nr_class == 1) // if only one class in training data, decision values are still returned. 
148 | tplhs[2] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL); 149 | else 150 | tplhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class*(nr_class-1)/2, mxREAL); 151 | } 152 | 153 | ptr_predict_label = mxGetPr(tplhs[0]); 154 | ptr_prob_estimates = mxGetPr(tplhs[2]); 155 | ptr_dec_values = mxGetPr(tplhs[2]); 156 | x = (struct svm_node*)malloc((feature_number+1)*sizeof(struct svm_node) ); 157 | for(instance_index=0;instance_indexparam.kernel_type != PRECOMPUTED) // prhs[1]^T is still sparse 165 | read_sparse_instance(pplhs[0], instance_index, x); 166 | else 167 | { 168 | for(i=0;i 3 || nrhs > 4 || nrhs < 3) 283 | { 284 | exit_with_help(); 285 | fake_answer(nlhs, plhs); 286 | return; 287 | } 288 | 289 | if(!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1])) { 290 | mexPrintf("Error: label vector and instance matrix must be double\n"); 291 | fake_answer(nlhs, plhs); 292 | return; 293 | } 294 | 295 | if(mxIsStruct(prhs[2])) 296 | { 297 | const char *error_msg; 298 | 299 | // parse options 300 | if(nrhs==4) 301 | { 302 | int i, argc = 1; 303 | char cmd[CMD_LEN], *argv[CMD_LEN/2]; 304 | 305 | // put options in argv[] 306 | mxGetString(prhs[3], cmd, mxGetN(prhs[3]) + 1); 307 | if((argv[argc] = strtok(cmd, " ")) != NULL) 308 | while((argv[++argc] = strtok(NULL, " ")) != NULL) 309 | ; 310 | 311 | for(i=1;i=argc) && argv[i-1][1] != 'q') 315 | { 316 | exit_with_help(); 317 | fake_answer(nlhs, plhs); 318 | return; 319 | } 320 | switch(argv[i-1][1]) 321 | { 322 | case 'b': 323 | prob_estimate_flag = atoi(argv[i]); 324 | break; 325 | case 'q': 326 | i--; 327 | info = &print_null; 328 | break; 329 | default: 330 | mexPrintf("Unknown option: -%c\n", argv[i-1][1]); 331 | exit_with_help(); 332 | fake_answer(nlhs, plhs); 333 | return; 334 | } 335 | } 336 | } 337 | 338 | model = matlab_matrix_to_model(prhs[2], &error_msg); 339 | if (model == NULL) 340 | { 341 | mexPrintf("Error: can't read model: %s\n", error_msg); 342 | fake_answer(nlhs, plhs); 343 | return; 344 | } 345 | 346 | if(prob_estimate_flag) 347 | { 348 | if(svm_check_probability_model(model)==0) 349 | { 350 | mexPrintf("Model does not support probabiliy estimates\n"); 351 | fake_answer(nlhs, plhs); 352 | svm_free_and_destroy_model(&model); 353 | return; 354 | } 355 | } 356 | else 357 | { 358 | if(svm_check_probability_model(model)!=0) 359 | info("Model supports probability estimates, but disabled in predicton.\n"); 360 | } 361 | 362 | predict(nlhs, plhs, prhs, model, prob_estimate_flag); 363 | // destroy model 364 | svm_free_and_destroy_model(&model); 365 | } 366 | else 367 | { 368 | mexPrintf("model file should be a struct array\n"); 369 | fake_answer(nlhs, plhs); 370 | } 371 | 372 | return; 373 | } 374 | -------------------------------------------------------------------------------- /matlab/svmtrain.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include "svm.h" 6 | 7 | #include "mex.h" 8 | #include "svm_model_matlab.h" 9 | 10 | #ifdef MX_API_VER 11 | #if MX_API_VER < 0x07030000 12 | typedef int mwIndex; 13 | #endif 14 | #endif 15 | 16 | #define CMD_LEN 2048 17 | #define Malloc(type,n) (type *)malloc((n)*sizeof(type)) 18 | 19 | void print_null(const char *s) {} 20 | void print_string_matlab(const char *s) {mexPrintf(s);} 21 | 22 | void exit_with_help() 23 | { 24 | mexPrintf( 25 | "Usage: model = svmtrain(training_label_vector, training_instance_matrix, 'libsvm_options');\n" 26 | "libsvm_options:\n" 27 | "-s svm_type : set type of SVM 
(default 0)\n" 28 | " 0 -- C-SVC (multi-class classification)\n" 29 | " 1 -- nu-SVC (multi-class classification)\n" 30 | " 2 -- one-class SVM\n" 31 | " 3 -- epsilon-SVR (regression)\n" 32 | " 4 -- nu-SVR (regression)\n" 33 | "-t kernel_type : set type of kernel function (default 2)\n" 34 | " 0 -- linear: u'*v\n" 35 | " 1 -- polynomial: (gamma*u'*v + coef0)^degree\n" 36 | " 2 -- radial basis function: exp(-gamma*|u-v|^2)\n" 37 | " 3 -- sigmoid: tanh(gamma*u'*v + coef0)\n" 38 | " 4 -- precomputed kernel (kernel values in training_instance_matrix)\n" 39 | "-d degree : set degree in kernel function (default 3)\n" 40 | "-g gamma : set gamma in kernel function (default 1/num_features)\n" 41 | "-r coef0 : set coef0 in kernel function (default 0)\n" 42 | "-c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)\n" 43 | "-n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5)\n" 44 | "-p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)\n" 45 | "-m cachesize : set cache memory size in MB (default 100)\n" 46 | "-e epsilon : set tolerance of termination criterion (default 0.001)\n" 47 | "-h shrinking : whether to use the shrinking heuristics, 0 or 1 (default 1)\n" 48 | "-b probability_estimates : whether to train a SVC or SVR model for probability estimates, 0 or 1 (default 0)\n" 49 | "-wi weight : set the parameter C of class i to weight*C, for C-SVC (default 1)\n" 50 | "-v n: n-fold cross validation mode\n" 51 | "-q : quiet mode (no outputs)\n" 52 | ); 53 | } 54 | 55 | // svm arguments 56 | struct svm_parameter param; // set by parse_command_line 57 | struct svm_problem prob; // set by read_problem 58 | struct svm_model *model; 59 | struct svm_node *x_space; 60 | int cross_validation; 61 | int nr_fold; 62 | 63 | 64 | double do_cross_validation() 65 | { 66 | int i; 67 | int total_correct = 0; 68 | double total_error = 0; 69 | double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0; 70 | double *target = Malloc(double,prob.l); 71 | double retval = 0.0; 72 | 73 | svm_cross_validation(&prob,¶m,nr_fold,target); 74 | if(param.svm_type == EPSILON_SVR || 75 | param.svm_type == NU_SVR) 76 | { 77 | for(i=0;i 2) 137 | { 138 | // put options in argv[] 139 | mxGetString(prhs[2], cmd, mxGetN(prhs[2]) + 1); 140 | if((argv[argc] = strtok(cmd, " ")) != NULL) 141 | while((argv[++argc] = strtok(NULL, " ")) != NULL) 142 | ; 143 | } 144 | 145 | // parse options 146 | for(i=1;i=argc && argv[i-1][1] != 'q') // since option -q has no parameter 151 | return 1; 152 | switch(argv[i-1][1]) 153 | { 154 | case 's': 155 | param.svm_type = atoi(argv[i]); 156 | break; 157 | case 't': 158 | param.kernel_type = atoi(argv[i]); 159 | break; 160 | case 'd': 161 | param.degree = atoi(argv[i]); 162 | break; 163 | case 'g': 164 | param.gamma = atof(argv[i]); 165 | break; 166 | case 'r': 167 | param.coef0 = atof(argv[i]); 168 | break; 169 | case 'n': 170 | param.nu = atof(argv[i]); 171 | break; 172 | case 'm': 173 | param.cache_size = atof(argv[i]); 174 | break; 175 | case 'c': 176 | param.C = atof(argv[i]); 177 | break; 178 | case 'e': 179 | param.eps = atof(argv[i]); 180 | break; 181 | case 'p': 182 | param.p = atof(argv[i]); 183 | break; 184 | case 'h': 185 | param.shrinking = atoi(argv[i]); 186 | break; 187 | case 'b': 188 | param.probability = atoi(argv[i]); 189 | break; 190 | case 'q': 191 | print_func = &print_null; 192 | i--; 193 | break; 194 | case 'v': 195 | cross_validation = 1; 196 | nr_fold = atoi(argv[i]); 197 | if(nr_fold < 2) 198 | { 199 | mexPrintf("n-fold 
cross validation: n must >= 2\n"); 200 | return 1; 201 | } 202 | break; 203 | case 'w': 204 | ++param.nr_weight; 205 | param.weight_label = (int *)realloc(param.weight_label,sizeof(int)*param.nr_weight); 206 | param.weight = (double *)realloc(param.weight,sizeof(double)*param.nr_weight); 207 | param.weight_label[param.nr_weight-1] = atoi(&argv[i-1][2]); 208 | param.weight[param.nr_weight-1] = atof(argv[i]); 209 | break; 210 | default: 211 | mexPrintf("Unknown option -%c\n", argv[i-1][1]); 212 | return 1; 213 | } 214 | } 215 | 216 | svm_set_print_string_function(print_func); 217 | 218 | return 0; 219 | } 220 | 221 | // read in a problem (in svmlight format) 222 | int read_problem_dense(const mxArray *label_vec, const mxArray *instance_mat) 223 | { 224 | // using size_t due to the output type of matlab functions 225 | size_t i, j, k, l; 226 | size_t elements, max_index, sc, label_vector_row_num; 227 | double *samples, *labels; 228 | 229 | prob.x = NULL; 230 | prob.y = NULL; 231 | x_space = NULL; 232 | 233 | labels = mxGetPr(label_vec); 234 | samples = mxGetPr(instance_mat); 235 | sc = mxGetN(instance_mat); 236 | 237 | elements = 0; 238 | // number of instances 239 | l = mxGetM(instance_mat); 240 | label_vector_row_num = mxGetM(label_vec); 241 | prob.l = (int)l; 242 | 243 | if(label_vector_row_num!=l) 244 | { 245 | mexPrintf("Length of label vector does not match # of instances.\n"); 246 | return -1; 247 | } 248 | 249 | if(param.kernel_type == PRECOMPUTED) 250 | elements = l * (sc + 1); 251 | else 252 | { 253 | for(i = 0; i < l; i++) 254 | { 255 | for(k = 0; k < sc; k++) 256 | if(samples[k * l + i] != 0) 257 | elements++; 258 | // count the '-1' element 259 | elements++; 260 | } 261 | } 262 | 263 | prob.y = Malloc(double,l); 264 | prob.x = Malloc(struct svm_node *,l); 265 | x_space = Malloc(struct svm_node, elements); 266 | 267 | max_index = sc; 268 | j = 0; 269 | for(i = 0; i < l; i++) 270 | { 271 | prob.x[i] = &x_space[j]; 272 | prob.y[i] = labels[i]; 273 | 274 | for(k = 0; k < sc; k++) 275 | { 276 | if(param.kernel_type == PRECOMPUTED || samples[k * l + i] != 0) 277 | { 278 | x_space[j].index = (int)k + 1; 279 | x_space[j].value = samples[k * l + i]; 280 | j++; 281 | } 282 | } 283 | x_space[j++].index = -1; 284 | } 285 | 286 | if(param.gamma == 0 && max_index > 0) 287 | param.gamma = (double)(1.0/max_index); 288 | 289 | if(param.kernel_type == PRECOMPUTED) 290 | for(i=0;i (int)max_index) 293 | { 294 | mexPrintf("Wrong input format: sample_serial_number out of range\n"); 295 | return -1; 296 | } 297 | } 298 | 299 | return 0; 300 | } 301 | 302 | int read_problem_sparse(const mxArray *label_vec, const mxArray *instance_mat) 303 | { 304 | mwIndex *ir, *jc, low, high, k; 305 | // using size_t due to the output type of matlab functions 306 | size_t i, j, l, elements, max_index, label_vector_row_num; 307 | mwSize num_samples; 308 | double *samples, *labels; 309 | mxArray *instance_mat_col; // transposed instance sparse matrix 310 | 311 | prob.x = NULL; 312 | prob.y = NULL; 313 | x_space = NULL; 314 | 315 | // transpose instance matrix 316 | { 317 | mxArray *prhs[1], *plhs[1]; 318 | prhs[0] = mxDuplicateArray(instance_mat); 319 | if(mexCallMATLAB(1, plhs, 1, prhs, "transpose")) 320 | { 321 | mexPrintf("Error: cannot transpose training instance matrix\n"); 322 | return -1; 323 | } 324 | instance_mat_col = plhs[0]; 325 | mxDestroyArray(prhs[0]); 326 | } 327 | 328 | // each column is one instance 329 | labels = mxGetPr(label_vec); 330 | samples = mxGetPr(instance_mat_col); 331 | ir = 
mxGetIr(instance_mat_col); 332 | jc = mxGetJc(instance_mat_col); 333 | 334 | num_samples = mxGetNzmax(instance_mat_col); 335 | 336 | // number of instances 337 | l = mxGetN(instance_mat_col); 338 | label_vector_row_num = mxGetM(label_vec); 339 | prob.l = (int) l; 340 | 341 | if(label_vector_row_num!=l) 342 | { 343 | mexPrintf("Length of label vector does not match # of instances.\n"); 344 | return -1; 345 | } 346 | 347 | elements = num_samples + l; 348 | max_index = mxGetM(instance_mat_col); 349 | 350 | prob.y = Malloc(double,l); 351 | prob.x = Malloc(struct svm_node *,l); 352 | x_space = Malloc(struct svm_node, elements); 353 | 354 | j = 0; 355 | for(i=0;i 0) 370 | param.gamma = (double)(1.0/max_index); 371 | 372 | return 0; 373 | } 374 | 375 | static void fake_answer(int nlhs, mxArray *plhs[]) 376 | { 377 | int i; 378 | for(i=0;i 1) 394 | { 395 | exit_with_help(); 396 | fake_answer(nlhs, plhs); 397 | return; 398 | } 399 | 400 | // Transform the input Matrix to libsvm format 401 | if(nrhs > 1 && nrhs < 4) 402 | { 403 | int err; 404 | 405 | if(!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1])) 406 | { 407 | mexPrintf("Error: label vector and instance matrix must be double\n"); 408 | fake_answer(nlhs, plhs); 409 | return; 410 | } 411 | 412 | if(mxIsSparse(prhs[0])) 413 | { 414 | mexPrintf("Error: label vector should not be in sparse format\n"); 415 | fake_answer(nlhs, plhs); 416 | return; 417 | } 418 | 419 | if(parse_command_line(nrhs, prhs, NULL)) 420 | { 421 | exit_with_help(); 422 | svm_destroy_param(¶m); 423 | fake_answer(nlhs, plhs); 424 | return; 425 | } 426 | 427 | if(mxIsSparse(prhs[1])) 428 | { 429 | if(param.kernel_type == PRECOMPUTED) 430 | { 431 | // precomputed kernel requires dense matrix, so we make one 432 | mxArray *rhs[1], *lhs[1]; 433 | 434 | rhs[0] = mxDuplicateArray(prhs[1]); 435 | if(mexCallMATLAB(1, lhs, 1, rhs, "full")) 436 | { 437 | mexPrintf("Error: cannot generate a full training instance matrix\n"); 438 | svm_destroy_param(¶m); 439 | fake_answer(nlhs, plhs); 440 | return; 441 | } 442 | err = read_problem_dense(prhs[0], lhs[0]); 443 | mxDestroyArray(lhs[0]); 444 | mxDestroyArray(rhs[0]); 445 | } 446 | else 447 | err = read_problem_sparse(prhs[0], prhs[1]); 448 | } 449 | else 450 | err = read_problem_dense(prhs[0], prhs[1]); 451 | 452 | // svmtrain's original code 453 | error_msg = svm_check_parameter(&prob, ¶m); 454 | 455 | if(err || error_msg) 456 | { 457 | if (error_msg != NULL) 458 | mexPrintf("Error: %s\n", error_msg); 459 | svm_destroy_param(¶m); 460 | free(prob.y); 461 | free(prob.x); 462 | free(x_space); 463 | fake_answer(nlhs, plhs); 464 | return; 465 | } 466 | 467 | if(cross_validation) 468 | { 469 | double *ptr; 470 | plhs[0] = mxCreateDoubleMatrix(1, 1, mxREAL); 471 | ptr = mxGetPr(plhs[0]); 472 | ptr[0] = do_cross_validation(); 473 | } 474 | else 475 | { 476 | int nr_feat = (int)mxGetN(prhs[1]); 477 | const char *error_msg; 478 | model = svm_train(&prob, ¶m); 479 | error_msg = model_to_matlab_structure(plhs, nr_feat, model); 480 | if(error_msg) 481 | mexPrintf("Error: can't convert libsvm model to matrix structure: %s\n", error_msg); 482 | svm_free_and_destroy_model(&model); 483 | } 484 | svm_destroy_param(¶m); 485 | free(prob.y); 486 | free(prob.x); 487 | free(x_space); 488 | } 489 | else 490 | { 491 | exit_with_help(); 492 | fake_answer(nlhs, plhs); 493 | return; 494 | } 495 | } 496 | -------------------------------------------------------------------------------- /python/MANIFEST.in: 
-------------------------------------------------------------------------------- 1 | include cpp-source/* 2 | include cpp-source/*/* 3 | -------------------------------------------------------------------------------- /python/Makefile: -------------------------------------------------------------------------------- 1 | all = lib 2 | 3 | lib: 4 | make -C .. lib 5 | -------------------------------------------------------------------------------- /python/libsvm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjlin1/libsvm/65367d015d9e368cc79f5c56924455c3a4fb5e48/python/libsvm/__init__.py -------------------------------------------------------------------------------- /python/libsvm/commonutil.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from array import array 3 | import sys 4 | 5 | try: 6 | import numpy as np 7 | import scipy 8 | from scipy import sparse 9 | except: 10 | scipy = None 11 | 12 | 13 | __all__ = ['svm_read_problem', 'evaluations', 'csr_find_scale_param', 'csr_scale'] 14 | 15 | def svm_read_problem(data_source, return_scipy=False): 16 | """ 17 | svm_read_problem(data_source, return_scipy=False) -> [y, x], y: list, x: list of dictionary 18 | svm_read_problem(data_source, return_scipy=True) -> [y, x], y: ndarray, x: csr_matrix 19 | 20 | Read LIBSVM-format data from data_source and return labels y 21 | and data instances x. 22 | """ 23 | if scipy != None and return_scipy: 24 | prob_y = array('d') 25 | prob_x = array('d') 26 | row_ptr = array('l', [0]) 27 | col_idx = array('l') 28 | else: 29 | prob_y = [] 30 | prob_x = [] 31 | row_ptr = [0] 32 | col_idx = [] 33 | indx_start = 1 34 | 35 | if hasattr(data_source, "read"): 36 | file = data_source 37 | else: 38 | file = open(data_source) 39 | try: 40 | for line in file: 41 | line = line.split(None, 1) 42 | # In case an instance with all zero features 43 | if len(line) == 1: line += [''] 44 | label, features = line 45 | prob_y.append(float(label)) 46 | if scipy != None and return_scipy: 47 | nz = 0 48 | for e in features.split(): 49 | ind, val = e.split(":") 50 | if ind == '0': 51 | indx_start = 0 52 | val = float(val) 53 | if val != 0: 54 | col_idx.append(int(ind)-indx_start) 55 | prob_x.append(val) 56 | nz += 1 57 | row_ptr.append(row_ptr[-1]+nz) 58 | else: 59 | xi = {} 60 | for e in features.split(): 61 | ind, val = e.split(":") 62 | xi[int(ind)] = float(val) 63 | prob_x += [xi] 64 | except Exception as err_msg: 65 | raise err_msg 66 | finally: 67 | if not hasattr(data_source, "read"): 68 | # close file only if it was created by us 69 | file.close() 70 | 71 | if scipy != None and return_scipy: 72 | prob_y = np.frombuffer(prob_y, dtype='d') 73 | prob_x = np.frombuffer(prob_x, dtype='d') 74 | col_idx = np.frombuffer(col_idx, dtype='l') 75 | row_ptr = np.frombuffer(row_ptr, dtype='l') 76 | prob_x = sparse.csr_matrix((prob_x, col_idx, row_ptr)) 77 | return (prob_y, prob_x) 78 | 79 | def evaluations_scipy(ty, pv): 80 | """ 81 | evaluations_scipy(ty, pv) -> (ACC, MSE, SCC) 82 | ty, pv: ndarray 83 | 84 | Calculate accuracy, mean squared error and squared correlation coefficient 85 | using the true values (ty) and predicted values (pv). 
86 | """ 87 | if not (scipy != None and isinstance(ty, np.ndarray) and isinstance(pv, np.ndarray)): 88 | raise TypeError("type of ty and pv must be ndarray") 89 | if len(ty) != len(pv): 90 | raise ValueError("len(ty) must be equal to len(pv)") 91 | ACC = 100.0*(ty == pv).mean() 92 | MSE = ((ty - pv)**2).mean() 93 | l = len(ty) 94 | sumv = pv.sum() 95 | sumy = ty.sum() 96 | sumvy = (pv*ty).sum() 97 | sumvv = (pv*pv).sum() 98 | sumyy = (ty*ty).sum() 99 | with np.errstate(all = 'raise'): 100 | try: 101 | SCC = ((l*sumvy-sumv*sumy)*(l*sumvy-sumv*sumy))/((l*sumvv-sumv*sumv)*(l*sumyy-sumy*sumy)) 102 | except: 103 | SCC = float('nan') 104 | return (float(ACC), float(MSE), float(SCC)) 105 | 106 | def evaluations(ty, pv, useScipy = True): 107 | """ 108 | evaluations(ty, pv, useScipy) -> (ACC, MSE, SCC) 109 | ty, pv: list, tuple or ndarray 110 | useScipy: convert ty, pv to ndarray, and use scipy functions for the evaluation 111 | 112 | Calculate accuracy, mean squared error and squared correlation coefficient 113 | using the true values (ty) and predicted values (pv). 114 | """ 115 | if scipy != None and useScipy: 116 | return evaluations_scipy(np.asarray(ty), np.asarray(pv)) 117 | if len(ty) != len(pv): 118 | raise ValueError("len(ty) must be equal to len(pv)") 119 | total_correct = total_error = 0 120 | sumv = sumy = sumvv = sumyy = sumvy = 0 121 | for v, y in zip(pv, ty): 122 | if y == v: 123 | total_correct += 1 124 | total_error += (v-y)*(v-y) 125 | sumv += v 126 | sumy += y 127 | sumvv += v*v 128 | sumyy += y*y 129 | sumvy += v*y 130 | l = len(ty) 131 | ACC = 100.0*total_correct/l 132 | MSE = total_error/l 133 | try: 134 | SCC = ((l*sumvy-sumv*sumy)*(l*sumvy-sumv*sumy))/((l*sumvv-sumv*sumv)*(l*sumyy-sumy*sumy)) 135 | except: 136 | SCC = float('nan') 137 | return (float(ACC), float(MSE), float(SCC)) 138 | 139 | def csr_find_scale_param(x, lower=-1, upper=1): 140 | assert isinstance(x, sparse.csr_matrix) 141 | assert lower < upper 142 | l, n = x.shape 143 | feat_min = x.min(axis=0).toarray().flatten() 144 | feat_max = x.max(axis=0).toarray().flatten() 145 | coef = (feat_max - feat_min) / (upper - lower) 146 | coef[coef != 0] = 1.0 / coef[coef != 0] 147 | 148 | # (x - ones(l,1) * feat_min') * diag(coef) + lower 149 | # = x * diag(coef) - ones(l, 1) * (feat_min' * diag(coef)) + lower 150 | # = x * diag(coef) + ones(l, 1) * (-feat_min' * diag(coef) + lower) 151 | # = x * diag(coef) + ones(l, 1) * offset' 152 | offset = -feat_min * coef + lower 153 | offset[coef == 0] = 0 154 | 155 | if sum(offset != 0) * l > 3 * x.getnnz(): 156 | print( 157 | "WARNING: The #nonzeros of the scaled data is at least 2 times larger than the original one.\n" 158 | "If feature values are non-negative and sparse, set lower=0 rather than the default lower=-1.", 159 | file=sys.stderr) 160 | 161 | return {'coef':coef, 'offset':offset} 162 | 163 | def csr_scale(x, scale_param): 164 | assert isinstance(x, sparse.csr_matrix) 165 | 166 | offset = scale_param['offset'] 167 | coef = scale_param['coef'] 168 | assert len(coef) == len(offset) 169 | 170 | l, n = x.shape 171 | 172 | if not n == len(coef): 173 | print("WARNING: The dimension of scaling parameters and feature number do not match.", file=sys.stderr) 174 | coef = coef.resize(n) # zeros padded if n > len(coef) 175 | offset = offset.resize(n) 176 | 177 | # scaled_x = x * diag(coef) + ones(l, 1) * offset' 178 | offset = sparse.csr_matrix(offset.reshape(1, n)) 179 | offset = sparse.vstack([offset] * l, format='csr', dtype=x.dtype) 180 | scaled_x = x.dot(sparse.diags(coef, 0, 
shape=(n, n))) + offset 181 | 182 | if scaled_x.getnnz() > x.getnnz(): 183 | print( 184 | "WARNING: original #nonzeros %d\n" % x.getnnz() + 185 | " > new #nonzeros %d\n" % scaled_x.getnnz() + 186 | "If feature values are non-negative and sparse, get scale_param by setting lower=0 rather than the default lower=-1.", 187 | file=sys.stderr) 188 | 189 | return scaled_x 190 | -------------------------------------------------------------------------------- /python/libsvm/svm.py: -------------------------------------------------------------------------------- 1 | from ctypes import * 2 | from ctypes.util import find_library 3 | from os import path 4 | from glob import glob 5 | from enum import IntEnum 6 | import sys 7 | 8 | try: 9 | import numpy as np 10 | import scipy 11 | from scipy import sparse 12 | except: 13 | scipy = None 14 | 15 | 16 | if sys.version_info[0] < 3: 17 | range = xrange 18 | from itertools import izip as zip 19 | 20 | __all__ = ['libsvm', 'svm_problem', 'svm_parameter', 21 | 'toPyModel', 'gen_svm_nodearray', 'print_null', 'svm_node', 'svm_forms', 22 | 'PRINT_STRING_FUN', 'kernel_names', 'c_double', 'svm_model'] 23 | 24 | try: 25 | dirname = path.dirname(path.abspath(__file__)) 26 | dynamic_lib_name = 'clib.cp*' 27 | path_to_so = glob(path.join(dirname, dynamic_lib_name))[0] 28 | libsvm = CDLL(path_to_so) 29 | except: 30 | try: 31 | if sys.platform == 'win32': 32 | libsvm = CDLL(path.join(dirname, r'..\..\windows\libsvm.dll')) 33 | else: 34 | libsvm = CDLL(path.join(dirname, '../../libsvm.so.4')) 35 | except: 36 | # For unix the prefix 'lib' is not considered. 37 | if find_library('svm'): 38 | libsvm = CDLL(find_library('svm')) 39 | elif find_library('libsvm'): 40 | libsvm = CDLL(find_library('libsvm')) 41 | else: 42 | raise Exception('LIBSVM library not found.') 43 | 44 | class svm_forms(IntEnum): 45 | C_SVC = 0 46 | NU_SVC = 1 47 | ONE_CLASS = 2 48 | EPSILON_SVR = 3 49 | NU_SVR = 4 50 | 51 | class kernel_names(IntEnum): 52 | LINEAR = 0 53 | POLY = 1 54 | RBF = 2 55 | SIGMOID = 3 56 | PRECOMPUTED = 4 57 | 58 | PRINT_STRING_FUN = CFUNCTYPE(None, c_char_p) 59 | def print_null(s): 60 | return 61 | 62 | # In multi-threading, all threads share the same memory space of 63 | # the dynamic library (libsvm). Thus, we use a module-level 64 | # variable to keep a reference to ctypes print_null, preventing 65 | # python from garbage collecting it in thread B while thread A 66 | # still needs it. Check the usage of svm_set_print_string_function() 67 | # in LIBSVM README for details. 
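# A minimal sketch of the pitfall this guards against (illustrative only,
# not code from this module): creating the callback inline, e.g.
#     libsvm.svm_set_print_string_function(PRINT_STRING_FUN(print_null))
# would let Python garbage-collect the temporary CFUNCTYPE object while the
# C side still holds its address, so a later print call from libsvm could
# crash. The module-level reference kept below avoids that.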
68 | ctypes_print_null = PRINT_STRING_FUN(print_null) 69 | 70 | def genFields(names, types): 71 | return list(zip(names, types)) 72 | 73 | def fillprototype(f, restype, argtypes): 74 | f.restype = restype 75 | f.argtypes = argtypes 76 | 77 | class svm_node(Structure): 78 | _names = ["index", "value"] 79 | _types = [c_int, c_double] 80 | _fields_ = genFields(_names, _types) 81 | 82 | def __init__(self, index=-1, value=0): 83 | self.index, self.value = index, value 84 | 85 | def __str__(self): 86 | return '%d:%g' % (self.index, self.value) 87 | 88 | def gen_svm_nodearray(xi, feature_max=None, isKernel=False): 89 | if feature_max: 90 | assert(isinstance(feature_max, int)) 91 | 92 | xi_shift = 0 # ensure correct indices of xi 93 | if scipy and isinstance(xi, tuple) and len(xi) == 2\ 94 | and isinstance(xi[0], np.ndarray) and isinstance(xi[1], np.ndarray): # for a sparse vector 95 | if not isKernel: 96 | index_range = xi[0] + 1 # index starts from 1 97 | else: 98 | index_range = xi[0] # index starts from 0 for precomputed kernel 99 | if feature_max: 100 | index_range = index_range[np.where(index_range <= feature_max)] 101 | elif scipy and isinstance(xi, np.ndarray): 102 | if not isKernel: 103 | xi_shift = 1 104 | index_range = xi.nonzero()[0] + 1 # index starts from 1 105 | else: 106 | index_range = np.arange(0, len(xi)) # index starts from 0 for precomputed kernel 107 | if feature_max: 108 | index_range = index_range[np.where(index_range <= feature_max)] 109 | elif isinstance(xi, (dict, list, tuple)): 110 | if isinstance(xi, dict): 111 | index_range = sorted(xi.keys()) 112 | elif isinstance(xi, (list, tuple)): 113 | if not isKernel: 114 | xi_shift = 1 115 | index_range = range(1, len(xi) + 1) # index starts from 1 116 | else: 117 | index_range = range(0, len(xi)) # index starts from 0 for precomputed kernel 118 | 119 | if feature_max: 120 | index_range = list(filter(lambda j: j <= feature_max, index_range)) 121 | if not isKernel: 122 | index_range = list(filter(lambda j:xi[j-xi_shift] != 0, index_range)) 123 | else: 124 | raise TypeError('xi should be a dictionary, list, tuple, 1-d numpy array, or tuple of (index, data)') 125 | 126 | ret = (svm_node*(len(index_range)+1))() 127 | ret[-1].index = -1 128 | 129 | if scipy and isinstance(xi, tuple) and len(xi) == 2\ 130 | and isinstance(xi[0], np.ndarray) and isinstance(xi[1], np.ndarray): # for a sparse vector 131 | # since xi=(indices, values), we must sort them simultaneously. 132 | for idx, arg in enumerate(np.argsort(index_range)): 133 | ret[idx].index = index_range[arg] 134 | ret[idx].value = (xi[1])[arg] 135 | else: 136 | for idx, j in enumerate(index_range): 137 | ret[idx].index = j 138 | ret[idx].value = xi[j - xi_shift] 139 | 140 | max_idx = 0 141 | if len(index_range) > 0: 142 | max_idx = index_range[-1] 143 | return ret, max_idx 144 | 145 | try: 146 | from numba import jit 147 | jit_enabled = True 148 | except: 149 | # We need to support two cases: when jit is called with no arguments, and when jit is called with 150 | # a keyword argument. 
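# For illustration (usage sketch only), the no-op fallback below must work
# for both decorator spellings:
#     @jit                     # jit receives the decorated function itself
#     @jit(nopython=True)      # jit receives only keyword arguments and must
#                              # return a decorator to apply to the function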
151 | def jit(func=None, *args, **kwargs): 152 | if func is None: 153 | # This handles the case where jit is used with parentheses: @jit(nopython=True) 154 | return lambda x: x 155 | else: 156 | # This handles the case where jit is used without parentheses: @jit 157 | return func 158 | jit_enabled = False 159 | 160 | @jit(nopython=True) 161 | def csr_to_problem_jit(l, x_val, x_ind, x_rowptr, prob_val, prob_ind, prob_rowptr, indx_start): 162 | for i in range(l): 163 | b1,e1 = x_rowptr[i], x_rowptr[i+1] 164 | b2,e2 = prob_rowptr[i], prob_rowptr[i+1]-1 165 | for j in range(b1,e1): 166 | prob_ind[j-b1+b2] = x_ind[j]+indx_start 167 | prob_val[j-b1+b2] = x_val[j] 168 | def csr_to_problem_nojit(l, x_val, x_ind, x_rowptr, prob_val, prob_ind, prob_rowptr, indx_start): 169 | for i in range(l): 170 | x_slice = slice(x_rowptr[i], x_rowptr[i+1]) 171 | prob_slice = slice(prob_rowptr[i], prob_rowptr[i+1]-1) 172 | prob_ind[prob_slice] = x_ind[x_slice]+indx_start 173 | prob_val[prob_slice] = x_val[x_slice] 174 | 175 | def csr_to_problem(x, prob, isKernel): 176 | if not x.has_sorted_indices: 177 | x.sort_indices() 178 | 179 | # Extra space for termination node and (possibly) bias term 180 | x_space = prob.x_space = np.empty((x.nnz+x.shape[0]), dtype=svm_node) 181 | # rowptr has to be a 64bit integer because it will later be used for pointer arithmetic, 182 | # which overflows when the added pointer points to an address that is numerically high. 183 | prob.rowptr = x.indptr.astype(np.int64, copy=True) 184 | prob.rowptr[1:] += np.arange(1,x.shape[0]+1) 185 | prob_ind = x_space["index"] 186 | prob_val = x_space["value"] 187 | prob_ind[:] = -1 188 | if not isKernel: 189 | indx_start = 1 # index starts from 1 190 | else: 191 | indx_start = 0 # index starts from 0 for precomputed kernel 192 | if jit_enabled: 193 | csr_to_problem_jit(x.shape[0], x.data, x.indices, x.indptr, prob_val, prob_ind, prob.rowptr, indx_start) 194 | else: 195 | csr_to_problem_nojit(x.shape[0], x.data, x.indices, x.indptr, prob_val, prob_ind, prob.rowptr, indx_start) 196 | 197 | class svm_problem(Structure): 198 | _names = ["l", "y", "x"] 199 | _types = [c_int, POINTER(c_double), POINTER(POINTER(svm_node))] 200 | _fields_ = genFields(_names, _types) 201 | 202 | def __init__(self, y, x, isKernel=False): 203 | if (not isinstance(y, (list, tuple))) and (not (scipy and isinstance(y, np.ndarray))): 204 | raise TypeError("type of y: {0} is not supported!".format(type(y))) 205 | 206 | if isinstance(x, (list, tuple)): 207 | if len(y) != len(x): 208 | raise ValueError("len(y) != len(x)") 209 | elif scipy != None and isinstance(x, (np.ndarray, sparse.spmatrix)): 210 | if len(y) != x.shape[0]: 211 | raise ValueError("len(y) != len(x)") 212 | if isinstance(x, np.ndarray): 213 | x = np.ascontiguousarray(x) # enforce row-major 214 | if isinstance(x, sparse.spmatrix): 215 | x = x.tocsr() 216 | pass 217 | else: 218 | raise TypeError("type of x: {0} is not supported!".format(type(x))) 219 | self.l = l = len(y) 220 | 221 | max_idx = 0 222 | x_space = self.x_space = [] 223 | if scipy != None and isinstance(x, sparse.csr_matrix): 224 | csr_to_problem(x, self, isKernel) 225 | max_idx = x.shape[1] 226 | else: 227 | for i, xi in enumerate(x): 228 | tmp_xi, tmp_idx = gen_svm_nodearray(xi,isKernel=isKernel) 229 | x_space += [tmp_xi] 230 | max_idx = max(max_idx, tmp_idx) 231 | self.n = max_idx 232 | 233 | self.y = (c_double * l)() 234 | if scipy != None and isinstance(y, np.ndarray): 235 | np.ctypeslib.as_array(self.y, (self.l,))[:] = y 236 | else: 237 | for i, yi 
in enumerate(y): self.y[i] = yi 238 | 239 | self.x = (POINTER(svm_node) * l)() 240 | if scipy != None and isinstance(x, sparse.csr_matrix): 241 | base = addressof(self.x_space.ctypes.data_as(POINTER(svm_node))[0]) 242 | x_ptr = cast(self.x, POINTER(c_uint64)) 243 | x_ptr = np.ctypeslib.as_array(x_ptr,(self.l,)) 244 | x_ptr[:] = self.rowptr[:-1]*sizeof(svm_node)+base 245 | else: 246 | for i, xi in enumerate(self.x_space): self.x[i] = xi 247 | 248 | class svm_parameter(Structure): 249 | _names = ["svm_type", "kernel_type", "degree", "gamma", "coef0", 250 | "cache_size", "eps", "C", "nr_weight", "weight_label", "weight", 251 | "nu", "p", "shrinking", "probability"] 252 | _types = [c_int, c_int, c_int, c_double, c_double, 253 | c_double, c_double, c_double, c_int, POINTER(c_int), POINTER(c_double), 254 | c_double, c_double, c_int, c_int] 255 | _fields_ = genFields(_names, _types) 256 | 257 | def __init__(self, options = None): 258 | if options == None: 259 | options = '' 260 | self.parse_options(options) 261 | 262 | def __str__(self): 263 | s = '' 264 | attrs = svm_parameter._names + list(self.__dict__.keys()) 265 | values = map(lambda attr: getattr(self, attr), attrs) 266 | for attr, val in zip(attrs, values): 267 | s += (' %s: %s\n' % (attr, val)) 268 | s = s.strip() 269 | 270 | return s 271 | 272 | def set_to_default_values(self): 273 | self.svm_type = svm_forms.C_SVC; 274 | self.kernel_type = kernel_names.RBF 275 | self.degree = 3 276 | self.gamma = 0 277 | self.coef0 = 0 278 | self.nu = 0.5 279 | self.cache_size = 100 280 | self.C = 1 281 | self.eps = 0.001 282 | self.p = 0.1 283 | self.shrinking = 1 284 | self.probability = 0 285 | self.nr_weight = 0 286 | self.weight_label = None 287 | self.weight = None 288 | self.cross_validation = False 289 | self.nr_fold = 0 290 | self.print_func = cast(None, PRINT_STRING_FUN) 291 | 292 | def parse_options(self, options): 293 | if isinstance(options, list): 294 | argv = options 295 | elif isinstance(options, str): 296 | argv = options.split() 297 | else: 298 | raise TypeError("arg 1 should be a list or a str.") 299 | self.set_to_default_values() 300 | self.print_func = cast(None, PRINT_STRING_FUN) 301 | weight_label = [] 302 | weight = [] 303 | 304 | i = 0 305 | while i < len(argv): 306 | if argv[i] == "-s": 307 | i = i + 1 308 | self.svm_type = svm_forms(int(argv[i])) 309 | elif argv[i] == "-t": 310 | i = i + 1 311 | self.kernel_type = kernel_names(int(argv[i])) 312 | elif argv[i] == "-d": 313 | i = i + 1 314 | self.degree = int(argv[i]) 315 | elif argv[i] == "-g": 316 | i = i + 1 317 | self.gamma = float(argv[i]) 318 | elif argv[i] == "-r": 319 | i = i + 1 320 | self.coef0 = float(argv[i]) 321 | elif argv[i] == "-n": 322 | i = i + 1 323 | self.nu = float(argv[i]) 324 | elif argv[i] == "-m": 325 | i = i + 1 326 | self.cache_size = float(argv[i]) 327 | elif argv[i] == "-c": 328 | i = i + 1 329 | self.C = float(argv[i]) 330 | elif argv[i] == "-e": 331 | i = i + 1 332 | self.eps = float(argv[i]) 333 | elif argv[i] == "-p": 334 | i = i + 1 335 | self.p = float(argv[i]) 336 | elif argv[i] == "-h": 337 | i = i + 1 338 | self.shrinking = int(argv[i]) 339 | elif argv[i] == "-b": 340 | i = i + 1 341 | self.probability = int(argv[i]) 342 | elif argv[i] == "-q": 343 | self.print_func = ctypes_print_null 344 | elif argv[i] == "-v": 345 | i = i + 1 346 | self.cross_validation = 1 347 | self.nr_fold = int(argv[i]) 348 | if self.nr_fold < 2: 349 | raise ValueError("n-fold cross validation: n must >= 2") 350 | elif argv[i].startswith("-w"): 351 | i = i + 1 352 | 
self.nr_weight += 1 353 | weight_label += [int(argv[i-1][2:])] 354 | weight += [float(argv[i])] 355 | else: 356 | raise ValueError("Wrong options") 357 | i += 1 358 | 359 | libsvm.svm_set_print_string_function(self.print_func) 360 | self.weight_label = (c_int*self.nr_weight)() 361 | self.weight = (c_double*self.nr_weight)() 362 | for i in range(self.nr_weight): 363 | self.weight[i] = weight[i] 364 | self.weight_label[i] = weight_label[i] 365 | 366 | class svm_model(Structure): 367 | _names = ['param', 'nr_class', 'l', 'SV', 'sv_coef', 'rho', 368 | 'probA', 'probB', 'prob_density_marks', 'sv_indices', 369 | 'label', 'nSV', 'free_sv'] 370 | _types = [svm_parameter, c_int, c_int, POINTER(POINTER(svm_node)), 371 | POINTER(POINTER(c_double)), POINTER(c_double), 372 | POINTER(c_double), POINTER(c_double), POINTER(c_double), 373 | POINTER(c_int), POINTER(c_int), POINTER(c_int), c_int] 374 | _fields_ = genFields(_names, _types) 375 | 376 | def __init__(self): 377 | self.__createfrom__ = 'python' 378 | 379 | def __del__(self): 380 | # free memory created by C to avoid memory leak 381 | if hasattr(self, '__createfrom__') and self.__createfrom__ == 'C': 382 | libsvm.svm_free_and_destroy_model(pointer(pointer(self))) 383 | 384 | def get_svm_type(self): 385 | return libsvm.svm_get_svm_type(self) 386 | 387 | def get_nr_class(self): 388 | return libsvm.svm_get_nr_class(self) 389 | 390 | def get_svr_probability(self): 391 | return libsvm.svm_get_svr_probability(self) 392 | 393 | def get_labels(self): 394 | nr_class = self.get_nr_class() 395 | labels = (c_int * nr_class)() 396 | libsvm.svm_get_labels(self, labels) 397 | return labels[:nr_class] 398 | 399 | def get_sv_indices(self): 400 | total_sv = self.get_nr_sv() 401 | sv_indices = (c_int * total_sv)() 402 | libsvm.svm_get_sv_indices(self, sv_indices) 403 | return sv_indices[:total_sv] 404 | 405 | def get_nr_sv(self): 406 | return libsvm.svm_get_nr_sv(self) 407 | 408 | def is_probability_model(self): 409 | return (libsvm.svm_check_probability_model(self) == 1) 410 | 411 | def get_sv_coef(self): 412 | return [tuple(self.sv_coef[j][i] for j in range(self.nr_class - 1)) 413 | for i in range(self.l)] 414 | 415 | def get_SV(self): 416 | result = [] 417 | for sparse_sv in self.SV[:self.l]: 418 | row = dict() 419 | 420 | i = 0 421 | while True: 422 | if sparse_sv[i].index == -1: 423 | break 424 | row[sparse_sv[i].index] = sparse_sv[i].value 425 | i += 1 426 | 427 | result.append(row) 428 | return result 429 | 430 | def toPyModel(model_ptr): 431 | """ 432 | toPyModel(model_ptr) -> svm_model 433 | 434 | Convert a ctypes POINTER(svm_model) to a Python svm_model 435 | """ 436 | if bool(model_ptr) == False: 437 | raise ValueError("Null pointer") 438 | m = model_ptr.contents 439 | m.__createfrom__ = 'C' 440 | return m 441 | 442 | fillprototype(libsvm.svm_train, POINTER(svm_model), [POINTER(svm_problem), POINTER(svm_parameter)]) 443 | fillprototype(libsvm.svm_cross_validation, None, [POINTER(svm_problem), POINTER(svm_parameter), c_int, POINTER(c_double)]) 444 | 445 | fillprototype(libsvm.svm_save_model, c_int, [c_char_p, POINTER(svm_model)]) 446 | fillprototype(libsvm.svm_load_model, POINTER(svm_model), [c_char_p]) 447 | 448 | fillprototype(libsvm.svm_get_svm_type, c_int, [POINTER(svm_model)]) 449 | fillprototype(libsvm.svm_get_nr_class, c_int, [POINTER(svm_model)]) 450 | fillprototype(libsvm.svm_get_labels, None, [POINTER(svm_model), POINTER(c_int)]) 451 | fillprototype(libsvm.svm_get_sv_indices, None, [POINTER(svm_model), POINTER(c_int)]) 452 | 
fillprototype(libsvm.svm_get_nr_sv, c_int, [POINTER(svm_model)]) 453 | fillprototype(libsvm.svm_get_svr_probability, c_double, [POINTER(svm_model)]) 454 | 455 | fillprototype(libsvm.svm_predict_values, c_double, [POINTER(svm_model), POINTER(svm_node), POINTER(c_double)]) 456 | fillprototype(libsvm.svm_predict, c_double, [POINTER(svm_model), POINTER(svm_node)]) 457 | fillprototype(libsvm.svm_predict_probability, c_double, [POINTER(svm_model), POINTER(svm_node), POINTER(c_double)]) 458 | 459 | fillprototype(libsvm.svm_free_model_content, None, [POINTER(svm_model)]) 460 | fillprototype(libsvm.svm_free_and_destroy_model, None, [POINTER(POINTER(svm_model))]) 461 | fillprototype(libsvm.svm_destroy_param, None, [POINTER(svm_parameter)]) 462 | 463 | fillprototype(libsvm.svm_check_parameter, c_char_p, [POINTER(svm_problem), POINTER(svm_parameter)]) 464 | fillprototype(libsvm.svm_check_probability_model, c_int, [POINTER(svm_model)]) 465 | fillprototype(libsvm.svm_set_print_string_function, None, [PRINT_STRING_FUN]) 466 | -------------------------------------------------------------------------------- /python/libsvm/svmutil.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | from .svm import * 3 | from .svm import __all__ as svm_all 4 | from .commonutil import * 5 | from .commonutil import __all__ as common_all 6 | 7 | try: 8 | import numpy as np 9 | import scipy 10 | from scipy import sparse 11 | except: 12 | scipy = None 13 | 14 | 15 | if sys.version_info[0] < 3: 16 | range = xrange 17 | from itertools import izip as zip 18 | _cstr = lambda s: s.encode("utf-8") if isinstance(s,unicode) else str(s) 19 | else: 20 | _cstr = lambda s: bytes(s, "utf-8") 21 | 22 | __all__ = ['svm_load_model', 'svm_predict', 'svm_save_model', 'svm_train'] + svm_all + common_all 23 | 24 | 25 | def svm_load_model(model_file_name): 26 | """ 27 | svm_load_model(model_file_name) -> model 28 | 29 | Load a LIBSVM model from model_file_name and return. 30 | """ 31 | model = libsvm.svm_load_model(_cstr(model_file_name)) 32 | if not model: 33 | print("can't open model file %s" % model_file_name) 34 | return None 35 | model = toPyModel(model) 36 | return model 37 | 38 | def svm_save_model(model_file_name, model): 39 | """ 40 | svm_save_model(model_file_name, model) -> None 41 | 42 | Save a LIBSVM model to the file model_file_name. 43 | """ 44 | libsvm.svm_save_model(_cstr(model_file_name), model) 45 | 46 | def svm_train(arg1, arg2=None, arg3=None): 47 | """ 48 | svm_train(y, x [, options]) -> model | ACC | MSE 49 | 50 | y: a list/tuple/ndarray of l true labels (type must be int/double). 51 | 52 | x: 1. a list/tuple of l training instances. Feature vector of 53 | each training instance is a list/tuple or dictionary. 54 | 55 | 2. an l * n numpy ndarray or scipy spmatrix (n: number of features). 56 | 57 | svm_train(prob [, options]) -> model | ACC | MSE 58 | svm_train(prob, param) -> model | ACC| MSE 59 | 60 | Train an SVM model from data (y, x) or an svm_problem prob using 61 | 'options' or an svm_parameter param. 62 | If '-v' is specified in 'options' (i.e., cross validation) 63 | either accuracy (ACC) or mean-squared error (MSE) is returned. 
64 | options: 65 | -s svm_type : set type of SVM (default 0) 66 | 0 -- C-SVC (multi-class classification) 67 | 1 -- nu-SVC (multi-class classification) 68 | 2 -- one-class SVM 69 | 3 -- epsilon-SVR (regression) 70 | 4 -- nu-SVR (regression) 71 | -t kernel_type : set type of kernel function (default 2) 72 | 0 -- linear: u'*v 73 | 1 -- polynomial: (gamma*u'*v + coef0)^degree 74 | 2 -- radial basis function: exp(-gamma*|u-v|^2) 75 | 3 -- sigmoid: tanh(gamma*u'*v + coef0) 76 | 4 -- precomputed kernel (kernel values in training_set_file) 77 | -d degree : set degree in kernel function (default 3) 78 | -g gamma : set gamma in kernel function (default 1/num_features) 79 | -r coef0 : set coef0 in kernel function (default 0) 80 | -c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1) 81 | -n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5) 82 | -p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1) 83 | -m cachesize : set cache memory size in MB (default 100) 84 | -e epsilon : set tolerance of termination criterion (default 0.001) 85 | -h shrinking : whether to use the shrinking heuristics, 0 or 1 (default 1) 86 | -b probability_estimates : whether to train a model for probability estimates, 0 or 1 (default 0) 87 | -wi weight : set the parameter C of class i to weight*C, for C-SVC (default 1) 88 | -v n: n-fold cross validation mode 89 | -q : quiet mode (no outputs) 90 | """ 91 | prob, param = None, None 92 | if isinstance(arg1, (list, tuple)) or (scipy and isinstance(arg1, np.ndarray)): 93 | assert isinstance(arg2, (list, tuple)) or (scipy and isinstance(arg2, (np.ndarray, sparse.spmatrix))) 94 | y, x, options = arg1, arg2, arg3 95 | param = svm_parameter(options) 96 | prob = svm_problem(y, x, isKernel=(param.kernel_type == kernel_names.PRECOMPUTED)) 97 | elif isinstance(arg1, svm_problem): 98 | prob = arg1 99 | if isinstance(arg2, svm_parameter): 100 | param = arg2 101 | else: 102 | param = svm_parameter(arg2) 103 | if prob == None or param == None: 104 | raise TypeError("Wrong types for the arguments") 105 | 106 | if param.kernel_type == kernel_names.PRECOMPUTED: 107 | for i in range(prob.l): 108 | xi = prob.x[i] 109 | idx, val = xi[0].index, xi[0].value 110 | if idx != 0: 111 | raise ValueError('Wrong input format: first column must be 0:sample_serial_number') 112 | if val <= 0 or val > prob.n: 113 | raise ValueError('Wrong input format: sample_serial_number out of range') 114 | 115 | if param.gamma == 0 and prob.n > 0: 116 | param.gamma = 1.0 / prob.n 117 | libsvm.svm_set_print_string_function(param.print_func) 118 | err_msg = libsvm.svm_check_parameter(prob, param) 119 | if err_msg: 120 | raise ValueError('Error: %s' % err_msg) 121 | 122 | if param.cross_validation: 123 | l, nr_fold = prob.l, param.nr_fold 124 | target = (c_double * l)() 125 | libsvm.svm_cross_validation(prob, param, nr_fold, target) 126 | ACC, MSE, SCC = evaluations(prob.y[:l], target[:l]) 127 | if param.svm_type in [svm_forms.EPSILON_SVR, svm_forms.NU_SVR]: 128 | print("Cross Validation Mean squared error = %g" % MSE) 129 | print("Cross Validation Squared correlation coefficient = %g" % SCC) 130 | return MSE 131 | else: 132 | print("Cross Validation Accuracy = %g%%" % ACC) 133 | return ACC 134 | else: 135 | m = libsvm.svm_train(prob, param) 136 | m = toPyModel(m) 137 | 138 | # If prob is destroyed, data including SVs pointed by m can remain. 
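# For illustration (hypothetical caller code, not part of this module):
#     m = svm_train(y, x, '-c 4')
#     del y, x   # predicting with m is still safe afterwards, because
#                # m.x_space keeps the underlying svm_node buffer alive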
139 | m.x_space = prob.x_space 140 | return m 141 | 142 | def svm_predict(y, x, m, options=""): 143 | """ 144 | svm_predict(y, x, m [, options]) -> (p_labels, p_acc, p_vals) 145 | 146 | y: a list/tuple/ndarray of l true labels (type must be int/double). 147 | It is used for calculating the accuracy. Use [] if true labels are 148 | unavailable. 149 | 150 | x: 1. a list/tuple of l training instances. Feature vector of 151 | each training instance is a list/tuple or dictionary. 152 | 153 | 2. an l * n numpy ndarray or scipy spmatrix (n: number of features). 154 | 155 | Predict data (y, x) with the SVM model m. 156 | options: 157 | -b probability_estimates: whether to predict probability estimates, 158 | 0 or 1 (default 0). 159 | -q : quiet mode (no outputs). 160 | 161 | The return tuple contains 162 | p_labels: a list of predicted labels 163 | p_acc: a tuple including accuracy (for classification), mean-squared 164 | error, and squared correlation coefficient (for regression). 165 | p_vals: a list of decision values or probability estimates (if '-b 1' 166 | is specified). If k is the number of classes, for decision values, 167 | each element includes results of predicting k(k-1)/2 binary-class 168 | SVMs. For probabilities, each element contains k values indicating 169 | the probability that the testing instance is in each class. 170 | Note that the order of classes here is the same as 'model.label' 171 | field in the model structure. 172 | """ 173 | 174 | def info(s): 175 | print(s) 176 | 177 | if scipy and isinstance(x, np.ndarray): 178 | x = np.ascontiguousarray(x) # enforce row-major 179 | elif sparse and isinstance(x, sparse.spmatrix): 180 | x = x.tocsr() 181 | elif not isinstance(x, (list, tuple)): 182 | raise TypeError("type of x: {0} is not supported!".format(type(x))) 183 | 184 | if (not isinstance(y, (list, tuple))) and (not (scipy and isinstance(y, np.ndarray))): 185 | raise TypeError("type of y: {0} is not supported!".format(type(y))) 186 | 187 | predict_probability = 0 188 | argv = options.split() 189 | i = 0 190 | while i < len(argv): 191 | if argv[i] == '-b': 192 | i += 1 193 | predict_probability = int(argv[i]) 194 | elif argv[i] == '-q': 195 | info = print_null 196 | else: 197 | raise ValueError("Wrong options") 198 | i+=1 199 | 200 | svm_type = m.get_svm_type() 201 | is_prob_model = m.is_probability_model() 202 | nr_class = m.get_nr_class() 203 | pred_labels = [] 204 | pred_values = [] 205 | 206 | if scipy and isinstance(x, sparse.spmatrix): 207 | nr_instance = x.shape[0] 208 | else: 209 | nr_instance = len(x) 210 | 211 | if predict_probability: 212 | if not is_prob_model: 213 | raise ValueError("Model does not support probabiliy estimates") 214 | 215 | if svm_type in [svm_forms.NU_SVR, svm_forms.EPSILON_SVR]: 216 | info("Prob. 
model for test data: target value = predicted value + z,\n" 217 | "z: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma=%g" % m.get_svr_probability()); 218 | nr_class = 0 219 | 220 | prob_estimates = (c_double * nr_class)() 221 | for i in range(nr_instance): 222 | if scipy and isinstance(x, sparse.spmatrix): 223 | indslice = slice(x.indptr[i], x.indptr[i+1]) 224 | xi, idx = gen_svm_nodearray((x.indices[indslice], x.data[indslice]), isKernel=(m.param.kernel_type == kernel_names.PRECOMPUTED)) 225 | else: 226 | xi, idx = gen_svm_nodearray(x[i], isKernel=(m.param.kernel_type == kernel_names.PRECOMPUTED)) 227 | label = libsvm.svm_predict_probability(m, xi, prob_estimates) 228 | values = prob_estimates[:nr_class] 229 | pred_labels += [label] 230 | pred_values += [values] 231 | else: 232 | if is_prob_model: 233 | info("Model supports probability estimates, but disabled in predicton.") 234 | if svm_type in [svm_forms.ONE_CLASS, svm_forms.EPSILON_SVR, svm_forms.NU_SVC]: 235 | nr_classifier = 1 236 | else: 237 | nr_classifier = nr_class*(nr_class-1)//2 238 | dec_values = (c_double * nr_classifier)() 239 | for i in range(nr_instance): 240 | if scipy and isinstance(x, sparse.spmatrix): 241 | indslice = slice(x.indptr[i], x.indptr[i+1]) 242 | xi, idx = gen_svm_nodearray((x.indices[indslice], x.data[indslice]), isKernel=(m.param.kernel_type == kernel_names.PRECOMPUTED)) 243 | else: 244 | xi, idx = gen_svm_nodearray(x[i], isKernel=(m.param.kernel_type == kernel_names.PRECOMPUTED)) 245 | label = libsvm.svm_predict_values(m, xi, dec_values) 246 | if(nr_class == 1): 247 | values = [1] 248 | else: 249 | values = dec_values[:nr_classifier] 250 | pred_labels += [label] 251 | pred_values += [values] 252 | 253 | if len(y) == 0: 254 | y = [0] * nr_instance 255 | ACC, MSE, SCC = evaluations(y, pred_labels) 256 | 257 | if svm_type in [svm_forms.EPSILON_SVR, svm_forms.NU_SVR]: 258 | info("Mean squared error = %g (regression)" % MSE) 259 | info("Squared correlation coefficient = %g (regression)" % SCC) 260 | else: 261 | info("Accuracy = %g%% (%d/%d) (classification)" % (ACC, int(round(nr_instance*ACC/100)), nr_instance)) 262 | 263 | return pred_labels, (ACC, MSE, SCC), pred_values 264 | -------------------------------------------------------------------------------- /python/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import sys, os 4 | from os import path 5 | from shutil import copyfile, rmtree 6 | from glob import glob 7 | 8 | from setuptools import setup, Extension 9 | from distutils.command.clean import clean as clean_cmd 10 | 11 | # a technique to build a shared library on windows 12 | from distutils.command.build_ext import build_ext 13 | 14 | build_ext.get_export_symbols = lambda x, y: [] 15 | 16 | 17 | PACKAGE_DIR = "libsvm" 18 | PACKAGE_NAME = "libsvm-official" 19 | VERSION = "3.36.0" 20 | cpp_dir = "cpp-source" 21 | # should be consistent with dynamic_lib_name in libsvm/svm.py 22 | dynamic_lib_name = "clib" 23 | 24 | # sources to be included to build the shared library 25 | source_codes = [ 26 | "svm.cpp", 27 | ] 28 | headers = [ 29 | "svm.h", 30 | "svm.def", 31 | ] 32 | 33 | # license parameters 34 | license_source = path.join("..", "COPYRIGHT") 35 | license_file = "LICENSE" 36 | license_name = "BSD-3-Clause" 37 | 38 | kwargs_for_extension = { 39 | "sources": [path.join(cpp_dir, f) for f in source_codes], 40 | "depends": [path.join(cpp_dir, f) for f in headers], 41 | "include_dirs": [cpp_dir], 42 | "language": "c++", 43 | } 44 
| 45 | # see ../Makefile.win and enable openmp 46 | if sys.platform == "win32": 47 | kwargs_for_extension.update( 48 | { 49 | "define_macros": [("_WIN64", ""), ("_CRT_SECURE_NO_DEPRECATE", "")], 50 | "extra_link_args": [r"-DEF:{}\svm.def".format(cpp_dir)], 51 | "extra_compile_args": ["/openmp"], 52 | } 53 | ) 54 | else: 55 | kwargs_for_extension.update( 56 | { 57 | "extra_compile_args": ["-fopenmp"], 58 | "extra_link_args": ["-fopenmp"], 59 | } 60 | ) 61 | 62 | 63 | def create_cpp_source(): 64 | for f in source_codes + headers: 65 | src_file = path.join("..", f) 66 | tgt_file = path.join(cpp_dir, f) 67 | # ensure blas directory is created 68 | os.makedirs(path.dirname(tgt_file), exist_ok=True) 69 | copyfile(src_file, tgt_file) 70 | 71 | 72 | class CleanCommand(clean_cmd): 73 | def run(self): 74 | clean_cmd.run(self) 75 | to_be_removed = ["build/", "dist/", "MANIFEST", cpp_dir, "{}.egg-info".format(PACKAGE_NAME), license_file] 76 | to_be_removed += glob("./{}/{}.*".format(PACKAGE_DIR, dynamic_lib_name)) 77 | for root, dirs, files in os.walk(os.curdir, topdown=False): 78 | if "__pycache__" in dirs: 79 | to_be_removed.append(path.join(root, "__pycache__")) 80 | to_be_removed += [f for f in files if f.endswith(".pyc")] 81 | 82 | for f in to_be_removed: 83 | print("remove {}".format(f)) 84 | if f == ".": 85 | continue 86 | elif path.isfile(f): 87 | os.remove(f) 88 | elif path.isdir(f): 89 | rmtree(f) 90 | 91 | def main(): 92 | if not path.exists(cpp_dir): 93 | create_cpp_source() 94 | 95 | if not path.exists(license_file): 96 | copyfile(license_source, license_file) 97 | 98 | with open("README") as f: 99 | long_description = f.read() 100 | 101 | setup( 102 | name=PACKAGE_NAME, 103 | packages=[PACKAGE_DIR], 104 | version=VERSION, 105 | description="Python binding of LIBSVM", 106 | long_description=long_description, 107 | long_description_content_type="text/plain", 108 | author="ML group @ National Taiwan University", 109 | author_email="cjlin@csie.ntu.edu.tw", 110 | url="https://www.csie.ntu.edu.tw/~cjlin/libsvm", 111 | license=license_name, 112 | install_requires=["scipy"], 113 | ext_modules=[ 114 | Extension( 115 | "{}.{}".format(PACKAGE_DIR, dynamic_lib_name), **kwargs_for_extension 116 | ) 117 | ], 118 | cmdclass={"clean": CleanCommand}, 119 | ) 120 | 121 | 122 | main() 123 | 124 | -------------------------------------------------------------------------------- /svm-predict.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "svm.h" 7 | 8 | int print_null(const char *s,...) {return 0;} 9 | 10 | static int (*info)(const char *fmt,...) 
= &printf; 11 | 12 | struct svm_node *x; 13 | int max_nr_attr = 64; 14 | 15 | struct svm_model* model; 16 | int predict_probability=0; 17 | 18 | static char *line = NULL; 19 | static int max_line_len; 20 | 21 | static char* readline(FILE *input) 22 | { 23 | int len; 24 | 25 | if(fgets(line,max_line_len,input) == NULL) 26 | return NULL; 27 | 28 | while(strrchr(line,'\n') == NULL) 29 | { 30 | max_line_len *= 2; 31 | line = (char *) realloc(line,max_line_len); 32 | len = (int) strlen(line); 33 | if(fgets(line+len,max_line_len-len,input) == NULL) 34 | break; 35 | } 36 | return line; 37 | } 38 | 39 | void exit_input_error(int line_num) 40 | { 41 | fprintf(stderr,"Wrong input format at line %d\n", line_num); 42 | exit(1); 43 | } 44 | 45 | void predict(FILE *input, FILE *output) 46 | { 47 | int correct = 0; 48 | int total = 0; 49 | double error = 0; 50 | double sump = 0, sumt = 0, sumpp = 0, sumtt = 0, sumpt = 0; 51 | 52 | int svm_type=svm_get_svm_type(model); 53 | int nr_class=svm_get_nr_class(model); 54 | double *prob_estimates=NULL; 55 | int j; 56 | 57 | if(predict_probability) 58 | { 59 | if (svm_type==NU_SVR || svm_type==EPSILON_SVR) 60 | info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma=%g\n",svm_get_svr_probability(model)); 61 | else if(svm_type==ONE_CLASS) 62 | { 63 | // nr_class = 2 for ONE_CLASS 64 | prob_estimates = (double *) malloc(nr_class*sizeof(double)); 65 | fprintf(output,"label normal outlier\n"); 66 | } 67 | else 68 | { 69 | int *labels=(int *) malloc(nr_class*sizeof(int)); 70 | svm_get_labels(model,labels); 71 | prob_estimates = (double *) malloc(nr_class*sizeof(double)); 72 | fprintf(output,"labels"); 73 | for(j=0;j start from 0 88 | 89 | label = strtok(line," \t\n"); 90 | if(label == NULL) // empty line 91 | exit_input_error(total+1); 92 | 93 | target_label = strtod(label,&endptr); 94 | if(endptr == label || *endptr != '\0') 95 | exit_input_error(total+1); 96 | 97 | while(1) 98 | { 99 | if(i>=max_nr_attr-1) // need one more for index = -1 100 | { 101 | max_nr_attr *= 2; 102 | x = (struct svm_node *) realloc(x,max_nr_attr*sizeof(struct svm_node)); 103 | } 104 | 105 | idx = strtok(NULL,":"); 106 | val = strtok(NULL," \t"); 107 | 108 | if(val == NULL) 109 | break; 110 | errno = 0; 111 | x[i].index = (int) strtol(idx,&endptr,10); 112 | if(endptr == idx || errno != 0 || *endptr != '\0' || x[i].index <= inst_max_index) 113 | exit_input_error(total+1); 114 | else 115 | inst_max_index = x[i].index; 116 | 117 | errno = 0; 118 | x[i].value = strtod(val,&endptr); 119 | if(endptr == val || errno != 0 || (*endptr != '\0' && !isspace(*endptr))) 120 | exit_input_error(total+1); 121 | 122 | ++i; 123 | } 124 | x[i].index = -1; 125 | 126 | if (predict_probability && (svm_type==C_SVC || svm_type==NU_SVC || svm_type==ONE_CLASS)) 127 | { 128 | predict_label = svm_predict_probability(model,x,prob_estimates); 129 | fprintf(output,"%g",predict_label); 130 | for(j=0;j=argc-2) 201 | exit_with_help(); 202 | 203 | input = fopen(argv[i],"r"); 204 | if(input == NULL) 205 | { 206 | fprintf(stderr,"can't open input file %s\n",argv[i]); 207 | exit(1); 208 | } 209 | 210 | output = fopen(argv[i+2],"w"); 211 | if(output == NULL) 212 | { 213 | fprintf(stderr,"can't open output file %s\n",argv[i+2]); 214 | exit(1); 215 | } 216 | 217 | if((model=svm_load_model(argv[i+1]))==0) 218 | { 219 | fprintf(stderr,"can't open model file %s\n",argv[i+1]); 220 | exit(1); 221 | } 222 | 223 | x = (struct svm_node *) malloc(max_nr_attr*sizeof(struct 
svm_node)); 224 | if(predict_probability) 225 | { 226 | if(svm_check_probability_model(model)==0) 227 | { 228 | fprintf(stderr,"Model does not support probabiliy estimates\n"); 229 | exit(1); 230 | } 231 | } 232 | else 233 | { 234 | if(svm_check_probability_model(model)!=0) 235 | info("Model supports probability estimates, but disabled in prediction.\n"); 236 | } 237 | 238 | predict(input,output); 239 | svm_free_and_destroy_model(&model); 240 | free(x); 241 | free(line); 242 | fclose(input); 243 | fclose(output); 244 | return 0; 245 | } 246 | -------------------------------------------------------------------------------- /svm-scale.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | void exit_with_help() 8 | { 9 | printf( 10 | "Usage: svm-scale [options] data_filename\n" 11 | "options:\n" 12 | "-l lower : x scaling lower limit (default -1)\n" 13 | "-u upper : x scaling upper limit (default +1)\n" 14 | "-y y_lower y_upper : y scaling limits (default: no y scaling)\n" 15 | "-s save_filename : save scaling parameters to save_filename\n" 16 | "-r restore_filename : restore scaling parameters from restore_filename\n" 17 | ); 18 | exit(1); 19 | } 20 | 21 | char *line = NULL; 22 | int max_line_len = 1024; 23 | double lower=-1.0,upper=1.0,y_lower,y_upper; 24 | int y_scaling = 0; 25 | double *feature_max; 26 | double *feature_min; 27 | double y_max = -DBL_MAX; 28 | double y_min = DBL_MAX; 29 | int max_index; 30 | int min_index; 31 | long int num_nonzeros = 0; 32 | long int new_num_nonzeros = 0; 33 | 34 | #define max(x,y) (((x)>(y))?(x):(y)) 35 | #define min(x,y) (((x)<(y))?(x):(y)) 36 | 37 | void output_target(double value); 38 | void output(int index, double value); 39 | char* readline(FILE *input); 40 | int clean_up(FILE *fp_restore, FILE *fp, const char *msg); 41 | 42 | int main(int argc,char **argv) 43 | { 44 | int i,index; 45 | FILE *fp, *fp_restore = NULL; 46 | char *save_filename = NULL; 47 | char *restore_filename = NULL; 48 | 49 | for(i=1;i lower) || (y_scaling && !(y_upper > y_lower))) 72 | { 73 | fprintf(stderr,"inconsistent lower/upper specification\n"); 74 | exit(1); 75 | } 76 | 77 | if(restore_filename && save_filename) 78 | { 79 | fprintf(stderr,"cannot use -r and -s simultaneously\n"); 80 | exit(1); 81 | } 82 | 83 | if(argc != i+1) 84 | exit_with_help(); 85 | 86 | fp=fopen(argv[i],"r"); 87 | 88 | if(fp==NULL) 89 | { 90 | fprintf(stderr,"can't open file %s\n", argv[i]); 91 | exit(1); 92 | } 93 | 94 | line = (char *) malloc(max_line_len*sizeof(char)); 95 | 96 | #define SKIP_TARGET\ 97 | while(isspace(*p)) ++p;\ 98 | while(!isspace(*p)) ++p; 99 | 100 | #define SKIP_ELEMENT\ 101 | while(*p!=':') ++p;\ 102 | ++p;\ 103 | while(isspace(*p)) ++p;\ 104 | while(*p && !isspace(*p)) ++p; 105 | 106 | /* assumption: min index of attributes is 1 */ 107 | /* pass 1: find out max index of attributes */ 108 | max_index = 0; 109 | min_index = 1; 110 | 111 | if(restore_filename) 112 | { 113 | int idx, c; 114 | 115 | fp_restore = fopen(restore_filename,"r"); 116 | if(fp_restore==NULL) 117 | { 118 | fprintf(stderr,"can't open file %s\n", restore_filename); 119 | exit(1); 120 | } 121 | 122 | c = fgetc(fp_restore); 123 | if(c == 'y') 124 | { 125 | readline(fp_restore); 126 | readline(fp_restore); 127 | readline(fp_restore); 128 | } 129 | readline(fp_restore); 130 | readline(fp_restore); 131 | 132 | while(fscanf(fp_restore,"%d %*f %*f\n",&idx) == 1) 133 | max_index = max(idx,max_index); 134 | 
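		/*
		 * For reference, the file read back here (written by a previous
		 * run with -s) has the following layout; the leading 'y' block is
		 * present only if -y was given when the parameters were saved:
		 *
		 *   y
		 *   y_lower y_upper
		 *   y_min y_max
		 *   x
		 *   lower upper
		 *   index min max        (one line per feature index)
		 */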
rewind(fp_restore); 135 | } 136 | 137 | while(readline(fp)!=NULL) 138 | { 139 | char *p=line; 140 | 141 | SKIP_TARGET 142 | 143 | while(sscanf(p,"%d:%*f",&index)==1) 144 | { 145 | max_index = max(max_index, index); 146 | min_index = min(min_index, index); 147 | SKIP_ELEMENT 148 | num_nonzeros++; 149 | } 150 | } 151 | 152 | if(min_index < 1) 153 | fprintf(stderr, 154 | "WARNING: minimal feature index is %d, but indices should start from 1\n", min_index); 155 | 156 | rewind(fp); 157 | 158 | feature_max = (double *)malloc((max_index+1)* sizeof(double)); 159 | feature_min = (double *)malloc((max_index+1)* sizeof(double)); 160 | 161 | if(feature_max == NULL || feature_min == NULL) 162 | { 163 | fprintf(stderr,"can't allocate enough memory\n"); 164 | exit(1); 165 | } 166 | 167 | for(i=0;i<=max_index;i++) 168 | { 169 | feature_max[i]=-DBL_MAX; 170 | feature_min[i]=DBL_MAX; 171 | } 172 | 173 | /* pass 2: find out min/max value */ 174 | while(readline(fp)!=NULL) 175 | { 176 | char *p=line; 177 | int next_index=1; 178 | double target; 179 | double value; 180 | 181 | if (sscanf(p,"%lf",&target) != 1) 182 | return clean_up(fp_restore, fp, "ERROR: failed to read labels\n"); 183 | y_max = max(y_max,target); 184 | y_min = min(y_min,target); 185 | 186 | SKIP_TARGET 187 | 188 | while(sscanf(p,"%d:%lf",&index,&value)==2) 189 | { 190 | for(i=next_index;i num_nonzeros) 327 | fprintf(stderr, 328 | "WARNING: original #nonzeros %ld\n" 329 | " > new #nonzeros %ld\n" 330 | "If feature values are non-negative and sparse, use -l 0 rather than the default -l -1\n", 331 | num_nonzeros, new_num_nonzeros); 332 | 333 | free(line); 334 | free(feature_max); 335 | free(feature_min); 336 | fclose(fp); 337 | return 0; 338 | } 339 | 340 | char* readline(FILE *input) 341 | { 342 | int len; 343 | 344 | if(fgets(line,max_line_len,input) == NULL) 345 | return NULL; 346 | 347 | while(strrchr(line,'\n') == NULL) 348 | { 349 | max_line_len *= 2; 350 | line = (char *) realloc(line, max_line_len); 351 | len = (int) strlen(line); 352 | if(fgets(line+len,max_line_len-len,input) == NULL) 353 | break; 354 | } 355 | return line; 356 | } 357 | 358 | void output_target(double value) 359 | { 360 | if(y_scaling) 361 | { 362 | if(value == y_min) 363 | value = y_lower; 364 | else if(value == y_max) 365 | value = y_upper; 366 | else value = y_lower + (y_upper-y_lower) * 367 | (value - y_min)/(y_max-y_min); 368 | } 369 | printf("%.17g ",value); 370 | } 371 | 372 | void output(int index, double value) 373 | { 374 | /* skip single-valued attribute */ 375 | if(feature_max[index] == feature_min[index]) 376 | return; 377 | 378 | if(value == feature_min[index]) 379 | value = lower; 380 | else if(value == feature_max[index]) 381 | value = upper; 382 | else 383 | value = lower + (upper-lower) * 384 | (value-feature_min[index])/ 385 | (feature_max[index]-feature_min[index]); 386 | 387 | if(value != 0) 388 | { 389 | printf("%d:%g ",index, value); 390 | new_num_nonzeros++; 391 | } 392 | } 393 | 394 | int clean_up(FILE *fp_restore, FILE *fp, const char* msg) 395 | { 396 | fprintf(stderr, "%s", msg); 397 | free(line); 398 | free(feature_max); 399 | free(feature_min); 400 | fclose(fp); 401 | if (fp_restore) 402 | fclose(fp_restore); 403 | return -1; 404 | } 405 | 406 | -------------------------------------------------------------------------------- /svm-toy/qt/Makefile: -------------------------------------------------------------------------------- 1 | # use ``export QT_SELECT=qt5'' in a command window for using qt5 2 | # may need to adjust the path of header 
files 3 | CXX? = g++ 4 | INCLUDE = /usr/include/x86_64-linux-gnu/qt5 5 | CFLAGS = -Wall -O3 -I$(INCLUDE) -I$(INCLUDE)/QtWidgets -I$(INCLUDE)/QtGui -I$(INCLUDE)/QtCore -fPIC -std=c++11 6 | LIB = -lQt5Widgets -lQt5Gui -lQt5Core 7 | MOC = /usr/bin/moc 8 | 9 | svm-toy: svm-toy.cpp svm-toy.moc ../../svm.o 10 | $(CXX) $(CFLAGS) svm-toy.cpp ../../svm.o -o svm-toy $(LIB) 11 | 12 | svm-toy.moc: svm-toy.cpp 13 | $(MOC) svm-toy.cpp -o svm-toy.moc 14 | 15 | ../../svm.o: ../../svm.cpp ../../svm.h 16 | make -C ../.. svm.o 17 | 18 | clean: 19 | rm -f *~ svm-toy svm-toy.moc ../../svm.o 20 | -------------------------------------------------------------------------------- /svm-toy/qt/svm-toy.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include "../../svm.h" 8 | using namespace std; 9 | 10 | #define DEFAULT_PARAM "-t 2 -c 100" 11 | #define XLEN 500 12 | #define YLEN 500 13 | 14 | QRgb colors[] = 15 | { 16 | qRgb(0,0,0), 17 | qRgb(0,120,120), 18 | qRgb(120,120,0), 19 | qRgb(120,0,120), 20 | qRgb(0,200,200), 21 | qRgb(200,200,0), 22 | qRgb(200,0,200) 23 | }; 24 | 25 | class SvmToyWindow : public QWidget 26 | { 27 | 28 | Q_OBJECT 29 | 30 | public: 31 | SvmToyWindow(); 32 | ~SvmToyWindow(); 33 | protected: 34 | virtual void mousePressEvent( QMouseEvent* ); 35 | virtual void paintEvent( QPaintEvent* ); 36 | 37 | private: 38 | QPixmap buffer; 39 | QPixmap icon1; 40 | QPixmap icon2; 41 | QPixmap icon3; 42 | QPushButton button_change_icon; 43 | QPushButton button_run; 44 | QPushButton button_clear; 45 | QPushButton button_save; 46 | QPushButton button_load; 47 | QLineEdit input_line; 48 | QPainter buffer_painter; 49 | struct point { 50 | double x, y; 51 | signed char value; 52 | }; 53 | list point_list; 54 | int current_value; 55 | const QPixmap& choose_icon(int v) 56 | { 57 | if(v==1) return icon1; 58 | else if(v==2) return icon2; 59 | else return icon3; 60 | } 61 | void clear_all() 62 | { 63 | point_list.clear(); 64 | buffer.fill(Qt::black); 65 | repaint(); 66 | } 67 | void draw_point(const point& p) 68 | { 69 | const QPixmap& icon = choose_icon(p.value); 70 | buffer_painter.drawPixmap((int)(p.x*XLEN),(int)(p.y*YLEN),icon); 71 | repaint(); 72 | } 73 | void draw_all_points() 74 | { 75 | for(list::iterator p = point_list.begin(); p != point_list.end();p++) 76 | draw_point(*p); 77 | } 78 | private slots: 79 | void button_change_icon_clicked() 80 | { 81 | ++current_value; 82 | if(current_value > 3) current_value = 1; 83 | button_change_icon.setIcon(choose_icon(current_value)); 84 | } 85 | void button_run_clicked() 86 | { 87 | // guard 88 | if(point_list.empty()) return; 89 | 90 | svm_parameter param; 91 | int i,j; 92 | 93 | // default values 94 | param.svm_type = C_SVC; 95 | param.kernel_type = RBF; 96 | param.degree = 3; 97 | param.gamma = 0; 98 | param.coef0 = 0; 99 | param.nu = 0.5; 100 | param.cache_size = 100; 101 | param.C = 1; 102 | param.eps = 1e-3; 103 | param.p = 0.1; 104 | param.shrinking = 1; 105 | param.probability = 0; 106 | param.nr_weight = 0; 107 | param.weight_label = NULL; 108 | param.weight = NULL; 109 | 110 | // parse options 111 | const char *p = input_line.text().toLatin1().constData(); 112 | 113 | while (1) { 114 | while (*p && *p != '-') 115 | p++; 116 | 117 | if (*p == '\0') 118 | break; 119 | 120 | p++; 121 | switch (*p++) { 122 | case 's': 123 | param.svm_type = atoi(p); 124 | break; 125 | case 't': 126 | param.kernel_type = atoi(p); 127 | break; 128 | case 'd': 129 | 
param.degree = atoi(p); 130 | break; 131 | case 'g': 132 | param.gamma = atof(p); 133 | break; 134 | case 'r': 135 | param.coef0 = atof(p); 136 | break; 137 | case 'n': 138 | param.nu = atof(p); 139 | break; 140 | case 'm': 141 | param.cache_size = atof(p); 142 | break; 143 | case 'c': 144 | param.C = atof(p); 145 | break; 146 | case 'e': 147 | param.eps = atof(p); 148 | break; 149 | case 'p': 150 | param.p = atof(p); 151 | break; 152 | case 'h': 153 | param.shrinking = atoi(p); 154 | break; 155 | case 'b': 156 | param.probability = atoi(p); 157 | break; 158 | case 'w': 159 | ++param.nr_weight; 160 | param.weight_label = (int *)realloc(param.weight_label,sizeof(int)*param.nr_weight); 161 | param.weight = (double *)realloc(param.weight,sizeof(double)*param.nr_weight); 162 | param.weight_label[param.nr_weight-1] = atoi(p); 163 | while(*p && !isspace(*p)) ++p; 164 | param.weight[param.nr_weight-1] = atof(p); 165 | break; 166 | } 167 | } 168 | 169 | // build problem 170 | svm_problem prob; 171 | 172 | prob.l = point_list.size(); 173 | prob.y = new double[prob.l]; 174 | 175 | if(param.kernel_type == PRECOMPUTED) 176 | { 177 | } 178 | else if(param.svm_type == EPSILON_SVR || 179 | param.svm_type == NU_SVR) 180 | { 181 | if(param.gamma == 0) param.gamma = 1; 182 | svm_node *x_space = new svm_node[2 * prob.l]; 183 | prob.x = new svm_node *[prob.l]; 184 | 185 | i = 0; 186 | for (list ::iterator q = point_list.begin(); q != point_list.end(); q++, i++) 187 | { 188 | x_space[2 * i].index = 1; 189 | x_space[2 * i].value = q->x; 190 | x_space[2 * i + 1].index = -1; 191 | prob.x[i] = &x_space[2 * i]; 192 | prob.y[i] = q->y; 193 | } 194 | 195 | // build model & classify 196 | svm_model *model = svm_train(&prob, ¶m); 197 | svm_node x[2]; 198 | x[0].index = 1; 199 | x[1].index = -1; 200 | int *j = new int[XLEN]; 201 | 202 | for (i = 0; i < XLEN; i++) 203 | { 204 | x[0].value = (double) i / XLEN; 205 | j[i] = (int)(YLEN*svm_predict(model, x)); 206 | } 207 | 208 | buffer_painter.setPen(colors[0]); 209 | buffer_painter.drawLine(0,0,0,YLEN-1); 210 | 211 | int p = (int)(param.p * YLEN); 212 | for(i = 1; i < XLEN; i++) 213 | { 214 | buffer_painter.setPen(colors[0]); 215 | buffer_painter.drawLine(i,0,i,YLEN-1); 216 | 217 | buffer_painter.setPen(colors[5]); 218 | buffer_painter.drawLine(i-1,j[i-1],i,j[i]); 219 | 220 | if(param.svm_type == EPSILON_SVR) 221 | { 222 | buffer_painter.setPen(colors[2]); 223 | buffer_painter.drawLine(i-1,j[i-1]+p,i,j[i]+p); 224 | 225 | buffer_painter.setPen(colors[2]); 226 | buffer_painter.drawLine(i-1,j[i-1]-p,i,j[i]-p); 227 | } 228 | } 229 | 230 | svm_free_and_destroy_model(&model); 231 | delete[] j; 232 | delete[] x_space; 233 | delete[] prob.x; 234 | delete[] prob.y; 235 | } 236 | else 237 | { 238 | if(param.gamma == 0) param.gamma = 0.5; 239 | svm_node *x_space = new svm_node[3 * prob.l]; 240 | prob.x = new svm_node *[prob.l]; 241 | 242 | i = 0; 243 | for (list ::iterator q = point_list.begin(); q != point_list.end(); q++, i++) 244 | { 245 | x_space[3 * i].index = 1; 246 | x_space[3 * i].value = q->x; 247 | x_space[3 * i + 1].index = 2; 248 | x_space[3 * i + 1].value = q->y; 249 | x_space[3 * i + 2].index = -1; 250 | prob.x[i] = &x_space[3 * i]; 251 | prob.y[i] = q->value; 252 | } 253 | 254 | // build model & classify 255 | svm_model *model = svm_train(&prob, ¶m); 256 | svm_node x[3]; 257 | x[0].index = 1; 258 | x[1].index = 2; 259 | x[2].index = -1; 260 | 261 | for (i = 0; i < XLEN; i++) 262 | for (j = 0; j < YLEN ; j++) { 263 | x[0].value = (double) i / XLEN; 264 | x[1].value = 
(double) j / YLEN; 265 | double d = svm_predict(model, x); 266 | if (param.svm_type == ONE_CLASS && d<0) d=2; 267 | buffer_painter.setPen(colors[(int)d]); 268 | buffer_painter.drawPoint(i,j); 269 | } 270 | 271 | svm_free_and_destroy_model(&model); 272 | delete[] x_space; 273 | delete[] prob.x; 274 | delete[] prob.y; 275 | } 276 | free(param.weight_label); 277 | free(param.weight); 278 | draw_all_points(); 279 | } 280 | void button_clear_clicked() 281 | { 282 | clear_all(); 283 | } 284 | void button_save_clicked() 285 | { 286 | QString filename = QFileDialog::getSaveFileName(); 287 | if(!filename.isNull()) 288 | { 289 | FILE *fp = fopen(filename.toLatin1().constData(),"w"); 290 | 291 | const char *p = input_line.text().toLatin1().constData(); 292 | const char* svm_type_str = strstr(p, "-s "); 293 | int svm_type = C_SVC; 294 | if(svm_type_str != NULL) 295 | sscanf(svm_type_str, "-s %d", &svm_type); 296 | 297 | if(fp) 298 | { 299 | if(svm_type == EPSILON_SVR || svm_type == NU_SVR) 300 | { 301 | for(list::iterator p = point_list.begin(); p != point_list.end();p++) 302 | fprintf(fp,"%f 1:%f\n", p->y, p->x); 303 | } 304 | else 305 | { 306 | for(list::iterator p = point_list.begin(); p != point_list.end();p++) 307 | fprintf(fp,"%d 1:%f 2:%f\n", p->value, p->x, p->y); 308 | } 309 | fclose(fp); 310 | } 311 | } 312 | } 313 | void button_load_clicked() 314 | { 315 | QString filename = QFileDialog::getOpenFileName(); 316 | if(!filename.isNull()) 317 | { 318 | FILE *fp = fopen(filename.toLatin1().constData(),"r"); 319 | if(fp) 320 | { 321 | clear_all(); 322 | char buf[4096]; 323 | while(fgets(buf,sizeof(buf),fp)) 324 | { 325 | int v; 326 | double x,y; 327 | if(sscanf(buf,"%d%*d:%lf%*d:%lf",&v,&x,&y)==3) 328 | { 329 | point p = {x,y,v}; 330 | point_list.push_back(p); 331 | } 332 | else if(sscanf(buf,"%lf%*d:%lf",&y,&x)==2) 333 | { 334 | point p = {x,y,current_value}; 335 | point_list.push_back(p); 336 | } 337 | else 338 | break; 339 | } 340 | fclose(fp); 341 | draw_all_points(); 342 | } 343 | } 344 | 345 | } 346 | }; 347 | 348 | #include "svm-toy.moc" 349 | 350 | SvmToyWindow::SvmToyWindow() 351 | :button_change_icon(this) 352 | ,button_run("Run",this) 353 | ,button_clear("Clear",this) 354 | ,button_save("Save",this) 355 | ,button_load("Load",this) 356 | ,input_line(this) 357 | ,current_value(1) 358 | { 359 | buffer = QPixmap(XLEN,YLEN); 360 | buffer.fill(Qt::black); 361 | 362 | buffer_painter.begin(&buffer); 363 | 364 | QObject::connect(&button_change_icon, SIGNAL(clicked()), this, 365 | SLOT(button_change_icon_clicked())); 366 | QObject::connect(&button_run, SIGNAL(clicked()), this, 367 | SLOT(button_run_clicked())); 368 | QObject::connect(&button_clear, SIGNAL(clicked()), this, 369 | SLOT(button_clear_clicked())); 370 | QObject::connect(&button_save, SIGNAL(clicked()), this, 371 | SLOT(button_save_clicked())); 372 | QObject::connect(&button_load, SIGNAL(clicked()), this, 373 | SLOT(button_load_clicked())); 374 | QObject::connect(&input_line, SIGNAL(returnPressed()), this, 375 | SLOT(button_run_clicked())); 376 | 377 | // don't blank the window before repainting 378 | setAttribute(Qt::WA_NoBackground); 379 | 380 | icon1 = QPixmap(4,4); 381 | icon2 = QPixmap(4,4); 382 | icon3 = QPixmap(4,4); 383 | 384 | 385 | QPainter painter; 386 | painter.begin(&icon1); 387 | painter.fillRect(0,0,4,4,QBrush(colors[4])); 388 | painter.end(); 389 | 390 | painter.begin(&icon2); 391 | painter.fillRect(0,0,4,4,QBrush(colors[5])); 392 | painter.end(); 393 | 394 | painter.begin(&icon3); 395 | 
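	// (Icons 1-3 are filled with the brighter palette entries colors[4..6];
	//  button_run_clicked() paints the predicted regions with the dimmer
	//  entries colors[1..3] of the same hues, so training points remain
	//  visible on top of the decision regions.)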
painter.fillRect(0,0,4,4,QBrush(colors[6])); 396 | painter.end(); 397 | 398 | button_change_icon.setGeometry( 0, YLEN, 50, 25 ); 399 | button_run.setGeometry( 50, YLEN, 50, 25 ); 400 | button_clear.setGeometry( 100, YLEN, 50, 25 ); 401 | button_save.setGeometry( 150, YLEN, 50, 25); 402 | button_load.setGeometry( 200, YLEN, 50, 25); 403 | input_line.setGeometry( 250, YLEN, 250, 25); 404 | 405 | input_line.setText(DEFAULT_PARAM); 406 | button_change_icon.setIcon(icon1); 407 | } 408 | 409 | SvmToyWindow::~SvmToyWindow() 410 | { 411 | buffer_painter.end(); 412 | } 413 | 414 | void SvmToyWindow::mousePressEvent( QMouseEvent* event ) 415 | { 416 | point p = {(double)event->x()/XLEN, (double)event->y()/YLEN, current_value}; 417 | point_list.push_back(p); 418 | draw_point(p); 419 | } 420 | 421 | void SvmToyWindow::paintEvent( QPaintEvent* ) 422 | { 423 | // copy the image from the buffer pixmap to the window 424 | QPainter p(this); 425 | p.drawPixmap(0, 0, buffer); 426 | } 427 | 428 | int main( int argc, char* argv[] ) 429 | { 430 | QApplication myapp( argc, argv ); 431 | 432 | SvmToyWindow* mywidget = new SvmToyWindow(); 433 | mywidget->setGeometry( 100, 100, XLEN, YLEN+25 ); 434 | 435 | mywidget->show(); 436 | return myapp.exec(); 437 | } 438 | -------------------------------------------------------------------------------- /svm-toy/windows/svm-toy.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include "../../svm.h" 8 | using namespace std; 9 | 10 | #define DEFAULT_PARAM "-t 2 -c 100" 11 | #define XLEN 500 12 | #define YLEN 500 13 | #define DrawLine(dc,x1,y1,x2,y2,c) \ 14 | do { \ 15 | HPEN hpen = CreatePen(PS_SOLID,0,c); \ 16 | HPEN horig = SelectPen(dc,hpen); \ 17 | MoveToEx(dc,x1,y1,NULL); \ 18 | LineTo(dc,x2,y2); \ 19 | SelectPen(dc,horig); \ 20 | DeletePen(hpen); \ 21 | } while(0) 22 | 23 | using namespace std; 24 | 25 | COLORREF colors[] = 26 | { 27 | RGB(0,0,0), 28 | RGB(0,120,120), 29 | RGB(120,120,0), 30 | RGB(120,0,120), 31 | RGB(0,200,200), 32 | RGB(200,200,0), 33 | RGB(200,0,200) 34 | }; 35 | 36 | HWND main_window; 37 | HBITMAP buffer; 38 | HDC window_dc; 39 | HDC buffer_dc; 40 | HBRUSH brush1, brush2, brush3; 41 | HWND edit; 42 | 43 | enum { 44 | ID_BUTTON_CHANGE, ID_BUTTON_RUN, ID_BUTTON_CLEAR, 45 | ID_BUTTON_LOAD, ID_BUTTON_SAVE, ID_EDIT 46 | }; 47 | 48 | struct point { 49 | double x, y; 50 | signed char value; 51 | }; 52 | 53 | list point_list; 54 | int current_value = 1; 55 | 56 | LRESULT CALLBACK WndProc(HWND, UINT, WPARAM, LPARAM); 57 | 58 | int WINAPI WinMain(HINSTANCE hInstance, HINSTANCE hPrevInstance, 59 | PSTR szCmdLine, int iCmdShow) 60 | { 61 | static char szAppName[] = "SvmToy"; 62 | MSG msg; 63 | WNDCLASSEX wndclass; 64 | 65 | wndclass.cbSize = sizeof(wndclass); 66 | wndclass.style = CS_HREDRAW | CS_VREDRAW; 67 | wndclass.lpfnWndProc = WndProc; 68 | wndclass.cbClsExtra = 0; 69 | wndclass.cbWndExtra = 0; 70 | wndclass.hInstance = hInstance; 71 | wndclass.hIcon = LoadIcon(NULL, IDI_APPLICATION); 72 | wndclass.hCursor = LoadCursor(NULL, IDC_ARROW); 73 | wndclass.hbrBackground = (HBRUSH) GetStockObject(BLACK_BRUSH); 74 | wndclass.lpszMenuName = NULL; 75 | wndclass.lpszClassName = szAppName; 76 | wndclass.hIconSm = LoadIcon(NULL, IDI_APPLICATION); 77 | 78 | RegisterClassEx(&wndclass); 79 | 80 | main_window = CreateWindow(szAppName, // window class name 81 | "SVM Toy", // window caption 82 | WS_OVERLAPPEDWINDOW,// window style 83 | CW_USEDEFAULT, // 
initial x position 84 | CW_USEDEFAULT, // initial y position 85 | XLEN, // initial x size 86 | YLEN+52, // initial y size 87 | NULL, // parent window handle 88 | NULL, // window menu handle 89 | hInstance, // program instance handle 90 | NULL); // creation parameters 91 | 92 | ShowWindow(main_window, iCmdShow); 93 | UpdateWindow(main_window); 94 | 95 | CreateWindow("button", "Change", WS_CHILD | WS_VISIBLE | BS_PUSHBUTTON, 96 | 0, YLEN, 50, 25, main_window, (HMENU) ID_BUTTON_CHANGE, hInstance, NULL); 97 | CreateWindow("button", "Run", WS_CHILD | WS_VISIBLE | BS_PUSHBUTTON, 98 | 50, YLEN, 50, 25, main_window, (HMENU) ID_BUTTON_RUN, hInstance, NULL); 99 | CreateWindow("button", "Clear", WS_CHILD | WS_VISIBLE | BS_PUSHBUTTON, 100 | 100, YLEN, 50, 25, main_window, (HMENU) ID_BUTTON_CLEAR, hInstance, NULL); 101 | CreateWindow("button", "Save", WS_CHILD | WS_VISIBLE | BS_PUSHBUTTON, 102 | 150, YLEN, 50, 25, main_window, (HMENU) ID_BUTTON_SAVE, hInstance, NULL); 103 | CreateWindow("button", "Load", WS_CHILD | WS_VISIBLE | BS_PUSHBUTTON, 104 | 200, YLEN, 50, 25, main_window, (HMENU) ID_BUTTON_LOAD, hInstance, NULL); 105 | 106 | edit = CreateWindow("edit", NULL, WS_CHILD | WS_VISIBLE, 107 | 250, YLEN, 250, 25, main_window, (HMENU) ID_EDIT, hInstance, NULL); 108 | 109 | Edit_SetText(edit,DEFAULT_PARAM); 110 | 111 | brush1 = CreateSolidBrush(colors[4]); 112 | brush2 = CreateSolidBrush(colors[5]); 113 | brush3 = CreateSolidBrush(colors[6]); 114 | 115 | window_dc = GetDC(main_window); 116 | buffer = CreateCompatibleBitmap(window_dc, XLEN, YLEN); 117 | buffer_dc = CreateCompatibleDC(window_dc); 118 | SelectObject(buffer_dc, buffer); 119 | PatBlt(buffer_dc, 0, 0, XLEN, YLEN, BLACKNESS); 120 | 121 | while (GetMessage(&msg, NULL, 0, 0)) { 122 | TranslateMessage(&msg); 123 | DispatchMessage(&msg); 124 | } 125 | return msg.wParam; 126 | } 127 | 128 | int getfilename( HWND hWnd , char *filename, int len, int save) 129 | { 130 | OPENFILENAME OpenFileName; 131 | memset(&OpenFileName,0,sizeof(OpenFileName)); 132 | filename[0]='\0'; 133 | 134 | OpenFileName.lStructSize = sizeof(OPENFILENAME); 135 | OpenFileName.hwndOwner = hWnd; 136 | OpenFileName.lpstrFile = filename; 137 | OpenFileName.nMaxFile = len; 138 | OpenFileName.Flags = 0; 139 | 140 | return save?GetSaveFileName(&OpenFileName):GetOpenFileName(&OpenFileName); 141 | } 142 | 143 | void clear_all() 144 | { 145 | point_list.clear(); 146 | PatBlt(buffer_dc, 0, 0, XLEN, YLEN, BLACKNESS); 147 | InvalidateRect(main_window, 0, 0); 148 | } 149 | 150 | HBRUSH choose_brush(int v) 151 | { 152 | if(v==1) return brush1; 153 | else if(v==2) return brush2; 154 | else return brush3; 155 | } 156 | 157 | void draw_point(const point & p) 158 | { 159 | RECT rect; 160 | rect.left = int(p.x*XLEN); 161 | rect.top = int(p.y*YLEN); 162 | rect.right = int(p.x*XLEN) + 3; 163 | rect.bottom = int(p.y*YLEN) + 3; 164 | FillRect(window_dc, &rect, choose_brush(p.value)); 165 | FillRect(buffer_dc, &rect, choose_brush(p.value)); 166 | } 167 | 168 | void draw_all_points() 169 | { 170 | for(list::iterator p = point_list.begin(); p != point_list.end(); p++) 171 | draw_point(*p); 172 | } 173 | 174 | void button_run_clicked() 175 | { 176 | // guard 177 | if(point_list.empty()) return; 178 | 179 | svm_parameter param; 180 | int i,j; 181 | 182 | // default values 183 | param.svm_type = C_SVC; 184 | param.kernel_type = RBF; 185 | param.degree = 3; 186 | param.gamma = 0; 187 | param.coef0 = 0; 188 | param.nu = 0.5; 189 | param.cache_size = 100; 190 | param.C = 1; 191 | param.eps = 1e-3; 192 | 
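	// (eps is the solver's stopping tolerance and p the epsilon-SVR tube
	//  width; the remaining defaults mirror svm-train's, and DEFAULT_PARAM
	//  "-t 2 -c 100" overrides -t and -c once the option string is parsed
	//  below.)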
param.p = 0.1; 193 | param.shrinking = 1; 194 | param.probability = 0; 195 | param.nr_weight = 0; 196 | param.weight_label = NULL; 197 | param.weight = NULL; 198 | 199 | // parse options 200 | char str[1024]; 201 | Edit_GetLine(edit, 0, str, sizeof(str)); 202 | const char *p = str; 203 | 204 | while (1) { 205 | while (*p && *p != '-') 206 | p++; 207 | 208 | if (*p == '\0') 209 | break; 210 | 211 | p++; 212 | switch (*p++) { 213 | case 's': 214 | param.svm_type = atoi(p); 215 | break; 216 | case 't': 217 | param.kernel_type = atoi(p); 218 | break; 219 | case 'd': 220 | param.degree = atoi(p); 221 | break; 222 | case 'g': 223 | param.gamma = atof(p); 224 | break; 225 | case 'r': 226 | param.coef0 = atof(p); 227 | break; 228 | case 'n': 229 | param.nu = atof(p); 230 | break; 231 | case 'm': 232 | param.cache_size = atof(p); 233 | break; 234 | case 'c': 235 | param.C = atof(p); 236 | break; 237 | case 'e': 238 | param.eps = atof(p); 239 | break; 240 | case 'p': 241 | param.p = atof(p); 242 | break; 243 | case 'h': 244 | param.shrinking = atoi(p); 245 | break; 246 | case 'b': 247 | param.probability = atoi(p); 248 | break; 249 | case 'w': 250 | ++param.nr_weight; 251 | param.weight_label = (int *)realloc(param.weight_label,sizeof(int)*param.nr_weight); 252 | param.weight = (double *)realloc(param.weight,sizeof(double)*param.nr_weight); 253 | param.weight_label[param.nr_weight-1] = atoi(p); 254 | while(*p && !isspace(*p)) ++p; 255 | param.weight[param.nr_weight-1] = atof(p); 256 | break; 257 | } 258 | } 259 | 260 | // build problem 261 | svm_problem prob; 262 | 263 | prob.l = point_list.size(); 264 | prob.y = new double[prob.l]; 265 | 266 | if(param.kernel_type == PRECOMPUTED) 267 | { 268 | } 269 | else if(param.svm_type == EPSILON_SVR || 270 | param.svm_type == NU_SVR) 271 | { 272 | if(param.gamma == 0) param.gamma = 1; 273 | svm_node *x_space = new svm_node[2 * prob.l]; 274 | prob.x = new svm_node *[prob.l]; 275 | 276 | i = 0; 277 | for (list::iterator q = point_list.begin(); q != point_list.end(); q++, i++) 278 | { 279 | x_space[2 * i].index = 1; 280 | x_space[2 * i].value = q->x; 281 | x_space[2 * i + 1].index = -1; 282 | prob.x[i] = &x_space[2 * i]; 283 | prob.y[i] = q->y; 284 | } 285 | 286 | // build model & classify 287 | svm_model *model = svm_train(&prob, ¶m); 288 | svm_node x[2]; 289 | x[0].index = 1; 290 | x[1].index = -1; 291 | int *j = new int[XLEN]; 292 | 293 | for (i = 0; i < XLEN; i++) 294 | { 295 | x[0].value = (double) i / XLEN; 296 | j[i] = (int)(YLEN*svm_predict(model, x)); 297 | } 298 | 299 | DrawLine(buffer_dc,0,0,0,YLEN,colors[0]); 300 | DrawLine(window_dc,0,0,0,YLEN,colors[0]); 301 | 302 | int p = (int)(param.p * YLEN); 303 | for(int i=1; i < XLEN; i++) 304 | { 305 | DrawLine(buffer_dc,i,0,i,YLEN,colors[0]); 306 | DrawLine(window_dc,i,0,i,YLEN,colors[0]); 307 | 308 | DrawLine(buffer_dc,i-1,j[i-1],i,j[i],colors[5]); 309 | DrawLine(window_dc,i-1,j[i-1],i,j[i],colors[5]); 310 | 311 | if(param.svm_type == EPSILON_SVR) 312 | { 313 | DrawLine(buffer_dc,i-1,j[i-1]+p,i,j[i]+p,colors[2]); 314 | DrawLine(window_dc,i-1,j[i-1]+p,i,j[i]+p,colors[2]); 315 | 316 | DrawLine(buffer_dc,i-1,j[i-1]-p,i,j[i]-p,colors[2]); 317 | DrawLine(window_dc,i-1,j[i-1]-p,i,j[i]-p,colors[2]); 318 | } 319 | } 320 | 321 | svm_free_and_destroy_model(&model); 322 | delete[] j; 323 | delete[] x_space; 324 | delete[] prob.x; 325 | delete[] prob.y; 326 | } 327 | else 328 | { 329 | if(param.gamma == 0) param.gamma = 0.5; 330 | svm_node *x_space = new svm_node[3 * prob.l]; 331 | prob.x = new svm_node *[prob.l]; 
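		// Each 2-D toy point becomes three svm_nodes below: index 1 holds
		// the x coordinate, index 2 the y coordinate, and index -1 marks
		// the end of the sparse vector expected by svm_train().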
332 | 333 | i = 0; 334 | for (list::iterator q = point_list.begin(); q != point_list.end(); q++, i++) 335 | { 336 | x_space[3 * i].index = 1; 337 | x_space[3 * i].value = q->x; 338 | x_space[3 * i + 1].index = 2; 339 | x_space[3 * i + 1].value = q->y; 340 | x_space[3 * i + 2].index = -1; 341 | prob.x[i] = &x_space[3 * i]; 342 | prob.y[i] = q->value; 343 | } 344 | 345 | // build model & classify 346 | svm_model *model = svm_train(&prob, ¶m); 347 | svm_node x[3]; 348 | x[0].index = 1; 349 | x[1].index = 2; 350 | x[2].index = -1; 351 | 352 | for (i = 0; i < XLEN; i++) 353 | for (j = 0; j < YLEN; j++) { 354 | x[0].value = (double) i / XLEN; 355 | x[1].value = (double) j / YLEN; 356 | double d = svm_predict(model, x); 357 | if (param.svm_type == ONE_CLASS && d<0) d=2; 358 | SetPixel(window_dc, i, j, colors[(int)d]); 359 | SetPixel(buffer_dc, i, j, colors[(int)d]); 360 | } 361 | 362 | svm_free_and_destroy_model(&model); 363 | delete[] x_space; 364 | delete[] prob.x; 365 | delete[] prob.y; 366 | } 367 | free(param.weight_label); 368 | free(param.weight); 369 | draw_all_points(); 370 | } 371 | 372 | LRESULT CALLBACK WndProc(HWND hwnd, UINT iMsg, WPARAM wParam, LPARAM lParam) 373 | { 374 | HDC hdc; 375 | PAINTSTRUCT ps; 376 | 377 | switch (iMsg) { 378 | case WM_LBUTTONDOWN: 379 | { 380 | int x = LOWORD(lParam); 381 | int y = HIWORD(lParam); 382 | point p = {(double)x/XLEN, (double)y/YLEN, current_value}; 383 | point_list.push_back(p); 384 | draw_point(p); 385 | } 386 | return 0; 387 | case WM_PAINT: 388 | { 389 | hdc = BeginPaint(hwnd, &ps); 390 | BitBlt(hdc, 0, 0, XLEN, YLEN, buffer_dc, 0, 0, SRCCOPY); 391 | EndPaint(hwnd, &ps); 392 | } 393 | return 0; 394 | case WM_COMMAND: 395 | { 396 | int id = LOWORD(wParam); 397 | switch (id) { 398 | case ID_BUTTON_CHANGE: 399 | ++current_value; 400 | if(current_value > 3) current_value = 1; 401 | break; 402 | case ID_BUTTON_RUN: 403 | button_run_clicked(); 404 | break; 405 | case ID_BUTTON_CLEAR: 406 | clear_all(); 407 | break; 408 | case ID_BUTTON_SAVE: 409 | { 410 | char filename[1024]; 411 | if(getfilename(hwnd,filename,1024,1)) 412 | { 413 | FILE *fp = fopen(filename,"w"); 414 | 415 | char str[1024]; 416 | Edit_GetLine(edit, 0, str, sizeof(str)); 417 | const char *p = str; 418 | const char* svm_type_str = strstr(p, "-s "); 419 | int svm_type = C_SVC; 420 | if(svm_type_str != NULL) 421 | sscanf(svm_type_str, "-s %d", &svm_type); 422 | 423 | if(fp) 424 | { 425 | if(svm_type == EPSILON_SVR || svm_type == NU_SVR) 426 | { 427 | for(list::iterator p = point_list.begin(); p != point_list.end();p++) 428 | fprintf(fp,"%f 1:%f\n", p->y, p->x); 429 | } 430 | else 431 | { 432 | for(list::iterator p = point_list.begin(); p != point_list.end();p++) 433 | fprintf(fp,"%d 1:%f 2:%f\n", p->value, p->x, p->y); 434 | } 435 | fclose(fp); 436 | } 437 | } 438 | } 439 | break; 440 | case ID_BUTTON_LOAD: 441 | { 442 | char filename[1024]; 443 | if(getfilename(hwnd,filename,1024,0)) 444 | { 445 | FILE *fp = fopen(filename,"r"); 446 | if(fp) 447 | { 448 | clear_all(); 449 | char buf[4096]; 450 | while(fgets(buf,sizeof(buf),fp)) 451 | { 452 | int v; 453 | double x,y; 454 | if(sscanf(buf,"%d%*d:%lf%*d:%lf",&v,&x,&y)==3) 455 | { 456 | point p = {x,y,v}; 457 | point_list.push_back(p); 458 | } 459 | else if(sscanf(buf,"%lf%*d:%lf",&y,&x)==2) 460 | { 461 | point p = {x,y,current_value}; 462 | point_list.push_back(p); 463 | } 464 | else 465 | break; 466 | } 467 | fclose(fp); 468 | draw_all_points(); 469 | } 470 | } 471 | } 472 | break; 473 | } 474 | } 475 | return 0; 476 | case 
WM_DESTROY: 477 | PostQuitMessage(0); 478 | return 0; 479 | } 480 | 481 | return DefWindowProc(hwnd, iMsg, wParam, lParam); 482 | } 483 | -------------------------------------------------------------------------------- /svm-train.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "svm.h" 7 | #define Malloc(type,n) (type *)malloc((n)*sizeof(type)) 8 | 9 | void print_null(const char *s) {} 10 | 11 | void exit_with_help() 12 | { 13 | printf( 14 | "Usage: svm-train [options] training_set_file [model_file]\n" 15 | "options:\n" 16 | "-s svm_type : set type of SVM (default 0)\n" 17 | " 0 -- C-SVC (multi-class classification)\n" 18 | " 1 -- nu-SVC (multi-class classification)\n" 19 | " 2 -- one-class SVM\n" 20 | " 3 -- epsilon-SVR (regression)\n" 21 | " 4 -- nu-SVR (regression)\n" 22 | "-t kernel_type : set type of kernel function (default 2)\n" 23 | " 0 -- linear: u'*v\n" 24 | " 1 -- polynomial: (gamma*u'*v + coef0)^degree\n" 25 | " 2 -- radial basis function: exp(-gamma*|u-v|^2)\n" 26 | " 3 -- sigmoid: tanh(gamma*u'*v + coef0)\n" 27 | " 4 -- precomputed kernel (kernel values in training_set_file)\n" 28 | "-d degree : set degree in kernel function (default 3)\n" 29 | "-g gamma : set gamma in kernel function (default 1/num_features)\n" 30 | "-r coef0 : set coef0 in kernel function (default 0)\n" 31 | "-c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)\n" 32 | "-n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5)\n" 33 | "-p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)\n" 34 | "-m cachesize : set cache memory size in MB (default 100)\n" 35 | "-e epsilon : set tolerance of termination criterion (default 0.001)\n" 36 | "-h shrinking : whether to use the shrinking heuristics, 0 or 1 (default 1)\n" 37 | "-b probability_estimates : whether to train a SVC or SVR model for probability estimates, 0 or 1 (default 0)\n" 38 | "-wi weight : set the parameter C of class i to weight*C, for C-SVC (default 1)\n" 39 | "-v n: n-fold cross validation mode\n" 40 | "-q : quiet mode (no outputs)\n" 41 | ); 42 | exit(1); 43 | } 44 | 45 | void exit_input_error(int line_num) 46 | { 47 | fprintf(stderr,"Wrong input format at line %d\n", line_num); 48 | exit(1); 49 | } 50 | 51 | void parse_command_line(int argc, char **argv, char *input_file_name, char *model_file_name); 52 | void read_problem(const char *filename); 53 | void do_cross_validation(); 54 | 55 | struct svm_parameter param; // set by parse_command_line 56 | struct svm_problem prob; // set by read_problem 57 | struct svm_model *model; 58 | struct svm_node *x_space; 59 | int cross_validation; 60 | int nr_fold; 61 | 62 | static char *line = NULL; 63 | static int max_line_len; 64 | 65 | static char* readline(FILE *input) 66 | { 67 | int len; 68 | 69 | if(fgets(line,max_line_len,input) == NULL) 70 | return NULL; 71 | 72 | while(strrchr(line,'\n') == NULL) 73 | { 74 | max_line_len *= 2; 75 | line = (char *) realloc(line,max_line_len); 76 | len = (int) strlen(line); 77 | if(fgets(line+len,max_line_len-len,input) == NULL) 78 | break; 79 | } 80 | return line; 81 | } 82 | 83 | int main(int argc, char **argv) 84 | { 85 | char input_file_name[1024]; 86 | char model_file_name[1024]; 87 | const char *error_msg; 88 | 89 | parse_command_line(argc, argv, input_file_name, model_file_name); 90 | read_problem(input_file_name); 91 | error_msg = svm_check_parameter(&prob,¶m); 92 | 93 | 
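	// svm_check_parameter() returns NULL when prob and param are
	// consistent, or a human-readable error string otherwise (handled
	// just below).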
if(error_msg) 94 | { 95 | fprintf(stderr,"ERROR: %s\n",error_msg); 96 | exit(1); 97 | } 98 | 99 | if(cross_validation) 100 | { 101 | do_cross_validation(); 102 | } 103 | else 104 | { 105 | model = svm_train(&prob,¶m); 106 | if(svm_save_model(model_file_name,model)) 107 | { 108 | fprintf(stderr, "can't save model to file %s\n", model_file_name); 109 | exit(1); 110 | } 111 | svm_free_and_destroy_model(&model); 112 | } 113 | svm_destroy_param(¶m); 114 | free(prob.y); 115 | free(prob.x); 116 | free(x_space); 117 | free(line); 118 | 119 | return 0; 120 | } 121 | 122 | void do_cross_validation() 123 | { 124 | int i; 125 | int total_correct = 0; 126 | double total_error = 0; 127 | double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0; 128 | double *target = Malloc(double,prob.l); 129 | 130 | svm_cross_validation(&prob,¶m,nr_fold,target); 131 | if(param.svm_type == EPSILON_SVR || 132 | param.svm_type == NU_SVR) 133 | { 134 | for(i=0;i=argc) 189 | exit_with_help(); 190 | switch(argv[i-1][1]) 191 | { 192 | case 's': 193 | param.svm_type = atoi(argv[i]); 194 | break; 195 | case 't': 196 | param.kernel_type = atoi(argv[i]); 197 | break; 198 | case 'd': 199 | param.degree = atoi(argv[i]); 200 | break; 201 | case 'g': 202 | param.gamma = atof(argv[i]); 203 | break; 204 | case 'r': 205 | param.coef0 = atof(argv[i]); 206 | break; 207 | case 'n': 208 | param.nu = atof(argv[i]); 209 | break; 210 | case 'm': 211 | param.cache_size = atof(argv[i]); 212 | break; 213 | case 'c': 214 | param.C = atof(argv[i]); 215 | break; 216 | case 'e': 217 | param.eps = atof(argv[i]); 218 | break; 219 | case 'p': 220 | param.p = atof(argv[i]); 221 | break; 222 | case 'h': 223 | param.shrinking = atoi(argv[i]); 224 | break; 225 | case 'b': 226 | param.probability = atoi(argv[i]); 227 | break; 228 | case 'q': 229 | print_func = &print_null; 230 | i--; 231 | break; 232 | case 'v': 233 | cross_validation = 1; 234 | nr_fold = atoi(argv[i]); 235 | if(nr_fold < 2) 236 | { 237 | fprintf(stderr,"n-fold cross validation: n must >= 2\n"); 238 | exit_with_help(); 239 | } 240 | break; 241 | case 'w': 242 | ++param.nr_weight; 243 | param.weight_label = (int *)realloc(param.weight_label,sizeof(int)*param.nr_weight); 244 | param.weight = (double *)realloc(param.weight,sizeof(double)*param.nr_weight); 245 | param.weight_label[param.nr_weight-1] = atoi(&argv[i-1][2]); 246 | param.weight[param.nr_weight-1] = atof(argv[i]); 247 | break; 248 | default: 249 | fprintf(stderr,"Unknown option: -%c\n", argv[i-1][1]); 250 | exit_with_help(); 251 | } 252 | } 253 | 254 | svm_set_print_string_function(print_func); 255 | 256 | // determine filenames 257 | 258 | if(i>=argc) 259 | exit_with_help(); 260 | 261 | strcpy(input_file_name, argv[i]); 262 | 263 | if(i start from 0 323 | readline(fp); 324 | prob.x[i] = &x_space[j]; 325 | label = strtok(line," \t\n"); 326 | if(label == NULL) // empty line 327 | exit_input_error(i+1); 328 | 329 | prob.y[i] = strtod(label,&endptr); 330 | if(endptr == label || *endptr != '\0') 331 | exit_input_error(i+1); 332 | 333 | while(1) 334 | { 335 | idx = strtok(NULL,":"); 336 | val = strtok(NULL," \t"); 337 | 338 | if(val == NULL) 339 | break; 340 | 341 | errno = 0; 342 | x_space[j].index = (int) strtol(idx,&endptr,10); 343 | if(endptr == idx || errno != 0 || *endptr != '\0' || x_space[j].index <= inst_max_index) 344 | exit_input_error(i+1); 345 | else 346 | inst_max_index = x_space[j].index; 347 | 348 | errno = 0; 349 | x_space[j].value = strtod(val,&endptr); 350 | if(endptr == val || errno != 0 || (*endptr != '\0' && 
!isspace(*endptr))) 351 | exit_input_error(i+1); 352 | 353 | ++j; 354 | } 355 | 356 | if(inst_max_index > max_index) 357 | max_index = inst_max_index; 358 | x_space[j++].index = -1; 359 | } 360 | 361 | if(param.gamma == 0 && max_index > 0) 362 | param.gamma = 1.0/max_index; 363 | 364 | if(param.kernel_type == PRECOMPUTED) 365 | for(i=0;i max_index) 373 | { 374 | fprintf(stderr,"Wrong input format: sample_serial_number out of range\n"); 375 | exit(1); 376 | } 377 | } 378 | 379 | fclose(fp); 380 | } 381 | -------------------------------------------------------------------------------- /svm.def: -------------------------------------------------------------------------------- 1 | LIBRARY libsvm 2 | EXPORTS 3 | svm_train @1 4 | svm_cross_validation @2 5 | svm_save_model @3 6 | svm_load_model @4 7 | svm_get_svm_type @5 8 | svm_get_nr_class @6 9 | svm_get_labels @7 10 | svm_get_svr_probability @8 11 | svm_predict_values @9 12 | svm_predict @10 13 | svm_predict_probability @11 14 | svm_free_model_content @12 15 | svm_free_and_destroy_model @13 16 | svm_destroy_param @14 17 | svm_check_parameter @15 18 | svm_check_probability_model @16 19 | svm_set_print_string_function @17 20 | svm_get_sv_indices @18 21 | svm_get_nr_sv @19 22 | -------------------------------------------------------------------------------- /svm.h: -------------------------------------------------------------------------------- 1 | #ifndef _LIBSVM_H 2 | #define _LIBSVM_H 3 | 4 | #define LIBSVM_VERSION 336 5 | 6 | #ifdef __cplusplus 7 | extern "C" { 8 | #endif 9 | 10 | extern int libsvm_version; 11 | 12 | struct svm_node 13 | { 14 | int index; 15 | double value; 16 | }; 17 | 18 | struct svm_problem 19 | { 20 | int l; 21 | double *y; 22 | struct svm_node **x; 23 | }; 24 | 25 | enum { C_SVC, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR }; /* svm_type */ 26 | enum { LINEAR, POLY, RBF, SIGMOID, PRECOMPUTED }; /* kernel_type */ 27 | 28 | struct svm_parameter 29 | { 30 | int svm_type; 31 | int kernel_type; 32 | int degree; /* for poly */ 33 | double gamma; /* for poly/rbf/sigmoid */ 34 | double coef0; /* for poly/sigmoid */ 35 | 36 | /* these are for training only */ 37 | double cache_size; /* in MB */ 38 | double eps; /* stopping criteria */ 39 | double C; /* for C_SVC, EPSILON_SVR and NU_SVR */ 40 | int nr_weight; /* for C_SVC */ 41 | int *weight_label; /* for C_SVC */ 42 | double* weight; /* for C_SVC */ 43 | double nu; /* for NU_SVC, ONE_CLASS, and NU_SVR */ 44 | double p; /* for EPSILON_SVR */ 45 | int shrinking; /* use the shrinking heuristics */ 46 | int probability; /* do probability estimates */ 47 | }; 48 | 49 | // 50 | // svm_model 51 | // 52 | struct svm_model 53 | { 54 | struct svm_parameter param; /* parameter */ 55 | int nr_class; /* number of classes, = 2 in regression/one class svm */ 56 | int l; /* total #SV */ 57 | struct svm_node **SV; /* SVs (SV[l]) */ 58 | double **sv_coef; /* coefficients for SVs in decision functions (sv_coef[k-1][l]) */ 59 | double *rho; /* constants in decision functions (rho[k*(k-1)/2]) */ 60 | double *probA; /* pariwise probability information */ 61 | double *probB; 62 | double *prob_density_marks; /* probability information for ONE_CLASS */ 63 | int *sv_indices; /* sv_indices[0,...,nSV-1] are values in [1,...,num_traning_data] to indicate SVs in the training set */ 64 | 65 | /* for classification only */ 66 | 67 | int *label; /* label of each class (label[k]) */ 68 | int *nSV; /* number of SVs for each class (nSV[k]) */ 69 | /* nSV[0] + nSV[1] + ... 
+ nSV[k-1] = l */ 70 | /* XXX */ 71 | int free_sv; /* 1 if svm_model is created by svm_load_model*/ 72 | /* 0 if svm_model is created by svm_train */ 73 | }; 74 | 75 | struct svm_model *svm_train(const struct svm_problem *prob, const struct svm_parameter *param); 76 | void svm_cross_validation(const struct svm_problem *prob, const struct svm_parameter *param, int nr_fold, double *target); 77 | 78 | int svm_save_model(const char *model_file_name, const struct svm_model *model); 79 | struct svm_model *svm_load_model(const char *model_file_name); 80 | 81 | int svm_get_svm_type(const struct svm_model *model); 82 | int svm_get_nr_class(const struct svm_model *model); 83 | void svm_get_labels(const struct svm_model *model, int *label); 84 | void svm_get_sv_indices(const struct svm_model *model, int *sv_indices); 85 | int svm_get_nr_sv(const struct svm_model *model); 86 | double svm_get_svr_probability(const struct svm_model *model); 87 | 88 | double svm_predict_values(const struct svm_model *model, const struct svm_node *x, double* dec_values); 89 | double svm_predict(const struct svm_model *model, const struct svm_node *x); 90 | double svm_predict_probability(const struct svm_model *model, const struct svm_node *x, double* prob_estimates); 91 | 92 | void svm_free_model_content(struct svm_model *model_ptr); 93 | void svm_free_and_destroy_model(struct svm_model **model_ptr_ptr); 94 | void svm_destroy_param(struct svm_parameter *param); 95 | 96 | const char *svm_check_parameter(const struct svm_problem *prob, const struct svm_parameter *param); 97 | int svm_check_probability_model(const struct svm_model *model); 98 | 99 | void svm_set_print_string_function(void (*print_func)(const char *)); 100 | 101 | #ifdef __cplusplus 102 | } 103 | #endif 104 | 105 | #endif /* _LIBSVM_H */ 106 | -------------------------------------------------------------------------------- /tools/README: -------------------------------------------------------------------------------- 1 | This directory includes some useful codes: 2 | 3 | 1. subset selection tools. 4 | 2. parameter selection tools. 5 | 3. LIBSVM format checking tools 6 | 7 | Part I: Subset selection tools 8 | 9 | Introduction 10 | ============ 11 | 12 | Training large data is time consuming. Sometimes one should work on a 13 | smaller subset first. The python script subset.py randomly selects a 14 | specified number of samples. For classification data, we provide a 15 | stratified selection to ensure the same class distribution in the 16 | subset. 17 | 18 | Usage: subset.py [options] dataset number [output1] [output2] 19 | 20 | This script selects a subset of the given data set. 21 | 22 | options: 23 | -s method : method of selection (default 0) 24 | 0 -- stratified selection (classification only) 25 | 1 -- random selection 26 | 27 | output1 : the subset (optional) 28 | output2 : the rest of data (optional) 29 | 30 | If output1 is omitted, the subset will be printed on the screen. 31 | 32 | Example 33 | ======= 34 | 35 | > python subset.py heart_scale 100 file1 file2 36 | 37 | From heart_scale 100 samples are randomly selected and stored in 38 | file1. All remaining instances are stored in file2. 39 | 40 | 41 | Part II: Parameter Selection Tools 42 | 43 | Introduction 44 | ============ 45 | 46 | grid.py is a parameter selection tool for C-SVM classification using 47 | the RBF (radial basis function) kernel. 
It uses cross validation (CV) 48 | technique to estimate the accuracy of each parameter combination in 49 | the specified range and helps you to decide the best parameters for 50 | your problem. 51 | 52 | grid.py directly executes libsvm binaries (so no python binding is needed) 53 | for cross validation and then draw contour of CV accuracy using gnuplot. 54 | You must have libsvm and gnuplot installed before using it. The package 55 | gnuplot is available at http://www.gnuplot.info/ 56 | 57 | On Mac OSX, the precompiled gnuplot file needs the library Aquarterm, 58 | which thus must be installed as well. In addition, this version of 59 | gnuplot does not support png, so you need to change "set term png 60 | transparent small" and use other image formats. For example, you may 61 | have "set term pbm small color". 62 | 63 | Usage: grid.py [grid_options] [svm_options] dataset 64 | 65 | grid_options : 66 | -log2c {begin,end,step | "null"} : set the range of c (default -5,15,2) 67 | begin,end,step -- c_range = 2^{begin,...,begin+k*step,...,end} 68 | "null" -- do not grid with c 69 | -log2g {begin,end,step | "null"} : set the range of g (default 3,-15,-2) 70 | begin,end,step -- g_range = 2^{begin,...,begin+k*step,...,end} 71 | "null" -- do not grid with g 72 | -v n : n-fold cross validation (default 5) 73 | -svmtrain pathname : set svm executable path and name 74 | -gnuplot {pathname | "null"} : 75 | pathname -- set gnuplot executable path and name 76 | "null" -- do not plot 77 | -out {pathname | "null"} : (default dataset.out) 78 | pathname -- set output file path and name 79 | "null" -- do not output file 80 | -png pathname : set graphic output file path and name (default dataset.png) 81 | -resume [pathname] : resume the grid task using an existing output file (default pathname is dataset.out) 82 | Use this option only if some parameters have been checked for the SAME data. 83 | 84 | svm_options : additional options for svm-train 85 | 86 | The program conducts v-fold cross validation using parameter C (and gamma) 87 | = 2^begin, 2^(begin+step), ..., 2^end. 88 | 89 | You can specify where the libsvm executable and gnuplot are using the 90 | -svmtrain and -gnuplot parameters. 91 | 92 | For windows users, please use pgnuplot.exe. If you are using gnuplot 93 | 3.7.1, please upgrade to version 3.7.3 or higher. The version 3.7.1 94 | has a bug. If you use cygwin on windows, please use gunplot-x11. 95 | 96 | If the task is terminated accidentally or you would like to change the 97 | range of parameters, you can apply '-resume' to save time by re-using 98 | previous results. You may specify the output file of a previous run 99 | or use the default (i.e., dataset.out) without giving a name. Please 100 | note that the same condition must be used in two runs. For example, 101 | you cannot use '-v 10' earlier and resume the task with '-v 5'. 102 | 103 | The value of some options can be "null." For example, `-log2c -1,0,1 104 | -log2 "null"' means that C=2^-1,2^0,2^1 and g=LIBSVM's default gamma 105 | value. That is, you do not conduct parameter selection on gamma. 106 | 107 | Example 108 | ======= 109 | 110 | > python grid.py -log2c -5,5,1 -log2g -4,0,1 -v 5 -m 300 heart_scale 111 | 112 | Users (in particular MS Windows users) may need to specify the path of 113 | executable files. You can either change paths in the beginning of 114 | grid.py or specify them in the command line. 
For example, 115 | 116 | > grid.py -log2c -5,5,1 -svmtrain "c:\Program Files\libsvm\windows\svm-train.exe" -gnuplot c:\tmp\gnuplot\binary\pgnuplot.exe -v 10 heart_scale 117 | 118 | Output: two files 119 | dataset.png: the CV accuracy contour plot generated by gnuplot 120 | dataset.out: the CV accuracy at each (log2(C),log2(gamma)) 121 | 122 | The following example saves running time by loading the output file of a previous run. 123 | 124 | > python grid.py -log2c -7,7,1 -log2g -5,2,1 -v 5 -resume heart_scale.out heart_scale 125 | 126 | Parallel grid search 127 | ==================== 128 | 129 | You can conduct a parallel grid search by dispatching jobs to a 130 | cluster of computers which share the same file system. First, you add 131 | machine names in grid.py: 132 | 133 | ssh_workers = ["linux1", "linux5", "linux5"] 134 | 135 | and then setup your ssh so that the authentication works without 136 | asking a password. 137 | 138 | The same machine (e.g., linux5 here) can be listed more than once if 139 | it has multiple CPUs or has more RAM. If the local machine is the 140 | best, you can also enlarge the nr_local_worker. For example: 141 | 142 | nr_local_worker = 2 143 | 144 | Example: 145 | 146 | > python grid.py heart_scale 147 | [local] -1 -1 78.8889 (best c=0.5, g=0.5, rate=78.8889) 148 | [linux5] -1 -7 83.3333 (best c=0.5, g=0.0078125, rate=83.3333) 149 | [linux5] 5 -1 77.037 (best c=0.5, g=0.0078125, rate=83.3333) 150 | [linux1] 5 -7 83.3333 (best c=0.5, g=0.0078125, rate=83.3333) 151 | . 152 | . 153 | . 154 | 155 | If -log2c, -log2g, or -v is not specified, default values are used. 156 | 157 | If your system uses telnet instead of ssh, you list the computer names 158 | in telnet_workers. 159 | 160 | Calling grid in Python 161 | ====================== 162 | 163 | In addition to using grid.py as a command-line tool, you can use it as a 164 | Python module. 165 | 166 | >>> rate, param = find_parameters(dataset, options) 167 | 168 | You need to specify `dataset' and `options' (default ''). See the following example. 169 | 170 | > python 171 | 172 | >>> from grid import * 173 | >>> rate, param = find_parameters('../heart_scale', '-log2c -1,1,1 -log2g -1,1,1') 174 | [local] 0.0 0.0 rate=74.8148 (best c=1.0, g=1.0, rate=74.8148) 175 | [local] 0.0 -1.0 rate=77.037 (best c=1.0, g=0.5, rate=77.037) 176 | . 177 | . 178 | [local] -1.0 -1.0 rate=78.8889 (best c=0.5, g=0.5, rate=78.8889) 179 | . 180 | . 181 | >>> rate 182 | 78.8889 183 | >>> param 184 | {'c': 0.5, 'g': 0.5} 185 | 186 | 187 | Part III: LIBSVM format checking tools 188 | 189 | Introduction 190 | ============ 191 | 192 | `svm-train' conducts only a simple check of the input data. To do a 193 | detailed check, we provide a python script `checkdata.py.' 194 | 195 | Usage: checkdata.py dataset 196 | 197 | Exit status (returned value): 1 if there are errors, 0 otherwise. 198 | 199 | This tool is written by Rong-En Fan at National Taiwan University. 200 | 201 | Example 202 | ======= 203 | 204 | > cat bad_data 205 | 1 3:1 2:4 206 | > python checkdata.py bad_data 207 | line 1: feature indices must be in an ascending order, previous/current features 3:1 2:4 208 | Found 1 lines with error. 
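Beyond the interactive example above, a small hedged sketch of driving the
same workflow entirely from Python (it assumes grid.py is importable, e.g.
the script is run from the tools/ directory, that the libsvm Python package
is installed, and that the dataset path is illustrative):

from grid import find_parameters
from libsvm.svmutil import svm_read_problem, svm_train

# grid search returns the best CV rate and a dict such as {'c': 0.5, 'g': 0.5}
rate, param = find_parameters('../heart_scale', '-log2c -1,1,1 -log2g -1,1,1')

# retrain on the full data with the selected parameters
y, x = svm_read_problem('../heart_scale')
model = svm_train(y, x, '-c {c} -g {g}'.format(**param))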
209 | 210 | 211 | -------------------------------------------------------------------------------- /tools/checkdata.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # 4 | # A format checker for LIBSVM 5 | # 6 | 7 | # 8 | # Copyright (c) 2007, Rong-En Fan 9 | # 10 | # All rights reserved. 11 | # 12 | # This program is distributed under the same license of the LIBSVM package. 13 | # 14 | 15 | from sys import argv, exit 16 | import os.path 17 | 18 | def err(line_no, msg): 19 | print("line {0}: {1}".format(line_no, msg)) 20 | 21 | # works like float() but does not accept nan and inf 22 | def my_float(x): 23 | if x.lower().find("nan") != -1 or x.lower().find("inf") != -1: 24 | raise ValueError 25 | 26 | return float(x) 27 | 28 | def main(): 29 | if len(argv) != 2: 30 | print("Usage: {0} dataset".format(argv[0])) 31 | exit(1) 32 | 33 | dataset = argv[1] 34 | 35 | if not os.path.exists(dataset): 36 | print("dataset {0} not found".format(dataset)) 37 | exit(1) 38 | 39 | line_no = 1 40 | error_line_count = 0 41 | for line in open(dataset, 'r'): 42 | line_error = False 43 | 44 | # each line must end with a newline character 45 | if line[-1] != '\n': 46 | err(line_no, "missing a newline character in the end") 47 | line_error = True 48 | 49 | nodes = line.split() 50 | 51 | # check label 52 | try: 53 | label = nodes.pop(0) 54 | 55 | if label.find(',') != -1: 56 | # multi-label format 57 | try: 58 | for l in label.split(','): 59 | l = my_float(l) 60 | except: 61 | err(line_no, "label {0} is not a valid multi-label form".format(label)) 62 | line_error = True 63 | else: 64 | try: 65 | label = my_float(label) 66 | except: 67 | err(line_no, "label {0} is not a number".format(label)) 68 | line_error = True 69 | except: 70 | err(line_no, "missing label, perhaps an empty line?") 71 | line_error = True 72 | 73 | # check features 74 | prev_index = -1 75 | for i in range(len(nodes)): 76 | try: 77 | (index, value) = nodes[i].split(':') 78 | 79 | index = int(index) 80 | value = my_float(value) 81 | 82 | # precomputed kernel's index starts from 0 and LIBSVM 83 | # checks it. Hence, don't treat index 0 as an error. 
84 | if index < 0: 85 | err(line_no, "feature index must be positive; wrong feature {0}".format(nodes[i])) 86 | line_error = True 87 | elif index <= prev_index: 88 | err(line_no, "feature indices must be in an ascending order, previous/current features {0} {1}".format(nodes[i-1], nodes[i])) 89 | line_error = True 90 | prev_index = index 91 | except: 92 | err(line_no, "feature '{0}' not an : pair, integer, real number ".format(nodes[i])) 93 | line_error = True 94 | 95 | line_no += 1 96 | 97 | if line_error: 98 | error_line_count += 1 99 | 100 | if error_line_count > 0: 101 | print("Found {0} lines with error.".format(error_line_count)) 102 | return 1 103 | else: 104 | print("No error.") 105 | return 0 106 | 107 | if __name__ == "__main__": 108 | exit(main()) 109 | -------------------------------------------------------------------------------- /tools/easy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import sys 4 | import os 5 | from subprocess import * 6 | 7 | if len(sys.argv) <= 1: 8 | print('Usage: {0} training_file [testing_file]'.format(sys.argv[0])) 9 | raise SystemExit 10 | 11 | # svm, grid, and gnuplot executable files 12 | 13 | is_win32 = (sys.platform == 'win32') 14 | if not is_win32: 15 | svmscale_exe = "../svm-scale" 16 | svmtrain_exe = "../svm-train" 17 | svmpredict_exe = "../svm-predict" 18 | grid_py = "./grid.py" 19 | gnuplot_exe = "/usr/bin/gnuplot" 20 | else: 21 | # example for windows 22 | svmscale_exe = r"..\windows\svm-scale.exe" 23 | svmtrain_exe = r"..\windows\svm-train.exe" 24 | svmpredict_exe = r"..\windows\svm-predict.exe" 25 | gnuplot_exe = r"c:\tmp\gnuplot\binary\pgnuplot.exe" 26 | grid_py = r".\grid.py" 27 | 28 | assert os.path.exists(svmscale_exe),"svm-scale executable not found" 29 | assert os.path.exists(svmtrain_exe),"svm-train executable not found" 30 | assert os.path.exists(svmpredict_exe),"svm-predict executable not found" 31 | assert os.path.exists(gnuplot_exe),"gnuplot executable not found" 32 | assert os.path.exists(grid_py),"grid.py not found" 33 | 34 | train_pathname = sys.argv[1] 35 | assert os.path.exists(train_pathname),"training file not found" 36 | file_name = os.path.split(train_pathname)[1] 37 | scaled_file = file_name + ".scale" 38 | model_file = file_name + ".model" 39 | range_file = file_name + ".range" 40 | 41 | if len(sys.argv) > 2: 42 | test_pathname = sys.argv[2] 43 | file_name = os.path.split(test_pathname)[1] 44 | assert os.path.exists(test_pathname),"testing file not found" 45 | scaled_test_file = file_name + ".scale" 46 | predict_test_file = file_name + ".predict" 47 | 48 | cmd = '{0} -s "{1}" "{2}" > "{3}"'.format(svmscale_exe, range_file, train_pathname, scaled_file) 49 | print('Scaling training data...') 50 | Popen(cmd, shell = True, stdout = PIPE).communicate() 51 | 52 | cmd = '{0} -svmtrain "{1}" -gnuplot "{2}" "{3}"'.format(grid_py, svmtrain_exe, gnuplot_exe, scaled_file) 53 | print('Cross validation...') 54 | f = Popen(cmd, shell = True, stdout = PIPE).stdout 55 | 56 | line = '' 57 | while True: 58 | last_line = line 59 | line = f.readline() 60 | if not line: break 61 | c,g,rate = map(float,last_line.split()) 62 | 63 | print('Best c={0}, g={1} CV rate={2}'.format(c,g,rate)) 64 | 65 | cmd = '{0} -c {1} -g {2} "{3}" "{4}"'.format(svmtrain_exe,c,g,scaled_file,model_file) 66 | print('Training...') 67 | Popen(cmd, shell = True, stdout = PIPE).communicate() 68 | 69 | print('Output model: {0}'.format(model_file)) 70 | if len(sys.argv) > 2: 71 | cmd = '{0} -r 
"{1}" "{2}" > "{3}"'.format(svmscale_exe, range_file, test_pathname, scaled_test_file) 72 | print('Scaling testing data...') 73 | Popen(cmd, shell = True, stdout = PIPE).communicate() 74 | 75 | cmd = '{0} "{1}" "{2}" "{3}"'.format(svmpredict_exe, scaled_test_file, model_file, predict_test_file) 76 | print('Testing...') 77 | Popen(cmd, shell = True).communicate() 78 | 79 | print('Output prediction: {0}'.format(predict_test_file)) 80 | -------------------------------------------------------------------------------- /tools/subset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os, sys, math, random 4 | from collections import defaultdict 5 | 6 | if sys.version_info[0] >= 3: 7 | xrange = range 8 | 9 | def exit_with_help(argv): 10 | print("""\ 11 | Usage: {0} [options] dataset subset_size [output1] [output2] 12 | 13 | This script randomly selects a subset of the dataset. 14 | 15 | options: 16 | -s method : method of selection (default 0) 17 | 0 -- stratified selection (classification only) 18 | 1 -- random selection 19 | 20 | output1 : the subset (optional) 21 | output2 : rest of the data (optional) 22 | If output1 is omitted, the subset will be printed on the screen.""".format(argv[0])) 23 | exit(1) 24 | 25 | def process_options(argv): 26 | argc = len(argv) 27 | if argc < 3: 28 | exit_with_help(argv) 29 | 30 | # default method is stratified selection 31 | method = 0 32 | subset_file = sys.stdout 33 | rest_file = None 34 | 35 | i = 1 36 | while i < argc: 37 | if argv[i][0] != "-": 38 | break 39 | if argv[i] == "-s": 40 | i = i + 1 41 | method = int(argv[i]) 42 | if method not in [0,1]: 43 | print("Unknown selection method {0}".format(method)) 44 | exit_with_help(argv) 45 | i = i + 1 46 | 47 | dataset = argv[i] 48 | subset_size = int(argv[i+1]) 49 | if i+2 < argc: 50 | subset_file = open(argv[i+2],'w') 51 | if i+3 < argc: 52 | rest_file = open(argv[i+3],'w') 53 | 54 | return dataset, subset_size, method, subset_file, rest_file 55 | 56 | def random_selection(dataset, subset_size): 57 | l = sum(1 for line in open(dataset,'r')) 58 | return sorted(random.sample(xrange(l), subset_size)) 59 | 60 | def stratified_selection(dataset, subset_size): 61 | labels = [line.split(None,1)[0] for line in open(dataset)] 62 | label_linenums = defaultdict(list) 63 | for i, label in enumerate(labels): 64 | label_linenums[label] += [i] 65 | 66 | l = len(labels) 67 | remaining = subset_size 68 | ret = [] 69 | 70 | # classes with fewer data are sampled first; otherwise 71 | # some rare classes may not be selected 72 | for label in sorted(label_linenums, key=lambda x: len(label_linenums[x])): 73 | linenums = label_linenums[label] 74 | label_size = len(linenums) 75 | # at least one instance per class 76 | s = int(min(remaining, max(1, math.ceil(label_size*(float(subset_size)/l))))) 77 | if s == 0: 78 | sys.stderr.write('''\ 79 | Error: failed to have at least one instance per class 80 | 1. You may have regression data. 81 | 2. Your classification data is unbalanced or too small. 82 | Please use -s 1. 
83 | ''') 84 | sys.exit(-1) 85 | remaining -= s 86 | ret += [linenums[i] for i in random.sample(xrange(label_size), s)] 87 | return sorted(ret) 88 | 89 | def main(argv=sys.argv): 90 | dataset, subset_size, method, subset_file, rest_file = process_options(argv) 91 | #uncomment the following line to fix the random seed 92 | #random.seed(0) 93 | selected_lines = [] 94 | 95 | if method == 0: 96 | selected_lines = stratified_selection(dataset, subset_size) 97 | elif method == 1: 98 | selected_lines = random_selection(dataset, subset_size) 99 | 100 | #select instances based on selected_lines 101 | dataset = open(dataset,'r') 102 | prev_selected_linenum = -1 103 | for i in xrange(len(selected_lines)): 104 | for cnt in xrange(selected_lines[i]-prev_selected_linenum-1): 105 | line = dataset.readline() 106 | if rest_file: 107 | rest_file.write(line) 108 | subset_file.write(dataset.readline()) 109 | prev_selected_linenum = selected_lines[i] 110 | subset_file.close() 111 | 112 | if rest_file: 113 | for line in dataset: 114 | rest_file.write(line) 115 | rest_file.close() 116 | dataset.close() 117 | 118 | if __name__ == '__main__': 119 | main(sys.argv) 120 | 121 | -------------------------------------------------------------------------------- /windows/libsvm.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjlin1/libsvm/65367d015d9e368cc79f5c56924455c3a4fb5e48/windows/libsvm.dll -------------------------------------------------------------------------------- /windows/libsvmread.mexw64: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjlin1/libsvm/65367d015d9e368cc79f5c56924455c3a4fb5e48/windows/libsvmread.mexw64 -------------------------------------------------------------------------------- /windows/libsvmwrite.mexw64: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjlin1/libsvm/65367d015d9e368cc79f5c56924455c3a4fb5e48/windows/libsvmwrite.mexw64 -------------------------------------------------------------------------------- /windows/svm-predict.exe: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjlin1/libsvm/65367d015d9e368cc79f5c56924455c3a4fb5e48/windows/svm-predict.exe -------------------------------------------------------------------------------- /windows/svm-scale.exe: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjlin1/libsvm/65367d015d9e368cc79f5c56924455c3a4fb5e48/windows/svm-scale.exe -------------------------------------------------------------------------------- /windows/svm-toy.exe: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjlin1/libsvm/65367d015d9e368cc79f5c56924455c3a4fb5e48/windows/svm-toy.exe -------------------------------------------------------------------------------- /windows/svm-train.exe: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjlin1/libsvm/65367d015d9e368cc79f5c56924455c3a4fb5e48/windows/svm-train.exe -------------------------------------------------------------------------------- /windows/svmpredict.mexw64: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cjlin1/libsvm/65367d015d9e368cc79f5c56924455c3a4fb5e48/windows/svmpredict.mexw64 -------------------------------------------------------------------------------- /windows/svmtrain.mexw64: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjlin1/libsvm/65367d015d9e368cc79f5c56924455c3a4fb5e48/windows/svmtrain.mexw64 --------------------------------------------------------------------------------