├── .gitignore ├── students-filters-master ├── StudentFilters.zip ├── filters-0.0.1-SNAPSHOT.jar ├── build.xml ├── Description.props ├── pom.xml ├── UNLICENSE.txt ├── src │ ├── test │ │ └── java │ │ │ └── test │ │ │ └── filters │ │ │ └── unsupervised │ │ │ └── attribute │ │ │ └── IndependentComponentsTest.java │ └── main │ │ └── java │ │ ├── filters │ │ ├── NegativeEntropyEstimator.java │ │ ├── LogCosh.java │ │ └── FastICA.java │ │ └── weka │ │ └── filters │ │ └── unsupervised │ │ └── attribute │ │ └── IndependentComponents.java └── README.md ├── src └── main │ ├── java │ └── com │ │ └── derek │ │ └── ml │ │ ├── models │ │ ├── FileName.java │ │ ├── ML.java │ │ ├── Options.java │ │ ├── EMModel.java │ │ ├── SVMModel.java │ │ ├── Cluster.java │ │ ├── FeatureReduction.java │ │ ├── NeuralNetworkModel.java │ │ ├── DecisionTree.java │ │ └── NearestNeighbor.java │ │ ├── Server.java │ │ ├── controllers │ │ ├── Converter.java │ │ ├── SVMController.java │ │ ├── Clustering.java │ │ ├── KNNController.java │ │ ├── NeuralNetworkController.java │ │ ├── FeatureReductionController.java │ │ └── DecisionTreeController.java │ │ └── services │ │ ├── EvaluationService.java │ │ ├── LoadData.java │ │ ├── SVMService.java │ │ ├── KNNService.java │ │ ├── NeuralNetworkService.java │ │ ├── DecisionTreeService.java │ │ ├── FileFactory.java │ │ ├── FeatureReductionService.java │ │ └── ClusterService.java │ └── resources │ ├── arffs │ ├── car_test.arff │ └── car_bin_test.arff │ └── csv │ └── car_bin_test.csv ├── pom.xml └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | *.iml 3 | target 4 | .classpath 5 | .project 6 | .settings 7 | .DS_Store 8 | application.properties 9 | logs -------------------------------------------------------------------------------- /students-filters-master/StudentFilters.zip: -------------------------------------------------------------------------------- 
/**
 * Base request model shared by every learning-algorithm model.
 * Carries which data set to operate on and how the resulting classifier
 * should be evaluated.
 */
public class ML {

    /** Known data sets the services can load. */
    public static enum Files {
        Boston, Census, Car, CarBin, CensusBin, CensusKm, CensusEm;
    }

    /** Supported evaluation strategies. */
    public static enum TestType {
        CrossValidation, TestData, Train;
    }

    // Target data set; the Car data set is the default.
    protected ML.Files fileName = Files.Car;
    // Evaluation strategy; cross validation is the default.
    protected ML.TestType testType = ML.TestType.CrossValidation;

    /** @return the data set this model targets */
    public Files getFileName() {
        return fileName;
    }

    /** @param fileName the data set this model should target */
    public void setFileName(Files fileName) {
        this.fileName = fileName;
    }

    /** @return the evaluation strategy to apply */
    public TestType getTestType() {
        return testType;
    }

    /** @param testType the evaluation strategy to apply */
    public void setTestType(TestType testType) {
        this.testType = testType;
    }

}
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /src/main/java/com/derek/ml/models/Options.java: -------------------------------------------------------------------------------- 1 | package com.derek.ml.models; 2 | 3 | 4 | public class Options { 5 | 6 | public Options(){} 7 | 8 | public Options(boolean featureSelection){ 9 | this.featureSelection = featureSelection; 10 | } 11 | 12 | public Options(boolean featureSelection, boolean noClass){ 13 | this.featureSelection = featureSelection; 14 | this.noClass = noClass; 15 | } 16 | 17 | private boolean featureSelection = false; 18 | private boolean noClass = false; 19 | 20 | public boolean isFeatureSelection() { 21 | return featureSelection; 22 | } 23 | 24 | public void setFeatureSelection(boolean featureSelection) { 25 | this.featureSelection = featureSelection; 26 | } 27 | 28 | public boolean isNoClass() { 29 | return noClass; 30 | } 31 | 32 | public void setNoClass(boolean noClass) { 33 | this.noClass = noClass; 34 | } 35 | 36 | } 37 | -------------------------------------------------------------------------------- /src/main/java/com/derek/ml/Server.java: -------------------------------------------------------------------------------- 1 | package com.derek.ml; 2 | 3 | import com.derek.ml.models.*; 4 | import com.derek.ml.services.DecisionTreeService; 5 | import com.derek.ml.services.KNNService; 6 | import com.derek.ml.services.NeuralNetworkService; 7 | import com.derek.ml.services.SVMService; 8 | import org.springframework.beans.factory.annotation.Autowired; 9 | import org.springframework.boot.SpringApplication; 10 | import org.springframework.boot.autoconfigure.EnableAutoConfiguration; 11 | import org.springframework.boot.autoconfigure.SpringBootApplication; 12 | import 
org.springframework.boot.autoconfigure.web.WebMvcAutoConfiguration; 13 | import org.springframework.context.ApplicationContext; 14 | import org.springframework.context.annotation.ComponentScan; 15 | import org.springframework.context.annotation.Configuration; 16 | 17 | @SpringBootApplication 18 | public class Server extends WebMvcAutoConfiguration { 19 | 20 | public static void main (String[] args) throws Exception { 21 | ApplicationContext ctx = new SpringApplication(Server.class).run(args); 22 | System.out.println("MACHINE LEARNING IS RUNNING"); 23 | } 24 | } 25 | 26 | -------------------------------------------------------------------------------- /students-filters-master/Description.props: -------------------------------------------------------------------------------- 1 | # Template Description file for a Weka package 2 | # 3 | 4 | # Package name (required) 5 | PackageName=StudentFilters 6 | 7 | # Version (required) 8 | Version=1.0.0 9 | 10 | #Date (year-month-day) 11 | Date=2014-08-02 12 | 13 | # Title (required) 14 | Title=Student Filters 15 | 16 | # Category (recommended) 17 | Category=Preprocessing 18 | 19 | # Author (required) 20 | Author=Chris Gearhart 21 | 22 | # Maintainer (required) 23 | Maintainer=Chris Gearhart 24 | 25 | # License (required) 26 | License=Unlicense 27 | 28 | # Description (required) 29 | Description=Student's filters is a set of unsupervised learning algorithms (initially only a port of parallel FastICA for Independent Component Analysis) that should be particularly useful for the ML classes at Georgia Tech. 
30 | 31 | # Package URL for obtaining the package archive (required) 32 | PackageURL=https://github.com/cgearhart/students-filters/raw/master/StudentFilters.zip 33 | 34 | # URL for further information 35 | URL=https://github.com/cgearhart/students-filters/ 36 | 37 | # Dependencies (format: packageName (equality/inequality version_number) 38 | Depends=weka (>=3.7.1) 39 | -------------------------------------------------------------------------------- /src/main/java/com/derek/ml/models/EMModel.java: -------------------------------------------------------------------------------- 1 | package com.derek.ml.models; 2 | 3 | 4 | public class EMModel extends ML{ 5 | private int iterations = 500; 6 | private int clusters = 2; 7 | private double standardDeviations = 1.0E-6D; 8 | private FeatureSelection featureSelection = null; 9 | 10 | public int getIterations() { 11 | return iterations; 12 | } 13 | 14 | public void setIterations(int iterations) { 15 | this.iterations = iterations; 16 | } 17 | 18 | public int getClusters() { 19 | return clusters; 20 | } 21 | 22 | public void setClusters(int clusters) { 23 | this.clusters = clusters; 24 | } 25 | 26 | public double getStandardDeviations() { 27 | return standardDeviations; 28 | } 29 | 30 | public void setStandardDeviations(double standardDeviations) { 31 | this.standardDeviations = standardDeviations; 32 | } 33 | 34 | public FeatureSelection getFeatureSelection() { 35 | return featureSelection; 36 | } 37 | 38 | public void setFeatureSelection(FeatureSelection featureSelection) { 39 | this.featureSelection = featureSelection; 40 | } 41 | 42 | public enum FeatureSelection { 43 | ICA, PCA, RP, CFS; 44 | } 45 | 46 | } 47 | -------------------------------------------------------------------------------- /students-filters-master/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 4.0.0 3 | filters 4 | filters 5 | 0.0.1-SNAPSHOT 6 | WEKA filters plugin 7 | 8 | 9 | 10 | 
maven-compiler-plugin 11 | 3.1 12 | 13 | 1.7 14 | 1.7 15 | 16 | 17 | 18 | 19 | 20 | 21 | com.googlecode.efficient-java-matrix-library 22 | ejml 23 | 0.25 24 | 25 | 26 | nz.ac.waikato.cms.weka 27 | weka-dev 28 | 3.7.10 29 | 30 | 31 | junit 32 | junit 33 | 4.11 34 | 35 | 36 | -------------------------------------------------------------------------------- /students-filters-master/UNLICENSE.txt: -------------------------------------------------------------------------------- 1 | This is free and unencumbered software released into the public domain. 2 | 3 | Anyone is free to copy, modify, publish, use, compile, sell, or 4 | distribute this software, either in source code form or as a compiled 5 | binary, for any purpose, commercial or non-commercial, and by any 6 | means. 7 | 8 | In jurisdictions that recognize copyright laws, the author or authors 9 | of this software dedicate any and all copyright interest in the 10 | software to the public domain. We make this dedication for the benefit 11 | of the public at large and to the detriment of our heirs and 12 | successors. We intend this dedication to be an overt act of 13 | relinquishment in perpetuity of all present and future rights to this 14 | software under copyright law. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR 20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 22 | OTHER DEALINGS IN THE SOFTWARE. 
23 | 24 | For more information, please refer to -------------------------------------------------------------------------------- /students-filters-master/src/test/java/test/filters/unsupervised/attribute/IndependentComponentsTest.java: -------------------------------------------------------------------------------- 1 | package test.filters.unsupervised.attribute; 2 | 3 | import java.util.Random; 4 | 5 | 6 | import org.ejml.simple.SimpleMatrix; 7 | import org.junit.BeforeClass; 8 | import org.junit.Test; 9 | 10 | import weka.core.Attribute; 11 | import weka.core.CheckOptionHandler; 12 | import weka.core.Instances; 13 | import weka.filters.Filter; 14 | import weka.filters.unsupervised.attribute.IndependentComponents; 15 | 16 | public class IndependentComponentsTest { 17 | 18 | static Instances signals; 19 | 20 | @BeforeClass 21 | public static void mixSignals() { 22 | // Read instances from data file containing 2000 samples mixing two 23 | // signals, sin(2t) + U[0,0.2] and square(3t) + U[0,0.2], with the 24 | // mixing matrix [[1,1], [0.5,2]] 25 | 26 | for (int j = 0; j < 3; j++) { 27 | signals.insertAttributeAt(new Attribute("S_" + Integer.toString(j)), j); 28 | } 29 | 30 | } 31 | 32 | @Test 33 | public void testFilter() { 34 | Instances result; 35 | IndependentComponents ICF = new IndependentComponents(); 36 | 37 | try { 38 | result = Filter.useFilter(signals, ICF); 39 | } catch (Exception e) { 40 | e.printStackTrace(); 41 | } 42 | 43 | // assert something about result 44 | } 45 | 46 | } 47 | -------------------------------------------------------------------------------- /src/main/java/com/derek/ml/models/SVMModel.java: -------------------------------------------------------------------------------- 1 | package com.derek.ml.models; 2 | 3 | 4 | public class SVMModel extends ML { 5 | 6 | private KernelType kernelType = KernelType.Linear; 7 | private boolean featureSelection = false; 8 | 9 | public SVMModel(){} 10 | 11 | public SVMModel(KernelType kernelType, 
ML.Files fileName){ 12 | this.kernelType = kernelType; 13 | super.setFileName(fileName); 14 | } 15 | 16 | public static enum KernelType { 17 | Linear, Sigmoid, Polynomial, RBF; 18 | } 19 | 20 | public void setFileName(Files fileName){ 21 | super.setFileName(fileName); 22 | } 23 | public Files getFileName(){ 24 | return super.getFileName(); 25 | } 26 | 27 | public KernelType getKernelType() { 28 | return kernelType; 29 | } 30 | 31 | public void setKernelType(KernelType kernelType) { 32 | this.kernelType = kernelType; 33 | } 34 | 35 | public void setTestType(TestType testType){ 36 | super.setTestType(testType); 37 | } 38 | 39 | public TestType getTestType(){ 40 | return super.getTestType(); 41 | } 42 | 43 | public boolean isFeatureSelection() { 44 | return featureSelection; 45 | } 46 | 47 | public void setFeatureSelection(boolean featureSelection) { 48 | this.featureSelection = featureSelection; 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /src/main/java/com/derek/ml/controllers/Converter.java: -------------------------------------------------------------------------------- 1 | package com.derek.ml.controllers; 2 | 3 | import com.derek.ml.models.ML; 4 | import com.derek.ml.models.Options; 5 | import com.derek.ml.services.FileFactory; 6 | import com.derek.ml.services.LoadData; 7 | import org.springframework.beans.factory.annotation.Autowired; 8 | import org.springframework.stereotype.Component; 9 | import org.springframework.stereotype.Controller; 10 | import org.springframework.web.bind.annotation.RequestMapping; 11 | import org.springframework.web.bind.annotation.RequestMethod; 12 | import org.springframework.web.bind.annotation.ResponseBody; 13 | import weka.core.Instances; 14 | import weka.filters.Filter; 15 | import weka.filters.supervised.attribute.NominalToBinary; 16 | 17 | @Controller 18 | public class Converter { 19 | 20 | @Autowired 21 | private FileFactory fileFactory; 22 | 23 | @Autowired 24 | private 
LoadData loadData; 25 | 26 | @ResponseBody 27 | @RequestMapping(value="/convert", method={RequestMethod.GET}) 28 | public void doConvert() throws Exception{ 29 | NominalToBinary nominalToBinary = new NominalToBinary(); 30 | FileFactory.TrainTest data = fileFactory.handlePublicCensus(new Options(false, false)); 31 | nominalToBinary.setInputFormat(data.test); 32 | Instances instances = Filter.useFilter(data.test, nominalToBinary); 33 | loadData.saveToArff(instances, "justATest2.arff"); 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /src/main/java/com/derek/ml/models/Cluster.java: -------------------------------------------------------------------------------- 1 | package com.derek.ml.models; 2 | 3 | 4 | public class Cluster extends ML{ 5 | private int clusters = 2; 6 | private Distances distances = Distances.Euclidean; 7 | private int iterations = 500; 8 | 9 | private FeatureSelection featureSelection = null; 10 | 11 | public int getClusters() { 12 | return clusters; 13 | } 14 | 15 | public void setClusters(int clusters) { 16 | this.clusters = clusters; 17 | } 18 | 19 | public Distances getDistances() { 20 | return distances; 21 | } 22 | 23 | public void setDistances(Distances distances) { 24 | this.distances = distances; 25 | } 26 | 27 | public enum Distances { 28 | Manhatten, Euclidean; 29 | } 30 | 31 | public int getIterations() { 32 | return iterations; 33 | } 34 | 35 | public void setIterations(int iterations) { 36 | this.iterations = iterations; 37 | } 38 | 39 | public void setFileName(Files fileName){ 40 | this.fileName = fileName; 41 | } 42 | 43 | public Files getFileName(){ 44 | return this.fileName; 45 | } 46 | 47 | public FeatureSelection getFeatureSelection() { 48 | return featureSelection; 49 | } 50 | 51 | public void setFeatureSelection(FeatureSelection featureSelection) { 52 | this.featureSelection = featureSelection; 53 | } 54 | 55 | public enum FeatureSelection { 56 | ICA, PCA, RP, CFS; 57 | } 58 | 
/**
 * Parameter bag for the feature-reduction endpoints (PCA, ICA, randomized
 * projection, and CFS subset evaluation).
 */
public class FeatureReduction {

    // Fraction of variance PCA must retain.
    private double varianceCovered = .95;
    // Cap on attribute names included in PCA output labels.
    private int maximumAttributeNames = 10;
    // Percentage of attributes kept by randomized projection.
    private double percent = 50;
    // Attribute count for subset-based selection.
    private int numberOfAttributes = 7;
    // Iteration budget for the reduction search.
    private int numberOfIterations = 100;

    /** @return the iteration budget */
    public int getNumberOfIterations() {
        return numberOfIterations;
    }

    /** @param numberOfIterations the iteration budget */
    public void setNumberOfIterations(int numberOfIterations) {
        this.numberOfIterations = numberOfIterations;
    }

    /** @return the fraction of variance to retain */
    public double getVarianceCovered() {
        return varianceCovered;
    }

    /** @param varianceCovered the fraction of variance to retain */
    public void setVarianceCovered(double varianceCovered) {
        this.varianceCovered = varianceCovered;
    }

    /** @return the cap on attribute names in output labels */
    public int getMaximumAttributeNames() {
        return maximumAttributeNames;
    }

    /** @param maximumAttributeNames the cap on attribute names in output labels */
    public void setMaximumAttributeNames(int maximumAttributeNames) {
        this.maximumAttributeNames = maximumAttributeNames;
    }

    /** @return the percentage of attributes to keep */
    public double getPercent() {
        return percent;
    }

    /** @param percent the percentage of attributes to keep */
    public void setPercent(double percent) {
        this.percent = percent;
    }

    /** @return the attribute count for subset selection */
    public int getNumberOfAttributes() {
        return numberOfAttributes;
    }

    /** @param numberOfAttributes the attribute count for subset selection */
    public void setNumberOfAttributes(int numberOfAttributes) {
        this.numberOfAttributes = numberOfAttributes;
    }

}
org.springframework.stereotype.Controller; 7 | import org.springframework.web.bind.annotation.RequestMapping; 8 | import org.springframework.web.bind.annotation.RequestMethod; 9 | import org.springframework.web.bind.annotation.ResponseBody; 10 | 11 | @Controller 12 | public class SVMController { 13 | 14 | @Autowired 15 | private SVMService svmService; 16 | 17 | @ResponseBody 18 | @RequestMapping(value ="/svm", method={RequestMethod.GET}) 19 | public String svmHandle(SVMModel svmModel) throws Exception{ 20 | return svmService.handleSVM(svmModel); 21 | } 22 | 23 | @ResponseBody 24 | @RequestMapping(value = "/svm/test", method = {RequestMethod.GET}) 25 | public String svmTestHandle(SVMModel svmModel) throws Exception{ 26 | return svmService.handleSplitData(svmModel, 1, ""); 27 | } 28 | 29 | @ResponseBody 30 | @RequestMapping(value="/svm/model", method={RequestMethod.POST}) 31 | public void createModel(SVMModel svmModel) throws Exception{ 32 | svmService.createModel(svmModel); 33 | } 34 | 35 | @ResponseBody 36 | @RequestMapping(value="/svm/model", method={RequestMethod.GET}) 37 | public String getModel(SVMModel svmModel) throws Exception{ 38 | return svmService.getModel(svmModel); 39 | } 40 | 41 | } 42 | -------------------------------------------------------------------------------- /src/main/java/com/derek/ml/controllers/Clustering.java: -------------------------------------------------------------------------------- 1 | package com.derek.ml.controllers; 2 | 3 | import com.derek.ml.models.EMModel; 4 | import com.derek.ml.models.Cluster; 5 | import com.derek.ml.services.ClusterService; 6 | import org.springframework.beans.factory.annotation.Autowired; 7 | import org.springframework.stereotype.Controller; 8 | import org.springframework.web.bind.annotation.RequestMapping; 9 | import org.springframework.web.bind.annotation.RequestMethod; 10 | import org.springframework.web.bind.annotation.ResponseBody; 11 | 12 | 13 | @Controller 14 | public class Clustering { 15 | 16 | 
@Autowired 17 | private ClusterService clusterService; 18 | 19 | @ResponseBody 20 | @RequestMapping(value ="/kMeans", method={RequestMethod.GET}) 21 | public String handleKmeans(Cluster cluster) throws Exception{ 22 | return clusterService.handleKmeans(cluster); 23 | } 24 | 25 | @ResponseBody 26 | @RequestMapping(value = "/em", method = {RequestMethod.GET}) 27 | public String handleEM(Cluster emModel) throws Exception{ 28 | return clusterService.handleEM(emModel); 29 | } 30 | 31 | @ResponseBody 32 | @RequestMapping(value = "/em/plot", method = {RequestMethod.GET}) 33 | public void handleEMPlot() throws Exception{ 34 | clusterService.plotEMWithFeature(); 35 | } 36 | 37 | @ResponseBody 38 | @RequestMapping(value = "/kMeans/plot", method = {RequestMethod.GET}) 39 | public void handleKMPlot() throws Exception{ 40 | clusterService.plotKMWithFeature(); 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /src/main/java/com/derek/ml/models/NeuralNetworkModel.java: -------------------------------------------------------------------------------- 1 | package com.derek.ml.models; 2 | 3 | 4 | public class NeuralNetworkModel extends ML{ 5 | private int epochRate = 250; 6 | private int hiddenLayers = 1; 7 | private boolean featureSelection = false; 8 | 9 | public NeuralNetworkModel(){} 10 | 11 | public NeuralNetworkModel(int hiddenLayers, int epochRate, ML.Files fileName){ 12 | this.hiddenLayers = hiddenLayers; 13 | this.epochRate = epochRate; 14 | super.setFileName(fileName); 15 | } 16 | 17 | public NeuralNetworkModel(int hiddenLayers, int epochRate, ML.Files fileName, boolean featureSelection){ 18 | this.hiddenLayers = hiddenLayers; 19 | this.epochRate = epochRate; 20 | super.setFileName(fileName); 21 | this.featureSelection = featureSelection; 22 | } 23 | 24 | public int getEpochRate() { 25 | return epochRate; 26 | } 27 | 28 | public void setEpochRate(int epochRate) { 29 | this.epochRate = epochRate; 30 | } 31 | 32 | public int 
getHiddenLayers() { 33 | return hiddenLayers; 34 | } 35 | 36 | public void setHiddenLayers(int hiddenLayers) { 37 | this.hiddenLayers = hiddenLayers; 38 | } 39 | 40 | public void setFileName(Files fileName){ 41 | super.setFileName(fileName); 42 | } 43 | 44 | public Files getFileName(){ 45 | return super.getFileName(); 46 | } 47 | 48 | public boolean isFeatureSelection() { 49 | return featureSelection; 50 | } 51 | 52 | public void setFeatureSelection(boolean featureSelection) { 53 | this.featureSelection = featureSelection; 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /src/main/java/com/derek/ml/controllers/KNNController.java: -------------------------------------------------------------------------------- 1 | package com.derek.ml.controllers; 2 | 3 | import com.derek.ml.models.NearestNeighbor; 4 | import com.derek.ml.services.FileFactory; 5 | import com.derek.ml.services.KNNService; 6 | import com.derek.ml.services.LoadData; 7 | import org.springframework.beans.factory.annotation.Autowired; 8 | import org.springframework.stereotype.Controller; 9 | import org.springframework.web.bind.annotation.RequestMapping; 10 | import org.springframework.web.bind.annotation.RequestMethod; 11 | import org.springframework.web.bind.annotation.ResponseBody; 12 | 13 | @Controller 14 | public class KNNController { 15 | 16 | @Autowired 17 | private KNNService knnService; 18 | 19 | @ResponseBody 20 | @RequestMapping(value ="/knn", method={RequestMethod.GET}) 21 | public String knnClassify(NearestNeighbor nearestNeighbor) throws Exception{ 22 | return knnService.handleKNNService(nearestNeighbor); 23 | } 24 | 25 | @ResponseBody 26 | @RequestMapping(value ="/knn/test", method={RequestMethod.GET}) 27 | public String knnTesting(NearestNeighbor nearestNeighbor) throws Exception{ 28 | return knnService.handleSplitData(nearestNeighbor, 1, ""); 29 | } 30 | 31 | @ResponseBody 32 | @RequestMapping(value="/knn/model", 
method={RequestMethod.POST}) 33 | public void createModel(NearestNeighbor nearestNeighbor) throws Exception{ 34 | knnService.createModel(nearestNeighbor); 35 | } 36 | 37 | @ResponseBody 38 | @RequestMapping(value="/knn/model", method={RequestMethod.GET}) 39 | public String getModel(NearestNeighbor nearestNeighbor) throws Exception{ 40 | return knnService.getModel(nearestNeighbor); 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /src/main/java/com/derek/ml/services/EvaluationService.java: -------------------------------------------------------------------------------- 1 | package com.derek.ml.services; 2 | 3 | import com.derek.ml.models.ML; 4 | import org.springframework.beans.factory.annotation.Autowired; 5 | import org.springframework.stereotype.Service; 6 | import weka.classifiers.Classifier; 7 | import weka.classifiers.Evaluation; 8 | import weka.core.Instances; 9 | 10 | import java.util.Random; 11 | 12 | @Service 13 | public class EvaluationService { 14 | 15 | @Autowired 16 | FileFactory fileFactory; 17 | 18 | public String evaluateData(Instances data, Classifier classifier) throws Exception{ 19 | return evaluateData(data, classifier, 10); 20 | } 21 | 22 | public String evaluateData(Instances data, Classifier classifier, int numberOfFolds) throws Exception{ 23 | return evaluateData(data, classifier, numberOfFolds, null); 24 | } 25 | 26 | public String evaluateData(Instances data, Classifier classifier, Instances testData) throws Exception { 27 | return evaluateData(data, classifier, 10, testData); 28 | } 29 | 30 | public String evaluateData(Instances data, Classifier classifier, int numberOfFolds, Instances testData) throws Exception{ 31 | Evaluation evaluation = new Evaluation(data); 32 | if (testData == null){ 33 | evaluation.crossValidateModel(classifier, data, numberOfFolds, new Random(1)); 34 | } else { 35 | evaluation.evaluateModel(classifier, testData); 36 | } 37 | String retString = ""; 38 | retString 
+= evaluation.toSummaryString() + " \n"; 39 | retString += evaluation.toClassDetailsString() + " \n"; 40 | retString += evaluation.toMatrixString() + " \n"; 41 | return retString; 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /src/main/java/com/derek/ml/controllers/NeuralNetworkController.java: -------------------------------------------------------------------------------- 1 | package com.derek.ml.controllers; 2 | 3 | import com.derek.ml.models.NeuralNetworkModel; 4 | import com.derek.ml.services.FileFactory; 5 | import com.derek.ml.services.LoadData; 6 | import com.derek.ml.services.NeuralNetworkService; 7 | import org.springframework.beans.factory.annotation.Autowired; 8 | import org.springframework.stereotype.Controller; 9 | import org.springframework.web.bind.annotation.RequestMapping; 10 | import org.springframework.web.bind.annotation.RequestMethod; 11 | import org.springframework.web.bind.annotation.ResponseBody; 12 | 13 | @Controller 14 | public class NeuralNetworkController { 15 | 16 | @Autowired 17 | private NeuralNetworkService neuralNetworkService; 18 | 19 | @ResponseBody 20 | @RequestMapping(value ="/neuralnetwork", method={RequestMethod.GET}) 21 | public String neuralNetwork(NeuralNetworkModel neuralNetworkModel) throws Exception{ 22 | return neuralNetworkService.handleNeuralNetwork(neuralNetworkModel); 23 | } 24 | 25 | @ResponseBody 26 | @RequestMapping(value = "/neuralnetwork/test", method={RequestMethod.GET}) 27 | public String neuralnetworkTest(NeuralNetworkModel neuralNetworkModel) throws Exception{ 28 | return neuralNetworkService.handleSplitData(neuralNetworkModel, 1, ""); 29 | } 30 | 31 | @ResponseBody 32 | @RequestMapping(value="/neuralnetwork/model", method={RequestMethod.POST}) 33 | public void createModel(NeuralNetworkModel nn) throws Exception{ 34 | neuralNetworkService.createModel(nn); 35 | } 36 | 37 | @ResponseBody 38 | @RequestMapping(value="/neuralnetwork/model", 
method={RequestMethod.GET}) 39 | public String getModel(NeuralNetworkModel nn) throws Exception{ 40 | return neuralNetworkService.getModel(nn); 41 | } 42 | 43 | @ResponseBody 44 | @RequestMapping(value="/neuralnetwork/reduction", method={RequestMethod.GET}) 45 | public String reduction() throws Exception{ 46 | return neuralNetworkService.neuralNetworkWithReduction(); 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /src/main/java/com/derek/ml/models/DecisionTree.java: -------------------------------------------------------------------------------- 1 | package com.derek.ml.models; 2 | 3 | 4 | public class DecisionTree extends ML{ 5 | 6 | private TreeType treeType = TreeType.J48; 7 | private String confidence = "0.25"; 8 | private Boolean unpruned = false; 9 | private Integer minNumObj = 2; 10 | private boolean boost = false; 11 | private boolean featureSelection = false; 12 | 13 | public DecisionTree(){} 14 | 15 | public DecisionTree(int minNumObj, String confidence, boolean boost, ML.Files fileName){ 16 | this.minNumObj = minNumObj; 17 | this.confidence = confidence; 18 | this.boost = boost; 19 | super.setFileName(fileName); 20 | } 21 | 22 | public static enum TreeType { 23 | ID3, 24 | J48, 25 | ALL 26 | } 27 | 28 | public Integer getMinNumObj() { 29 | return minNumObj; 30 | } 31 | 32 | public void setMinNumObj(Integer minNumObj) { 33 | this.minNumObj = minNumObj; 34 | } 35 | 36 | public TreeType getTreeType() { 37 | return treeType; 38 | } 39 | 40 | public void setTreeType(TreeType treeType) { 41 | this.treeType = treeType; 42 | } 43 | 44 | public String getConfidence() { 45 | return confidence; 46 | } 47 | 48 | public void setConfidence(String confidence) { 49 | this.confidence = confidence; 50 | } 51 | 52 | public Boolean getUnpruned() { 53 | return unpruned; 54 | } 55 | 56 | public void setUnpruned(Boolean unpruned) { 57 | this.unpruned = unpruned; 58 | } 59 | 60 | public boolean isBoost() { 61 | return boost; 
62 | } 63 | 64 | public void setBoost(boolean boost) { 65 | this.boost = boost; 66 | } 67 | 68 | public void setFileName(Files fileName){ 69 | super.setFileName(fileName); 70 | } 71 | 72 | public Files getFileName(){ 73 | return super.getFileName(); 74 | } 75 | 76 | public boolean isFeatureSelection() { 77 | return featureSelection; 78 | } 79 | 80 | public void setFeatureSelection(boolean featureSelection) { 81 | this.featureSelection = featureSelection; 82 | } 83 | 84 | } 85 | -------------------------------------------------------------------------------- /src/main/java/com/derek/ml/controllers/FeatureReductionController.java: -------------------------------------------------------------------------------- 1 | package com.derek.ml.controllers; 2 | 3 | import com.derek.ml.services.FeatureReductionService; 4 | import org.springframework.beans.factory.annotation.Autowired; 5 | import org.springframework.stereotype.Controller; 6 | import org.springframework.web.bind.annotation.RequestMapping; 7 | import org.springframework.web.bind.annotation.RequestMethod; 8 | import org.springframework.web.bind.annotation.ResponseBody; 9 | 10 | @Controller 11 | public class FeatureReductionController { 12 | 13 | @Autowired 14 | private FeatureReductionService featureReductionService; 15 | 16 | @ResponseBody 17 | @RequestMapping(value ="/featureReduction/pca", method={RequestMethod.GET}) 18 | public String handleFeatureReduction() throws Exception{ 19 | return featureReductionService.handlePCAFeatures(); 20 | } 21 | 22 | @ResponseBody 23 | @RequestMapping(value ="/featureReduction/rp", method={RequestMethod.GET}) 24 | public String handleRP() throws Exception{ 25 | return featureReductionService.handleRandomizedProjectionFeatures(); 26 | } 27 | 28 | @ResponseBody 29 | @RequestMapping(value ="/featureReduction/ica", method={RequestMethod.GET}) 30 | public String handleICA() throws Exception{ 31 | return featureReductionService.handleICAFeatures(); 32 | } 33 | 34 | @ResponseBody 35 | 
@RequestMapping(value ="/featureReduction/cfs", method={RequestMethod.GET}) 36 | public String handleCfsSubsetEval() throws Exception{ 37 | return featureReductionService.handleCFSSubsetEval(); 38 | } 39 | 40 | @ResponseBody 41 | @RequestMapping(value ="/featureReduction/rp/plot", method={RequestMethod.GET}) 42 | public void handlePlotRp() throws Exception{ 43 | featureReductionService.plotRP(); 44 | } 45 | 46 | @ResponseBody 47 | @RequestMapping(value ="/featureReduction/pca/plot", method={RequestMethod.GET}) 48 | public void handlePlotPCA() throws Exception{ 49 | featureReductionService.plotPCA(); 50 | } 51 | 52 | @ResponseBody 53 | @RequestMapping(value ="/featureReduction/ica/plot", method={RequestMethod.GET}) 54 | public void handlePlotICA() throws Exception{ 55 | featureReductionService.plotICA(); 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /src/main/java/com/derek/ml/models/NearestNeighbor.java: -------------------------------------------------------------------------------- 1 | package com.derek.ml.models; 2 | 3 | 4 | public class NearestNeighbor extends ML { 5 | 6 | private TreeTypes treeTypes = TreeTypes.Linear; 7 | private int k = 1; 8 | private boolean holdOneOut = false; 9 | private boolean useMeanError = false; 10 | private boolean boost = false; 11 | private boolean featureSelection = false; 12 | 13 | public NearestNeighbor(){} 14 | 15 | public NearestNeighbor(int k, ML.Files fileName){ 16 | this.k = k; 17 | super.setFileName(fileName); 18 | } 19 | 20 | public NearestNeighbor(int k, ML.Files fileName, boolean featureSelection){ 21 | this.k = k; 22 | super.setFileName(fileName); 23 | this.featureSelection = featureSelection; 24 | } 25 | 26 | public static enum TreeTypes { 27 | BallTree, CoverTree, Linear 28 | } 29 | 30 | public TreeTypes getTreeTypes() { 31 | return treeTypes; 32 | } 33 | 34 | public void setTreeTypes(TreeTypes treeTypes) { 35 | this.treeTypes = treeTypes; 36 | } 37 | 38 | public 
int getK() { 39 | return k; 40 | } 41 | 42 | public void setK(int k) { 43 | this.k = k; 44 | } 45 | 46 | public boolean isHoldOneOut() { 47 | return holdOneOut; 48 | } 49 | 50 | public void setHoldOneOut(boolean holdOneOut) { 51 | this.holdOneOut = holdOneOut; 52 | } 53 | 54 | public boolean isUseMeanError() { 55 | return useMeanError; 56 | } 57 | 58 | public void setUseMeanError(boolean useMeanError) { 59 | this.useMeanError = useMeanError; 60 | } 61 | 62 | public void setFileName(Files fileName){ 63 | super.setFileName(fileName); 64 | } 65 | 66 | public Files getFileName(){ 67 | return super.getFileName(); 68 | } 69 | 70 | public void setTestType(TestType testType){ 71 | super.setTestType(testType); 72 | } 73 | 74 | public TestType getTestType(){ 75 | return super.getTestType(); 76 | } 77 | 78 | public boolean isBoost() { 79 | return boost; 80 | } 81 | 82 | public void setBoost(boolean boost) { 83 | this.boost = boost; 84 | } 85 | 86 | public boolean isFeatureSelection() { 87 | return featureSelection; 88 | } 89 | 90 | public void setFeatureSelection(boolean featureSelection) { 91 | this.featureSelection = featureSelection; 92 | } 93 | 94 | 95 | } 96 | -------------------------------------------------------------------------------- /students-filters-master/src/main/java/filters/NegativeEntropyEstimator.java: -------------------------------------------------------------------------------- 1 | /* 2 | * This is free and unencumbered software released into the public domain. 3 | * 4 | * Anyone is free to copy, modify, publish, use, compile, sell, or 5 | * distribute this software, either in source code form or as a compiled 6 | * binary, for any purpose, commercial or non-commercial, and by any 7 | * means. 8 | * 9 | * In jurisdictions that recognize copyright laws, the author or authors 10 | * of this software dedicate any and all copyright interest in the 11 | * software to the public domain. 
We make this dedication for the benefit 12 | * of the public at large and to the detriment of our heirs and 13 | * successors. We intend this dedication to be an overt act of 14 | * relinquishment in perpetuity of all present and future rights to this 15 | * software under copyright law. 16 | * 17 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 18 | * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 19 | * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 20 | * IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR 21 | * OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 22 | * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 23 | * OTHER DEALINGS IN THE SOFTWARE. 24 | * 25 | * For more information, please refer to 26 | */ 27 | 28 | package filters; 29 | 30 | import org.ejml.simple.SimpleMatrix; 31 | 32 | 33 | /** 34 | * The entropy function is an approximation to neg-entropy in FastICA. Commonly 35 | * used nonlinear functions include log-cosh (recommended as a good, 36 | * general-purpose function), cubic, and exponential families. The function is 37 | * applied element-wise to each attribute of the data as part of the FastICA 38 | * algorithm. 39 | * 40 | * @author Chris Gearhart 41 | */ 42 | public interface NegativeEntropyEstimator { 43 | 44 | /** 45 | * Estimate the negative entropy of the input data matrix and store 46 | * the results. 
47 | * 48 | * @param x {@link SimpleMatrix} containing column vectors of data 49 | * to transform 50 | */ 51 | abstract void estimate(SimpleMatrix x); 52 | 53 | /** 54 | * 55 | * @return {@link SimpleMatrix} containing the value of the G function 56 | * applied to each value of the input matrix 57 | */ 58 | abstract SimpleMatrix getGx(); 59 | 60 | /** 61 | * 62 | * @return {@link SimpleMatrix} containing the value of the average of 63 | * the first derivative of the G function applied to each 64 | * value of the input matrix 65 | */ 66 | abstract SimpleMatrix getG_x(); 67 | 68 | } 69 | -------------------------------------------------------------------------------- /src/main/java/com/derek/ml/controllers/DecisionTreeController.java: -------------------------------------------------------------------------------- 1 | package com.derek.ml.controllers; 2 | 3 | import com.derek.ml.models.DecisionTree; 4 | import com.derek.ml.models.FileName; 5 | import com.derek.ml.services.DecisionTreeService; 6 | import com.derek.ml.services.FileFactory; 7 | import com.derek.ml.services.LoadData; 8 | import org.springframework.beans.factory.annotation.Autowired; 9 | import org.springframework.stereotype.Controller; 10 | import org.springframework.web.bind.annotation.RequestBody; 11 | import org.springframework.web.bind.annotation.RequestMapping; 12 | import org.springframework.web.bind.annotation.RequestMethod; 13 | import org.springframework.web.bind.annotation.ResponseBody; 14 | import weka.core.Instances; 15 | 16 | 17 | @Controller 18 | public class DecisionTreeController { 19 | 20 | @Autowired 21 | DecisionTreeService decisionTreeService; 22 | 23 | @Autowired 24 | LoadData loadData; 25 | 26 | @Autowired 27 | FileFactory fileFactory; 28 | 29 | @ResponseBody 30 | @RequestMapping(value ="/decisiontree", method={RequestMethod.GET}) 31 | public String getDecisionTreeAccuracy(DecisionTree decisionTree) throws Exception{ 32 | return 
decisionTreeService.getDecisionTreeInformation(decisionTree); 33 | } 34 | 35 | @ResponseBody 36 | @RequestMapping(value = "/regressiontree", method = {RequestMethod.GET}) 37 | public String getRegressionTree(DecisionTree decisionTree) throws Exception { 38 | return decisionTreeService.handleRegressionTree(decisionTree); 39 | } 40 | 41 | @ResponseBody 42 | @RequestMapping(value="/decisiontree/test", method={RequestMethod.GET}) 43 | public String testingError(DecisionTree decisionTree) throws Exception{ 44 | return decisionTreeService.handleSplitData(decisionTree, 1, ""); 45 | } 46 | 47 | @ResponseBody 48 | @RequestMapping(value="/decisiontree/model", method={RequestMethod.POST}) 49 | public void createModel(DecisionTree decisionTree) throws Exception{ 50 | decisionTreeService.createModel(decisionTree); 51 | } 52 | 53 | @ResponseBody 54 | @RequestMapping(value="/decisiontree/model", method={RequestMethod.GET}) 55 | public String getModel(DecisionTree decisionTree) throws Exception{ 56 | return decisionTreeService.getModel(decisionTree); 57 | } 58 | 59 | @ResponseBody 60 | @RequestMapping(value ="/createArff", method={RequestMethod.POST}) 61 | public void createArff(@RequestBody FileName fileName) throws Exception{ 62 | Instances instances = loadData.getDataFromCsvFile(fileName.getFileName() + ".csv"); 63 | loadData.saveToArff(instances, fileName.getFileName() + ".arff"); 64 | } 65 | 66 | @ResponseBody 67 | @RequestMapping(value ="/discretize", method={RequestMethod.POST}) 68 | public void discretizeCensus() throws Exception{ 69 | fileFactory.saveDiscretizedArff(); 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /src/main/java/com/derek/ml/services/LoadData.java: -------------------------------------------------------------------------------- 1 | package com.derek.ml.services; 2 | 3 | import org.springframework.stereotype.Service; 4 | import weka.classifiers.Classifier; 5 | import weka.core.Instance; 6 | import 
weka.core.Instances; 7 | import weka.core.converters.ArffLoader; 8 | import weka.core.converters.ArffSaver; 9 | import weka.core.converters.CSVLoader; 10 | 11 | import java.io.*; 12 | 13 | @Service 14 | public class LoadData { 15 | 16 | public Instances getDataFromCsvFile(String fileName) throws Exception { 17 | CSVLoader loader = new CSVLoader(); 18 | loader.setSource(new File("src/main/resources/csv/" + fileName)); 19 | return loader.getDataSet(); 20 | } 21 | 22 | public void saveToArff(Instances instances, String fileName) throws IOException { 23 | ArffSaver arffSaver = new ArffSaver(); 24 | arffSaver.setInstances(instances); 25 | arffSaver.setFile(new File("src/main/resources/" + fileName)); 26 | arffSaver.setDestination(new File("src/main/resources/arffs/" + fileName)); 27 | arffSaver.writeBatch(); 28 | } 29 | 30 | public Instances getDataFromArff(String fileName) throws IOException { 31 | BufferedReader reader = new BufferedReader(new FileReader("src/main/resources/arffs/" + fileName)); 32 | ArffLoader.ArffReader arff = new ArffLoader.ArffReader(reader, 100000); 33 | Instances data = arff.getStructure(); 34 | data.setClassIndex(data.numAttributes() - 1); 35 | Instance inst; 36 | while ((inst = arff.readInstance(data)) != null){ 37 | data.add(inst); 38 | } 39 | return data; 40 | } 41 | 42 | public Instances getDataFromArff(String fileName, boolean noClass) throws IOException { 43 | BufferedReader reader = new BufferedReader(new FileReader("src/main/resources/arffs/" + fileName)); 44 | ArffLoader.ArffReader arff = new ArffLoader.ArffReader(reader, 100000); 45 | Instances data = arff.getStructure(); 46 | if (!noClass){ 47 | data.setClassIndex(data.numAttributes() - 1); 48 | } 49 | Instance inst; 50 | while ((inst = arff.readInstance(data)) != null){ 51 | data.add(inst); 52 | } 53 | return data; 54 | } 55 | 56 | public void saveModel(Classifier cls, String name) throws Exception{ 57 | ObjectOutputStream objectOutputStream = new ObjectOutputStream(new 
FileOutputStream("src/main/resources/models/" + name)); 58 | objectOutputStream.writeObject(cls); 59 | objectOutputStream.flush(); 60 | objectOutputStream.close(); 61 | } 62 | 63 | public Classifier getModel(String name) throws Exception{ 64 | ObjectInputStream oos = new ObjectInputStream(new FileInputStream("src/main/resources/models/" + name)); 65 | Classifier cls = (Classifier)oos.readObject(); 66 | oos.close(); 67 | return cls; 68 | 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /students-filters-master/src/main/java/filters/LogCosh.java: -------------------------------------------------------------------------------- 1 | /* 2 | * This is free and unencumbered software released into the public domain. 3 | * 4 | * Anyone is free to copy, modify, publish, use, compile, sell, or 5 | * distribute this software, either in source code form or as a compiled 6 | * binary, for any purpose, commercial or non-commercial, and by any 7 | * means. 8 | * 9 | * In jurisdictions that recognize copyright laws, the author or authors 10 | * of this software dedicate any and all copyright interest in the 11 | * software to the public domain. We make this dedication for the benefit 12 | * of the public at large and to the detriment of our heirs and 13 | * successors. We intend this dedication to be an overt act of 14 | * relinquishment in perpetuity of all present and future rights to this 15 | * software under copyright law. 16 | * 17 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 18 | * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 19 | * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 20 | * IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR 21 | * OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 22 | * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 23 | * OTHER DEALINGS IN THE SOFTWARE. 
24 | * 25 | * For more information, please refer to 26 | */ 27 | 28 | package filters; 29 | 30 | import org.ejml.simple.SimpleMatrix; 31 | 32 | 33 | /** 34 | * Default function used in FastICA to approxmate neg-entropy. 35 | * 36 | * @author Chris Gearhart 37 | * 38 | */ 39 | public class LogCosh implements NegativeEntropyEstimator { 40 | 41 | // Element-wise application of the neg-entropy function applied to data matrix 42 | private SimpleMatrix gx; 43 | 44 | // Column-wise average of the first derivative of gx; i.e., the average of 1 - gx[i]**2 45 | private SimpleMatrix g_x; 46 | 47 | // Scaling factor 48 | private final double alpha; 49 | 50 | public LogCosh(double alpha) { 51 | this.alpha = alpha; 52 | } 53 | 54 | public LogCosh() { 55 | this(1.); 56 | } 57 | 58 | /** 59 | * 60 | * @param x - {@link SimpleMatrix} of column vectors for each feature 61 | */ 62 | @Override 63 | public void estimate(SimpleMatrix x) { 64 | double val; 65 | double tmp; 66 | int m = x.numRows(); 67 | int n = x.numCols(); 68 | 69 | gx = new SimpleMatrix(m, n); 70 | g_x = new SimpleMatrix(1, n); 71 | for (int j = 0; j < n; j++) { 72 | tmp = 0; 73 | for (int i = 0; i < m; i++) { 74 | val = Math.tanh(x.get(i, j)); 75 | gx.set(i, j, val); 76 | tmp += alpha * (1 - Math.pow(val, 2)); 77 | } 78 | g_x.set(0, j, tmp / new Double(m)); 79 | } 80 | 81 | } 82 | 83 | @Override 84 | public SimpleMatrix getGx() { 85 | return gx; 86 | } 87 | 88 | @Override 89 | public SimpleMatrix getG_x() { 90 | return g_x; 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | 7 | com.derek.ml 8 | derek-assignment-1 9 | 0.1.0 10 | jar 11 | derek-assignment-1 12 | 13 | 14 | org.springframework.boot 15 | spring-boot-starter-parent 16 | 1.2.6.RELEASE 17 | 18 | 19 | 20 | 21 | UTF-8 22 | UTF-8 23 | 4.2.0.RELEASE 24 | 1.8.6 25 | 1.7 26 | 27 | 28 | 29 | 30 | 
org.springframework.boot 31 | spring-boot-starter-web 32 | 33 | 34 | nz.ac.waikato.cms.weka 35 | weka-dev 36 | 3.7.5 37 | 38 | 39 | nz.ac.waikato.cms.weka 40 | LibSVM 41 | 1.0.4 42 | 43 | 44 | filters 45 | filters 46 | 0.0.1-SNAPSHOT 47 | 48 | 49 | com.googlecode.efficient-java-matrix-library 50 | ejml 51 | 0.25 52 | 53 | 54 | 55 | 56 | 57 | 58 | org.springframework.boot 59 | spring-boot-maven-plugin 60 | 61 | 62 | org.apache.maven.plugins 63 | maven-compiler-plugin 64 | 3.0 65 | 66 | 1.7 67 | 1.7 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | spring-releases 76 | Spring Releases 77 | https://repo.spring.io/libs-release 78 | 79 | 80 | org.jboss.repository.releases 81 | JBoss Maven Release Repository 82 | https://repository.jboss.org/nexus/content/repositories/releases 83 | 84 | 85 | project.local 86 | project 87 | file:${project.basedir}/students-filters-master 88 | 89 | 90 | 91 | -------------------------------------------------------------------------------- /students-filters-master/README.md: -------------------------------------------------------------------------------- 1 | # Students.Filters 2 | 3 | ## Introduction 4 | 5 | Students.Filters is a [package](http://weka.wikispaces.com/Packages) that provides [unsupervised learning](http://en.wikipedia.org/wiki/Unsupervised_learning) filters for the [WEKA](http://www.cs.waikato.ac.nz/~ml/weka/index.html) machine learning toolkit version >3.7. Development will prioritize filters that are useful to students taking machine learning at Georgia Tech; initially only an [Independent Component Analysis](http://en.wikipedia.org/wiki/Independent_component_analysis) filter using the [FastICA](http://research.ics.aalto.fi/ica/newindex.shtml) algorithm has been implemented. 6 | 7 | ## Installation 8 | 9 | The preferred installation method is to use the WEKA package manager. 
The git repository contains additional files for an Eclipse project with Maven dependencies for the [EJML](https://code.google.com/p/efficient-java-matrix-library/) package, and Ant build files for the `jar`. 10 | 11 | ### WEKA Package Manager 12 | 13 | See instructions on the WEKA [homepage](http://weka.wikispaces.com/How+do+I+use+the+package+manager%3F). If the package is not available from the official package page, it can be installed directly from: 14 | 15 | https://github.com/cgearhart/students-filters/raw/master/StudentFilters.zip 16 | 17 | ### Git Repository 18 | 19 | The source code & package file can be installed from git: 20 | 21 | git clone https://github.com/cgearhart/students-filters.git 22 | 23 | ## Use 24 | 25 | The filter can be used like other WEKA filters from the command line, from the WEKA GUI, or directly within your own Java code. The specific options for each filter can be found in the source code, documentation, or from the command line with the `-h` flag. 26 | 27 | ### Command Line 28 | 29 | Read the [instructions](http://weka.wikispaces.com/How+do+I+use+WEKA+from+command+line%3F) first. Make sure that `weka.jar` and the `StudentFilters.jar` files are in the classpath and in order. Options for each filter can be determined with the `-h` argument. The filter can then be directly invoked (or chained like other WEKA filters), e.g.: 30 | 31 | java -cp /weka.jar:/studentfilters.jar weka.filters.unsupervised.attribute.IndependentComponent -i -o -W -A -1 -N 200 -T 1E-4 32 | 33 | ### IDE 34 | 35 | The FastICA algorithm is implemented independent of WEKA, so it can be included without adding WEKA to your project by including the `StudentFilters.jar` file and importing `filters.FastICA`. However, using the WEKA-compatible IndependentComponents filter requires the `weka.jar` in the classpath, and can be imported as `weka.filters.unsupervised.attribute.IndependentComponents`. 
See the WEKA [documentation](http://weka.wikispaces.com/Use+WEKA+in+your+Java+code) for more details. 36 | 37 | The `build.xml` file can be used with [Apache Ant](http://ant.apache.org/) to rebuild `StudentFilters.jar` by running: 38 | 39 | ant build 40 | 41 | NOTE: the EMJL library needs to be installed on your system in the expected location; follow the [instructions](https://code.google.com/p/efficient-java-matrix-library/) to install it with [Maven](http://maven.apache.org/). 42 | 43 | ### WEKA GUI 44 | 45 | Once the filter is installed with the package manager, or has been simply unzipped to the package folder on the weka path, it will automatically appear in the WEKA gui. (The GUI must usually be restarted after new packages are added.) See the WEKA [documentation](http://weka.wikispaces.com/How+do+I+use+the+package+manager%3F) for more details. 46 | 47 | ### Alternative Maven Build 48 | 49 | The `pom.xml` file can be used with [Apache Maven](http://maven.apache.org/) to rebuild `filters-0.0.1-SNAPSHOT.jar` by running: 50 | 51 | mvn install -Dmaven.test.skip=true 52 | 53 | NOTE: dependencies will be handled automatically by Maven. 
54 | 55 | GUI can then be launched with 56 | 57 | java -Xmx1g -classpath /.m2/repository/com/googlecode/efficient-java-matrix-library/ejml/0.25/ejml-0.25.jar:/.m2/repository/nz/ac/waikato/cms/weka/weka-dev/3.7.10/weka-dev-3.7.10.jar:/.m2/repository/net/sf/squirrel-sql/thirdparty-non-maven/java-cup/0.11a/java-cup-0.11a.jar:/.m2/repository/org/pentaho/pentaho-commons/pentaho-package-manager/1.0.8/pentaho-package-manager-1.0.8.jar:/.m2/repository/junit/junit/4.11/junit-4.11.jar:/.m2/repository/org/hamcrest/hamcrest-core/1.3/hamcrest-core-1.3.jar weka.gui.Main 58 | 59 | ## License 60 | 61 | The filters are dependent on [WEKA](http://www.cs.waikato.ac.nz/~ml/weka/index.html) (licensed under [GPL](http://www.gnu.org/licenses/gpl.html)) and the Efficient Java Matrix Library ([EJML](https://code.google.com/p/efficient-java-matrix-library/)) (licensed under [Apache License 2.0](http://www.apache.org/licenses/LICENSE-2.0)). The [FastICA](http://research.ics.aalto.fi/ica/newindex.shtml) algorithm is released under the [GPL](http://research.ics.aalto.fi/ica/fastica/about.shtml). The implementation in this package is based on the [scikit-learn](http://scikit-learn.org/stable/index.html) implementation which is released under [BSD](https://github.com/scikit-learn/scikit-learn/blob/master/COPYING). To the extent that there may be any original copyright, it is licensed under the [Unlicense](http://unlicense.org/) - i.e., it is released to the Public Domain. 
-------------------------------------------------------------------------------- /src/main/java/com/derek/ml/services/SVMService.java: -------------------------------------------------------------------------------- 1 | package com.derek.ml.services; 2 | 3 | import com.derek.ml.models.ML; 4 | import com.derek.ml.models.Options; 5 | import com.derek.ml.models.SVMModel; 6 | import org.springframework.beans.factory.annotation.Autowired; 7 | import org.springframework.stereotype.Service; 8 | import weka.classifiers.Classifier; 9 | import weka.classifiers.functions.LibSVM; 10 | import weka.core.Instances; 11 | import weka.core.SelectedTag; 12 | 13 | @Service 14 | public class SVMService { 15 | 16 | @Autowired 17 | private FileFactory fileFactory; 18 | 19 | @Autowired 20 | private EvaluationService evaluationService; 21 | 22 | @Autowired 23 | LoadData loadData; 24 | 25 | public String handleSVM(SVMModel svmModel) throws Exception{ 26 | FileFactory.TrainTest instances = fileFactory.getInstancesFromFile(svmModel.getFileName(), new Options(svmModel.isFeatureSelection())); 27 | LibSVM libSVM = new LibSVM(); 28 | if (svmModel.getKernelType() == SVMModel.KernelType.Sigmoid){ 29 | libSVM.setKernelType(new SelectedTag(LibSVM.KERNELTYPE_SIGMOID, LibSVM.TAGS_KERNELTYPE)); 30 | } 31 | else if (svmModel.getKernelType() == SVMModel.KernelType.Linear){ 32 | libSVM.setKernelType(new SelectedTag(LibSVM.KERNELTYPE_LINEAR, LibSVM.TAGS_KERNELTYPE)); 33 | } 34 | else if (svmModel.getKernelType() == SVMModel.KernelType.Polynomial){ 35 | libSVM.setKernelType(new SelectedTag(LibSVM.KERNELTYPE_POLYNOMIAL, LibSVM.TAGS_KERNELTYPE)); 36 | } 37 | else if (svmModel.getKernelType() == SVMModel.KernelType.RBF){ 38 | libSVM.setKernelType(new SelectedTag(LibSVM.KERNELTYPE_RBF, LibSVM.TAGS_KERNELTYPE)); 39 | } 40 | libSVM.buildClassifier(instances.train); 41 | return handleLibSvmEvaluation(libSVM, svmModel, instances); 42 | } 43 | 44 | public String handleSplitData(SVMModel svmModel, int num, String 
retString) throws Exception{ 45 | if (num <= 100){ 46 | retString += "Amount " + Integer.toString(num) + "\n"; 47 | FileFactory.TrainTest data; 48 | if (svmModel.getFileName() == ML.Files.Census){ 49 | data = fileFactory.handlePublicCensus(num, new Options(svmModel.isFeatureSelection())); 50 | } else { 51 | data = fileFactory.handlePublicCar(num); 52 | } 53 | 54 | LibSVM cls = svmClassifier(svmModel, data.train); 55 | Instances d; 56 | if (svmModel.getTestType() == ML.TestType.Train){ 57 | if (svmModel.getFileName() == ML.Files.Car){ 58 | d = fileFactory.handlePublicCar(0).train; 59 | } else { 60 | d = fileFactory.handlePublicCensus(0, new Options(svmModel.isFeatureSelection())).train; 61 | } 62 | } else { 63 | d = data.test; 64 | } 65 | return handleSplitData(svmModel, num==1 ? num+9 : num+10, retString + "\n \n" + evaluationService.evaluateData(data.train, cls, d)); 66 | } 67 | return retString; 68 | } 69 | 70 | public void createModel(SVMModel svm) throws Exception{ 71 | FileFactory.TrainTest data = fileFactory.getInstancesFromFile(svm.getFileName(), new Options(svm.isFeatureSelection())); 72 | Classifier cls = svmClassifier(svm, data.train); 73 | loadData.saveModel(cls, getString(svm)); 74 | } 75 | 76 | public String getModel(SVMModel svm) throws Exception { 77 | FileFactory.TrainTest data = fileFactory.getInstancesFromFile(svm.getFileName(), new Options(svm.isFeatureSelection())); 78 | Classifier cls = loadData.getModel(getString(svm)); 79 | svm.setTestType(ML.TestType.TestData); 80 | return handleLibSvmEvaluation((LibSVM) cls, svm, data); 81 | } 82 | 83 | private String getString(SVMModel svm){ 84 | String s = "SVM-KernelType=" + svm.getKernelType().toString() + 85 | "-FileName=" + svm.getFileName(); 86 | if (svm.isFeatureSelection()){ 87 | s += "-feature=true"; 88 | } 89 | return s + ".model"; 90 | } 91 | 92 | private LibSVM svmClassifier(SVMModel svmModel, Instances data) throws Exception{ 93 | LibSVM libSVM = new LibSVM(); 94 | if (svmModel.getKernelType() 
== SVMModel.KernelType.Sigmoid){ 95 | libSVM.setKernelType(new SelectedTag(LibSVM.KERNELTYPE_SIGMOID, LibSVM.TAGS_KERNELTYPE)); 96 | } 97 | else if (svmModel.getKernelType() == SVMModel.KernelType.Linear){ 98 | libSVM.setKernelType(new SelectedTag(LibSVM.KERNELTYPE_LINEAR, LibSVM.TAGS_KERNELTYPE)); 99 | } 100 | else if (svmModel.getKernelType() == SVMModel.KernelType.Polynomial){ 101 | libSVM.setKernelType(new SelectedTag(LibSVM.KERNELTYPE_POLYNOMIAL, LibSVM.TAGS_KERNELTYPE)); 102 | } 103 | else if (svmModel.getKernelType() == SVMModel.KernelType.RBF){ 104 | libSVM.setKernelType(new SelectedTag(LibSVM.KERNELTYPE_RBF, LibSVM.TAGS_KERNELTYPE)); 105 | } 106 | libSVM.buildClassifier(data); 107 | return libSVM; 108 | } 109 | 110 | private String handleLibSvmEvaluation(LibSVM libSVM, SVMModel svmModel, FileFactory.TrainTest instances) throws Exception{ 111 | if (svmModel.getTestType().equals(ML.TestType.CrossValidation)){ 112 | return evaluationService.evaluateData(instances.train, libSVM, 10); 113 | } 114 | else if (svmModel.getTestType().equals(ML.TestType.Train)){ 115 | return evaluationService.evaluateData(instances.train, libSVM, instances.train); 116 | } 117 | else { 118 | return evaluationService.evaluateData(instances.train, libSVM, instances.test); 119 | } 120 | } 121 | 122 | } 123 | -------------------------------------------------------------------------------- /src/main/java/com/derek/ml/services/KNNService.java: -------------------------------------------------------------------------------- 1 | package com.derek.ml.services; 2 | 3 | import com.derek.ml.models.ML; 4 | import com.derek.ml.models.NearestNeighbor; 5 | import com.derek.ml.models.Options; 6 | import org.springframework.beans.factory.annotation.Autowired; 7 | import org.springframework.stereotype.Service; 8 | import weka.classifiers.Classifier; 9 | import weka.classifiers.Evaluation; 10 | import weka.classifiers.lazy.IBk; 11 | import weka.classifiers.meta.AdaBoostM1; 12 | import 
weka.core.Instances; 13 | import weka.core.neighboursearch.BallTree; 14 | import weka.core.neighboursearch.CoverTree; 15 | import weka.core.neighboursearch.LinearNNSearch; 16 | import weka.core.neighboursearch.NearestNeighbourSearch; 17 | 18 | import java.util.Enumeration; 19 | import java.util.Random; 20 | 21 | @Service 22 | public class KNNService { 23 | 24 | @Autowired 25 | private FileFactory fileFactory; 26 | 27 | @Autowired 28 | private EvaluationService evaluationService; 29 | 30 | @Autowired 31 | LoadData loadData; 32 | 33 | public String handleKNNService(NearestNeighbor nearestNeighbor) throws Exception{ 34 | FileFactory.TrainTest instances = fileFactory.getInstancesFromFile(nearestNeighbor.getFileName(), new Options(nearestNeighbor.isFeatureSelection())); 35 | Classifier iBk = handleIBK(nearestNeighbor, instances.train); 36 | return evaluateKNN(nearestNeighbor, instances, iBk); 37 | } 38 | 39 | public Classifier handleIBK(NearestNeighbor nearestNeighbor, Instances instances) throws Exception{ 40 | IBk iBk = new IBk(); 41 | iBk.setKNN(nearestNeighbor.getK()); 42 | iBk.setNearestNeighbourSearchAlgorithm(getNearestNeighborAlgorithm(nearestNeighbor)); 43 | iBk.setCrossValidate(nearestNeighbor.isHoldOneOut()); 44 | iBk.setMeanSquared(nearestNeighbor.isUseMeanError()); 45 | if (nearestNeighbor.isBoost()){ 46 | AdaBoostM1 adaBoostM1 = new AdaBoostM1(); 47 | adaBoostM1.setClassifier(iBk); 48 | adaBoostM1.buildClassifier(instances); 49 | return adaBoostM1; 50 | } 51 | iBk.buildClassifier(instances); 52 | return iBk; 53 | } 54 | 55 | public String evaluateKNN(NearestNeighbor nearestNeighbor, FileFactory.TrainTest instances, Classifier iBk) throws Exception{ 56 | if (nearestNeighbor.getTestType() == ML.TestType.TestData){ 57 | return evaluationService.evaluateData(instances.train, iBk, instances.test) + "\n \n " + iBk.toString(); 58 | } else if (nearestNeighbor.getTestType() == ML.TestType.Train){ 59 | return evaluationService.evaluateData(instances.train, iBk, 
instances.train) + "\n \n " + iBk.toString(); 60 | } 61 | return evaluationService.evaluateData(instances.train, iBk) + "\n \n " + iBk.toString(); 62 | } 63 | 64 | public void createModel(NearestNeighbor nearestNeighbor) throws Exception{ 65 | FileFactory.TrainTest data = fileFactory.getInstancesFromFile(nearestNeighbor.getFileName(), new Options(nearestNeighbor.isFeatureSelection())); 66 | Classifier cls = handleIBK(nearestNeighbor, data.train); 67 | loadData.saveModel(cls, getString(nearestNeighbor)); 68 | } 69 | 70 | public String getModel(NearestNeighbor nearestNeighbor) throws Exception{ 71 | FileFactory.TrainTest data = fileFactory.getInstancesFromFile(nearestNeighbor.getFileName(), new Options(nearestNeighbor.isFeatureSelection())); 72 | Classifier cls = loadData.getModel(getString(nearestNeighbor)); 73 | nearestNeighbor.setTestType(ML.TestType.TestData); 74 | return evaluateKNN(nearestNeighbor, data, cls); 75 | } 76 | 77 | private String getString(NearestNeighbor nearestNeighbor){ 78 | String nn = "KNearestNeighbor-k=" + nearestNeighbor.getK() + "-fileName=" + nearestNeighbor.getFileName(); 79 | if (nearestNeighbor.isFeatureSelection()){ 80 | nn += "-feature=true"; 81 | } 82 | return nn + ".model"; 83 | } 84 | 85 | public String handleSplitData(NearestNeighbor nearestNeighbor, int num, String retString) throws Exception{ 86 | if (num <= 100){ 87 | retString += "Amount " + Integer.toString(num) + "\n"; 88 | FileFactory.TrainTest data; 89 | if (nearestNeighbor.getFileName() == ML.Files.Census){ 90 | data = fileFactory.handlePublicCensus(num, new Options(nearestNeighbor.isFeatureSelection())); 91 | } else { 92 | data = fileFactory.handlePublicCar(num); 93 | } 94 | Classifier cls = handleIBK(nearestNeighbor, data.train); 95 | Instances d; 96 | if (nearestNeighbor.getTestType() == ML.TestType.Train){ 97 | if (nearestNeighbor.getFileName() == ML.Files.Car){ 98 | d = fileFactory.handlePublicCar(0).train; 99 | } else { 100 | d = fileFactory.handlePublicCensus(0, 
new Options(nearestNeighbor.isFeatureSelection())).train; 101 | } 102 | } else { 103 | d = data.test; 104 | } 105 | return handleSplitData(nearestNeighbor, num==1 ? num+9 : num+10, retString + "\n \n" + evaluationService.evaluateData(data.train, cls, d)); 106 | } 107 | return retString; 108 | } 109 | 110 | private NearestNeighbourSearch getNearestNeighborAlgorithm(NearestNeighbor nearestNeighbor){ 111 | if (nearestNeighbor.getTreeTypes() == NearestNeighbor.TreeTypes.Linear){ 112 | return new LinearNNSearch(); 113 | } else if (nearestNeighbor.getTreeTypes() == NearestNeighbor.TreeTypes.BallTree){ 114 | return new BallTree(); 115 | } else if (nearestNeighbor.getTreeTypes() == NearestNeighbor.TreeTypes.CoverTree){ 116 | return new CoverTree(); 117 | } 118 | return null; 119 | } 120 | 121 | } 122 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Machine Learning with Weka and Spring examples 2 | 3 | Code url: https://github.com/dmmiller612/Machine_Learning_Spring_Weka 4 | 5 | # Instructions 6 | 7 | If wanting to run the server locally, instead of just using the Weka models located in /src/main/resources/models, there are a couple of dependencies needed: Maven and Java. 8 | 9 | 1. This uses java 1.7, but should work with 1.8 as well. For the JRE, `sudo apt-get install default-jre` . For the jdk, `sudo apt-get install default-jdk`. 10 | 11 | 2. This uses Maven 3.x . To Install maven 3, use `sudo apt-get install maven`. 12 | 13 | 3. Go to the root of the assignment code repository and type: `mvn clean package` into the command line. Then type `java -jar target/derek-assignment-1-0.1.0.jar`. Once running the jar, all of optimal models will start to run against the test datasets of both the Car Evaluation and Census dataset. This is here just to make it easier to visualize, so that you do not have to use the rest api. 
If you want to use the rest api, see documentation below. 14 | 15 | 4. If step three does not work, it is because the plugin did not install properly. Running this command should do the trick inside of the students-filters-master directory: 16 | "mvn install:install-file -Dfile=filters-0.0.1-SNAPSHOT.jar -DgroupId=filters -DartifactId=filters -Dversion=0.0.1-SNAPSHOT -Dpackaging=jar" 17 | 18 | 19 | # Navigating the Source Code 20 | 21 | ### src/main/java/com/derek/ml/controllers 22 | 23 | contains the rest endpoints. 24 | 25 | ### src/main/java/com/derek/ml/services 26 | 27 | contains all of the logic and configuration of weka models. ClusterService -> k-means and EM, FeatureReductionService -> ICA, PCA, RP, CFS, KNNService -> KNN, NeuralNetworkService -> Neural Network, DecisionTreeService->Decision Trees (boosted and unboosted), SVMService->SVM 28 | 29 | ### src/main/java/com/derek/ml/models 30 | 31 | DTO (data transfer object) passing layers 32 | 33 | 34 | # Navigating the Resources 35 | 36 | ### src/main/resources/csv 37 | 38 | Contains all of the initial csv files used (Arffs are only used for the models, however) 39 | 40 | ### src/main/resources/arffs 41 | 42 | Contains all of the arffs used. car_train.arff and car_test.arff are the training and testing instances for the car evaluation dataset. census.arff and censusTest.arff are the training and testing instances for the Census dataset. 43 | 44 | ### src/main/resources/models 45 | 46 | Contains several models used for the supervised learning analysis. If you don’t want to run the code locally, you can just use these models against the training and test arffs listed above. 47 | 48 | 49 | # Using the Rest API (Optional) 50 | 51 | I thought I would just add this to show the code I used for experimentation with Weka. I used the api so that I could do multiple concurrent requests.
52 | 53 | Universal Query parameters: fileName : {Car, Census, CarBin, CensusBin}, testType : {CrossValidation, TestData, Train} 54 | 55 | # Cluster 56 | 57 | Endpoints: /kMeans and /em 58 | Query Params => clusters : int, distances : {Euclidean, Manhatten}, iterations: int, featureSelection: {ICA, PCA, RP, CFS}; 59 | 60 | # Feature Reduction 61 | 62 | Endpoints: /featureReduction/pca /featureReduction/ica /featureReduction/rp /featureReduction/cfs 63 | 64 | # Decision Trees 65 | 66 | Endpoint: /decisiontree 67 | 68 | Query Params => minNumObj : int, boost : boolean, confidence : String, treeType : {ID3, J48} 69 | 70 | Example Requests: 71 | http://localhost:8080/decisiontree?fileName=Car&testType=TestData&minNumObj=2&confidence=.25 72 | http://localhost:8080/decisiontree?fileName=Census&testType=CrossValidation&minNumObj=2&confidence=.25 73 | http://localhost:8080/decisiontree?fileName=Car&testType=CrossValidation&minNumObj=2&confidence=.25&boost=true //with boosting 74 | 75 | Using incremental testing example: 76 | http://localhost:8080/decisiontree/test?fileName=Car&testType=TestData&minNumObj=2&confidence=.25&boost=true 77 | 78 | # KNN 79 | 80 | Endpoint: /knn 81 | 82 | Query Params => k : int, boost : boolean, treeTypes : {BallTree, CoverTree, Linear}, featureSelection : boolean (applies only to Census file) 83 | 84 | Examples : 85 | http://localhost:8080/knn?fileName=Car&testType=TestData&k=3 86 | http://localhost:8080/knn?fileName=Census&testType=TestData&k=5&featureSelection=true 87 | http://localhost:8080/knn?fileName=Census&testType=TestData&k=5&boost=true 88 | 89 | Using incremental testing example: 90 | http://localhost:8080/knn/test?fileName=Census&testType=TestData&k=5 91 | 92 | # ANN 93 | 94 | Endpoint: /neuralnetwork 95 | 96 | Query Params => hiddenLayers : int, epochRate : int, featureSelection : boolean (applies only to Census file) 97 | 98 | Examples: 99 |
http://localhost:8080/neuralnetwork?fileName=Car&testType=TestData&hiddenLayers=10&epochRate=500 100 | http://localhost:8080/neuralnetwork?fileName=Census&testType=TestData&hiddenLayers=5&epochRate=500&featureSelection=true 101 | 102 | Using incremental testing example: 103 | http://localhost:8080/neuralnetwork/test?fileName=Car&testType=TestData&hiddenLayers=10&epochRate=500 104 | 105 | # SVM 106 | 107 | Endpoint: /svm 108 | 109 | Query Params => kernelType : {Polynomial, RBF, Sigmoid, Linear} 110 | 111 | Examples: 112 | http://localhost:8080/svm?fileName=Car&testType=TestData&kernelType=Polynomial 113 | http://localhost:8080/svm?fileName=Census&testType=TestData&kernelType=RBF 114 | http://localhost:8080/svm?fileName=Census&testType=TestData&kernelType=Sigmoid 115 | http://localhost:8080/svm?fileName=Census&testType=TestData&kernelType=Linear 116 | 117 | 118 | # MODELS 119 | 120 | The model names contain the parameters that were used, fileName, and algorithm name. 121 | 122 | Decision Tree Naming Convention: decisionTree + minNumObj + Boosted + confidence + fileName + .model 123 | Example: decisionTree-minNumObj=100-Boosted=false-C=0.25-file=Census.model 124 | 125 | KNN Naming Convention: KNearestNeighbor + k + fileName + .model 126 | Example : KNearestNeighbor-k=20-fileName=Car.model 127 | 128 | ANN Naming Convention: ANN + hiddenLayers + epochRate + FileName + (Optional) featureSelection + .model 129 | Example : ANN-hiddenLayers=10-epochRate=250-FileName=Census.model 130 | 131 | SVM Naming Convention: SVM + kernelType + FileName + .model 132 | Example : SVM-KernelType=Linear-FileName=Car.model 133 | -------------------------------------------------------------------------------- /src/main/java/com/derek/ml/services/NeuralNetworkService.java: -------------------------------------------------------------------------------- 1 | package com.derek.ml.services; 2 | 3 | import com.derek.ml.models.ML; 4 | import com.derek.ml.models.NeuralNetworkModel; 5 | import 
com.derek.ml.models.Options; 6 | import org.springframework.beans.factory.annotation.Autowired; 7 | import org.springframework.stereotype.Service; 8 | import weka.classifiers.Classifier; 9 | import weka.classifiers.functions.MultilayerPerceptron; 10 | import weka.core.Instances; 11 | 12 | @Service 13 | public class NeuralNetworkService { 14 | 15 | @Autowired 16 | private FileFactory fileFactory; 17 | 18 | @Autowired 19 | private EvaluationService evaluationService; 20 | 21 | @Autowired 22 | private FeatureReductionService featureReductionService; 23 | 24 | @Autowired 25 | LoadData loadData; 26 | 27 | public String handleNeuralNetwork(NeuralNetworkModel neuralNetworkModel) throws Exception{ 28 | FileFactory.TrainTest data = fileFactory.getInstancesFromFile(neuralNetworkModel.getFileName(), new Options(neuralNetworkModel.isFeatureSelection())); 29 | MultilayerPerceptron multilayerPerceptron = handleClassification(data.train, neuralNetworkModel); 30 | return handleEvaluation(multilayerPerceptron, neuralNetworkModel, data); 31 | } 32 | 33 | public MultilayerPerceptron handleClassification(Instances data, NeuralNetworkModel neuralNetworkModel) throws Exception{ 34 | MultilayerPerceptron multilayerPerceptron = new MultilayerPerceptron(); 35 | multilayerPerceptron.setTrainingTime(neuralNetworkModel.getEpochRate()); 36 | multilayerPerceptron.setHiddenLayers(Integer.toString(neuralNetworkModel.getHiddenLayers())); 37 | multilayerPerceptron.buildClassifier(data); 38 | return multilayerPerceptron; 39 | } 40 | 41 | public String handleEvaluation(MultilayerPerceptron multilayerPerceptron, NeuralNetworkModel neuralNetworkModel, FileFactory.TrainTest data) throws Exception{ 42 | if (neuralNetworkModel.getTestType() == ML.TestType.TestData){ 43 | return evaluationService.evaluateData(data.train, multilayerPerceptron, data.test) + "\n \n" + multilayerPerceptron.toString(); 44 | } else if (neuralNetworkModel.getTestType() == ML.TestType.Train){ 45 | return 
evaluationService.evaluateData(data.train, multilayerPerceptron, data.train) + "\n \n" + multilayerPerceptron.toString(); 46 | } 47 | else { 48 | return evaluationService.evaluateData(data.train, multilayerPerceptron, 6) + " \n \n" + multilayerPerceptron.toString(); 49 | } 50 | } 51 | 52 | public String handleSplitData(NeuralNetworkModel neuralNetworkModel, int num, String retString) throws Exception{ 53 | if (num <= 100){ 54 | retString += "Amount " + Integer.toString(num) + "\n"; 55 | FileFactory.TrainTest data; 56 | if (neuralNetworkModel.getFileName() == ML.Files.Census){ 57 | data = fileFactory.handlePublicCensus(num, new Options(neuralNetworkModel.isFeatureSelection())); 58 | } else { 59 | data = fileFactory.handlePublicCar(num); 60 | } 61 | MultilayerPerceptron multilayerPerceptron = handleClassification(data.train, neuralNetworkModel); 62 | Instances d; 63 | if (neuralNetworkModel.getTestType() == ML.TestType.Train){ 64 | if (neuralNetworkModel.getFileName() == ML.Files.Car){ 65 | d = fileFactory.handlePublicCar(0).train; 66 | } else { 67 | d = fileFactory.handlePublicCensus(0, new Options(neuralNetworkModel.isFeatureSelection())).train; 68 | } 69 | } else { 70 | d = data.test; 71 | } 72 | return handleSplitData(neuralNetworkModel, num==1 ? 
num+9 : num+10, retString + "\n \n" + evaluationService.evaluateData(data.train, multilayerPerceptron, d)); 73 | } 74 | return retString; 75 | } 76 | 77 | public void createModel(NeuralNetworkModel nn) throws Exception{ 78 | FileFactory.TrainTest data = fileFactory.getInstancesFromFile(nn.getFileName(), new Options(nn.isFeatureSelection())); 79 | Classifier cls = handleClassification(data.train, nn); 80 | loadData.saveModel(cls, getString(nn)); 81 | } 82 | 83 | public String getModel(NeuralNetworkModel nn) throws Exception { 84 | FileFactory.TrainTest data = fileFactory.getInstancesFromFile(nn.getFileName(), new Options(nn.isFeatureSelection())); 85 | Classifier cls = loadData.getModel(getString(nn)); 86 | nn.setTestType(ML.TestType.TestData); 87 | return handleEvaluation((MultilayerPerceptron) cls, nn, data); 88 | } 89 | 90 | public String neuralNetworkWithReduction() throws Exception{ 91 | FileFactory.TrainTest censusTrainTest = fileFactory.getInstancesFromFile(ML.Files.CensusBin, new Options()); 92 | Instances pcaCensus = featureReductionService.applyPCAFilter(censusTrainTest.test, 30); 93 | Instances icaCensus = featureReductionService.applyICA(censusTrainTest.test, 30); 94 | Instances rpCensus = featureReductionService.applyRP(censusTrainTest.test, 30); 95 | 96 | NeuralNetworkModel neuralNetworkModel = new NeuralNetworkModel(); 97 | neuralNetworkModel.setEpochRate(500); 98 | neuralNetworkModel.setHiddenLayers(4); 99 | 100 | censusTrainTest.train = pcaCensus; 101 | String one = handleEvaluation(handleClassification(pcaCensus, neuralNetworkModel), neuralNetworkModel, censusTrainTest); 102 | censusTrainTest.train = icaCensus; 103 | String two = handleEvaluation(handleClassification(icaCensus, neuralNetworkModel), neuralNetworkModel, censusTrainTest); 104 | censusTrainTest.train = rpCensus; 105 | String three = handleEvaluation(handleClassification(rpCensus, neuralNetworkModel), neuralNetworkModel, censusTrainTest); 106 | 107 | return "PCA \n \n " + one + " \n \n 
\n ICA \n \n" + two + " \n \n \n RP \n \n" + three; 108 | } 109 | 110 | private String getString(NeuralNetworkModel nn){ 111 | String s = "ANN-hiddenLayers=" + nn.getHiddenLayers() + 112 | "-epochRate=" + nn.getEpochRate() + 113 | "-FileName=" + nn.getFileName(); 114 | if (nn.isFeatureSelection()){ 115 | s += "-feature=true"; 116 | } 117 | return s + ".model"; 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /src/main/java/com/derek/ml/services/DecisionTreeService.java: -------------------------------------------------------------------------------- 1 | package com.derek.ml.services; 2 | 3 | import com.derek.ml.models.DecisionTree; 4 | import com.derek.ml.models.ML; 5 | import com.derek.ml.models.Options; 6 | import org.springframework.beans.factory.annotation.Autowired; 7 | import org.springframework.stereotype.Service; 8 | import weka.classifiers.Classifier; 9 | import weka.classifiers.Evaluation; 10 | import weka.classifiers.bayes.net.search.local.SimulatedAnnealing; 11 | import weka.classifiers.evaluation.NominalPrediction; 12 | import weka.classifiers.meta.AdaBoostM1; 13 | import weka.classifiers.trees.DecisionStump; 14 | import weka.classifiers.trees.J48; 15 | import weka.classifiers.trees.REPTree; 16 | import weka.core.FastVector; 17 | import weka.core.Instances; 18 | 19 | import java.util.HashMap; 20 | import java.util.Random; 21 | 22 | @Service 23 | public class DecisionTreeService { 24 | 25 | private LoadData loadData; 26 | private FileFactory fileFactory; 27 | private EvaluationService evaluationService; 28 | 29 | @Autowired 30 | public DecisionTreeService(FileFactory fileFactory, LoadData loadData, EvaluationService evaluationService){ 31 | this.fileFactory = fileFactory; 32 | this.loadData = loadData; 33 | this.evaluationService = evaluationService; 34 | } 35 | 36 | public String getDecisionTreeInformation(DecisionTree decisionTree) throws Exception{ 37 | if (decisionTree.getTreeType() == 
DecisionTree.TreeType.J48) { 38 | return handleJ48(decisionTree); 39 | } 40 | return null; 41 | } 42 | 43 | public Classifier buildJ48(DecisionTree decisionTree, Instances data) throws Exception{ 44 | //uses information gain ratio 45 | J48 j48 = new J48(); 46 | if (decisionTree.getUnpruned() != null){ 47 | j48.setUnpruned(decisionTree.getUnpruned()); 48 | } 49 | if (decisionTree.getConfidence() != null){ 50 | j48.setOptions(new String[]{"-C", decisionTree.getConfidence()}); 51 | } 52 | if (decisionTree.getMinNumObj() != null){ 53 | j48.setMinNumObj(decisionTree.getMinNumObj()); 54 | } 55 | if (decisionTree.isBoost()){ 56 | AdaBoostM1 adaBoostM1 = new AdaBoostM1(); 57 | adaBoostM1.setUseResampling(false); 58 | adaBoostM1.setClassifier(j48); 59 | adaBoostM1.buildClassifier(data); 60 | return adaBoostM1; 61 | } else { 62 | j48.buildClassifier(data); 63 | return j48; 64 | } 65 | } 66 | 67 | public String evaluateData(FileFactory.TrainTest data, Classifier classifier, DecisionTree decisionTree) throws Exception{ 68 | Evaluation evaluation = new Evaluation(data.train); 69 | if (decisionTree.getTestType() == DecisionTree.TestType.CrossValidation){ 70 | evaluation.crossValidateModel(classifier, data.train, 10, new Random(1)); 71 | } else if (decisionTree.getTestType() == DecisionTree.TestType.Train){ 72 | FileFactory.TrainTest d; 73 | if (decisionTree.getFileName() == ML.Files.Census){ 74 | d = fileFactory.handlePublicCensus(0, new Options(decisionTree.isFeatureSelection())); 75 | } else { 76 | d = fileFactory.handlePublicCar(0); 77 | } 78 | evaluation.evaluateModel(classifier, d.train); 79 | } else { 80 | evaluation.evaluateModel(classifier, data.test); 81 | } 82 | String retString = ""; 83 | retString += evaluation.toSummaryString() + " \n"; 84 | retString += evaluation.toClassDetailsString() + " \n"; 85 | //retString += evaluation.toMatrixString() + " \n"; 86 | return retString; 87 | } 88 | 89 | public void createModel(DecisionTree decisionTree) throws Exception{ 90 | 
FileFactory.TrainTest data = fileFactory.getInstancesFromFile(decisionTree.getFileName(), new Options(decisionTree.isFeatureSelection())); 91 | Classifier cls = buildJ48(decisionTree, data.train); 92 | loadData.saveModel(cls, getString(decisionTree)); 93 | } 94 | 95 | public String getModel(DecisionTree decisionTree) throws Exception{ 96 | FileFactory.TrainTest data = fileFactory.getInstancesFromFile(decisionTree.getFileName(), new Options(decisionTree.isFeatureSelection())); 97 | Classifier cls = loadData.getModel(getString(decisionTree)); 98 | decisionTree.setTestType(ML.TestType.TestData); 99 | return evaluateData(data, cls, decisionTree); 100 | } 101 | 102 | private String getString(DecisionTree decisionTree){ 103 | String dString = "decisionTree-minNumObj=" + 104 | decisionTree.getMinNumObj().toString() + 105 | "-Boosted=" + decisionTree.isBoost() + 106 | "-C=" + decisionTree.getConfidence() + 107 | "-file=" + decisionTree.getFileName().toString(); 108 | if (decisionTree.isFeatureSelection()){ 109 | dString += "-feature=true"; 110 | } 111 | return dString + ".model"; 112 | } 113 | 114 | public String handleSplitData(DecisionTree decisionTree, int num, String retString) throws Exception{ 115 | if (num <= 100){ 116 | retString += "Amount " + Integer.toString(num) + "\n"; 117 | FileFactory.TrainTest data; 118 | if (decisionTree.getFileName() == ML.Files.Census){ 119 | data = fileFactory.handlePublicCensus(num, new Options(decisionTree.isFeatureSelection())); 120 | } else { 121 | data = fileFactory.handlePublicCar(num); 122 | } 123 | Classifier cls = buildJ48(decisionTree, data.train); 124 | Instances d; 125 | if (decisionTree.getTestType() == ML.TestType.Train){ 126 | if (decisionTree.getFileName() == ML.Files.Car){ 127 | d = fileFactory.handlePublicCar(0).train; 128 | } else { 129 | d = fileFactory.handlePublicCensus(0, new Options(decisionTree.isFeatureSelection())).train; 130 | } 131 | } else { 132 | d = data.test; 133 | } 134 | return 
handleSplitData(decisionTree, num==1 ? num+9 : num+10, retString + "\n \n" + evaluationService.evaluateData(data.train, cls, d)); 135 | } 136 | return retString; 137 | } 138 | 139 | public String handleRegressionTree(DecisionTree decisionTree) throws Exception{ 140 | REPTree repTree = new REPTree(); 141 | FileFactory.TrainTest trainTest = fileFactory.getInstancesFromFile(ML.Files.Boston, new Options()); 142 | repTree.buildClassifier(trainTest.train); 143 | 144 | Evaluation evaluation = new Evaluation(trainTest.train); 145 | evaluation.crossValidateModel(repTree, trainTest.train, 10, new Random(1)); 146 | return repTree.toString(); 147 | } 148 | 149 | 150 | private String handleJ48(DecisionTree decisionTree) throws Exception{ 151 | FileFactory.TrainTest data = fileFactory.getInstancesFromFile(decisionTree.getFileName(), new Options(decisionTree.isFeatureSelection())); 152 | Classifier j48 = buildJ48(decisionTree, data.train); 153 | return evaluateData(data, j48, decisionTree) + "\n \n " + j48.toString(); 154 | } 155 | 156 | } 157 | -------------------------------------------------------------------------------- /src/main/java/com/derek/ml/services/FileFactory.java: -------------------------------------------------------------------------------- 1 | package com.derek.ml.services; 2 | 3 | import com.derek.ml.models.ML; 4 | import com.derek.ml.models.Options; 5 | import org.springframework.beans.factory.annotation.Autowired; 6 | import org.springframework.stereotype.Service; 7 | import weka.attributeSelection.*; 8 | 9 | import weka.core.Instance; 10 | import weka.core.Instances; 11 | import weka.filters.Filter; 12 | import weka.filters.unsupervised.attribute.*; 13 | import weka.filters.unsupervised.attribute.Discretize; 14 | import weka.filters.unsupervised.instance.RemovePercentage; 15 | 16 | @Service 17 | public class FileFactory { 18 | 19 | @Autowired 20 | public LoadData loadData; 21 | 22 | public TrainTest getInstancesFromFile(ML.Files file, Options options) 
throws Exception{ 23 | switch (file) { 24 | case Census: 25 | return handlePublicCensus(options); 26 | case Car: 27 | return new TrainTest(handleData("car_train"), handleData("car_test")); 28 | case Boston: 29 | return new TrainTest(handleData("boston"), handleData("boston")); 30 | case CarBin: 31 | return new TrainTest(handleData("car_bin", true), handleData("car_bin")); 32 | case CensusBin: 33 | return new TrainTest(handleData("census_bin", true), handleData("census_bin")); 34 | case CensusEm: 35 | return new TrainTest(handleData("census_em"), handleData("census_em")); 36 | case CensusKm: 37 | return new TrainTest(handleData("census_km"), handleData("census_km")); 38 | } 39 | return null; 40 | } 41 | 42 | private Instances filterClass(Instances data) throws Exception{ 43 | Remove filter = new Remove(); 44 | filter.setAttributeIndices("" + (data.classIndex() + 1)); 45 | filter.setInputFormat(data); 46 | return Filter.useFilter(data, filter); 47 | } 48 | 49 | public void splitCarDataToTest(int amount) throws Exception{ 50 | Instances instances = handleData("car"); 51 | RemovePercentage removePercentage = new RemovePercentage(); 52 | removePercentage.setPercentage(amount); 53 | removePercentage.setInputFormat(instances); 54 | loadData.saveToArff(Filter.useFilter(instances, removePercentage), "car_train.arff"); 55 | removePercentage.setInvertSelection(true); 56 | loadData.saveToArff(Filter.useFilter(instances, removePercentage), "car_test.arff"); 57 | } 58 | 59 | private Instances handleData(String fileName) throws Exception{ 60 | try { 61 | return loadData.getDataFromArff(fileName + ".arff"); 62 | } catch (Exception e){ 63 | return loadData.getDataFromCsvFile(fileName + ".csv"); 64 | } 65 | } 66 | 67 | private Instances handleData(String fileName, boolean noClass) throws Exception{ 68 | try { 69 | return loadData.getDataFromArff(fileName + ".arff", noClass); 70 | } catch (Exception e){ 71 | return loadData.getDataFromCsvFile(fileName + ".csv"); 72 | } 73 | } 74 | 75 
| private Instances removeFilter(Instances data, String indicesToRemove) throws Exception{ 76 | Remove remove = new Remove(); 77 | remove.setAttributeIndices(indicesToRemove); 78 | remove.setInputFormat(data); 79 | remove.setInvertSelection(true); 80 | return Filter.useFilter(data, remove); 81 | } 82 | 83 | private Instances numericToNominalFilter(Instances data, String indicesToNominalize) throws Exception { 84 | NumericToNominal numericToNominal = new NumericToNominal(); 85 | numericToNominal.setAttributeIndices(indicesToNominalize); 86 | numericToNominal.setInputFormat(data); 87 | return Filter.useFilter(data, numericToNominal); 88 | } 89 | 90 | private Instances discretizeFilter(Instances data, String indices, int bins) throws Exception{ 91 | Discretize d = new Discretize(); 92 | if (indices != null){ 93 | d.setAttributeIndices(indices); 94 | } 95 | d.setIgnoreClass(true); 96 | d.setBins(bins); 97 | d.setInputFormat(data); 98 | return Filter.useFilter(data, d); 99 | } 100 | 101 | private Instances removeInstancesWithQuestionMarks(Instances data){ 102 | int numAttributes = data.numAttributes(); 103 | int numInstances = data.numInstances(); 104 | 105 | for (int out = 0; out < numInstances; out++){ 106 | Instance currentInstance = data.instance(out); 107 | for (int in = 0; in < numAttributes; in++){ 108 | if (currentInstance != null){ 109 | try{ 110 | String currentAttribute = currentInstance.toString(in); 111 | if (currentAttribute.contains("?")){ 112 | data.delete(out); 113 | } 114 | } catch (Exception e){ 115 | // 116 | } 117 | } 118 | } 119 | } 120 | return data; 121 | } 122 | 123 | public TrainTest handlePublicCensus(Options options) throws Exception{ 124 | return handlePublicCensus(0, options); 125 | } 126 | 127 | public TrainTest handlePublicCensus(int numToRemove, Options options) throws Exception{ 128 | Instances trainingData = handleData("census", options.isNoClass()); 129 | Instances testData; 130 | if (options.isNoClass()){ 131 | testData = 
handleData("census", false); 132 | } else { 133 | testData = handleData("censusTest", false); 134 | } 135 | 136 | Instances temp; 137 | if (numToRemove > 0){ 138 | RemovePercentage removePercentage = new RemovePercentage(); 139 | removePercentage.setInputFormat(trainingData); 140 | removePercentage.setPercentage(numToRemove); 141 | removePercentage.setInvertSelection(true); 142 | temp = Filter.useFilter(trainingData, removePercentage); 143 | }else { 144 | temp = trainingData; 145 | } 146 | 147 | //5,6,8,11,12 148 | if (options.isFeatureSelection()){ 149 | Instances trainingRemoved = removeFilter(temp, "1-4,7,9-10,13-14"); 150 | Instances testingRemoved = removeFilter(testData, "1-4,7,9-10,13-14"); 151 | return new TrainTest(trainingRemoved, testingRemoved); 152 | } 153 | 154 | return new TrainTest(temp, testData); 155 | } 156 | 157 | public TrainTest handlePublicCar(int num) throws Exception{ 158 | TrainTest data = new TrainTest(handleData("car_train"), handleData("car_test")); 159 | if (num <= 0){ 160 | return data; 161 | } 162 | RemovePercentage removePercentage = new RemovePercentage(); 163 | removePercentage.setInputFormat(data.train); 164 | removePercentage.setPercentage(num); 165 | removePercentage.setInvertSelection(true); 166 | Instances trainingData = Filter.useFilter(data.train, removePercentage); 167 | 168 | return new TrainTest(trainingData, data.test); 169 | } 170 | 171 | public void saveDiscretizedArff() throws Exception{ 172 | Instances temp = handleData("census"); 173 | Instances testData = handleData("censusTest"); 174 | weka.filters.supervised.attribute.Discretize discretize = new weka.filters.supervised.attribute.Discretize(); 175 | discretize.setInputFormat(temp); 176 | 177 | Instances trainingDiscretize = Filter.useFilter(temp, discretize); 178 | Instances testingDiscretize = Filter.useFilter(testData, discretize); 179 | loadData.saveToArff(trainingDiscretize, "census.arff"); 180 | loadData.saveToArff(testingDiscretize, "censusTest.arff"); 181 
package com.derek.ml.services;


import com.derek.ml.models.ML;
import com.derek.ml.models.Options;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import weka.attributeSelection.*;
import weka.core.Attribute;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Add;
import weka.filters.unsupervised.attribute.IndependentComponents;
import weka.filters.unsupervised.attribute.RandomProjection;
import weka.filters.unsupervised.attribute.Remove;

import java.io.FileWriter;

/**
 * Dimensionality-reduction helpers (PCA, ICA, Randomized Projection, CFS
 * subset evaluation) for the Car and Census data sets, plus CSV export of
 * 2-D projections for external plotting.
 */
@Service
public class FeatureReductionService {

    @Autowired
    private FileFactory fileFactory;

    /**
     * Runs PCA-based attribute selection on the Car and Census training sets.
     *
     * @return combined human-readable report for both data sets
     */
    public String handlePCAFeatures() throws Exception {
        FileFactory.TrainTest carTrainTest = fileFactory.getInstancesFromFile(ML.Files.Car, new Options());
        FileFactory.TrainTest censusTrainTest = fileFactory.getInstancesFromFile(ML.Files.Census, new Options());
        return ApplyPCA("CAR", carTrainTest.train) + "\n \n \n \n \n" + ApplyPCA("CENSUS", censusTrainTest.train);
    }

    /**
     * Projects both training sets onto 4 random-projection attributes and
     * returns the transformed instances rendered as text.
     */
    public String handleRandomizedProjectionFeatures() throws Exception {
        FileFactory.TrainTest carTrainTest = fileFactory.getInstancesFromFile(ML.Files.Car, new Options());
        FileFactory.TrainTest censusTrainTest = fileFactory.getInstancesFromFile(ML.Files.Census, new Options());
        return applyRP(carTrainTest.train, 4).toString() + "\n \n \n \n \n" + applyRP(censusTrainTest.train, 4).toString();
    }

    /**
     * Runs ICA on the class-stripped, binary-encoded test splits of the Car
     * and Census data sets and returns the transformed instances as text.
     */
    public String handleICAFeatures() throws Exception {
        FileFactory.TrainTest carTrainTest = fileFactory.getInstancesFromFile(ML.Files.CarBin, new Options());
        FileFactory.TrainTest censusTrainTest = fileFactory.getInstancesFromFile(ML.Files.CensusBin, new Options());
        return applyICA(dropClass(carTrainTest.test), 4).toString() + "\n \n \n \n \n" + applyICA(dropClass(censusTrainTest.test), 4).toString();
    }

    /** Runs CFS subset evaluation (greedy stepwise search) on both training sets. */
    public String handleCFSSubsetEval() throws Exception {
        FileFactory.TrainTest carTrainTest = fileFactory.getInstancesFromFile(ML.Files.Car, new Options());
        FileFactory.TrainTest censusTrainTest = fileFactory.getInstancesFromFile(ML.Files.Census, new Options());
        return applyCfsSubsetEval(carTrainTest.train) + " \n \n \n \n" + applyCfsSubsetEval(censusTrainTest.train);
    }

    /**
     * Performs PCA attribute selection (95% variance retained, top 5 attribute
     * names per component) and returns a textual report.
     * <p>
     * Method name keeps its original non-standard capitalization for
     * backward compatibility with existing callers.
     */
    public String ApplyPCA(String name, Instances trainingData) throws Exception {
        AttributeSelection selector = new AttributeSelection();

        PrincipalComponents principalComponents = new PrincipalComponents();
        principalComponents.setMaximumAttributeNames(5);
        principalComponents.setVarianceCovered(.95);
        principalComponents.buildEvaluator(trainingData);

        Ranker ranker = new Ranker();

        selector.setSearch(ranker);
        selector.setEvaluator(principalComponents);
        selector.SelectAttributes(trainingData);
        return name + "\n \n \n \n \n \n \n Principal Components: \n \n "
                + principalComponents.toString() + "\n \n Attribute Selection: \n \n" + selector.toResultsString();
    }

    /**
     * Applies Weka's RandomProjection filter.
     *
     * @param trainingData  instances to project
     * @param numAttributes number of output attributes
     */
    public Instances applyRP(Instances trainingData, int numAttributes) throws Exception {
        RandomProjection randomProjection = new RandomProjection();
        randomProjection.setNumberOfAttributes(numAttributes);
        randomProjection.setInputFormat(trainingData);
        return Filter.useFilter(trainingData, randomProjection);
    }

    /** Applies the PCA filter, keeping at most {@code numAttributes} components. */
    public Instances applyPCAFilter(Instances trainingData, int numAttributes) throws Exception {
        weka.filters.unsupervised.attribute.PrincipalComponents principalComponents =
                new weka.filters.unsupervised.attribute.PrincipalComponents();
        principalComponents.setMaximumAttributes(numAttributes);
        principalComponents.setInputFormat(trainingData);
        return Filter.useFilter(trainingData, principalComponents);
    }

    /**
     * Applies the IndependentComponents (FastICA) filter with 100 iterations.
     *
     * @param numAttributes currently UNUSED -- the filter runs with its default
     *                      output size. TODO(review): wire this through to the
     *                      filter's output-attribute option if per-call control
     *                      of the component count is intended.
     */
    public Instances applyICA(Instances trainingData, int numAttributes) throws Exception {
        IndependentComponents independentComponents = new IndependentComponents();
        independentComponents.setNumIterations(100);
        independentComponents.setInputFormat(trainingData);
        return Filter.useFilter(trainingData, independentComponents);
    }

    /**
     * Evaluates attribute subsets with CFS plus a ranked greedy stepwise
     * search (top 4) and returns a textual report.
     */
    public String applyCfsSubsetEval(Instances data) throws Exception {
        AttributeSelection selector = new AttributeSelection();
        CfsSubsetEval cfsSubsetEval = new CfsSubsetEval();
        cfsSubsetEval.buildEvaluator(data);

        GreedyStepwise greedyStepwise = new GreedyStepwise();
        greedyStepwise.setGenerateRanking(true);
        greedyStepwise.setNumToSelect(4);
        selector.setSearch(greedyStepwise);
        selector.setEvaluator(cfsSubsetEval);
        selector.SelectAttributes(data);

        // The header previously said "Principal Components" -- a copy-paste
        // slip from ApplyPCA; this report is the CFS subset evaluation.
        return " \n \n CFS Subset Evaluation: \n \n "
                + cfsSubsetEval.toString() + "\n \n Attribute Selection: \n \n" + selector.toResultsString();
    }

    /**
     * Appends one "x,y,classValue" row (first two attribute values) per
     * instance to {@code fileName}.csv.
     */
    private void plotRP(String fileName, Instances data) throws Exception {
        // try-with-resources: the writer was previously leaked if a row threw.
        try (FileWriter writer = new FileWriter(fileName + ".csv", true)) {
            for (int i = 0; i < data.size(); i++) {
                double[] values = data.get(i).toDoubleArray();
                writer.append(Double.toString(values[0]));
                writer.append(",");
                writer.append(Double.toString(values[1]));
                writer.append(",");
                writer.append(Double.toString(data.get(i).classValue()));
                writer.append("\n");
            }
            writer.flush();
        }
    }

    /** Writes 2-D random projections of the Car and Census training sets to CSV. */
    public void plotRP() throws Exception {
        FileFactory.TrainTest carTrainTest = fileFactory.getInstancesFromFile(ML.Files.Car, new Options());
        FileFactory.TrainTest censusTrainTest = fileFactory.getInstancesFromFile(ML.Files.Census, new Options());
        plotRP("RP_car", applyRP(carTrainTest.train, 2));
        plotRP("RP_census", applyRP(censusTrainTest.train, 2));
    }

    /** Writes 2-D PCA projections of the Car and Census training sets to CSV. */
    public void plotPCA() throws Exception {
        FileFactory.TrainTest carTrainTest = fileFactory.getInstancesFromFile(ML.Files.Car, new Options());
        FileFactory.TrainTest censusTrainTest = fileFactory.getInstancesFromFile(ML.Files.Census, new Options());
        plotRP("PCA_car", applyPCAFilter(carTrainTest.train, 2));
        plotRP("PCA_census", applyPCAFilter(censusTrainTest.train, 2));
    }

    /** Writes 2-D ICA projections (with the class re-attached) of the binary test splits to CSV. */
    public void plotICA() throws Exception {
        FileFactory.TrainTest carTrainTest = fileFactory.getInstancesFromFile(ML.Files.CarBin, new Options());
        FileFactory.TrainTest censusTrainTest = fileFactory.getInstancesFromFile(ML.Files.CensusBin, new Options());
        Instances newCarData = applyICA(dropClass(carTrainTest.test), 2);
        Instances newCensusData = applyICA(dropClass(censusTrainTest.test), 2);

        Instances one = reAddClassification(newCarData, carTrainTest.test);
        Instances two = reAddClassification(newCensusData, censusTrainTest.test);

        plotRP("ICA_car", one);
        plotRP("ICA_census", two);
    }

    /**
     * Returns a copy of {@code instances} with the class attribute removed.
     * Assumes the class is the LAST attribute (true for the project's ARFF
     * files) -- TODO(review): confirm against FileFactory.
     */
    private Instances dropClass(Instances instances) throws Exception {
        Remove removeFilter = new Remove();
        // Weka's Remove "-R" option takes 1-BASED attribute indices, so the
        // last attribute is numAttributes(). The previous value,
        // numAttributes() - 1, pointed at the second-to-last attribute and
        // left the class column in place.
        String[] options = new String[]{"-R", Integer.toString(instances.numAttributes())};
        removeFilter.setOptions(options);
        removeFilter.setInputFormat(instances);
        return Filter.useFilter(instances, removeFilter);
    }

    /**
     * Appends a numeric "NewNumeric" attribute to {@code first}, fills it with
     * the class values of the row-parallel instances in {@code second}, and
     * marks it as the class attribute.
     */
    public Instances reAddClassification(Instances first, Instances second) throws Exception {
        Add filter = new Add();
        filter.setAttributeIndex("last");
        filter.setAttributeName("NewNumeric");
        filter.setInputFormat(first);
        Instances newFirst = Filter.useFilter(first, filter);
        int classAt = newFirst.numAttributes() - 1;
        newFirst.setClassIndex(classAt); // hoisted: was needlessly re-set on every loop pass
        for (int i = 0; i < newFirst.size(); i++) {
            newFirst.instance(i).setValue(classAt, second.instance(i).classValue());
        }
        return newFirst;
    }

    /**
     * Same as {@link #reAddClassification(Instances, Instances)} except the
     * added attribute is nominal with labels {0, 1}.
     */
    public Instances reAddClassificationNominal(Instances first, Instances second) throws Exception {
        Add filter = new Add();
        filter.setAttributeIndex("last");
        filter.setNominalLabels("0,1");
        filter.setAttributeName("NewNumeric");
        filter.setInputFormat(first);
        Instances newFirst = Filter.useFilter(first, filter);
        int classAt = newFirst.numAttributes() - 1;
        newFirst.setClassIndex(classAt); // hoisted: was needlessly re-set on every loop pass
        for (int i = 0; i < newFirst.size(); i++) {
            newFirst.instance(i).setValue(classAt, second.instance(i).classValue());
        }
        return newFirst;
    }
}
| low,low,3.0,2.0,small,med,unacc 23 | low,low,3.0,2.0,small,high,unacc 24 | low,low,3.0,2.0,med,low,unacc 25 | low,low,3.0,2.0,med,med,unacc 26 | low,low,3.0,2.0,med,high,unacc 27 | low,low,3.0,2.0,big,low,unacc 28 | low,low,3.0,2.0,big,med,unacc 29 | low,low,3.0,2.0,big,high,unacc 30 | low,low,3.0,4.0,small,low,unacc 31 | low,low,3.0,4.0,small,med,acc 32 | low,low,3.0,4.0,small,high,good 33 | high,low,5more,2.0,small,med,unacc 34 | high,low,5more,2.0,small,high,unacc 35 | high,low,5more,2.0,med,low,unacc 36 | high,low,5more,2.0,med,med,unacc 37 | high,low,5more,2.0,med,high,unacc 38 | high,low,5more,2.0,big,low,unacc 39 | high,low,5more,2.0,big,med,unacc 40 | high,low,5more,2.0,big,high,unacc 41 | high,low,5more,4.0,small,low,unacc 42 | high,low,5more,4.0,small,med,unacc 43 | high,low,5more,4.0,small,high,acc 44 | high,low,5more,4.0,med,low,unacc 45 | high,low,5more,4.0,med,med,acc 46 | high,low,5more,4.0,med,high,acc 47 | high,low,5more,4.0,big,low,unacc 48 | high,low,5more,4.0,big,med,acc 49 | high,low,5more,4.0,big,high,acc 50 | high,low,5more,more,small,low,unacc 51 | high,low,5more,more,small,med,unacc 52 | high,low,5more,more,small,high,acc 53 | high,low,5more,more,med,low,unacc 54 | high,low,5more,more,med,med,acc 55 | high,low,5more,more,med,high,acc 56 | high,low,5more,more,big,low,unacc 57 | high,low,5more,more,big,med,acc 58 | high,low,5more,more,big,high,acc 59 | med,vhigh,2.0,2.0,small,low,unacc 60 | med,vhigh,2.0,2.0,small,med,unacc 61 | med,vhigh,2.0,2.0,small,high,unacc 62 | med,vhigh,2.0,2.0,med,low,unacc 63 | med,vhigh,2.0,2.0,med,med,unacc 64 | med,vhigh,2.0,2.0,med,high,unacc 65 | med,vhigh,2.0,2.0,big,low,unacc 66 | med,vhigh,2.0,2.0,big,med,unacc 67 | med,vhigh,2.0,2.0,big,high,unacc 68 | med,vhigh,2.0,4.0,small,low,unacc 69 | med,vhigh,2.0,4.0,small,med,unacc 70 | med,vhigh,2.0,4.0,small,high,acc 71 | med,vhigh,2.0,4.0,med,low,unacc 72 | med,vhigh,2.0,4.0,med,med,unacc 73 | med,vhigh,2.0,4.0,med,high,acc 74 | med,vhigh,2.0,4.0,big,low,unacc 
75 | med,vhigh,2.0,4.0,big,med,acc 76 | med,vhigh,2.0,4.0,big,high,acc 77 | med,vhigh,2.0,more,small,low,unacc 78 | med,vhigh,2.0,more,small,med,unacc 79 | med,vhigh,2.0,more,small,high,unacc 80 | med,vhigh,2.0,more,med,low,unacc 81 | med,vhigh,2.0,more,med,med,unacc 82 | med,vhigh,2.0,more,med,high,acc 83 | med,vhigh,2.0,more,big,low,unacc 84 | vhigh,med,2.0,more,med,med,unacc 85 | vhigh,med,2.0,more,med,high,acc 86 | vhigh,med,2.0,more,big,low,unacc 87 | vhigh,med,2.0,more,big,med,acc 88 | vhigh,med,2.0,more,big,high,acc 89 | vhigh,med,3.0,2.0,small,low,unacc 90 | vhigh,med,3.0,2.0,small,med,unacc 91 | vhigh,med,3.0,2.0,small,high,unacc 92 | vhigh,med,3.0,2.0,med,low,unacc 93 | vhigh,med,3.0,2.0,med,med,unacc 94 | vhigh,med,3.0,2.0,med,high,unacc 95 | vhigh,med,3.0,2.0,big,low,unacc 96 | vhigh,med,3.0,2.0,big,med,unacc 97 | vhigh,med,3.0,2.0,big,high,unacc 98 | vhigh,med,3.0,4.0,small,low,unacc 99 | vhigh,med,3.0,4.0,small,med,unacc 100 | vhigh,med,3.0,4.0,small,high,acc 101 | vhigh,med,3.0,4.0,med,low,unacc 102 | vhigh,med,3.0,4.0,med,med,unacc 103 | vhigh,med,3.0,4.0,med,high,acc 104 | vhigh,med,3.0,4.0,big,low,unacc 105 | vhigh,med,3.0,4.0,big,med,acc 106 | vhigh,med,3.0,4.0,big,high,acc 107 | vhigh,med,3.0,more,small,low,unacc 108 | vhigh,med,3.0,more,small,med,unacc 109 | vhigh,med,3.0,more,small,high,acc 110 | vhigh,med,3.0,more,med,low,unacc 111 | vhigh,med,3.0,more,med,med,acc 112 | vhigh,med,3.0,more,med,high,acc 113 | vhigh,med,3.0,more,big,low,unacc 114 | vhigh,med,3.0,more,big,med,acc 115 | vhigh,med,3.0,more,big,high,acc 116 | vhigh,med,4.0,2.0,small,low,unacc 117 | vhigh,med,4.0,2.0,small,med,unacc 118 | vhigh,med,4.0,2.0,small,high,unacc 119 | vhigh,med,4.0,2.0,med,low,unacc 120 | vhigh,med,4.0,2.0,med,med,unacc 121 | vhigh,med,4.0,2.0,med,high,unacc 122 | vhigh,med,4.0,2.0,big,low,unacc 123 | vhigh,med,4.0,2.0,big,med,unacc 124 | vhigh,med,4.0,2.0,big,high,unacc 125 | vhigh,med,4.0,4.0,small,low,unacc 126 | vhigh,med,4.0,4.0,small,med,unacc 127 | 
vhigh,med,4.0,4.0,small,high,acc 128 | vhigh,med,4.0,4.0,med,low,unacc 129 | vhigh,med,4.0,4.0,med,med,acc 130 | vhigh,med,4.0,4.0,med,high,acc 131 | vhigh,med,4.0,4.0,big,low,unacc 132 | vhigh,med,4.0,4.0,big,med,acc 133 | vhigh,med,4.0,4.0,big,high,acc 134 | vhigh,med,4.0,more,small,low,unacc 135 | low,high,5more,2.0,med,high,unacc 136 | low,high,5more,2.0,big,low,unacc 137 | low,high,5more,2.0,big,med,unacc 138 | low,high,5more,2.0,big,high,unacc 139 | low,high,5more,4.0,small,low,unacc 140 | low,high,5more,4.0,small,med,acc 141 | low,high,5more,4.0,small,high,acc 142 | low,high,5more,4.0,med,low,unacc 143 | low,high,5more,4.0,med,med,acc 144 | low,high,5more,4.0,med,high,vgood 145 | low,high,5more,4.0,big,low,unacc 146 | low,high,5more,4.0,big,med,acc 147 | low,high,5more,4.0,big,high,vgood 148 | low,high,5more,more,small,low,unacc 149 | low,high,5more,more,small,med,acc 150 | low,high,5more,more,small,high,acc 151 | low,high,5more,more,med,low,unacc 152 | low,high,5more,more,med,med,acc 153 | low,high,5more,more,med,high,vgood 154 | low,high,5more,more,big,low,unacc 155 | low,high,5more,more,big,med,acc 156 | low,high,5more,more,big,high,vgood 157 | low,med,2.0,2.0,small,low,unacc 158 | low,med,2.0,2.0,small,med,unacc 159 | low,med,2.0,2.0,small,high,unacc 160 | low,med,2.0,2.0,med,low,unacc 161 | low,med,2.0,2.0,med,med,unacc 162 | low,med,2.0,2.0,med,high,unacc 163 | low,med,2.0,2.0,big,low,unacc 164 | low,med,2.0,2.0,big,med,unacc 165 | low,med,2.0,2.0,big,high,unacc 166 | med,high,2.0,more,small,low,unacc 167 | med,high,2.0,more,small,med,unacc 168 | med,high,2.0,more,small,high,unacc 169 | med,high,2.0,more,med,low,unacc 170 | med,high,2.0,more,med,med,unacc 171 | med,high,2.0,more,med,high,acc 172 | med,high,2.0,more,big,low,unacc 173 | med,high,2.0,more,big,med,acc 174 | med,high,2.0,more,big,high,acc 175 | med,high,3.0,2.0,small,low,unacc 176 | med,high,3.0,2.0,small,med,unacc 177 | med,high,3.0,2.0,small,high,unacc 178 | med,high,3.0,2.0,med,low,unacc 
179 | med,high,3.0,2.0,med,med,unacc 180 | med,high,3.0,2.0,med,high,unacc 181 | med,high,3.0,2.0,big,low,unacc 182 | med,high,3.0,2.0,big,med,unacc 183 | med,high,3.0,2.0,big,high,unacc 184 | med,high,3.0,4.0,small,low,unacc 185 | med,high,3.0,4.0,small,med,unacc 186 | med,high,3.0,4.0,small,high,acc 187 | med,high,3.0,4.0,med,low,unacc 188 | med,high,3.0,4.0,med,med,unacc 189 | med,high,3.0,4.0,med,high,acc 190 | med,high,3.0,4.0,big,low,unacc 191 | med,high,3.0,4.0,big,med,acc 192 | med,high,3.0,4.0,big,high,acc 193 | med,high,3.0,more,small,low,unacc 194 | med,high,3.0,more,small,med,unacc 195 | med,high,3.0,more,small,high,acc 196 | med,high,3.0,more,med,low,unacc 197 | med,high,3.0,more,med,med,acc 198 | med,high,3.0,more,med,high,acc 199 | med,high,3.0,more,big,low,unacc 200 | med,high,3.0,more,big,med,acc 201 | med,high,3.0,more,big,high,acc 202 | med,high,4.0,2.0,small,low,unacc 203 | med,high,4.0,2.0,small,med,unacc 204 | med,high,4.0,2.0,small,high,unacc 205 | med,high,4.0,2.0,med,low,unacc 206 | med,high,4.0,2.0,med,med,unacc 207 | med,high,4.0,2.0,med,high,unacc 208 | med,high,4.0,2.0,big,low,unacc 209 | med,high,4.0,2.0,big,med,unacc 210 | med,high,4.0,2.0,big,high,unacc 211 | med,high,4.0,4.0,small,low,unacc 212 | med,high,4.0,4.0,small,med,unacc 213 | med,high,4.0,4.0,small,high,acc 214 | med,high,4.0,4.0,med,low,unacc 215 | med,high,4.0,4.0,med,med,acc 216 | med,high,4.0,4.0,med,high,acc 217 | vhigh,vhigh,3.0,4.0,small,high,unacc 218 | vhigh,vhigh,3.0,4.0,med,low,unacc 219 | vhigh,vhigh,3.0,4.0,med,med,unacc 220 | vhigh,vhigh,3.0,4.0,med,high,unacc 221 | vhigh,vhigh,3.0,4.0,big,low,unacc 222 | vhigh,vhigh,3.0,4.0,big,med,unacc 223 | vhigh,vhigh,3.0,4.0,big,high,unacc 224 | vhigh,vhigh,3.0,more,small,low,unacc 225 | vhigh,vhigh,3.0,more,small,med,unacc 226 | vhigh,vhigh,3.0,more,small,high,unacc 227 | vhigh,vhigh,3.0,more,med,low,unacc 228 | vhigh,vhigh,3.0,more,med,med,unacc 229 | vhigh,vhigh,3.0,more,med,high,unacc 230 | 
package com.derek.ml.services;

import com.derek.ml.models.EMModel;
import com.derek.ml.models.Cluster;
import com.derek.ml.models.ML;
import com.derek.ml.models.Options;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import weka.clusterers.ClusterEvaluation;
import weka.clusterers.MakeDensityBasedClusterer;
import weka.clusterers.SimpleKMeans;
import weka.clusterers.EM;
import weka.core.EuclideanDistance;
import weka.core.Instances;
import weka.core.ManhattanDistance;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

import java.io.FileWriter;

/**
 * K-means and EM clustering over the Car and Census data sets, optionally
 * preceded by a feature-reduction step (PCA / ICA / RP), plus CSV export of
 * cluster assignments for plotting.
 */
@Service
public class ClusterService {

    @Autowired
    private FileFactory fileFactory;

    @Autowired
    private EvaluationService evaluationService;

    @Autowired
    FeatureReductionService featureReductionService;

    /**
     * Trains k-means (wrapped in a density-based clusterer so a log-likelihood
     * can be reported) and evaluates it on the test split.
     *
     * @return cluster-evaluation report plus the model's log likelihood
     */
    public String handleKmeans(Cluster cluster) throws Exception {
        FileFactory.TrainTest trainTest = fileFactory.getInstancesFromFile(cluster.getFileName(), new Options(false, true));
        Instances data = trainTest.train;
        if (cluster.getFeatureSelection() != null) {
            data = applyFeatureSelection(trainTest, cluster);
        }
        // Renamed from "simpleKMeans": this local holds the density-based
        // wrapper, not the k-means model itself.
        MakeDensityBasedClusterer densityClusterer = makeDensityBasedClustererWrapper(trainKmeans(cluster, data), data);

        ClusterEvaluation clusterEvaluation = new ClusterEvaluation();
        clusterEvaluation.setClusterer(densityClusterer);
        Instances instancesToEvaluate = trainTest.test;
        if (cluster.getFeatureSelection() != null) {
            instancesToEvaluate = featureReductionService.reAddClassificationNominal(data, trainTest.test);
        }
        clusterEvaluation.evaluateClusterer(instancesToEvaluate);
        return clusterEvaluation.clusterResultsToString() + "\n \n \n Log Likelihood : " + clusterEvaluation.getLogLikelihood();
    }

    /**
     * Builds a SimpleKMeans model from the {@link Cluster} settings
     * (cluster count, distance function, iteration cap).
     */
    public SimpleKMeans trainKmeans(Cluster cluster, Instances data) throws Exception {
        SimpleKMeans simpleKMeans = new SimpleKMeans();
        simpleKMeans.setPreserveInstancesOrder(true);
        simpleKMeans.setNumClusters(cluster.getClusters());
        if (cluster.getDistances() == Cluster.Distances.Euclidean) {
            simpleKMeans.setDistanceFunction(new EuclideanDistance());
        } else {
            simpleKMeans.setDistanceFunction(new ManhattanDistance());
        }
        simpleKMeans.setMaxIterations(cluster.getIterations());
        simpleKMeans.buildClusterer(data);
        return simpleKMeans;
    }

    /**
     * Trains an EM clusterer and evaluates it.
     * NOTE(review): unlike handleKmeans, the class attribute is re-added to
     * the evaluation data even when no feature selection was requested --
     * confirm this asymmetry is intentional.
     */
    public String handleEM(Cluster emModel) throws Exception {
        FileFactory.TrainTest trainTest = fileFactory.getInstancesFromFile(emModel.getFileName(), new Options(false, true));
        Instances data = trainTest.train;
        if (emModel.getFeatureSelection() != null) {
            data = applyFeatureSelection(trainTest, emModel);
        }
        EM em = trainEm(emModel, data);

        ClusterEvaluation clusterEvaluation = new ClusterEvaluation();
        clusterEvaluation.setClusterer(em);
        clusterEvaluation.evaluateClusterer(featureReductionService.reAddClassificationNominal(data, trainTest.test));

        return clusterEvaluation.clusterResultsToString();
    }

    /** Builds an EM model from the {@link Cluster} settings (cluster count, iteration cap). */
    public EM trainEm(Cluster emModel, Instances data) throws Exception {
        EM em = new EM();
        em.setMaxIterations(emModel.getIterations());
        em.setNumClusters(emModel.getClusters());
        em.buildClusterer(data);
        return em;
    }

    /**
     * Trains k-means and appends one "x,y,cluster" row (first two attribute
     * values) per instance to {@code name}.csv.
     * Write failures are logged and swallowed (best-effort plotting).
     */
    public void plotKM(Cluster cluster, Instances instances, String name) throws Exception {
        SimpleKMeans simpleKMeans = trainKmeans(cluster, instances);
        int[] assignments = simpleKMeans.getAssignments();

        // try-with-resources: the writer was previously leaked if a row threw.
        try (FileWriter writer = new FileWriter(name + ".csv", true)) {
            for (int i = 0; i < assignments.length; i++) {
                double[] values = instances.get(i).toDoubleArray();
                writer.append(Double.toString(values[0]));
                writer.append(",");
                writer.append(Double.toString(values[1]));
                writer.append(",");
                writer.append(Integer.toString(assignments[i]));
                writer.append("\n");
            }
            writer.flush();
        } catch (Exception e) {
            // Preserve the original best-effort behavior: log and continue.
            System.out.println(e.toString());
        }
    }

    /**
     * Trains EM and appends one "x,y,cluster" row per instance to
     * {@code name}.csv. Write failures are logged and swallowed.
     */
    public void plotEM(Cluster emModel, Instances instances, String name) throws Exception {
        EM simpleEM = trainEm(emModel, instances);
        ClusterEvaluation clusterEvaluation = new ClusterEvaluation();
        clusterEvaluation.setClusterer(simpleEM);
        clusterEvaluation.evaluateClusterer(instances);
        double[] assignments = clusterEvaluation.getClusterAssignments();

        // try-with-resources: the writer was previously leaked if a row threw.
        try (FileWriter writer = new FileWriter(name + ".csv", true)) {
            for (int i = 0; i < assignments.length; i++) {
                double[] values = instances.get(i).toDoubleArray();
                writer.append(Double.toString(values[0]));
                writer.append(",");
                writer.append(Double.toString(values[1]));
                writer.append(",");
                writer.append(Double.toString(assignments[i]));
                writer.append("\n");
            }
            writer.flush();
        } catch (Exception e) {
            // Preserve the original best-effort behavior: log and continue.
            System.out.println(e.toString());
        }
    }

    /** Plots k-means (6 clusters, 1000 iterations) over PCA/ICA/RP projections of both data sets. */
    public void plotKMWithFeature() throws Exception {
        FileFactory.TrainTest carTrainTest = fileFactory.getInstancesFromFile(ML.Files.CarBin, new Options(false, true));
        FileFactory.TrainTest censusTrainTest = fileFactory.getInstancesFromFile(ML.Files.CensusBin, new Options(false, true));
        FileFactory.TrainTest carBin = fileFactory.getInstancesFromFile(ML.Files.CarBin, new Options(false, true));
        FileFactory.TrainTest censusBin = fileFactory.getInstancesFromFile(ML.Files.CensusBin, new Options(false, true));

        Instances pcaCar = featureReductionService.applyPCAFilter(carTrainTest.train, 30);
        Instances pcaCensus = featureReductionService.applyPCAFilter(censusTrainTest.train, 30);
        Instances icaCar = featureReductionService.applyICA(carBin.test, 30);
        Instances icaCensus = featureReductionService.applyICA(censusBin.test, 30);
        Instances rpCar = featureReductionService.applyRP(carBin.train, 30);
        Instances rpCensus = featureReductionService.applyRP(censusBin.train, 30);

        Cluster cluster = new Cluster();
        cluster.setClusters(6);
        cluster.setIterations(1000);

        plotKM(cluster, pcaCar, "PCA_CAR");
        plotKM(cluster, pcaCensus, "PCA_CENSUS");
        plotKM(cluster, filterClass(icaCar), "ICA_CAR");
        plotKM(cluster, filterClass(icaCensus), "ICA_CENSUS");
        plotKM(cluster, rpCar, "RP_CAR");
        plotKM(cluster, rpCensus, "RP_CENSUS");
    }

    /** Plots EM (6 clusters, 1000 iterations) over PCA/ICA/RP projections of both data sets. */
    public void plotEMWithFeature() throws Exception {
        FileFactory.TrainTest carTrainTest = fileFactory.getInstancesFromFile(ML.Files.CarBin, new Options(false, true));
        FileFactory.TrainTest censusTrainTest = fileFactory.getInstancesFromFile(ML.Files.CensusBin, new Options(false, true));
        FileFactory.TrainTest carBin = fileFactory.getInstancesFromFile(ML.Files.CarBin, new Options(false, true));
        FileFactory.TrainTest censusBin = fileFactory.getInstancesFromFile(ML.Files.CensusBin, new Options(false, true));

        Instances pcaCar = featureReductionService.applyPCAFilter(carTrainTest.train, 30);
        Instances pcaCensus = featureReductionService.applyPCAFilter(censusTrainTest.train, 30);
        Instances icaCar = featureReductionService.applyICA(carBin.test, 30);
        Instances icaCensus = featureReductionService.applyICA(censusBin.test, 30);
        Instances rpCar = featureReductionService.applyRP(carBin.train, 30);
        Instances rpCensus = featureReductionService.applyRP(censusBin.train, 30);

        Cluster em = new Cluster();
        em.setClusters(6);
        em.setIterations(1000);

        plotEM(em, pcaCar, "PCA_CAR");
        plotEM(em, pcaCensus, "PCA_CENSUS");
        plotEM(em, filterClass(icaCar), "ICA_CAR");
        plotEM(em, filterClass(icaCensus), "ICA_CENSUS");
        plotEM(em, rpCar, "RP_CAR");
        plotEM(em, rpCensus, "RP_CENSUS");
    }

    /** Returns a copy of {@code data} with the class attribute removed (Remove uses 1-based indices). */
    private Instances filterClass(Instances data) throws Exception {
        Remove filter = new Remove();
        filter.setAttributeIndices("" + (data.classIndex() + 1));
        filter.setInputFormat(data);
        return Filter.useFilter(data, filter);
    }

    /** Wraps a trained k-means model so the evaluation can report a log likelihood. */
    private MakeDensityBasedClusterer makeDensityBasedClustererWrapper(SimpleKMeans simpleKMeans, Instances data) throws Exception {
        MakeDensityBasedClusterer makeDensityBasedClusterer = new MakeDensityBasedClusterer();
        makeDensityBasedClusterer.setClusterer(simpleKMeans);
        makeDensityBasedClusterer.buildClusterer(data);
        return makeDensityBasedClusterer;
    }

    /**
     * Applies the requested feature-reduction technique to the training data
     * (ICA operates on the test split, mirroring FeatureReductionService).
     * WARNING: returns {@code null} for CFS (unimplemented) and unknown
     * selections; callers currently do not guard against this.
     */
    private Instances applyFeatureSelection(FileFactory.TrainTest data, Cluster cluster) throws Exception {
        switch (cluster.getFeatureSelection()) {
            case ICA:
                return filterClass(featureReductionService.applyICA(data.test, 5));
            case PCA:
                return featureReductionService.applyPCAFilter(data.train, 5);
            case RP:
                return featureReductionService.applyRP(data.train, 5);
            case CFS:
                return null;
        }
        return null;
    }
}
8 | * 9 | * In jurisdictions that recognize copyright laws, the author or authors 10 | * of this software dedicate any and all copyright interest in the 11 | * software to the public domain. We make this dedication for the benefit 12 | * of the public at large and to the detriment of our heirs and 13 | * successors. We intend this dedication to be an overt act of 14 | * relinquishment in perpetuity of all present and future rights to this 15 | * software under copyright law. 16 | * 17 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 18 | * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 19 | * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 20 | * IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR 21 | * OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 22 | * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 23 | * OTHER DEALINGS IN THE SOFTWARE. 24 | * 25 | * For more information, please refer to 26 | */ 27 | 28 | package filters; 29 | 30 | import java.util.Random; 31 | 32 | import org.ejml.simple.SimpleEVD; 33 | import org.ejml.simple.SimpleMatrix; 34 | import org.ejml.simple.SimpleSVD; 35 | 36 | 37 | 38 | /** 39 | * FastICA port of Python scikit-learn implementation. 
40 | * 41 | * @author Chris Gearhart 42 | * 43 | */ 44 | public class FastICA { 45 | 46 | // The estimated unmixing matrix 47 | private SimpleMatrix W; 48 | 49 | // The pre-whitening matrix 50 | private SimpleMatrix K; 51 | 52 | // the product of K and W 53 | private SimpleMatrix KW; 54 | 55 | // The data matrix 56 | private SimpleMatrix X; 57 | 58 | // The mean value of each column of the input matrix 59 | private SimpleMatrix X_means; 60 | 61 | // Reference to non-linear neg-entropy estimator function 62 | private NegativeEntropyEstimator G; 63 | 64 | // Number of components to output 65 | private int num_components; 66 | 67 | /** number of rows (instances) in X */ 68 | private int m; 69 | 70 | /** number of columns (features) in X */ 71 | private int n; 72 | 73 | // Convergence tolerance 74 | private final double tolerance; 75 | 76 | // Iteration limit 77 | private final int max_iter; 78 | 79 | // Whiten the data if true 80 | private final boolean whiten; 81 | 82 | /** 83 | * General FastICA instance constructor with an arbitrary (user-supplied) 84 | * function to estimate negative entropy. This implementation does not 85 | * perform automatic component selection or reduction. 
86 | * 87 | * @param data 2d array of doubles containing the source data; 88 | * each column contains a single signal, and each row contains 89 | * one sample of all signals (rows contain instances, 90 | * columns are features) 91 | * @param g_func {@link NegativeEntropyEstimator} to estimate 92 | * negative entropy 93 | * @param tolerance maximum allowable convergence error 94 | * @param max_iter max number of iterations 95 | * @param whiten whiten the data matrix (default true) 96 | */ 97 | public FastICA(NegativeEntropyEstimator g_func, 98 | double tolerance, int max_iter, boolean whiten) { 99 | this.G = g_func; 100 | this.tolerance = tolerance; 101 | this.max_iter = max_iter; 102 | this.whiten = whiten; 103 | } 104 | 105 | /** 106 | * FastICA instance using the LogCosh() function to estimate negative 107 | * entropy and mean-centering with opitonal whitening of the data matrix. 108 | * This implementation does not perform automatic component selection or 109 | * reduction. 110 | * 111 | * @param tolerance maximum allowable convergence error 112 | * @param max_iter max number of iterations 113 | * @param whiten whiten the data matrix (default true) 114 | */ 115 | public FastICA(double tolerance, int max_iter, boolean whiten) { 116 | this(new LogCosh(), tolerance, max_iter, whiten); 117 | } 118 | 119 | /** 120 | * FastICA instance using the LogCosh() function to estimate negative 121 | * entropy and whitening/mean-centering of the data matrix. This 122 | * implementation does not perform automatic component selection or 123 | * reduction. 
124 | * 125 | * @param tolerance maximum allowable convergence error 126 | * @param max_iter max number of iterations 127 | */ 128 | public FastICA(double tolerance, int max_iter) { 129 | this(new LogCosh(), tolerance, max_iter, true); 130 | } 131 | 132 | /** 133 | * Default FastICA instance using the LogCosh() function to estimate 134 | * negative entropy and whitening/mean-centering of the data matrix with 135 | * simple default values. 136 | */ 137 | public FastICA() { 138 | this(new LogCosh(), 1E-5, 200, true); 139 | } 140 | 141 | /** 142 | * Return the matrix that projects the data into the ICA domain 143 | * 144 | * @return W double[][] matrix containing estimated independent component 145 | * projection matrix 146 | */ 147 | public double[][] getW() { 148 | return FastICA.mToA(W); 149 | } 150 | 151 | /** 152 | * Return the pre-whitening matrix that was used on the data (defaults to 153 | * the identity matrix) 154 | * 155 | * @return K double[][] matrix containing the pre-whitening matrix 156 | */ 157 | public double[][] getK() { 158 | return FastICA.mToA(K); 159 | } 160 | 161 | /** 162 | * Return the estimated mixing matrix that maps sources to the data domain 163 | * 164 | * S * em = X 165 | * 166 | * @return em double[][] matrix containing the estimated mixing matrix 167 | */ 168 | public double[][] getEM() { 169 | return FastICA.mToA(K.mult(W).pseudoInverse()); 170 | } 171 | 172 | /** 173 | * Project a row-indexed matrix of data into the ICA domain by applying 174 | * the pre-whitening and un-mixing matrices. This method should not be 175 | * called prior to running fit() with input data. 
176 | * 177 | * @param data rectangular double[][] array containing values; the 178 | * number of columns should match the data provided to the 179 | * fit() method for training 180 | * @return result rectangular double[][] array containing the projected 181 | * output values 182 | */ 183 | public double[][] transform(double[][] data) { 184 | SimpleMatrix x = new SimpleMatrix(data); 185 | return FastICA.mToA(x.minus(X_means).mult(KW)); 186 | } 187 | 188 | /** 189 | * Estimate the unmixing matrix for the data provided 190 | * 191 | * @param data - 2d array of doubles containing the data; each column 192 | * contains a single signal, and each row contains one sample of all 193 | * signals (rows contain instances, columns are features) 194 | */ 195 | public void fit(double[][] data, int num_components) throws Exception { 196 | X = new SimpleMatrix(data); 197 | m = X.numRows(); 198 | n = X.numCols(); 199 | 200 | // mean center the attributes in X 201 | double[] means = center(X); 202 | X_means = new SimpleMatrix(new double[][]{means}); 203 | 204 | // get the size parameter of the symmetric W matrix; size cannot be 205 | // larger than the number of samples or the number of features 206 | this.num_components = Math.min(Math.min(m, n), num_components); 207 | 208 | K = SimpleMatrix.identity(this.num_components); // init K 209 | if (this.whiten) { 210 | X = whiten(X); // sets K 211 | } 212 | 213 | // start with an orthogonal initial W matrix drawn from a standard Normal distribution 214 | W = symmetricDecorrelation(gaussianSquareMatrix(num_components)); 215 | 216 | // fit the data 217 | parallel_ica(); // solves for W 218 | 219 | // Store the resulting transformation matrix 220 | KW = K.mult(W); 221 | 222 | } 223 | 224 | /* 225 | * FastICA main loop - using default symmetric decorrelation. 
(i.e., 226 | * estimate all the independent components in parallel) 227 | */ 228 | private void parallel_ica() throws Exception { 229 | 230 | double tmp; 231 | double lim; 232 | SimpleMatrix W_next; 233 | SimpleMatrix newRow; 234 | SimpleMatrix oldRow; 235 | 236 | for (int iter = 0; iter < max_iter; iter++) { 237 | 238 | // Estimate the negative entropy and first derivative average 239 | G.estimate(X.mult(W)); 240 | 241 | // Update the W matrix 242 | W_next = X.transpose().mult(G.getGx()).scale(1. / new Double(n)); 243 | for (int i = 0; i < num_components; i++) { 244 | newRow = W_next.extractVector(true, i); 245 | oldRow = W.extractVector(true, i); 246 | W_next.insertIntoThis(i, 0, 247 | newRow.minus(oldRow.elementMult(G.getG_x()))); 248 | } 249 | W_next = symmetricDecorrelation(W_next); 250 | 251 | // Test convergence criteria for W 252 | lim = 0; 253 | for (int i = 0; i < W.numRows(); i++) { 254 | newRow = W_next.extractVector(true, i); 255 | oldRow = W.extractVector(true, i); 256 | tmp = newRow.dot(oldRow.transpose()); 257 | tmp = Math.abs(Math.abs(tmp) - 1); 258 | if (tmp > lim) { 259 | lim = tmp; 260 | } 261 | } 262 | W = W_next; 263 | 264 | if (lim < tolerance) { 265 | return; 266 | } 267 | } 268 | 269 | throw new Exception("ICA did not converge - try again with more iterations."); 270 | } 271 | 272 | /* 273 | * Whiten a matrix of column vectors by decorrelating and scaling the 274 | * elements according to: x_new = ED^{-1/2}E'x , where E is the 275 | * orthogonal matrix of eigenvectors of E{xx'}. In this implementation 276 | * (based on the FastICA sklearn Python package) the eigen decomposition is 277 | * replaced with the SVD. 278 | * 279 | * The decomposition is ambiguous with regard to the direction of 280 | * column vectors (they can be either +/- without changing the result). 
281 | */ 282 | @SuppressWarnings("rawtypes") 283 | private SimpleMatrix whiten(SimpleMatrix x) { 284 | // get compact SVD (D matrix is min(m,n) square) 285 | SimpleSVD svd = x.svd(true); 286 | 287 | // K should only keep `num_components` columns if performing 288 | // dimensionality reduction 289 | K = svd.getV().mult(svd.getW().invert()) 290 | .extractMatrix(0, x.numCols(), 0, num_components); 291 | // K = K.scale(-1); // sklearn returns this version for K; doesn't affect results 292 | 293 | // return x.mult(K).scale(Math.sqrt(m)); // sklearn scales the input 294 | return x.mult(K); 295 | } 296 | 297 | /* 298 | * Center the input matrix and store it in X by subtracting the average of 299 | * each column vector from every element in the column 300 | */ 301 | private double[] center(SimpleMatrix x) { 302 | SimpleMatrix col; 303 | int numrows = x.numRows(); 304 | int numcols = x.numCols(); 305 | double[] means; 306 | 307 | means = new double[numcols]; 308 | for (int i = 0; i < numcols; i++) { 309 | col = x.extractVector(false, i); 310 | means[i] = col.elementSum() / new Double(numrows); 311 | for (int j = 0; j < numrows; j++) { 312 | col.set(j, col.get(j) - means[i]); 313 | } 314 | X.insertIntoThis(0, i, col); 315 | } 316 | 317 | return means; 318 | } 319 | 320 | /* 321 | * Perform symmetric decorrelation on the input matrix to ensure that each 322 | * column is independent from all the others. This is required in order 323 | * to prevent FastICA from solving for the same components in multiple 324 | * columns. 325 | * 326 | * NOTE: There are only real eigenvalues for the W matrix 327 | * 328 | * W <- W * (W.T * W)^{-1/2} 329 | * 330 | * Python (Numpy): 331 | * s, u = linalg.eigh(np.dot(W.T, W)) 332 | * W = np.dot(W, np.dot(u * (1. 
/ np.sqrt(s)), u)) 333 | * Matlab: 334 | * B = B * real(inv(B' * B)^(1/2)) 335 | * 336 | */ 337 | @SuppressWarnings("rawtypes") 338 | private static SimpleMatrix symmetricDecorrelation(SimpleMatrix x) { 339 | 340 | double d; 341 | SimpleMatrix QL; 342 | SimpleMatrix Q; 343 | 344 | SimpleEVD evd = x.transpose().mult(x).eig(); 345 | int len = evd.getNumberOfEigenvalues(); 346 | QL = new SimpleMatrix(len, len); 347 | Q = new SimpleMatrix(len, len); 348 | 349 | // Scale each column of the eigenvector matrix by the square root of 350 | // the reciprocal of the associated eigenvalue 351 | for (int i = 0; i < len; i++) { 352 | d = evd.getEigenvalue(i).getReal(); 353 | d = (d + Math.abs(d)) / 2; // improve numerical stability by eliminating small negatives near singular matrix zeros 354 | QL.insertIntoThis(0, i, evd.getEigenVector(i).divide(Math.sqrt(d))); 355 | Q.insertIntoThis(0, i, evd.getEigenVector(i)); 356 | } 357 | 358 | return x.mult(QL.mult(Q.transpose())); 359 | } 360 | 361 | /* 362 | * Randomly generate a square matrix drawn from a standard gaussian 363 | * distribution. 
364 | */ 365 | private static SimpleMatrix gaussianSquareMatrix(int size) { 366 | SimpleMatrix ret = new SimpleMatrix(size, size); 367 | Random rand = new Random(); 368 | for (int i = 0; i < size; i++) { 369 | for (int j = 0; j < size; j++) { 370 | ret.set(i, j, rand.nextGaussian()); 371 | } 372 | } 373 | return ret; 374 | } 375 | 376 | /* 377 | * Convert a {@link SimpleMatrix} to a 2d array of double[][] 378 | */ 379 | private static double[][] mToA(SimpleMatrix x) { 380 | double[][] result = new double[x.numRows()][]; 381 | for (int i = 0; i < x.numRows(); i++) { 382 | result[i] = new double[x.numCols()]; 383 | for (int j = 0; j < x.numCols(); j++) { 384 | result[i][j] = x.get(i, j); 385 | } 386 | } 387 | return result; 388 | } 389 | 390 | } 391 | -------------------------------------------------------------------------------- /students-filters-master/src/main/java/weka/filters/unsupervised/attribute/IndependentComponents.java: -------------------------------------------------------------------------------- 1 | /* 2 | * This is free and unencumbered software released into the public domain. 3 | * 4 | * Anyone is free to copy, modify, publish, use, compile, sell, or 5 | * distribute this software, either in source code form or as a compiled 6 | * binary, for any purpose, commercial or non-commercial, and by any 7 | * means. 8 | * 9 | * In jurisdictions that recognize copyright laws, the author or authors 10 | * of this software dedicate any and all copyright interest in the 11 | * software to the public domain. We make this dedication for the benefit 12 | * of the public at large and to the detriment of our heirs and 13 | * successors. We intend this dedication to be an overt act of 14 | * relinquishment in perpetuity of all present and future rights to this 15 | * software under copyright law. 
16 | * 17 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 18 | * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 19 | * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 20 | * IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR 21 | * OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 22 | * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 23 | * OTHER DEALINGS IN THE SOFTWARE. 24 | * 25 | * For more information, please refer to 26 | */ 27 | 28 | package weka.filters.unsupervised.attribute; 29 | 30 | import java.util.ArrayList; 31 | import java.util.Enumeration; 32 | import java.util.Vector; 33 | 34 | import filters.FastICA; 35 | 36 | 37 | import weka.core.Attribute; 38 | import weka.core.AttributeStats; 39 | import weka.core.Capabilities; 40 | import weka.core.DenseInstance; 41 | import weka.core.Instance; 42 | import weka.core.Instances; 43 | import weka.core.Option; 44 | import weka.core.OptionHandler; 45 | import weka.core.RevisionUtils; 46 | import weka.core.Capabilities.Capability; 47 | import weka.core.Utils; 48 | import weka.filters.Filter; 49 | import weka.filters.UnsupervisedFilter; 50 | import weka.filters.unsupervised.attribute.Remove; 51 | 52 | // TODO: add support for negative entropy estimator option 53 | public class IndependentComponents 54 | extends Filter 55 | implements OptionHandler, UnsupervisedFilter{ 56 | 57 | /** for serialization. */ 58 | private static final long serialVersionUID = -5416810876710954131L; 59 | 60 | protected FastICA m_filter; 61 | 62 | /** If true, whiten input data. */ 63 | protected boolean m_whiten = true; 64 | 65 | /** Number of attributes to include. */ 66 | protected int m_numAttributes = -1; 67 | 68 | /** Maximum number of FastICA iterations. */ 69 | protected int m_numIterations = 200; 70 | 71 | /** Error tolerance for convergence. 
*/ 72 | protected double m_tolerance = 1E-4; 73 | 74 | /** True when the instances sent to determineOutputFormat() has a class attribute */ 75 | protected boolean m_hasClass; 76 | 77 | public String globalInfo() { 78 | return "Performs Independent Component Analysis and transformation " + 79 | "of numeric data using the FastICA algorithm while ignoring " + 80 | "the class label."; 81 | } 82 | 83 | public String whitenDataTipText() { 84 | return "Whiten the data (decoupling transform) if set."; 85 | } 86 | 87 | public void setWhitenData(boolean flag) { 88 | m_whiten = flag; 89 | } 90 | 91 | public boolean getWhitenData() { 92 | return m_whiten; 93 | } 94 | 95 | public String numAttributesTipText() { 96 | return "Number of separate sources to identify in the output." + 97 | " (-1 = include all; default: -1)"; 98 | } 99 | 100 | public void setNumAttributes(int num) { 101 | m_numAttributes = num; 102 | } 103 | 104 | public int getNumAttributes() { 105 | return m_numAttributes; 106 | } 107 | 108 | public String numIterationsTipText() { 109 | return "The maximum number of iterations of the FastICA main loop to allow."; 110 | } 111 | 112 | public void setNumIterations(int num) { 113 | m_numIterations = num; 114 | } 115 | 116 | public int getNumIterations() { 117 | return m_numIterations; 118 | } 119 | 120 | public String toleranceTipText() { 121 | return "Error tolerance for solution convergence."; 122 | } 123 | 124 | public void setTolerance(double tolerance) { 125 | m_tolerance = tolerance; 126 | } 127 | 128 | public double getTolerance() { 129 | return m_tolerance; 130 | } 131 | 132 | /** 133 | * Returns the capabilities of this evaluator. 
134 | * 135 | * @return the capabilities of this evaluator 136 | * @see Capabilities 137 | */ 138 | public Capabilities getCapabilities() { 139 | Capabilities result = super.getCapabilities(); 140 | result.disableAll(); 141 | 142 | // attributes 143 | result.enable(Capability.NUMERIC_ATTRIBUTES); 144 | 145 | // class 146 | result.enable(Capability.NOMINAL_CLASS); 147 | result.enable(Capability.NUMERIC_CLASS); 148 | result.enable(Capability.DATE_CLASS); 149 | result.enable(Capability.MISSING_CLASS_VALUES); 150 | result.enable(Capability.NO_CLASS); 151 | 152 | return result; 153 | } 154 | 155 | /** 156 | * Returns an enumeration describing the available options. 157 | * 158 | * @return an enumeration of all the available options. 159 | */ 160 | public Enumeration