├── README.md ├── pom.xml └── src └── main ├── java └── com │ └── oreilly │ └── dswj │ ├── dataops │ ├── Batch.java │ ├── CrossEntropyLossFunction.java │ ├── Dictionary.java │ ├── HashingDictionary.java │ ├── IrisPCAExample.java │ ├── LabelEncoder.java │ ├── LinearLossFunction.java │ ├── ListResampler.java │ ├── LossFunction.java │ ├── MatrixResampler.java │ ├── MatrixScaleType.java │ ├── MatrixScaler.java │ ├── MatrixScalingOperator.java │ ├── PCA.java │ ├── PCAEIGImplementation.java │ ├── PCAImplementation.java │ ├── PCASVDImplementation.java │ ├── ProbabilityEncoder.java │ ├── QuadraticLossFunction.java │ ├── SimpleTokenizer.java │ ├── SoftMaxCrossEntropyLossFunction.java │ ├── TFIDF.java │ ├── TFIDFVectorizer.java │ ├── TermDictionary.java │ ├── Tokenizer.java │ ├── TwoPointLossFunction.java │ ├── Vectorizer.java │ └── VectorizerExample.java │ ├── datasets │ ├── Anscombe.java │ ├── Iris.java │ ├── MNIST.java │ ├── MultiNormalMixtureDataset.java │ └── Sentiment.java │ ├── io │ ├── BasicBarChart.java │ ├── BasicScatterChart.java │ ├── DBApp.java │ ├── DBInsertBatchApp.java │ ├── FileIOExample.java │ ├── FitPlot.java │ └── Record.java │ ├── learn │ ├── BernoulliConditionalProbabilityEstimator.java │ ├── ClassifierAccuracy.java │ ├── ConditionalProbabilityEstimator.java │ ├── DBSCANOptimizeExample.java │ ├── DBSCANPlotExample.java │ ├── DeepNetwork.java │ ├── DeepNetworkIrisExample.java │ ├── DeepNetworkMNISTExample.java │ ├── DeltaRule.java │ ├── GaussianConditionalProbabilityEstimator.java │ ├── GaussianMixtureClusteringExample.java │ ├── GradientDescent.java │ ├── GradientDescentMomentum.java │ ├── IterativeLearningProcess.java │ ├── KMeansExample.java │ ├── LinearModel.java │ ├── LinearModelEstimator.java │ ├── LinearOutputFunction.java │ ├── LogisticOutputFunction.java │ ├── MultinomialConditionalProbabilityEstimator.java │ ├── NaiveBayes.java │ ├── NaiveBayesGaussianIrisExample.java │ ├── NetworkLayer.java │ ├── Optimizer.java │ ├── OutputFunction.java │ ├── 
SilhouetteCoefficient.java │ ├── SoftmaxLinearModelExample.java │ ├── SoftmaxOutputFunction.java │ └── TanhOutputFunction.java │ ├── linalg │ ├── FunctionMapper.java │ ├── LinearSystemExample.java │ ├── MatrixOperations.java │ ├── RandomizedMatrix.java │ └── UnivariateFunctionMapper.java │ ├── mapreduce │ ├── BasicMapReduceExample.java │ ├── CustomWordCountMapReduceExample.java │ ├── JSONMapper.java │ ├── SimpleTokenMapper.java │ ├── SparseAlgebraMapReduceExample.java │ ├── SparseMatrixMultiplicationMapper.java │ ├── SparseMatrixMultiplicationReducer.java │ ├── SparseMatrixWritable.java │ └── WordCountMapReduceExample.java │ └── statistics │ ├── AggStatsExample.java │ ├── AnscombeStatsExample.java │ ├── AnscombesPlotExample.java │ ├── ContinuousDistributionPlot.java │ ├── DiscreteDistributionPlot.java │ ├── EntropyPlotExample.java │ ├── Guassian2DExample.java │ └── HistogramExample.java └── resources ├── css ├── chart.css ├── chart_lineplot.css ├── data_with_fit_line.css └── overlay-chart.css └── datasets ├── iris └── iris_data.csv ├── mnist └── README └── sentiment ├── amazon_cells_labelled.txt ├── imdb_labelled.txt └── yelp_labelled.txt /README.md: -------------------------------------------------------------------------------- 1 | Data Science with Java 2 | ========== 3 | 4 | This is the example code that accompanies Data Science with Java by Michael R. Brzustowicz, PhD (9781491934111). 5 | 6 | Click the Download Zip button to the right to download example code. 7 | 8 | Visit the catalog page [here](http://shop.oreilly.com/product/0636920043171.do). 9 | 10 | See an error? 
Report it [here](http://oreilly.com/catalog/errata.csp?isbn=0636920043171), or simply fork and send us a pull request 11 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4.0.0 4 | com.oreilly.javabook 5 | Dava_Science_with_Java 6 | 1.0-SNAPSHOT 7 | jar 8 | 9 | 10 | org.apache.commons 11 | commons-math3 12 | 3.6 13 | 14 | 15 | com.googlecode.json-simple 16 | json-simple 17 | 1.1 18 | 19 | 20 | mysql 21 | mysql-connector-java 22 | 5.1.35 23 | 24 | 25 | org.apache.hadoop 26 | hadoop-client 27 | 2.7.2 28 | 29 | 30 | 31 | UTF-8 32 | 1.8 33 | 1.8 34 | 35 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/dataops/Batch.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.oreilly.dswj.dataops; 17 | 18 | import org.apache.commons.math3.linear.RealMatrix; 19 | 20 | /** 21 | * 22 | * @author Michael Brzustowicz 23 | */ 24 | public class Batch extends MatrixResampler { 25 | 26 | public Batch(RealMatrix features, RealMatrix labels) { 27 | super(features, labels); 28 | } 29 | 30 | public void calcNextBatch(int batchSize) { 31 | super.calculateTestTrainSplit(batchSize); 32 | } 33 | 34 | public RealMatrix getInputBatch() { 35 | return super.getTestingFeatures(); 36 | } 37 | 38 | public RealMatrix getTargetBatch() { 39 | return super.getTestingLabels(); 40 | } 41 | 42 | } 43 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/dataops/CrossEntropyLossFunction.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.oreilly.dswj.dataops; 17 | 18 | import org.apache.commons.math3.linear.Array2DRowRealMatrix; 19 | import org.apache.commons.math3.linear.ArrayRealVector; 20 | import org.apache.commons.math3.linear.RealMatrix; 21 | import org.apache.commons.math3.linear.RealVector; 22 | import org.apache.commons.math3.stat.descriptive.SummaryStatistics; 23 | import org.apache.commons.math3.util.FastMath; 24 | 25 | /** 26 | * 27 | * @author Michael Brzustowicz 28 | */ 29 | public class CrossEntropyLossFunction implements LossFunction { 30 | 31 | @Override 32 | public double getSampleLoss(double predicted, double target) { 33 | return -1.0 * (target * ((predicted>0)?FastMath.log(predicted):0) + 34 | (1.0 - target)*(predicted<1?FastMath.log(1.0-predicted):0)); 35 | } 36 | 37 | public double getSampleLoss(double[] predicted, double[] target) { 38 | double loss = 0.0; 39 | for (int i = 0; i < predicted.length; i++) { 40 | loss += getSampleLoss(predicted[i], target[i]); 41 | } 42 | return loss; 43 | } 44 | 45 | @Override 46 | public double getSampleLoss(RealVector predicted, RealVector target) { 47 | double loss = 0.0; 48 | for (int i = 0; i < predicted.getDimension(); i++) { 49 | loss += getSampleLoss(predicted.getEntry(i), target.getEntry(i)); 50 | } 51 | return loss; 52 | } 53 | 54 | @Override 55 | public double getMeanLoss(RealMatrix predicted, RealMatrix target) { 56 | SummaryStatistics stats = new SummaryStatistics(); 57 | for (int i = 0; i < predicted.getRowDimension(); i++) { 58 | stats.addValue(getSampleLoss(predicted.getRowVector(i), target.getRowVector(i))); 59 | } 60 | return stats.getMean(); 61 | } 62 | 63 | @Override 64 | public double getSampleLossGradient(double predicted, double target) { 65 | // NOTE this blows up if predicted = 0 or 1, which it should never be 66 | return (predicted - target) / (predicted * (1 - predicted)); 67 | } 68 | 69 | @Override 70 | public RealVector getSampleLossGradient(RealVector predicted, RealVector target) { 71 | 
RealVector loss = new ArrayRealVector(predicted.getDimension()); 72 | for (int i = 0; i < predicted.getDimension(); i++) { 73 | loss.setEntry(i, getSampleLossGradient(predicted.getEntry(i), target.getEntry(i))); 74 | } 75 | return loss; 76 | } 77 | 78 | @Override 79 | public RealMatrix getLossGradient(RealMatrix predicted, RealMatrix target) { 80 | RealMatrix loss = new Array2DRowRealMatrix(predicted.getRowDimension(), predicted.getColumnDimension()); 81 | for (int i = 0; i < predicted.getRowDimension(); i++) { 82 | loss.setRowVector(i, getSampleLossGradient(predicted.getRowVector(i), target.getRowVector(i))); 83 | } 84 | return loss; 85 | } 86 | 87 | } 88 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/dataops/Dictionary.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
/**
 * Maps terms (strings) to integer indices.
 *
 * @author Michael Brzustowicz
 */
public interface Dictionary {

    /**
     * Looks up the index assigned to a term.
     *
     * <p>The return type is the boxed {@link Integer} so that an
     * implementation can answer {@code null} for a term it does not contain.
     *
     * @param term the term to look up
     * @return the index of {@code term}, or possibly {@code null} when the
     *         term is not in the dictionary
     */
    Integer getTermIndex(String term);

    /**
     * @return the number of terms this dictionary reports
     */
    int getNumTerms();
}
15 | */ 16 | package com.oreilly.dswj.dataops; 17 | 18 | /** 19 | * uses hashing trick to store terms in large hashmap to avoid collisions 20 | * @author Michael Brzustowicz 21 | */ 22 | public class HashingDictionary implements Dictionary { 23 | private int numTerms; // 2^n is optimal 24 | 25 | public HashingDictionary() { 26 | // 2^20 = 1048576 27 | this(new Double(Math.pow(2,20)).intValue()); 28 | } 29 | 30 | public HashingDictionary(int numTerms) { 31 | this.numTerms = numTerms; 32 | } 33 | 34 | @Override 35 | public Integer getTermIndex(String term) { 36 | return Math.floorMod(term.hashCode(), numTerms); 37 | } 38 | 39 | @Override 40 | public int getNumTerms() { 41 | return numTerms; 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/dataops/IrisPCAExample.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.oreilly.dswj.dataops; 17 | 18 | import com.oreilly.dswj.datasets.Iris; 19 | import java.io.File; 20 | import java.io.IOException; 21 | import javafx.application.Application; 22 | import javafx.embed.swing.SwingFXUtils; 23 | import javafx.scene.Scene; 24 | import javafx.scene.SnapshotParameters; 25 | import javafx.scene.chart.NumberAxis; 26 | import javafx.scene.chart.ScatterChart; 27 | import javafx.scene.chart.XYChart; 28 | import javafx.scene.image.WritableImage; 29 | import javafx.stage.Stage; 30 | import javax.imageio.ImageIO; 31 | import org.apache.commons.math3.linear.RealMatrix; 32 | 33 | /** 34 | * 35 | * @author Michael Brzustowicz 36 | */ 37 | public class IrisPCAExample extends Application { 38 | 39 | /** 40 | * @param args the command line arguments 41 | */ 42 | public static void main(String[] args) { 43 | launch(args); 44 | } 45 | 46 | @Override 47 | public void start(Stage stage) throws Exception { 48 | Iris iris = new Iris(); 49 | 50 | PCA pca = new PCA(iris.getData(), new PCAEIGImplementation()); 51 | System.out.println(pca.getExplainedVariances()); 52 | System.out.println(pca.getCumulativeVariances()); 53 | RealMatrix irispca = pca.getPrincipalComponents(0.98);//2); 54 | System.out.println(irispca); 55 | 56 | 57 | XYChart.Series series = new XYChart.Series(); 58 | for (int i = 0; i < irispca.getRowDimension(); i++) { 59 | series.getData().add(new XYChart.Data(irispca.getEntry(i, 0), irispca.getEntry(i, 1))); 60 | } 61 | 62 | NumberAxis xAxis = new NumberAxis(); 63 | xAxis.setLabel("X1"); 64 | NumberAxis yAxis = new NumberAxis(); 65 | yAxis.setLabel("X2"); 66 | ScatterChart scatterChart = new ScatterChart<>(xAxis,yAxis); 67 | scatterChart.getData().add(series); 68 | scatterChart.setAnimated(false); 69 | 70 | Scene scene = new Scene(scatterChart,800,600); 71 | 72 | /* 73 | tell the stage what scene to use and render it! 
74 | */ 75 | stage.setScene(scene); 76 | stage.show(); 77 | 78 | // WritableImage image = scatterChart.snapshot(new SnapshotParameters(), null); 79 | 80 | // TODO: probably use a file chooser here 81 | // File outFile = new File("Iris2PCA.png"); 82 | // 83 | // try { 84 | // ImageIO.write(SwingFXUtils.fromFXImage(image, null), "png", outFile); 85 | // } catch (IOException e) { 86 | // // TODO: handle exception here 87 | // } 88 | 89 | } 90 | 91 | } 92 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/dataops/LabelEncoder.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
/**
 * Encodes labels of type {@code T} as integer indices or one-hot vectors.
 * A label's index is its position in the array given to the constructor.
 *
 * @author Michael Brzustowicz
 * @param <T> the label type; looked up with {@link Object#equals(Object)}
 */
public class LabelEncoder<T> {

    private final List<T> classes;

    /**
     * Creates an encoder for the given labels.
     *
     * @param labels the known labels, in index order; the array is copied so
     *        later changes to it do not affect this encoder
     */
    public LabelEncoder(T[] labels) {
        // Arrays.sort(labels); // can sort first but not required
        classes = Arrays.asList(labels.clone());
    }

    /**
     * @return the labels in index order, as a fixed-size list
     */
    public List<T> getClasses() {
        return classes;
    }

    /**
     * @param label the label to encode
     * @return the index of {@code label}, or {@code -1} if it is unknown
     */
    public int encode(T label) {
        return classes.indexOf(label);
    }

    /**
     * @param index a label index
     * @return the label stored at {@code index}
     * @throws IndexOutOfBoundsException if {@code index} is out of range
     */
    public T decode(int index) {
        return classes.get(index);
    }

    /**
     * Builds a one-hot vector for a label.
     *
     * @param label the label to encode
     * @return an array as long as the number of classes, holding a single 1
     *         at the label's index and 0 elsewhere
     * @throws ArrayIndexOutOfBoundsException if {@code label} is unknown
     */
    public int[] encodeOneHot(T label) {
        int[] oneHot = new int[classes.size()];
        oneHot[encode(label)] = 1;
        return oneHot;
    }

    /**
     * Recovers the label from a one-hot vector.
     *
     * <p>The position of the 1 is found with a linear scan; a binary search
     * is not applicable because a one-hot array is generally not sorted.
     *
     * @param oneHot a one-hot vector as produced by {@link #encodeOneHot}
     * @return the label at the position of the first 1
     * @throws IllegalArgumentException if {@code oneHot} contains no 1
     * @throws IndexOutOfBoundsException if the 1 lies beyond the known classes
     */
    public T decodeOneHot(int[] oneHot) {
        for (int i = 0; i < oneHot.length; i++) {
            if (oneHot[i] == 1) {
                return decode(i);
            }
        }
        throw new IllegalArgumentException(
                "one-hot vector contains no 1: " + Arrays.toString(oneHot));
    }
}
15 | */ 16 | package com.oreilly.dswj.dataops; 17 | 18 | import org.apache.commons.math3.analysis.function.Signum; 19 | import org.apache.commons.math3.linear.Array2DRowRealMatrix; 20 | import org.apache.commons.math3.linear.RealMatrix; 21 | import org.apache.commons.math3.linear.RealVector; 22 | import org.apache.commons.math3.stat.descriptive.SummaryStatistics; 23 | 24 | /** 25 | * 26 | * @author Michael Brzustowicz 27 | */ 28 | public class LinearLossFunction implements LossFunction { 29 | 30 | @Override 31 | public double getSampleLoss(double predicted, double target) { 32 | return Math.abs(predicted - target); 33 | } 34 | 35 | @Override 36 | public double getSampleLoss(RealVector predicted, RealVector target) { 37 | return predicted.getL1Distance(target); 38 | } 39 | 40 | @Override 41 | public double getMeanLoss(RealMatrix predicted, RealMatrix target) { 42 | SummaryStatistics stats = new SummaryStatistics(); 43 | for (int i = 0; i < predicted.getRowDimension(); i++) { 44 | double dist = getSampleLoss(predicted.getRowVector(i), target.getRowVector(i)); 45 | stats.addValue(dist); 46 | } 47 | return stats.getMean(); 48 | } 49 | 50 | @Override 51 | public double getSampleLossGradient(double predicted, double target) { 52 | return Math.signum(predicted - target); // -1, 0, 1 53 | } 54 | 55 | @Override 56 | public RealVector getSampleLossGradient(RealVector predicted, RealVector target) { 57 | return predicted.subtract(target).map(new Signum()); 58 | } 59 | 60 | //TODO SparseToSignum would be nice!!! 
ie only process elements of the iterable 61 | @Override 62 | public RealMatrix getLossGradient(RealMatrix predicted, RealMatrix target) { 63 | RealMatrix loss = new Array2DRowRealMatrix(predicted.getRowDimension(), predicted.getColumnDimension()); 64 | for (int i = 0; i < predicted.getRowDimension(); i++) { 65 | loss.setRowVector(i, getSampleLossGradient(predicted.getRowVector(i), target.getRowVector(i))); 66 | } 67 | return loss; 68 | } 69 | 70 | } 71 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/dataops/ListResampler.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
/**
 * Shuffles a list once with a seeded random generator and exposes
 * consecutive slices of it as validation, test and training sets.
 *
 * <p>Layout after shuffling: {@code [validation | test | training]}.
 *
 * @author Michael Brzustowicz
 * @param <T> element type of the list
 */
public class ListResampler<T> {

    private final List<T> data;
    private final int trainingSetSize;
    private final int testSetSize;
    private final int validationSetSize;

    /**
     * Splits into test and training sets only (no validation set); the given
     * list is shuffled in place.
     *
     * @param data the list to split
     * @param testFraction fraction of the elements placed in the test set
     * @param seed seed of the shuffle
     */
    public ListResampler(List<T> data, double testFraction, long seed) {
        this(data, testFraction, 0.0, seed);
    }

    /**
     * Splits into validation, test and training sets; the given list is
     * shuffled in place.
     *
     * @param data the list to split
     * @param testFraction fraction of the elements placed in the test set
     * @param validationFraction fraction placed in the validation set
     * @param seed seed of the shuffle
     */
    public ListResampler(List<T> data, double testFraction, double validationFraction, long seed) {
        this(data, testFraction, validationFraction, seed, false);
    }

    /**
     * Splits into validation, test and training sets.
     *
     * @param data the list to split
     * @param testFraction fraction of the elements placed in the test set;
     *        the set size is the product with the list size, truncated
     * @param validationFraction fraction placed in the validation set,
     *        truncated the same way
     * @param seed seed of the shuffle
     * @param deepCopy if {@code true}, a copy of the list is shuffled and the
     *        caller's list keeps its order (the elements themselves are
     *        shared, not cloned); if {@code false}, {@code data} is shuffled
     *        in place
     */
    public ListResampler(List<T> data, double testFraction, double validationFraction, long seed, boolean deepCopy) {
        // copy so as not to alter original data!!!
        // (fully qualified so no additional import is required)
        this.data = deepCopy ? new java.util.ArrayList<T>(data) : data;
        validationSetSize = (int) (validationFraction * data.size());
        testSetSize = (int) (testFraction * data.size());
        trainingSetSize = data.size() - (testSetSize + validationSetSize);
        Random rnd = new Random(seed);
        // shuffle the list this instance serves, which is the copy when deepCopy is set
        Collections.shuffle(this.data, rnd);
    }

    /**
     * @return view of the first {@code validationSetSize} shuffled elements
     */
    public List<T> getValidationSet() {
        return data.subList(0, validationSetSize);
    }

    /**
     * @return view of the {@code testSetSize} elements following the
     *         validation set
     */
    public List<T> getTestSet() {
        return data.subList(validationSetSize, validationSetSize + testSetSize);
    }

    /**
     * @return view of all elements after the validation and test sets
     */
    public List<T> getTrainingSet() {
        return data.subList(validationSetSize + testSetSize, data.size());
    }

}
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.dataops; 17 | 18 | import org.apache.commons.math3.linear.RealMatrix; 19 | import org.apache.commons.math3.linear.RealVector; 20 | 21 | /** 22 | * 23 | * @author Michael Brzustowicz 24 | */ 25 | public interface LossFunction { 26 | 27 | /** 28 | * loss of one dimension of one sample output 29 | * @param predicted 30 | * @param target 31 | * @return 32 | */ 33 | public double getSampleLoss(double predicted, double target); 34 | 35 | /** 36 | * combined loss over all dimensions of one sample output 37 | * @param predicted 38 | * @param target 39 | * @return 40 | */ 41 | public double getSampleLoss(RealVector predicted, RealVector target); 42 | 43 | /** 44 | * average loss over all samples 45 | * @param predicted 46 | * @param target 47 | * @return 48 | */ 49 | public double getMeanLoss(RealMatrix predicted, RealMatrix target); 50 | 51 | /** 52 | * derivative of loss of one dimension of one sample output 53 | * @param predicted 54 | * @param target 55 | * @return 56 | */ 57 | public double getSampleLossGradient(double predicted, double target); 58 | 59 | /** 60 | * derivative of loss over all dimensions of one sample output 61 | * @param predicted 62 | * @param target 63 | * @return 64 | */ 65 | public RealVector getSampleLossGradient(RealVector predicted, RealVector target); 66 | 67 | /** 68 | * 69 | * @param predicted 70 | * @param target 71 | * @return 72 | */ 73 | public RealMatrix getLossGradient(RealMatrix predicted, RealMatrix target); 74 | } 75 | 
/**
 * Kinds of scaling selectable for a matrix. The formulas noted below are
 * the ones applied by {@code MatrixScalingOperator}.
 *
 * @author Michael Brzustowicz
 */
public enum MatrixScaleType {

    /** Column-wise {@code (value - min) / (max - min)}. */
    MINMAX,

    /** Column-wise {@code value - mean}. */
    CENTER,

    /** Column-wise {@code (value - mean) / standardDeviation}. */
    ZSCORE,

    /** Row-wise division by a per-row normal (presumably the L1 norm). */
    L1,

    /** Row-wise division by a per-row normal (presumably the L2 norm). */
    L2
}
15 | */ 16 | package com.oreilly.dswj.dataops; 17 | 18 | import org.apache.commons.math3.linear.RealMatrixChangingVisitor; 19 | import org.apache.commons.math3.linear.RealVector; 20 | import org.apache.commons.math3.stat.descriptive.MultivariateSummaryStatistics; 21 | 22 | /** 23 | * 24 | * @author Michael Brzustowicz 25 | */ 26 | public class MatrixScalingOperator implements RealMatrixChangingVisitor { 27 | 28 | RealVector normals; 29 | MultivariateSummaryStatistics mss; 30 | MatrixScaleType scaleType; 31 | 32 | public MatrixScalingOperator(MultivariateSummaryStatistics mss, MatrixScaleType scaleType) { 33 | this.normals = null; 34 | this.mss = mss; 35 | this.scaleType = scaleType; 36 | } 37 | 38 | public MatrixScalingOperator(RealVector normals, MatrixScaleType scaleType) { 39 | this.normals = normals; 40 | this.mss = null; 41 | this.scaleType = scaleType; 42 | } 43 | 44 | @Override 45 | public void start(int rows, int columns, int startRow, int endRow, 46 | int startColumn, int endColumn) { 47 | // nothing 48 | } 49 | 50 | @Override 51 | public double visit(int row, int column, double value) { 52 | double min, max, avg, std, rowNormal; 53 | double scaledValue = Double.NaN; 54 | 55 | switch (scaleType) { 56 | 57 | case MINMAX: 58 | min = mss.getMin()[column]; 59 | max = mss.getMax()[column]; 60 | scaledValue = (max > min) ? (value - min) / (max - min) : (value - min); 61 | break; 62 | 63 | case CENTER: 64 | avg = mss.getMean()[column]; 65 | scaledValue = value - avg; 66 | break; 67 | 68 | case ZSCORE: 69 | avg = mss.getMean()[column]; 70 | std = mss.getStandardDeviation()[column]; 71 | scaledValue = std > 0 ? (value - avg) / std : (value - avg); 72 | break; 73 | 74 | case L1: 75 | rowNormal = normals.getEntry(row); 76 | scaledValue = rowNormal > 0 ? value / rowNormal : 0; 77 | break; 78 | 79 | case L2: 80 | rowNormal = normals.getEntry(row); 81 | scaledValue = rowNormal > 0 ? 
value / rowNormal : 0; 82 | break; 83 | 84 | default: 85 | break; 86 | } 87 | 88 | return scaledValue; 89 | } 90 | 91 | @Override 92 | public double end() { 93 | return 0.0; 94 | } 95 | 96 | } 97 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/dataops/PCA.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.oreilly.dswj.dataops; 17 | 18 | import org.apache.commons.math3.linear.RealMatrix; 19 | import org.apache.commons.math3.linear.RealVector; 20 | 21 | /** 22 | * 23 | * @author Michael Brzustowicz 24 | */ 25 | public class PCA { 26 | 27 | private final PCAImplementation pCAImplementation; 28 | 29 | /** 30 | * default is SVD implementation 31 | * @param data 32 | */ 33 | public PCA(RealMatrix data) { 34 | this(data, new PCASVDImplementation()); 35 | } 36 | 37 | public PCA(RealMatrix data, PCAImplementation pCAImplementation) { 38 | this.pCAImplementation = pCAImplementation; 39 | this.pCAImplementation.compute(data); 40 | } 41 | 42 | /** 43 | * Projects the centered data onto the new basis with k components 44 | * @param k number of components to use 45 | * @return 46 | */ 47 | public RealMatrix getPrincipalComponents(int k) { 48 | return pCAImplementation.getPrincipalComponents(k); 49 | } 50 | 51 | public RealMatrix getPrincipalComponents(int k, RealMatrix otherData) { 52 | return pCAImplementation.getPrincipalComponents(k, otherData); 53 | } 54 | 55 | public RealVector getExplainedVariances() { 56 | return pCAImplementation.getExplainedVariances(); 57 | } 58 | 59 | public RealVector getCumulativeVariances() { 60 | RealVector variances = getExplainedVariances(); 61 | RealVector cumulative = variances.copy(); 62 | double sum = 0; 63 | for (int i = 0; i < cumulative.getDimension(); i++) { 64 | sum += cumulative.getEntry(i); 65 | cumulative.setEntry(i, sum); 66 | } 67 | return cumulative; 68 | } 69 | 70 | public int getNumberOfComponents(double threshold) { 71 | RealVector cumulative = getCumulativeVariances(); 72 | int numComponents=1; 73 | for (int i = 0; i < cumulative.getDimension(); i++) { 74 | numComponents = i + 1; 75 | if(cumulative.getEntry(i) >= threshold) { 76 | break; 77 | } 78 | } 79 | return numComponents; 80 | } 81 | 82 | public RealMatrix getPrincipalComponents(double threshold) { 83 | int numComponents = 
getNumberOfComponents(threshold); 84 | return getPrincipalComponents(numComponents); 85 | } 86 | 87 | public RealMatrix getPrincipalComponents(double threshold, RealMatrix otherData) { 88 | int numComponents = getNumberOfComponents(threshold); 89 | return getPrincipalComponents(numComponents, otherData); 90 | } 91 | 92 | } 93 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/dataops/PCAEIGImplementation.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.oreilly.dswj.dataops; 17 | 18 | import org.apache.commons.math3.linear.ArrayRealVector; 19 | import org.apache.commons.math3.linear.EigenDecomposition; 20 | import org.apache.commons.math3.linear.RealMatrix; 21 | import org.apache.commons.math3.linear.RealVector; 22 | import org.apache.commons.math3.stat.correlation.Covariance; 23 | 24 | /** 25 | * 26 | * @author Michael Brzustowicz 27 | */ 28 | public class PCAEIGImplementation implements PCAImplementation { 29 | 30 | private transient RealMatrix data; 31 | private RealMatrix d; // eigenvalue matrix 32 | private RealMatrix v; // eigenvector matrix 33 | private RealVector explainedVariances; 34 | private transient EigenDecomposition eig; 35 | private final MatrixScaler matrixScaler; 36 | 37 | public PCAEIGImplementation() { 38 | matrixScaler = new MatrixScaler(MatrixScaleType.CENTER); 39 | } 40 | 41 | @Override 42 | public void compute(RealMatrix data) { 43 | this.data = data; 44 | eig = new EigenDecomposition(new Covariance(data).getCovarianceMatrix()); 45 | d = eig.getD(); 46 | v = eig.getV(); 47 | } 48 | 49 | @Override 50 | public RealVector getExplainedVariances() { 51 | //TODO just make this a getter and compute in compute method 52 | int n = eig.getD().getColumnDimension(); //colD = rowD 53 | explainedVariances = new ArrayRealVector(n); 54 | double[] eigenValues = eig.getRealEigenvalues(); 55 | double cumulative = 0.0; 56 | for (int i = 0; i < n; i++) { 57 | double var = eigenValues[i]; 58 | cumulative += var; 59 | explainedVariances.setEntry(i, var); 60 | } 61 | /* dividing the vector by the last (highest) value maximizes to 1 */ 62 | return explainedVariances.mapDivideToSelf(cumulative); 63 | } 64 | 65 | @Override 66 | public RealMatrix getPrincipalComponents(int k) { 67 | int m = eig.getV().getColumnDimension(); // rowD = colD 68 | matrixScaler.transform(data); 69 | // MatrixScaler.center(data); 70 | return data.multiply(eig.getV().getSubMatrix(0, m-1, 0, k-1)); 71 | } 72 | 73 | 74 
| 75 | @Override 76 | public RealMatrix getPrincipalComponents(int numComponents, RealMatrix otherData) { 77 | int numRows = v.getRowDimension(); 78 | // NEW data transformed under OLD means 79 | matrixScaler.transform(otherData); 80 | return otherData.multiply(v.getSubMatrix(0, numRows-1, 0, numComponents-1)); 81 | } 82 | 83 | } 84 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/dataops/PCAImplementation.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.oreilly.dswj.dataops; 17 | 18 | import org.apache.commons.math3.linear.RealMatrix; 19 | import org.apache.commons.math3.linear.RealVector; 20 | 21 | /** 22 | * 23 | * @author Michael Brzustowicz 24 | */ 25 | public interface PCAImplementation { 26 | 27 | void compute(RealMatrix data); 28 | 29 | RealVector getExplainedVariances(); 30 | 31 | RealMatrix getPrincipalComponents(int numComponents); 32 | 33 | RealMatrix getPrincipalComponents(int numComponents, RealMatrix otherData); 34 | } 35 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/dataops/PCASVDImplementation.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.oreilly.dswj.dataops; 17 | 18 | import org.apache.commons.math3.linear.ArrayRealVector; 19 | import org.apache.commons.math3.linear.RealMatrix; 20 | import org.apache.commons.math3.linear.RealVector; 21 | import org.apache.commons.math3.linear.SingularValueDecomposition; 22 | 23 | /** 24 | * 25 | * @author Michael Brzustowicz 26 | */ 27 | public class PCASVDImplementation implements PCAImplementation { 28 | private RealMatrix u; 29 | private RealMatrix s; 30 | private RealMatrix v; 31 | private MatrixScaler matrixScaler; 32 | private SingularValueDecomposition svd; 33 | 34 | @Override 35 | public void compute(RealMatrix data) { 36 | //centers the data in place and stores the column stats for later use 37 | matrixScaler = new MatrixScaler(data, MatrixScaleType.CENTER); 38 | svd = new SingularValueDecomposition(data); 39 | u = svd.getU(); 40 | s = svd.getS(); 41 | v = svd.getV(); 42 | } 43 | 44 | @Override 45 | public RealVector getExplainedVariances() { 46 | 47 | double[] singularValues = svd.getSingularValues(); 48 | int n = singularValues.length; 49 | int m = u.getRowDimension(); // number of rows in U is same as in data 50 | RealVector explainedVariances = new ArrayRealVector(n); 51 | double sum = 0.0; 52 | for (int i = 0; i < n; i++) { 53 | double var = Math.pow(singularValues[i], 2) / (double)(m-1); 54 | sum += var; 55 | explainedVariances.setEntry(i, var); 56 | } 57 | /* dividing the vector by the last (highest) value maximizes to 1 */ 58 | return explainedVariances.mapDivideToSelf(sum); 59 | 60 | } 61 | 62 | @Override 63 | public RealMatrix getPrincipalComponents(int numComponents) { 64 | int numRows = svd.getU().getRowDimension(); 65 | /* submatrix limits are inclusive */ 66 | RealMatrix uk = u.getSubMatrix(0, numRows-1, 0, numComponents-1); 67 | RealMatrix sk = s.getSubMatrix(0, numComponents-1, 0, numComponents-1); 68 | return uk.multiply(sk); 69 | } 70 | 71 | @Override 72 | public RealMatrix getPrincipalComponents(int 
numComponents, RealMatrix otherData) { 73 | // center the (new) data on means from original data 74 | matrixScaler.transform(otherData); 75 | int numRows = v.getRowDimension(); 76 | // subMatrix indeces are inclusive 77 | return otherData.multiply(v.getSubMatrix(0, numRows-1, 0, numComponents-1)); 78 | } 79 | 80 | } 81 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/dataops/ProbabilityEncoder.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.dataops; 17 | 18 | import org.apache.commons.math3.linear.Array2DRowRealMatrix; 19 | import org.apache.commons.math3.linear.ArrayRealVector; 20 | import org.apache.commons.math3.linear.RealMatrix; 21 | import org.apache.commons.math3.linear.RealVector; 22 | 23 | /** 24 | * 25 | * @author Michael Brzustowicz 26 | */ 27 | public class ProbabilityEncoder { 28 | 29 | public RealVector getBinary(RealVector probabilities, double threshhold) { 30 | RealVector out = new ArrayRealVector(probabilities.getDimension()); 31 | for (int i = 0; i < probabilities.getDimension(); i++) { 32 | double entry = probabilities.getEntry(i) >= threshhold ? 
1 : 0; 33 | out.setEntry(i, entry); 34 | } 35 | return out; 36 | } 37 | 38 | public RealVector getOneHot(RealVector probabilities) { 39 | RealVector out = new ArrayRealVector(probabilities.getDimension()); 40 | out.setEntry(probabilities.getMaxIndex(), 1); 41 | return out; 42 | } 43 | 44 | public RealMatrix getOneHot(RealMatrix probabilities) { 45 | int numRows = probabilities.getRowDimension(); 46 | int numCols = probabilities.getColumnDimension(); 47 | RealMatrix out = new Array2DRowRealMatrix(numRows, numCols); 48 | for (int i = 0; i < numRows; i++) { 49 | int maxIndex = probabilities.getRowVector(i).getMaxIndex(); 50 | out.setEntry(i, maxIndex, 1); 51 | } 52 | return out; 53 | } 54 | } 55 | 56 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/dataops/QuadraticLossFunction.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.oreilly.dswj.dataops; 17 | 18 | import org.apache.commons.math3.linear.RealMatrix; 19 | import org.apache.commons.math3.linear.RealVector; 20 | import org.apache.commons.math3.stat.descriptive.SummaryStatistics; 21 | 22 | /** 23 | * 24 | * @author Michael Brzustowicz 25 | */ 26 | public class QuadraticLossFunction implements LossFunction { 27 | @Override 28 | public double getSampleLoss(double predicted, double target) { 29 | double diff = predicted - target; 30 | return 0.5 * diff * diff; 31 | } 32 | 33 | @Override 34 | public double getSampleLoss(RealVector predicted, RealVector target) { 35 | double dist = predicted.getDistance(target); 36 | return 0.5 * dist * dist; 37 | } 38 | 39 | @Override 40 | public double getMeanLoss(RealMatrix predicted, RealMatrix target) { 41 | SummaryStatistics stats = new SummaryStatistics(); 42 | for (int i = 0; i < predicted.getRowDimension(); i++) { 43 | double dist = getSampleLoss(predicted.getRowVector(i), target.getRowVector(i)); 44 | stats.addValue(dist); 45 | } 46 | return stats.getMean(); 47 | } 48 | 49 | @Override 50 | public double getSampleLossGradient(double predicted, double target) { 51 | return predicted - target; 52 | } 53 | 54 | @Override 55 | public RealVector getSampleLossGradient(RealVector predicted, RealVector target) { 56 | return predicted.subtract(target); 57 | } 58 | 59 | @Override 60 | public RealMatrix getLossGradient(RealMatrix predicted, RealMatrix target) { 61 | return predicted.subtract(target); 62 | } 63 | 64 | } 65 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/dataops/SimpleTokenizer.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.dataops; 17 | 18 | import java.util.ArrayList; 19 | import java.util.List; 20 | 21 | /** 22 | * 23 | * @author Michael Brzustowicz 24 | */ 25 | public class SimpleTokenizer implements Tokenizer { 26 | 27 | private final int minTokenSize; 28 | 29 | public SimpleTokenizer(int minTokenSize) { 30 | this.minTokenSize = minTokenSize; 31 | } 32 | 33 | public SimpleTokenizer() { 34 | this(0); 35 | } 36 | 37 | 38 | 39 | @Override 40 | public String[] getTokens(String document) { 41 | String[] tokens = document.trim().split("\\s+"); 42 | List cleanTokens = new ArrayList<>(); 43 | for (String token : tokens) { 44 | String cleanToken = token.trim().toLowerCase().replaceAll("[^A-Za-z\']+", ""); 45 | if(cleanToken.length() > minTokenSize) { 46 | cleanTokens.add(cleanToken); 47 | } 48 | } 49 | return cleanTokens.toArray(new String[0]); 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/dataops/SoftMaxCrossEntropyLossFunction.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.dataops; 17 | 18 | import org.apache.commons.math3.linear.Array2DRowRealMatrix; 19 | import org.apache.commons.math3.linear.RealMatrix; 20 | import org.apache.commons.math3.linear.RealVector; 21 | import org.apache.commons.math3.stat.descriptive.SummaryStatistics; 22 | import org.apache.commons.math3.util.FastMath; 23 | 24 | /** 25 | * 26 | * @author Michael Brzustowicz 27 | */ 28 | public class SoftMaxCrossEntropyLossFunction implements LossFunction { 29 | 30 | @Override 31 | public double getSampleLoss(double predicted, double target) { 32 | return predicted > 0 ? 
-1.0 * target * FastMath.log(predicted) : 0; 33 | } 34 | 35 | @Override 36 | public double getSampleLoss(RealVector predicted, RealVector target) { 37 | double sampleLoss = 0.0; 38 | for (int i = 0; i < predicted.getDimension(); i++) { 39 | sampleLoss += getSampleLoss(predicted.getEntry(i), target.getEntry(i)); 40 | } 41 | return sampleLoss; 42 | } 43 | 44 | @Override 45 | public double getMeanLoss(RealMatrix predicted, RealMatrix target) { 46 | SummaryStatistics stats = new SummaryStatistics(); 47 | for (int i = 0; i < predicted.getRowDimension(); i++) { 48 | stats.addValue(getSampleLoss(predicted.getRowVector(i), target.getRowVector(i))); 49 | } 50 | return stats.getMean(); 51 | } 52 | 53 | /** 54 | * dE/dy = - t_i / y_i 55 | * @param predicted 56 | * @param target 57 | * @return 58 | */ 59 | @Override 60 | public double getSampleLossGradient(double predicted, double target) { 61 | return -1.0 * target / predicted; 62 | } 63 | 64 | @Override 65 | public RealVector getSampleLossGradient(RealVector predicted, RealVector target) { 66 | return target.ebeDivide(predicted).mapMultiplyToSelf(-1.0); 67 | } 68 | 69 | @Override 70 | public RealMatrix getLossGradient(RealMatrix predicted, RealMatrix target) { 71 | RealMatrix loss = new Array2DRowRealMatrix(predicted.getRowDimension(), predicted.getColumnDimension()); 72 | for (int i = 0; i < predicted.getRowDimension(); i++) { 73 | loss.setRowVector(i, getSampleLossGradient(predicted.getRowVector(i), target.getRowVector(i))); 74 | } 75 | return loss; 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/dataops/TFIDF.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.dataops; 17 | 18 | import org.apache.commons.math3.linear.RealMatrixChangingVisitor; 19 | import org.apache.commons.math3.linear.RealVector; 20 | 21 | /** 22 | * 23 | * @author Michael Brzustowicz 24 | */ 25 | public class TFIDF implements RealMatrixChangingVisitor { 26 | 27 | private final int numDocuments; 28 | private final RealVector termDocumentFrequency; 29 | private final double logNumDocuments; 30 | 31 | public TFIDF(int numDocuments, RealVector termDocumentFrequency) { 32 | this.numDocuments = numDocuments; 33 | this.termDocumentFrequency = termDocumentFrequency; 34 | this.logNumDocuments = numDocuments > 0 ? Math.log(numDocuments) : 0; 35 | } 36 | 37 | @Override 38 | public void start(int rows, int columns, int startRow, int endRow, 39 | int startColumn, int endColumn) { 40 | //NA 41 | } 42 | 43 | @Override 44 | public double visit(int row, int column, double value) { 45 | double df = termDocumentFrequency.getEntry(column); 46 | double logDF = df > 0 ? Math.log(df) : 0.0; 47 | // TFIDF = TF_i * log(N/DF_i) = TF_i * ( log(N) - log(DF_i) ) 48 | return value * (logNumDocuments - logDF); 49 | } 50 | 51 | @Override 52 | public double end() { 53 | return 0.0; 54 | } 55 | 56 | } 57 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/dataops/TFIDFVectorizer.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 
3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.dataops; 17 | 18 | import java.util.List; 19 | import org.apache.commons.math3.linear.OpenMapRealVector; 20 | import org.apache.commons.math3.linear.RealMatrix; 21 | import org.apache.commons.math3.linear.RealVector; 22 | 23 | /** 24 | * 25 | * @author Michael Brzustowicz 26 | */ 27 | public class TFIDFVectorizer { 28 | 29 | private Vectorizer vectorizer; 30 | private Vectorizer binaryVectorizer; 31 | private int numTerms; 32 | 33 | public TFIDFVectorizer(Dictionary dictionary, Tokenizer tokenzier) { 34 | vectorizer = new Vectorizer(dictionary, tokenzier, false); 35 | binaryVectorizer = new Vectorizer(dictionary, tokenzier, true); 36 | numTerms = dictionary.getNumTerms(); 37 | } 38 | 39 | public TFIDFVectorizer() { 40 | this(new HashingDictionary(16384), new SimpleTokenizer()); 41 | } 42 | 43 | public RealVector getTermDocumentCount(List documents) { 44 | RealVector vector = new OpenMapRealVector(numTerms); 45 | for (String document : documents) { 46 | vector.add(binaryVectorizer.getCountVector(document)); 47 | } 48 | return vector; 49 | } 50 | 51 | public RealMatrix getTFIDF(List documents) { 52 | int numDocuments = documents.size(); 53 | RealVector df = getTermDocumentCount(documents); 54 | RealMatrix tfidf = vectorizer.getCountMatrix(documents); 55 | tfidf.walkInOptimizedOrder(new TFIDF(numDocuments, df)); 56 | return tfidf; 57 | } 58 | } 59 
| -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/dataops/TermDictionary.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.dataops; 17 | 18 | import java.util.HashMap; 19 | import java.util.Map; 20 | 21 | /** 22 | * 23 | * @author Michael Brzustowicz 24 | */ 25 | public class TermDictionary implements Dictionary { 26 | 27 | private final Map indexedTerms; 28 | private int counter; 29 | 30 | public TermDictionary() { 31 | indexedTerms = new HashMap<>(); 32 | counter = 0; 33 | } 34 | 35 | public void addTerm(String term) { 36 | if(!indexedTerms.containsKey(term)) { 37 | indexedTerms.put(term, counter++); 38 | } 39 | } 40 | 41 | public void addTerms(String[] terms) { 42 | for (String term : terms) { 43 | addTerm(term); 44 | } 45 | } 46 | 47 | @Override 48 | public Integer getTermIndex(String term) { 49 | return indexedTerms.get(term); 50 | } 51 | 52 | @Override 53 | public int getNumTerms() { 54 | return indexedTerms.size(); 55 | } 56 | 57 | } 58 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/dataops/Tokenizer.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael 
/**
 * Splits a document into tokens.
 *
 * @author Michael Brzustowicz
 */
public interface Tokenizer {

    /**
     * @param document text to tokenize
     * @return the tokens extracted from the document
     */
    String[] getTokens(String document);
}
15 | */ 16 | package com.oreilly.dswj.dataops; 17 | 18 | import org.apache.commons.math3.linear.Array2DRowRealMatrix; 19 | import org.apache.commons.math3.linear.ArrayRealVector; 20 | import org.apache.commons.math3.linear.RealMatrix; 21 | import org.apache.commons.math3.linear.RealVector; 22 | import org.apache.commons.math3.stat.descriptive.SummaryStatistics; 23 | import org.apache.commons.math3.util.FastMath; 24 | 25 | /** 26 | * target = -1 or 1 27 | * predicted is double between -1.0 and 1.0 28 | * @author Michael Brzustowicz 29 | */ 30 | public class TwoPointLossFunction implements LossFunction { 31 | 32 | @Override 33 | public double getSampleLoss(double predicted, double target) { 34 | // convert -1:1 to 0:1 scale 35 | double y = 0.5 * (predicted + 1); 36 | double t = 0.5 * (target + 1); 37 | return -1.0 * (t * ((y>0)?FastMath.log(y):0) + 38 | (1.0 - t)*(y<1?FastMath.log(1.0-y):0)); 39 | } 40 | 41 | @Override 42 | public double getSampleLoss(RealVector predicted, RealVector target) { 43 | double loss = 0.0; 44 | for (int i = 0; i < predicted.getDimension(); i++) { 45 | loss += getSampleLoss(predicted.getEntry(i), target.getEntry(i)); 46 | } 47 | return loss; 48 | } 49 | 50 | @Override 51 | public double getMeanLoss(RealMatrix predicted, RealMatrix target) { 52 | SummaryStatistics stats = new SummaryStatistics(); 53 | for (int i = 0; i < predicted.getRowDimension(); i++) { 54 | stats.addValue(getSampleLoss(predicted.getRowVector(i), target.getRowVector(i))); 55 | } 56 | return stats.getMean(); 57 | } 58 | 59 | @Override 60 | public double getSampleLossGradient(double predicted, double target) { 61 | return (predicted - target) / (1 - predicted * predicted); 62 | } 63 | 64 | @Override 65 | public RealVector getSampleLossGradient(RealVector predicted, RealVector target) { 66 | RealVector loss = new ArrayRealVector(predicted.getDimension()); 67 | for (int i = 0; i < predicted.getDimension(); i++) { 68 | loss.setEntry(i, 
getSampleLossGradient(predicted.getEntry(i), target.getEntry(i))); 69 | } 70 | return loss; 71 | } 72 | 73 | @Override 74 | public RealMatrix getLossGradient(RealMatrix predicted, RealMatrix target) { 75 | RealMatrix loss = new Array2DRowRealMatrix(predicted.getRowDimension(), predicted.getColumnDimension()); 76 | for (int i = 0; i < predicted.getRowDimension(); i++) { 77 | loss.setRowVector(i, getSampleLossGradient(predicted.getRowVector(i), target.getRowVector(i))); 78 | } 79 | return loss; 80 | } 81 | 82 | } 83 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/dataops/Vectorizer.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.oreilly.dswj.dataops; 17 | 18 | import java.util.List; 19 | import org.apache.commons.math3.linear.OpenMapRealMatrix; 20 | import org.apache.commons.math3.linear.OpenMapRealVector; 21 | import org.apache.commons.math3.linear.RealMatrix; 22 | import org.apache.commons.math3.linear.RealVector; 23 | 24 | /** 25 | * 26 | * @author Michael Brzustowicz 27 | */ 28 | public class Vectorizer { 29 | 30 | private final Dictionary dictionary; 31 | private final Tokenizer tokenzier; 32 | private final boolean isBinary; 33 | 34 | public Vectorizer(Dictionary dictionary, Tokenizer tokenzier, boolean isBinary) { 35 | this.dictionary = dictionary; 36 | this.tokenzier = tokenzier; 37 | this.isBinary = isBinary; 38 | } 39 | 40 | public Vectorizer() { 41 | this(new HashingDictionary(), new SimpleTokenizer(), false); 42 | } 43 | 44 | public RealVector getCountVector(String document) { 45 | 46 | RealVector vector = new OpenMapRealVector(dictionary.getNumTerms()); 47 | 48 | String[] tokens = tokenzier.getTokens(document); 49 | 50 | for (String token : tokens) { 51 | 52 | Integer index = dictionary.getTermIndex(token); 53 | 54 | if(index != null) { 55 | 56 | if(isBinary) { 57 | vector.setEntry(index, 1); 58 | } else { 59 | vector.addToEntry(index, 1); // increment ! 
60 | } 61 | } 62 | } 63 | return vector; 64 | } 65 | 66 | public RealMatrix getCountMatrix(List documents) { 67 | int rowDimension = documents.size(); 68 | int columnDimension = dictionary.getNumTerms(); 69 | RealMatrix matrix = new OpenMapRealMatrix(rowDimension, columnDimension); 70 | int counter = 0; 71 | for (String document : documents) { 72 | matrix.setRowVector(counter++, getCountVector(document)); 73 | } 74 | return matrix; 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/dataops/VectorizerExample.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.oreilly.dswj.dataops; 17 | 18 | import com.oreilly.dswj.datasets.Sentiment; 19 | import java.io.IOException; 20 | import org.apache.commons.math3.linear.RealMatrix; 21 | 22 | /** 23 | * 24 | * @author Michael Brzustowicz 25 | */ 26 | public class VectorizerExample { 27 | 28 | /** 29 | * @param args the command line arguments 30 | * @throws java.io.IOException 31 | */ 32 | public static void main(String[] args) throws IOException { 33 | 34 | Sentiment sentiment = new Sentiment(); 35 | 36 | TermDictionary termDictionary = new TermDictionary(); 37 | 38 | SimpleTokenizer tokenizer = new SimpleTokenizer(); 39 | 40 | for (String document : sentiment.getDocuments()) { 41 | String[]tokens = tokenizer.getTokens(document); 42 | termDictionary.addTerms(tokens); 43 | } 44 | 45 | // System.out.println(termDictionary.getTermIndex("iasdfasd")); 46 | // System.exit(0); 47 | 48 | 49 | Vectorizer vectorizer = new Vectorizer(termDictionary, tokenizer, false); 50 | RealMatrix counts = vectorizer.getCountMatrix(sentiment.getDocuments()); 51 | System.out.println(counts.getSubMatrix(0, 5, 0, 5)); 52 | 53 | Vectorizer binaryVectorizer = new Vectorizer(termDictionary, tokenizer, true); 54 | RealMatrix binCounts = binaryVectorizer.getCountMatrix(sentiment.getDocuments()); 55 | System.out.println(binCounts.getSubMatrix(0, 5, 0, 5)); 56 | 57 | TFIDFVectorizer tfidfVectorizer = new TFIDFVectorizer(termDictionary, tokenizer); 58 | RealMatrix tfidf = tfidfVectorizer.getTFIDF(sentiment.getDocuments()); 59 | System.out.println(tfidf.getSubMatrix(0, 5, 0, 5)); 60 | 61 | } 62 | 63 | } 64 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/datasets/Anscombe.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 
3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.datasets; 17 | 18 | /** 19 | * 20 | * @author Michael Brzustowicz 21 | */ 22 | public class Anscombe { 23 | public static final double[] x1 = {10.0, 8.0, 13.0, 9.0, 11.0, 14.0, 6.0, 4.0, 12.0, 7.0, 5.0}; 24 | public static final double[] y1 = {8.04, 6.95, 7.58, 8.81, 8.33, 9.96, 7.24, 4.26, 10.84, 4.82, 5.68}; 25 | public static final double[] x2 = {10.0, 8.0, 13.0, 9.0, 11.0, 14.0, 6.0, 4.0, 12.0, 7.0, 5.0}; 26 | public static final double[] y2 = {9.14, 8.14, 8.74, 8.77, 9.26, 8.10, 6.13, 3.10, 9.13, 7.26, 4.74}; 27 | public static final double[] x3 = {10.0, 8.0, 13.0, 9.0, 11.0, 14.0, 6.0, 4.0, 12.0, 7.0, 5.0}; 28 | public static final double[] y3 = {7.46, 6.77, 12.74, 7.11, 7.81, 8.84, 6.08, 5.39, 8.15, 6.42, 5.73}; 29 | public static final double[] x4 = {8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 19.0, 8.0, 8.0, 8.0}; 30 | public static final double[] y4 = {6.58, 5.76, 7.71, 8.84, 8.47, 7.04, 5.25, 12.50, 5.56, 7.91, 6.89}; 31 | 32 | } 33 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/datasets/Iris.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.datasets; 17 | 18 | import java.io.BufferedReader; 19 | 20 | import java.io.IOException; 21 | import java.io.InputStream; 22 | import java.io.InputStreamReader; 23 | import org.apache.commons.math3.linear.Array2DRowRealMatrix; 24 | import org.apache.commons.math3.linear.RealMatrix; 25 | 26 | /** 27 | * 28 | * @author Michael Brzustowicz 29 | */ 30 | public class Iris { 31 | 32 | private final RealMatrix data; 33 | private final RealMatrix labels; 34 | private static final String FILEPATH = "/datasets/iris/iris_data.csv"; 35 | 36 | public Iris() throws IOException { 37 | 38 | data = new Array2DRowRealMatrix(150, 4); 39 | labels = new Array2DRowRealMatrix(150, 3); // binarized 40 | 41 | try(InputStream inputStream = getClass().getResourceAsStream(FILEPATH)) { 42 | BufferedReader br = new BufferedReader(new InputStreamReader(inputStream)); 43 | String line; 44 | int rowCounter = 0; 45 | while ((line = br.readLine()) != null) { 46 | 47 | String[] s = line.split(","); 48 | double sepalLength = Double.parseDouble(s[0].trim()); 49 | double sepalWidth = Double.parseDouble(s[1].trim()); 50 | double petalLength = Double.parseDouble(s[2].trim()); 51 | double petalWidth = Double.parseDouble(s[3].trim()); 52 | String plantClass = s[4].trim(); 53 | 54 | data.setEntry(rowCounter, 0, sepalLength); 55 | data.setEntry(rowCounter, 1, sepalWidth); 56 | data.setEntry(rowCounter, 2, petalLength); 57 | data.setEntry(rowCounter, 3, petalWidth); 58 | 59 | if (null != plantClass) switch (plantClass) { 60 | case 
"Iris-setosa": 61 | labels.setEntry(rowCounter, 0, 1); 62 | break; 63 | case "Iris-versicolor": 64 | labels.setEntry(rowCounter, 1, 1); 65 | break; 66 | case "Iris-virginica": 67 | labels.setEntry(rowCounter, 3, 1); 68 | break; 69 | default: 70 | System.out.println("something wrong with " + plantClass); 71 | break; 72 | } 73 | 74 | rowCounter++; 75 | } 76 | } 77 | } 78 | 79 | public RealMatrix getData() { 80 | return data; 81 | } 82 | 83 | public RealMatrix getLabels() { 84 | return labels; 85 | } 86 | 87 | } 88 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/datasets/MNIST.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.oreilly.dswj.datasets; 17 | 18 | import java.io.BufferedInputStream; 19 | import java.io.DataInputStream; 20 | import java.io.FileNotFoundException; 21 | import java.io.IOException; 22 | import org.apache.commons.math3.linear.BlockRealMatrix; 23 | import org.apache.commons.math3.linear.OpenMapRealMatrix; 24 | import org.apache.commons.math3.linear.RealMatrix; 25 | 26 | /** 27 | * all external data files are in src/main/resources/ 28 | * these four files in src/main/resources/datasets/mnist/ 29 | * @author Michael Brzustowicz 30 | */ 31 | public class MNIST { 32 | 33 | public RealMatrix trainingData; 34 | public RealMatrix trainingLabels; 35 | public RealMatrix testingData; 36 | public RealMatrix testingLabels; 37 | 38 | public MNIST() throws IOException { 39 | trainingData = new BlockRealMatrix(60000, 784); // image to vector 40 | trainingLabels = new OpenMapRealMatrix(60000, 10); // the one hot label 41 | testingData = new BlockRealMatrix(10000, 784); // image to vector 42 | testingLabels = new OpenMapRealMatrix(10000, 10); // the one hot label 43 | loadTrainingData("/datasets/mnist/train-images-idx3-ubyte"); // loads from jar 44 | loadTrainingLabels("/datasets/mnist/train-labels-idx1-ubyte"); 45 | loadTestingData("/datasets/mnist/t10k-images-idx3-ubyte"); 46 | loadTestingLabels("/datasets/mnist/t10k-labels-idx1-ubyte"); 47 | 48 | } 49 | 50 | private void loadTrainingData(String filename) throws FileNotFoundException, IOException { 51 | try (DataInputStream di = new DataInputStream(new BufferedInputStream(getClass().getResourceAsStream(filename)))) { 52 | int magicNumber = di.readInt(); //2051 53 | int numImages = di.readInt(); // 60000 54 | int numRows = di.readInt(); // 28 55 | int numCols = di.readInt(); // 28 56 | int vecSize = numRows * numCols; // 784 57 | for (int i = 0; i < numImages; i++) { 58 | for (int j = 0; j < vecSize; j++) { 59 | // values are 0 to 255, so normalize 60 | trainingData.setEntry(i, j, di.readUnsignedByte() / 
255.0); 61 | } 62 | } 63 | } 64 | } 65 | 66 | private void loadTestingData(String filename) throws FileNotFoundException, IOException { 67 | 68 | try (DataInputStream di = new DataInputStream(new BufferedInputStream(getClass().getResourceAsStream(filename)))) { 69 | int magicNumber = di.readInt(); //2051 70 | int numImages = di.readInt(); // 10000 71 | int numRows = di.readInt(); // 28 72 | int numCols = di.readInt(); // 28 73 | for (int i = 0; i < numImages; i++) { 74 | for (int j = 0; j < 784; j++) { 75 | // values are 0 to 255, so normalize 76 | testingData.setEntry(i, j, di.readUnsignedByte() / 255.0); 77 | } 78 | } 79 | } 80 | } 81 | 82 | private void loadTrainingLabels(String filename) throws FileNotFoundException, IOException { 83 | try (DataInputStream di = new DataInputStream( new BufferedInputStream(getClass().getResourceAsStream(filename)))) { 84 | int magicNumber = di.readInt(); //2049 85 | int numImages = di.readInt(); // 60000 86 | for (int i = 0; i < numImages; i++) { 87 | // one-hot-encoding, column of 0-9 is given one all else 0 88 | trainingLabels.setEntry(i, di.readUnsignedByte(), 1.0); 89 | } 90 | } 91 | } 92 | 93 | private void loadTestingLabels(String filename) throws FileNotFoundException, IOException { 94 | try (DataInputStream di = new DataInputStream( new BufferedInputStream(getClass().getResourceAsStream(filename)))) { 95 | int magicNumber = di.readInt(); //2049 96 | int numImages = di.readInt(); // 10000 97 | for (int i = 0; i < numImages; i++) { 98 | // one-hot-encoding, column of 0-9 is given one all else 0 99 | testingLabels.setEntry(i, di.readUnsignedByte(), 1.0); 100 | } 101 | } 102 | } 103 | 104 | } 105 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/datasets/MultiNormalMixtureDataset.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 
3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.datasets; 17 | 18 | import java.util.ArrayList; 19 | import java.util.List; 20 | import java.util.Random; 21 | import org.apache.commons.math3.distribution.AbstractRealDistribution; 22 | import org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution; 23 | import org.apache.commons.math3.distribution.MultivariateNormalDistribution; 24 | import org.apache.commons.math3.distribution.UniformRealDistribution; 25 | import org.apache.commons.math3.linear.CholeskyDecomposition; 26 | import org.apache.commons.math3.stat.correlation.Covariance; 27 | import org.apache.commons.math3.util.Pair; 28 | 29 | /** 30 | * 31 | * @author Michael Brzustowicz 32 | */ 33 | public class MultiNormalMixtureDataset { 34 | int dimension; 35 | List> mixture; 36 | MixtureMultivariateNormalDistribution mixtureDistribution; 37 | 38 | public MultiNormalMixtureDataset(int dimension) { 39 | this.dimension = dimension; 40 | mixture = new ArrayList<>(); 41 | } 42 | 43 | public MixtureMultivariateNormalDistribution getMixtureDistribution() { 44 | return mixtureDistribution; 45 | } 46 | 47 | public void createRandomMixtureModel(int numComponents, double boxSize, long seed) { 48 | Random rnd = new Random(seed); 49 | double limit = boxSize / dimension; 50 | UniformRealDistribution dist = new UniformRealDistribution(-limit, limit); 51 | UniformRealDistribution disC = 
new UniformRealDistribution(-1, 1); 52 | dist.reseedRandomGenerator(seed); 53 | disC.reseedRandomGenerator(seed); 54 | for (int i = 0; i < numComponents; i++) { 55 | double alpha = rnd.nextDouble(); 56 | double[] means = dist.sample(dimension); 57 | double[][] cov = getRandomCovariance(disC); 58 | MultivariateNormalDistribution multiNorm = new MultivariateNormalDistribution(means, cov); 59 | addMultinormalDistributionToModel(alpha, multiNorm); 60 | } 61 | mixtureDistribution = new MixtureMultivariateNormalDistribution(mixture); 62 | mixtureDistribution.reseedRandomGenerator(seed); // calls to sample() will return same results 63 | } 64 | 65 | /** 66 | * NOTE this is for adding both internal and external, known distros but 67 | * need to figure out clean way to add the mixture to mixtureDistribution!!! 68 | * @param alpha 69 | * @param dist 70 | */ 71 | public void addMultinormalDistributionToModel(double alpha, MultivariateNormalDistribution dist) { 72 | // note all alpha will be L1 normed 73 | mixture.add(new Pair(alpha, dist)); 74 | } 75 | 76 | public double[][] getSimulatedData(int size) { 77 | return mixtureDistribution.sample(size); 78 | } 79 | 80 | private double[][] getRandomCovariance(AbstractRealDistribution dist) { 81 | double[][] data = new double[2*dimension][dimension]; 82 | double determinant = 0.0; 83 | Covariance cov = new Covariance(); 84 | while(Math.abs(determinant) == 0) { 85 | for (int i = 0; i < data.length; i++) { 86 | data[i] = dist.sample(dimension); 87 | } 88 | // check if cov matrix is singular ... 
if so, keep going 89 | cov = new Covariance(data); 90 | determinant = new CholeskyDecomposition(cov.getCovarianceMatrix()).getDeterminant(); 91 | } 92 | return cov.getCovarianceMatrix().getData(); 93 | } 94 | 95 | } 96 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/datasets/Sentiment.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.oreilly.dswj.datasets; 17 | 18 | import java.io.BufferedReader; 19 | import java.io.IOException; 20 | import java.io.InputStream; 21 | import java.io.InputStreamReader; 22 | import java.util.ArrayList; 23 | import java.util.List; 24 | 25 | 26 | /** 27 | * Sentiment labelled sentences 28 | * https://archive.ics.uci.edu/ml/datasets/Sentiment+Labelled+Sentences 29 | * contains data from imdb, yelp and amazon 30 | * imdb has 1000 sent with 500 pos (1) and 500 neg (0) 31 | * yelp has 3729 sent with 500 pos (1) and 500 neg (0) 32 | * amzn has 15004 sent with 500 pos (1) and 500 neg (0) 33 | * @author Michael Brzustowicz 34 | */ 35 | public class Sentiment { 36 | 37 | private final List documents = new ArrayList<>(); 38 | private final List sentiments = new ArrayList<>(); 39 | private static final String IMDB_RESOURCE = "/datasets/sentiment/imdb_labelled.txt"; 40 | private static final String YELP_RESOURCE = "/datasets/sentiment/yelp_labelled.txt"; 41 | private static final String AMZN_RESOURCE = "/datasets/sentiment/amazon_cells_labelled.txt"; 42 | public enum DataSource {IMDB, YELP, AMZN}; 43 | 44 | public Sentiment() throws IOException { 45 | parseResource(IMDB_RESOURCE); // 1000 sentences 46 | parseResource(YELP_RESOURCE); // 1000 sentences 47 | parseResource(AMZN_RESOURCE); // 1000 sentences 48 | } 49 | 50 | public List getSentiments(DataSource dataSource) { 51 | int fromIndex = 0; // inclusive 52 | int toIndex = 3000; // exclusive 53 | switch(dataSource) { 54 | case IMDB: 55 | fromIndex = 0; 56 | toIndex = 1000; 57 | break; 58 | case YELP: 59 | fromIndex = 1000; 60 | toIndex = 2000; 61 | break; 62 | case AMZN: 63 | fromIndex = 2000; 64 | toIndex = 3000; 65 | break; 66 | } 67 | return sentiments.subList(fromIndex, toIndex); 68 | } 69 | 70 | public List getDocuments(DataSource dataSource) { 71 | int fromIndex = 0; // inclusive 72 | int toIndex = 3000; // exclusive 73 | switch(dataSource) { 74 | case IMDB: 75 | fromIndex = 0; 76 | toIndex = 
1000; 77 | break; 78 | case YELP: 79 | fromIndex = 1000; 80 | toIndex = 2000; 81 | break; 82 | case AMZN: 83 | fromIndex = 2000; 84 | toIndex = 3000; 85 | break; 86 | } 87 | return documents.subList(fromIndex, toIndex); 88 | } 89 | 90 | public List getSentiments() { 91 | return sentiments; 92 | } 93 | 94 | public List getDocuments() { 95 | return documents; 96 | } 97 | 98 | private void parseResource(String resource) throws IOException { 99 | try(InputStream inputStream = getClass().getResourceAsStream(resource)) { 100 | BufferedReader br = new BufferedReader(new InputStreamReader(inputStream)); 101 | String line; 102 | while ((line = br.readLine()) != null) { 103 | String[] splitLine = line.split("\t"); 104 | // both yelp and amzn have many sentences with no label 105 | if (splitLine.length > 1) { 106 | documents.add(splitLine[0]); 107 | sentiments.add(Integer.parseInt(splitLine[1])); 108 | } 109 | } 110 | } 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/io/BasicBarChart.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.oreilly.dswj.io; 17 | 18 | 19 | import java.io.File; 20 | import javafx.application.Application; 21 | import javafx.collections.FXCollections; 22 | import javafx.embed.swing.SwingFXUtils; 23 | import javafx.scene.Scene; 24 | import javafx.scene.SnapshotParameters; 25 | import javafx.scene.chart.BarChart; 26 | import javafx.scene.chart.CategoryAxis; 27 | import javafx.scene.chart.NumberAxis; 28 | import javafx.scene.chart.XYChart.Series; 29 | import javafx.scene.chart.XYChart.Data; 30 | import javafx.scene.image.WritableImage; 31 | import javafx.stage.Stage; 32 | import javax.imageio.ImageIO; 33 | 34 | /** 35 | * 36 | * @author Michael Brzustowicz 37 | */ 38 | public class BasicBarChart extends Application { 39 | 40 | 41 | /** 42 | * @param args the command line arguments 43 | */ 44 | public static void main(String[] args) { 45 | 46 | launch(args); 47 | } 48 | 49 | @Override 50 | public void start(Stage stage) throws Exception { 51 | 52 | String[] catData = {"Mon", "Tues", "Wed", "Thurs", "Fri"}; 53 | double[] yData = {1.3, 2.1, 3.3, 4.0, 4.8}; 54 | 55 | /* 56 | create some data 57 | */ 58 | Series series = new Series(); 59 | for (int i = 0; i < yData.length; i++) { 60 | series.getData().add(new Data(catData[i], yData[i])); 61 | } 62 | 63 | 64 | //defining the axes 65 | CategoryAxis xAxis = new CategoryAxis(); 66 | NumberAxis yAxis = new NumberAxis(); 67 | xAxis.setLabel("x"); 68 | yAxis.setLabel("y"); 69 | 70 | //creating the bar chart; 71 | BarChart barChart = new BarChart<>(xAxis, yAxis); 72 | barChart.setAnimated(false); 73 | barChart.getData().add(series); 74 | barChart.setTitle("x vs. y"); 75 | barChart.setHorizontalGridLinesVisible(false); 76 | barChart.setVerticalGridLinesVisible(false); 77 | barChart.setVerticalZeroLineVisible(false); 78 | 79 | /* 80 | create a scene using the chart 81 | */ 82 | Scene scene = new Scene(barChart,800,600); 83 | 84 | /* 85 | tell the stage what scene to use and render it! 
86 | */ 87 | stage.setScene(scene); 88 | stage.show(); 89 | 90 | // WritableImage image = scatterChart.snapshot(new SnapshotParameters(), null); 91 | // File file = new File("chart.png"); 92 | // ImageIO.write(SwingFXUtils.fromFXImage(image, null), "png", file); 93 | 94 | } 95 | 96 | } 97 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/io/BasicScatterChart.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.oreilly.dswj.io; 17 | 18 | 19 | import java.io.File; 20 | import javafx.application.Application; 21 | import javafx.embed.swing.SwingFXUtils; 22 | import javafx.scene.Scene; 23 | import javafx.scene.SnapshotParameters; 24 | import javafx.scene.chart.NumberAxis; 25 | import javafx.scene.chart.ScatterChart; 26 | import javafx.scene.chart.XYChart.Series; 27 | import javafx.scene.chart.XYChart.Data; 28 | import javafx.scene.image.WritableImage; 29 | import javafx.stage.Stage; 30 | import javax.imageio.ImageIO; 31 | 32 | /** 33 | * 34 | * @author Michael Brzustowicz 35 | */ 36 | public class BasicScatterChart extends Application { 37 | 38 | 39 | /** 40 | * @param args the command line arguments 41 | */ 42 | public static void main(String[] args) { 43 | 44 | launch(args); 45 | } 46 | 47 | @Override 48 | public void start(Stage stage) throws Exception { 49 | int[] xData = {1, 2, 3, 4, 5}; 50 | double[] yData = {1.3, 2.1, 3.3, 4.0, 4.8}; 51 | 52 | /* 53 | create some data 54 | */ 55 | Series series = new Series(); 56 | for (int i = 0; i < yData.length; i++) { 57 | series.getData().add(new Data(xData[i], yData[i])); 58 | } 59 | 60 | 61 | //defining the axes 62 | NumberAxis xAxis = new NumberAxis(); 63 | NumberAxis yAxis = new NumberAxis(); 64 | xAxis.setLabel("x"); 65 | yAxis.setLabel("y"); 66 | 67 | //creating the scatter chart 68 | ScatterChart scatterChart = new ScatterChart<>(xAxis,yAxis); 69 | scatterChart.setAnimated(false); 70 | scatterChart.getData().add(series); 71 | scatterChart.setTitle("x vs. y"); 72 | scatterChart.setHorizontalGridLinesVisible(false); 73 | scatterChart.setVerticalGridLinesVisible(false); 74 | scatterChart.setVerticalZeroLineVisible(false); 75 | 76 | /* 77 | create a scene using the chart 78 | */ 79 | Scene scene = new Scene(scatterChart,800,600); 80 | 81 | /* 82 | tell the stage what scene to use and render it! 
83 | */ 84 | stage.setScene(scene); 85 | stage.show(); 86 | 87 | // WritableImage image = scatterChart.snapshot(new SnapshotParameters(), null); 88 | // File file = new File("chart.png"); 89 | // ImageIO.write(SwingFXUtils.fromFXImage(image, null), "png", file); 90 | 91 | } 92 | 93 | } 94 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/io/DBApp.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.oreilly.dswj.io; 17 | 18 | import java.sql.Connection; 19 | import java.sql.DriverManager; 20 | import java.sql.PreparedStatement; 21 | import java.sql.ResultSet; 22 | import java.sql.SQLException; 23 | import java.sql.Statement; 24 | 25 | /** 26 | * 27 | * @author Michael Brzustowicz 28 | */ 29 | public class DBApp { 30 | 31 | /** 32 | * @param args the command line arguments 33 | */ 34 | public static void main(String[] args) { 35 | 36 | String uri = "jdbc:mysql://localhost:3306/mydb?user=root"; 37 | 38 | try(Connection c = DriverManager.getConnection(uri)) { 39 | 40 | /* DROP / CREATE TABLE */ 41 | String dropSQL = "DROP TABLE IF EXISTS data"; 42 | String createSQL = "CREATE TABLE IF NOT EXISTS data(id INTEGER PRIMARY KEY, yr INTEGER, city VARCHAR(80))"; 43 | 44 | try (Statement stmt = c.createStatement()) { 45 | 46 | stmt.execute(dropSQL); 47 | 48 | stmt.execute(createSQL); 49 | 50 | } 51 | 52 | /* INSERT DATA */ 53 | String insertSQL = "INSERT INTO data(id, yr, city) VALUES(?, ?, ?)"; 54 | 55 | try (PreparedStatement ps = c.prepareStatement(insertSQL)) { 56 | 57 | /* for one entry */ 58 | ps.setInt(1, 1); 59 | ps.setInt(2, 2015); 60 | ps.setString(3, "San Francisco"); 61 | ps.execute(); 62 | } 63 | 64 | /* SELECT DATA */ 65 | String selectSQL = "SELECT id, yr, city FROM data"; 66 | try (Statement st = c.createStatement(); ResultSet rs = st.executeQuery(selectSQL)) { 67 | 68 | while(rs.next()) { 69 | // TODO ... do something with data 70 | System.out.println(rs.getInt("id")+" "+rs.getInt("yr")+" "+rs.getString("city")); 71 | } 72 | } 73 | 74 | } catch (SQLException e) { 75 | System.out.println(e.getMessage()); 76 | } 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/io/DBInsertBatchApp.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 
3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.io; 17 | 18 | import java.sql.Connection; 19 | import java.sql.DriverManager; 20 | import java.sql.PreparedStatement; 21 | import java.sql.ResultSet; 22 | import java.sql.SQLException; 23 | import java.sql.Statement; 24 | import java.util.ArrayList; 25 | import java.util.List; 26 | 27 | /** 28 | * 29 | * @author Michael Brzustowicz 30 | */ 31 | public class DBInsertBatchApp { 32 | 33 | public static List getData() { 34 | List data = new ArrayList<>(); 35 | data.add(new Record(1, 2015, "San Francisco")); 36 | data.add(new Record(2, 2014, "New York")); 37 | data.add(new Record(3, 2012, "Los Angeles")); 38 | return data; 39 | } 40 | 41 | /** 42 | * @param args the command line arguments 43 | */ 44 | public static void main(String[] args) { 45 | 46 | String uri = "jdbc:mysql://localhost:3306/mydb?user=root"; 47 | 48 | try(Connection c = DriverManager.getConnection(uri)) { 49 | 50 | /* DROP / CREATE TABLE */ 51 | String dropSQL = "DROP TABLE IF EXISTS data"; 52 | String createSQL = "CREATE TABLE IF NOT EXISTS data(id INTEGER PRIMARY KEY, yr INTEGER, city VARCHAR(80))"; 53 | 54 | try (Statement stmt = c.createStatement()) { 55 | 56 | stmt.execute(dropSQL); 57 | 58 | stmt.execute(createSQL); 59 | 60 | } 61 | 62 | /* INSERT BATCH DATA */ 63 | String insertSQL = "INSERT INTO data(id, yr, city) VALUES(?, ?, ?)"; 64 | 65 | try (PreparedStatement ps = 
c.prepareStatement(insertSQL)) { 66 | 67 | for (Record data : getData()) { 68 | 69 | ps.setInt(1, data.id); 70 | ps.setInt(2, data.year); 71 | ps.setString(3, data.city); 72 | 73 | /* add record to the batch !!! */ 74 | ps.addBatch(); 75 | } 76 | 77 | /* note this is different than the regular execute */ 78 | ps.executeBatch(); 79 | 80 | } 81 | 82 | /* SELECT DATA */ 83 | String selectSQL = "SELECT id, yr, city FROM data"; 84 | try (Statement st = c.createStatement(); ResultSet rs = st.executeQuery(selectSQL)) { 85 | 86 | while(rs.next()) { 87 | // TODO ... do something with data 88 | System.out.println(rs.getInt("id")+" "+rs.getInt("yr")+" "+rs.getString("city")); 89 | } 90 | } 91 | 92 | } catch (SQLException e) { 93 | System.out.println(e.getMessage()); 94 | } 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/io/FileIOExample.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.oreilly.dswj.io; 17 | 18 | import java.io.BufferedReader; 19 | import java.io.FileReader; 20 | import java.io.IOException; 21 | 22 | /** 23 | * 24 | * @author Michael Brzustowicz 25 | */ 26 | public class FileIOExample { 27 | 28 | /** 29 | * @param args the command line arguments 30 | */ 31 | public static void main(String[] args) { 32 | 33 | String filename = ""; 34 | 35 | // or use args to get filename 36 | 37 | try(BufferedReader br = new BufferedReader(new FileReader(filename))) { 38 | 39 | String line; 40 | 41 | while ((line = br.readLine()) != null) { 42 | // TODO ... do something with line 43 | // TODO ... parse line e.g. CSV, TSV, JSON 44 | // TODO ... check each value if required 45 | System.out.println(line); 46 | } 47 | 48 | } catch (IOException e) { 49 | System.err.println(e); 50 | } 51 | 52 | } 53 | 54 | } 55 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/io/Record.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.oreilly.dswj.io; 17 | 18 | /** 19 | * 20 | * @author Michael Brzustowicz 21 | */ 22 | public class Record { 23 | public int id; 24 | public int year; 25 | public String city; 26 | 27 | public Record() { 28 | } 29 | 30 | public Record(int id, int year, String city) { 31 | this.id = id; 32 | this.year = year; 33 | this.city = city; 34 | } 35 | 36 | } 37 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/learn/BernoulliConditionalProbabilityEstimator.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.learn; 17 | 18 | import org.apache.commons.math3.stat.descriptive.MultivariateSummaryStatistics; 19 | 20 | /** 21 | * 22 | * @author Michael Brzustowicz 23 | */ 24 | public class BernoulliConditionalProbabilityEstimator implements ConditionalProbabilityEstimator { 25 | 26 | @Override 27 | public double getProbability(MultivariateSummaryStatistics mss, double[] features) { 28 | int n = features.length; 29 | double[] means = mss.getMean(); // this is actually the prob per features e.g. count / total 30 | double prob = 1.0; 31 | for (int i = 0; i < n; i++) { 32 | // if x_i = 1, the p, if x_i = 0 then 1-p ... but here x_i is a double 33 | prob *= (features[i] > 0.0) ? 
means[i] : 1-means[i]; 34 | } 35 | return prob; 36 | } 37 | 38 | } 39 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/learn/ClassifierAccuracy.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.learn; 17 | 18 | import com.oreilly.dswj.dataops.ProbabilityEncoder; 19 | import org.apache.commons.math3.analysis.function.Abs; 20 | import org.apache.commons.math3.linear.ArrayRealVector; 21 | import org.apache.commons.math3.linear.RealMatrix; 22 | import org.apache.commons.math3.linear.RealVector; 23 | 24 | /** 25 | * 26 | * @author Michael Brzustowicz 27 | */ 28 | public class ClassifierAccuracy { 29 | 30 | private final RealMatrix predictions; 31 | private final RealMatrix targets; 32 | private final ProbabilityEncoder probabilityEncoder; 33 | private RealVector classCount; 34 | 35 | /** 36 | * 37 | * @param predictions assumes probabilities 38 | * @param targets assumes binary multi-label OR one-hot-encoding 39 | */ 40 | public ClassifierAccuracy(RealMatrix predictions, RealMatrix targets) { 41 | this.predictions = predictions; 42 | this.targets = targets; 43 | probabilityEncoder = new ProbabilityEncoder(); 44 | //tally the binary class occurances per dimension 45 | classCount = new 
ArrayRealVector(targets.getColumnDimension()); 46 | for (int i = 0; i < targets.getRowDimension(); i++) { 47 | classCount = classCount.add(targets.getRowVector(i)); 48 | } 49 | } 50 | 51 | /** 52 | * assumes one hot encoding 53 | * @return 54 | */ 55 | public RealVector getAccuracyPerDimension() { 56 | 57 | RealVector accuracy = new ArrayRealVector(predictions.getColumnDimension()); 58 | 59 | for (int i = 0; i < predictions.getRowDimension(); i++) { 60 | 61 | RealVector binarized = probabilityEncoder.getOneHot(predictions.getRowVector(i)); 62 | 63 | // 0*0, 0*1, 1*1 = 0 and 1*1 = 1 giving only true positives as 1 and all other 0 64 | RealVector decision = binarized.ebeMultiply(targets.getRowVector(i)); 65 | 66 | // append TP counts to accuracy 67 | accuracy = accuracy.add(decision); 68 | } 69 | return accuracy.ebeDivide(classCount); 70 | } 71 | 72 | /** 73 | * assumes one hot encoding 74 | * @return 75 | */ 76 | public double getAccuracy() { 77 | // convert accuracy_per_dim back to counts, then sum and divide by total rows 78 | return getAccuracyPerDimension().ebeMultiply(classCount).getL1Norm() / targets.getRowDimension(); 79 | } 80 | 81 | // implements jaccard similarity scores 82 | public RealVector getAccuracyPerDimension(double threshold) { // assumes un-correlated multi-output 83 | 84 | RealVector accuracy = new ArrayRealVector(targets.getColumnDimension()); 85 | 86 | for (int i = 0; i < predictions.getRowDimension(); i++) { 87 | 88 | //binarize the row vector according to the threshold 89 | RealVector binarized = probabilityEncoder.getBinary(predictions.getRowVector(i), threshold); 90 | 91 | // 0-0 (TN) and 1-1 (TP) = 0 while 1-0 = 1 and 0-1 = -1 92 | RealVector decision = binarized.subtract(targets.getRowVector(i)).map(new Abs()).mapMultiply(-1).mapAdd(1); 93 | 94 | // append either TP and TN counts to accuracy 95 | accuracy = accuracy.add(decision); 96 | } 97 | return accuracy.mapDivide((double) predictions.getRowDimension()); // accuracy for each 
dimension, given the threshold 98 | } 99 | 100 | public double getAccuracy(double threshold) { 101 | // mean of the accuracy vector 102 | return getAccuracyPerDimension(threshold).getL1Norm() / targets.getColumnDimension(); 103 | } 104 | 105 | 106 | 107 | } 108 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/learn/ConditionalProbabilityEstimator.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.learn; 17 | 18 | import org.apache.commons.math3.stat.descriptive.MultivariateSummaryStatistics; 19 | 20 | /** 21 | * 22 | * @author Michael Brzustowicz 23 | */ 24 | public interface ConditionalProbabilityEstimator { 25 | 26 | /** 27 | * 28 | * @param mss multivariate statistics object for this class 29 | * @param features 30 | * @return 31 | */ 32 | double getProbability(MultivariateSummaryStatistics mss, double[] features); 33 | } 34 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/learn/DeepNetwork.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 
3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.learn; 17 | 18 | import java.util.ArrayList; 19 | import java.util.List; 20 | import java.util.ListIterator; 21 | import org.apache.commons.math3.linear.RealMatrix; 22 | 23 | /** 24 | * 25 | * @author Michael Brzustowicz 26 | */ 27 | public class DeepNetwork extends IterativeLearningProcess { 28 | 29 | private final List layers; 30 | 31 | public DeepNetwork() { 32 | this.layers = new ArrayList<>(); 33 | } 34 | 35 | public void addLayer(NetworkLayer networkLayer) { 36 | layers.add(networkLayer); 37 | } 38 | 39 | public List getLayers() { 40 | return layers; 41 | } 42 | 43 | @Override 44 | public RealMatrix predict(RealMatrix input) { 45 | 46 | /* the initial input MUST BE DEEP COPIED or is overwritten */ 47 | RealMatrix layerInput = input.copy(); 48 | 49 | for (NetworkLayer layer : layers) { 50 | 51 | layer.setInput(layerInput); 52 | 53 | /* calc the output and set to next layer input*/ 54 | RealMatrix output = layer.getOutput(layerInput); 55 | layer.setOutput(output); 56 | 57 | /* 58 | does not need a deep copy, but be aware that 59 | every layer input shares memory of prior layer output 60 | */ 61 | layerInput = output; 62 | 63 | } 64 | 65 | /* layerInput is holding the final output ... 
get a deep copy */ 66 | return layerInput.copy(); 67 | 68 | } 69 | 70 | @Override 71 | protected void update(RealMatrix input, RealMatrix target, RealMatrix output) { 72 | 73 | /* this is the gradient of the network error and starts the back prop process */ 74 | RealMatrix layerError = getLossFunction().getLossGradient(output, target).copy(); 75 | 76 | /* create list iterator and set cursor to last! */ 77 | ListIterator li = layers.listIterator(layers.size()); 78 | 79 | while (li.hasPrevious()) { 80 | 81 | NetworkLayer layer = (NetworkLayer) li.previous(); 82 | 83 | /* get error input from higher layer */ 84 | layer.setOutputError(layerError); 85 | 86 | /* this back propagates the error and updates weights */ 87 | layer.update(); 88 | 89 | /* pass along error to next layer down */ 90 | layerError = layer.getInputError(); 91 | 92 | } 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/learn/DeepNetworkIrisExample.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.oreilly.dswj.learn; 17 | 18 | import com.oreilly.dswj.dataops.MatrixResampler; 19 | import com.oreilly.dswj.datasets.Iris; 20 | import com.oreilly.dswj.dataops.SoftMaxCrossEntropyLossFunction; 21 | import java.io.IOException; 22 | import org.apache.commons.math3.linear.RealMatrix; 23 | 24 | /** 25 | * 26 | * @author Michael Brzustowicz 27 | */ 28 | public class DeepNetworkIrisExample { 29 | 30 | /** 31 | * @param args the command line arguments 32 | * @throws java.io.IOException 33 | */ 34 | public static void main(String[] args) throws IOException { 35 | Iris iris = new Iris(); 36 | MatrixResampler mr = new MatrixResampler(iris.getData(), iris.getLabels()); 37 | mr.calculateTestTrainSplit(0.4, 0L); 38 | DeepNetwork net = new DeepNetwork(); 39 | net.addLayer(new NetworkLayer(4, 10, new TanhOutputFunction(), new GradientDescent(0.001))); 40 | net.addLayer(new NetworkLayer(10, 3, new SoftmaxOutputFunction(), new GradientDescent(0.001))); 41 | net.setLossFunction(new SoftMaxCrossEntropyLossFunction()); 42 | net.setBatchSize(0); 43 | net.setMaxIterations(6000); 44 | net.setTolerance(10E-6); 45 | net.learn(mr.getTrainingFeatures(), mr.getTrainingLabels()); 46 | RealMatrix predictions = net.predict(mr.getTestingFeatures()); 47 | ClassifierAccuracy acc = new ClassifierAccuracy(predictions, mr.getTestingLabels()); 48 | System.out.println("converged = " + net.isConverged()); 49 | System.out.println("iterations = " + net.getNumIterations()); 50 | System.out.println("accuracy = " + acc.getAccuracy()); 51 | } 52 | 53 | } 54 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/learn/DeepNetworkMNISTExample.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.learn; 17 | 18 | import com.oreilly.dswj.datasets.MNIST; 19 | import com.oreilly.dswj.dataops.SoftMaxCrossEntropyLossFunction; 20 | import java.io.IOException; 21 | import org.apache.commons.math3.linear.RealMatrix; 22 | 23 | /** 24 | * 25 | * @author Michael Brzustowicz 26 | */ 27 | public class DeepNetworkMNISTExample { 28 | 29 | /** 30 | * @param args the command line arguments 31 | * @throws java.io.IOException 32 | */ 33 | public static void main(String[] args) throws IOException { 34 | 35 | MNIST mnist = new MNIST(); 36 | 37 | DeepNetwork network = new DeepNetwork(); 38 | 39 | /* input, hidden and output layers */ 40 | network.addLayer(new NetworkLayer(784, 500, new TanhOutputFunction(), 41 | new GradientDescentMomentum(0.0001, 0.95))); 42 | 43 | network.addLayer(new NetworkLayer(500, 300, new TanhOutputFunction(), 44 | new GradientDescentMomentum(0.0001, 0.95))); 45 | 46 | network.addLayer(new NetworkLayer(300, 10, new SoftmaxOutputFunction(), 47 | new GradientDescentMomentum(0.0001, 0.95))); 48 | 49 | /* runtime parameters */ 50 | network.setLossFunction(new SoftMaxCrossEntropyLossFunction()); 51 | network.setMaxIterations(10000); 52 | network.setTolerance(10E-6); 53 | network.setBatchSize(200); 54 | 55 | /* learn */ 56 | network.learn(mnist.trainingData, mnist.trainingLabels); 57 | 58 | /* predict */ 59 | RealMatrix prediction = network.predict(mnist.testingData); 60 | 61 | /* compute accuracy */ 62 | ClassifierAccuracy accuracy = new ClassifierAccuracy(prediction, 
mnist.testingLabels); 63 | 64 | /* print report */ 65 | System.out.println("isConverged = " + network.isConverged()); 66 | System.out.println("numIter = " + network.getNumIterations()); 67 | System.out.println("error = " + network.getLoss()); 68 | System.out.println("accuracy = " + accuracy.getAccuracy()); 69 | System.out.println("accuracy per dim = " + accuracy.getAccuracyPerDimension()); 70 | 71 | } 72 | 73 | } 74 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/learn/DeltaRule.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.oreilly.dswj.learn; 17 | 18 | import org.apache.commons.math3.linear.RealMatrix; 19 | import org.apache.commons.math3.linear.RealVector; 20 | 21 | /** 22 | * 23 | * @author Michael Brzustowicz 24 | */ 25 | public class DeltaRule implements Optimizer { 26 | 27 | private final double learningRate; 28 | 29 | public DeltaRule(double learningRate) { 30 | this.learningRate = learningRate; 31 | } 32 | 33 | @Override 34 | public RealMatrix getWeightUpdate(RealMatrix weightGradient) { 35 | return weightGradient.scalarMultiply(-1.0 * learningRate); 36 | } 37 | 38 | @Override 39 | public RealVector getBiasUpdate(RealVector biasGradient) { 40 | return biasGradient.mapMultiply(-1.0 * learningRate); 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/learn/GaussianConditionalProbabilityEstimator.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.oreilly.dswj.learn; 17 | 18 | import org.apache.commons.math3.distribution.NormalDistribution; 19 | import org.apache.commons.math3.stat.descriptive.MultivariateSummaryStatistics; 20 | 21 | 22 | /** 23 | * 24 | * @author Michael Brzustowicz 25 | */ 26 | public class GaussianConditionalProbabilityEstimator implements ConditionalProbabilityEstimator{ 27 | 28 | @Override 29 | public double getProbability(MultivariateSummaryStatistics mss, double[] features) { 30 | double[] means = mss.getMean(); 31 | double[] stds = mss.getStandardDeviation(); 32 | double prob = 1.0; 33 | for (int i = 0; i < features.length; i++) { 34 | prob *= new NormalDistribution(means[i], stds[i]).density(features[i]); 35 | } 36 | return prob; 37 | } 38 | 39 | } 40 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/learn/GradientDescent.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.oreilly.dswj.learn; 17 | 18 | import org.apache.commons.math3.linear.RealMatrix; 19 | import org.apache.commons.math3.linear.RealVector; 20 | 21 | /** 22 | * 23 | * @author Michael Brzustowicz 24 | */ 25 | public class GradientDescent implements Optimizer { 26 | 27 | private double learningRate; 28 | 29 | public GradientDescent(double learningRate) { 30 | this.learningRate = learningRate; 31 | } 32 | 33 | public void setLearningRate(double learningRate) { 34 | this.learningRate = learningRate; 35 | } 36 | 37 | @Override 38 | public RealMatrix getWeightUpdate(RealMatrix weightGradient) { 39 | return weightGradient.scalarMultiply(-1.0 * learningRate); 40 | } 41 | 42 | @Override 43 | public RealVector getBiasUpdate(RealVector biasGradient) { 44 | return biasGradient.mapMultiply(-1.0 * learningRate); 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/learn/GradientDescentMomentum.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.oreilly.dswj.learn; 17 | 18 | import org.apache.commons.math3.linear.ArrayRealVector; 19 | import org.apache.commons.math3.linear.BlockRealMatrix; 20 | import org.apache.commons.math3.linear.RealMatrix; 21 | import org.apache.commons.math3.linear.RealVector; 22 | 23 | /** 24 | * 25 | * @author Michael Brzustowicz 26 | */ 27 | public class GradientDescentMomentum extends GradientDescent { 28 | 29 | private final double momentum; 30 | private RealMatrix priorWeightUpdate; 31 | private RealVector priorBiasUpdate; 32 | 33 | /** 34 | * 35 | * @param learningRate good default is 0.0001 36 | * @param momentum good default is 0.95 37 | */ 38 | public GradientDescentMomentum(double learningRate, double momentum) { 39 | super(learningRate); 40 | this.momentum = momentum; 41 | priorWeightUpdate = null; 42 | priorBiasUpdate = null; 43 | } 44 | 45 | @Override 46 | public RealMatrix getWeightUpdate(RealMatrix weightGradient) { 47 | // creates matrix of zeros same size as gradients if not already exists 48 | if(priorWeightUpdate == null) { 49 | priorWeightUpdate = new BlockRealMatrix(weightGradient.getRowDimension(), weightGradient.getColumnDimension()); 50 | } 51 | // add term from GradientDescent since it is already negative ( - eta * gradW ) 52 | RealMatrix update = priorWeightUpdate.scalarMultiply(momentum).add(super.getWeightUpdate(weightGradient)); 53 | priorWeightUpdate = update; 54 | return update; 55 | } 56 | 57 | @Override 58 | public RealVector getBiasUpdate(RealVector biasGradient) { 59 | if(priorBiasUpdate == null) { 60 | priorBiasUpdate = new ArrayRealVector(biasGradient.getDimension()); 61 | } 62 | // add term from GradientDescent since it is already negative ( - eta * gradW ) 63 | RealVector update = priorBiasUpdate.mapMultiply(momentum).add(super.getBiasUpdate(biasGradient)); 64 | priorBiasUpdate = update; 65 | return update; 66 | } 67 | } 68 | -------------------------------------------------------------------------------- 
/src/main/java/com/oreilly/dswj/learn/IterativeLearningProcess.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.learn; 17 | 18 | import com.oreilly.dswj.dataops.Batch; 19 | import com.oreilly.dswj.dataops.LossFunction; 20 | import com.oreilly.dswj.dataops.QuadraticLossFunction; 21 | import org.apache.commons.math3.linear.RealMatrix; 22 | 23 | /** 24 | * 25 | * @author Michael Brzustowicz 26 | */ 27 | public class IterativeLearningProcess { 28 | 29 | private boolean converged; 30 | private int numIterations; 31 | private int maxIterations; 32 | private double loss; 33 | private double tolerance; 34 | private int batchSize; // if == 0 then uses ALL data 35 | private LossFunction lossFunction; 36 | 37 | public IterativeLearningProcess(LossFunction lossFunction) { 38 | this.lossFunction = lossFunction; 39 | loss = 0; 40 | converged = false; 41 | numIterations = 0; 42 | maxIterations = 200; 43 | tolerance = 10E-6; 44 | batchSize = 100; 45 | } 46 | 47 | public IterativeLearningProcess() { 48 | this(new QuadraticLossFunction()); 49 | } 50 | 51 | public void learn(RealMatrix input, RealMatrix target) { 52 | 53 | double priorLoss = tolerance; 54 | 55 | numIterations = 0; 56 | 57 | loss = 0; 58 | 59 | converged = false; 60 | 61 | Batch batch = new Batch(input, target); 
62 | RealMatrix inputBatch; 63 | RealMatrix targetBatch; 64 | 65 | 66 | while(numIterations < maxIterations && !converged) { 67 | 68 | if(batchSize > 0 && batchSize < input.getRowDimension()) { 69 | 70 | batch.calcNextBatch(batchSize); 71 | inputBatch = batch.getInputBatch(); 72 | targetBatch = batch.getTargetBatch(); 73 | 74 | } else { 75 | 76 | inputBatch = input; 77 | targetBatch = target; 78 | 79 | } 80 | 81 | RealMatrix outputBatch = predict(inputBatch); 82 | 83 | loss = lossFunction.getMeanLoss(outputBatch, targetBatch); 84 | 85 | if(Math.abs(priorLoss - loss) < tolerance) { 86 | 87 | converged = true; 88 | 89 | } else { 90 | 91 | update(inputBatch, targetBatch, outputBatch); 92 | 93 | priorLoss = loss; 94 | 95 | } 96 | 97 | numIterations++; 98 | } 99 | 100 | } 101 | 102 | public RealMatrix predict(RealMatrix input) { 103 | throw new UnsupportedOperationException("Implement the predict method!"); 104 | } 105 | 106 | protected void update(RealMatrix input, RealMatrix target, RealMatrix output) { 107 | throw new UnsupportedOperationException("Implement the update method!"); 108 | } 109 | 110 | public void setBatchSize(int batchSize) { 111 | this.batchSize = batchSize; 112 | } 113 | 114 | public void setMaxIterations(int maxIterations) { 115 | this.maxIterations = maxIterations; 116 | } 117 | 118 | public void setTolerance(double tolerance) { 119 | this.tolerance = tolerance; 120 | } 121 | 122 | public int getNumIterations() { 123 | return numIterations; 124 | } 125 | 126 | public double getLoss() { 127 | return loss; 128 | } 129 | 130 | public LossFunction getLossFunction() { 131 | return lossFunction; 132 | } 133 | 134 | public void setLossFunction(LossFunction lossFunction) { 135 | this.lossFunction = lossFunction; 136 | } 137 | 138 | public boolean isConverged() { 139 | return converged; 140 | } 141 | } 142 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/learn/KMeansExample.java: 
-------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.learn; 17 | 18 | import java.util.ArrayList; 19 | import java.util.List; 20 | import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer; 21 | import org.apache.commons.math3.ml.clustering.MultiKMeansPlusPlusClusterer; 22 | import org.apache.commons.math3.distribution.NormalDistribution; 23 | import org.apache.commons.math3.ml.clustering.CentroidCluster; 24 | import org.apache.commons.math3.ml.clustering.DoublePoint; 25 | import org.apache.commons.math3.ml.clustering.evaluation.SumOfClusterVariances; 26 | import org.apache.commons.math3.ml.distance.EuclideanDistance; 27 | 28 | /** 29 | * 30 | * @author Michael Brzustowicz 31 | */ 32 | public class KMeansExample { 33 | 34 | /** 35 | * @param args the command line arguments 36 | */ 37 | public static void main(String[] args) { 38 | int p = 7; //dimensions 39 | int n = 1000; //num points 40 | 41 | List data = new ArrayList<>(); 42 | NormalDistribution dist = new NormalDistribution(); 43 | for (int i = 0; i < 1000; i++) { 44 | data.add(new DoublePoint(dist.sample(p))); 45 | } 46 | 47 | for (int i = 1; i < 5; i++) { 48 | 49 | System.out.println("cluster: " + i); 50 | 51 | KMeansPlusPlusClusterer kmpp = new KMeansPlusPlusClusterer<>(i); 52 | List> results = 
kmpp.cluster(data); 53 | 54 | /* use cluster vars to observe cluster quality */ 55 | SumOfClusterVariances clusterVar = new SumOfClusterVariances<>(new EuclideanDistance()); 56 | System.out.println("score: " + clusterVar.score(results)); 57 | 58 | for (CentroidCluster result : results) { 59 | DoublePoint centroid = (DoublePoint) result.getCenter(); 60 | // System.out.println(centroid); 61 | // result.getPoints(); 62 | } 63 | } 64 | 65 | /* performs k++ numTrials times and returns only the best result */ 66 | int numTrials = 10; 67 | for (int i = 1; i < 5; i++) { 68 | System.out.println("MULTI " + i); 69 | KMeansPlusPlusClusterer kmpp2 = new KMeansPlusPlusClusterer<>(i); 70 | MultiKMeansPlusPlusClusterer multiKMPP = new MultiKMeansPlusPlusClusterer<>(kmpp2, numTrials); 71 | List> multiResults = multiKMPP.cluster(data); 72 | /* use cluster vars to observe cluster quality */ 73 | SumOfClusterVariances clusterVar = new SumOfClusterVariances<>(new EuclideanDistance()); 74 | System.out.println("score: " + clusterVar.score(multiResults)); 75 | for (CentroidCluster multiResult : multiResults) { 76 | // System.out.println(multiResult.getCenter()); 77 | } 78 | } 79 | 80 | 81 | } 82 | 83 | } 84 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/learn/LinearModel.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.learn; 17 | 18 | import com.oreilly.dswj.linalg.RandomizedMatrix; 19 | import org.apache.commons.math3.linear.RealMatrix; 20 | import org.apache.commons.math3.linear.RealVector; 21 | 22 | /** 23 | * 24 | * @author Michael Brzustowicz 25 | */ 26 | public class LinearModel { 27 | 28 | private RealMatrix weight; 29 | private RealVector bias; 30 | private final OutputFunction outputFunction; 31 | 32 | public LinearModel(int inputDimension, int outputDimension, 33 | OutputFunction outputFunction) { 34 | RandomizedMatrix randM = new RandomizedMatrix(); 35 | weight = randM.getMatrix(inputDimension, outputDimension); 36 | bias = randM.getVector(outputDimension); 37 | this.outputFunction = outputFunction; 38 | } 39 | 40 | public RealMatrix getOutput(RealMatrix input) { 41 | return outputFunction.getOutput(input, weight, bias); 42 | } 43 | 44 | public void addUpdateToWeight(RealMatrix weightUpdate) { 45 | weight = weight.add(weightUpdate); 46 | } 47 | 48 | public void addUpdateToBias(RealVector biasUpdate) { 49 | bias = bias.add(biasUpdate); 50 | } 51 | 52 | /* setter and getters */ 53 | 54 | public void setWeight(RealMatrix weight) { 55 | this.weight = weight; 56 | } 57 | 58 | public void setBias(RealVector bias) { 59 | this.bias = bias; 60 | } 61 | 62 | public RealMatrix getWeight() { 63 | return weight; 64 | } 65 | 66 | public RealVector getBias() { 67 | return bias; 68 | } 69 | 70 | public OutputFunction getOutputFunction() { 71 | return outputFunction; 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/learn/LinearModelEstimator.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 
3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.learn; 17 | 18 | import com.oreilly.dswj.dataops.LossFunction; 19 | import org.apache.commons.math3.linear.ArrayRealVector; 20 | import org.apache.commons.math3.linear.RealMatrix; 21 | import org.apache.commons.math3.linear.RealVector; 22 | 23 | /** 24 | * 25 | * @author Michael Brzustowicz 26 | */ 27 | public class LinearModelEstimator extends IterativeLearningProcess { 28 | 29 | private final LinearModel linearModel; 30 | private final Optimizer optimizer; 31 | 32 | public LinearModelEstimator( 33 | LinearModel linearModel, 34 | LossFunction lossFunction, 35 | Optimizer optimizer) { 36 | super(lossFunction); 37 | this.linearModel = linearModel; 38 | this.optimizer = optimizer; 39 | } 40 | 41 | @Override 42 | public RealMatrix predict(RealMatrix input) { 43 | return linearModel.getOutput(input); 44 | } 45 | 46 | @Override 47 | protected void update(RealMatrix input, RealMatrix target, RealMatrix output) { 48 | RealMatrix weightGradient = input.transpose().multiply(output.subtract(target)); 49 | RealMatrix weightUpdate = optimizer.getWeightUpdate(weightGradient); 50 | linearModel.addUpdateToWeight(weightUpdate); 51 | 52 | RealVector h = new ArrayRealVector(input.getRowDimension(), 1.0); 53 | RealVector biasGradient = output.subtract(target).preMultiply(h); 54 | RealVector biasUpdate = optimizer.getBiasUpdate(biasGradient); 55 | 
linearModel.addUpdateToBias(biasUpdate); 56 | 57 | } 58 | 59 | public LinearModel getLinearModel() { 60 | return linearModel; 61 | } 62 | 63 | public Optimizer getOptimizer() { 64 | return optimizer; 65 | } 66 | 67 | } 68 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/learn/LinearOutputFunction.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.learn; 17 | 18 | import com.oreilly.dswj.linalg.MatrixOperations; 19 | import org.apache.commons.math3.linear.RealMatrix; 20 | import org.apache.commons.math3.linear.RealVector; 21 | 22 | /** 23 | * 24 | * @author Michael Brzustowicz 25 | */ 26 | public class LinearOutputFunction implements OutputFunction { 27 | 28 | @Override 29 | public RealMatrix getOutput(RealMatrix input, RealMatrix weight, RealVector bias) { 30 | return MatrixOperations.XWplusB(input, weight, bias); 31 | } 32 | 33 | @Override 34 | public RealMatrix getDelta(RealMatrix errorGradient, RealMatrix output) { 35 | // output gradient is all 1's ... 
so just return errorGradient 36 | return errorGradient; 37 | } 38 | 39 | } 40 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/learn/LogisticOutputFunction.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.learn; 17 | 18 | import com.oreilly.dswj.linalg.MatrixOperations; 19 | import com.oreilly.dswj.linalg.UnivariateFunctionMapper; 20 | import org.apache.commons.math3.analysis.UnivariateFunction; 21 | import org.apache.commons.math3.analysis.function.Sigmoid; 22 | import org.apache.commons.math3.linear.RealMatrix; 23 | import org.apache.commons.math3.linear.RealVector; 24 | 25 | /** 26 | * 27 | * @author Michael Brzustowicz 28 | */ 29 | public class LogisticOutputFunction implements OutputFunction { 30 | 31 | @Override 32 | public RealMatrix getOutput(RealMatrix input, RealMatrix weight, RealVector bias) { 33 | return MatrixOperations.XWplusB(input, weight, bias, new Sigmoid()); 34 | } 35 | 36 | @Override 37 | public RealMatrix getDelta(RealMatrix errorGradient, RealMatrix output) { 38 | 39 | // this changes output permanently 40 | output.walkInOptimizedOrder(new UnivariateFunctionMapper(new LogisticGradient())); 41 | 42 | // output is now the output gradient 43 | return 
MatrixOperations.ebeMultiply(errorGradient, output); 44 | } 45 | 46 | private class LogisticGradient implements UnivariateFunction { 47 | 48 | @Override 49 | public double value(double x) { 50 | return x * (1 - x); 51 | } 52 | 53 | } 54 | 55 | } 56 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/learn/MultinomialConditionalProbabilityEstimator.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.oreilly.dswj.learn; 17 | 18 | import org.apache.commons.math3.stat.descriptive.MultivariateSummaryStatistics; 19 | 20 | /** 21 | * 22 | * @author Michael Brzustowicz 23 | */ 24 | public class MultinomialConditionalProbabilityEstimator implements ConditionalProbabilityEstimator { 25 | 26 | private double alpha; 27 | 28 | public MultinomialConditionalProbabilityEstimator(double alpha) { 29 | this.alpha = alpha; // Lidstone smoothing 0 > alpha > 1 30 | } 31 | 32 | public MultinomialConditionalProbabilityEstimator() { 33 | this(1); // Laplace smoothing 34 | } 35 | 36 | @Override 37 | public double getProbability(MultivariateSummaryStatistics mss, double[] features) { 38 | int n = features.length; 39 | double prob = 1.0; 40 | double[] sum = mss.getSum(); // array of x_i sums for this class 41 | double total = 0.0; // total count of all features 42 | for (int i = 0; i < n; i++) { 43 | total += sum[i]; 44 | } 45 | 46 | /* works great for smaller x_i ie features[i] */ 47 | // for (int i = 0; i < n; i++) { 48 | // prob *= Math.pow((sum[i] + alpha) / (total + alpha * n), features[i]); 49 | // } 50 | // return prob; 51 | 52 | /* for large x_i need to solve in log space and convert back with exp */ 53 | prob = 0; 54 | for (int i = 0; i < n; i++) { 55 | prob += features[i] * Math.log((sum[i] + alpha) / (total + alpha * n)); 56 | } 57 | return Math.exp(prob); 58 | } 59 | 60 | } 61 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/learn/NaiveBayes.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.learn; 17 | 18 | import java.util.HashMap; 19 | import java.util.Map; 20 | import org.apache.commons.math3.linear.BlockRealMatrix; 21 | import org.apache.commons.math3.linear.RealMatrix; 22 | import org.apache.commons.math3.stat.descriptive.MultivariateSummaryStatistics; 23 | 24 | /** 25 | * 26 | * @author Michael Brzustowicz 27 | */ 28 | public class NaiveBayes { 29 | 30 | private Map statistics; 31 | private final Map stats; 32 | private final ConditionalProbabilityEstimator conditionalProbabilityEstimator; 33 | private int numberOfPoints; // total number of points the model was trained on 34 | 35 | public NaiveBayes() { 36 | this(new GaussianConditionalProbabilityEstimator()); 37 | } 38 | 39 | /** 40 | * provide the strategy for implementing the conditional probability 41 | * @param conditionalProbabilityEstimator 42 | */ 43 | public NaiveBayes(ConditionalProbabilityEstimator conditionalProbabilityEstimator) { 44 | stats = new HashMap<>(); 45 | statistics = new HashMap<>(); 46 | this.conditionalProbabilityEstimator = conditionalProbabilityEstimator; 47 | numberOfPoints = 0; 48 | } 49 | 50 | 51 | /** 52 | * 53 | * @param input 54 | * @param target multi class OR one hot encoded labels 55 | */ 56 | public void learn(RealMatrix input, RealMatrix target) { 57 | // if numTargetCols == 1 then multiclass e.g. 0, 1, 2, 3 58 | // else one-hot e.g. 
1000, 0100, 0010, 0001 59 | 60 | numberOfPoints += input.getRowDimension(); 61 | 62 | for (int i = 0; i < input.getRowDimension(); i++) { 63 | 64 | double[] rowData = input.getRow(i); 65 | int label; 66 | 67 | if (target.getColumnDimension()==1) { 68 | label = new Double(target.getEntry(i, 0)).intValue(); 69 | } else { 70 | label = target.getRowVector(i).getMaxIndex(); 71 | } 72 | 73 | if(!statistics.containsKey(label)) { 74 | statistics.put(label, new MultivariateSummaryStatistics(rowData.length, true)); 75 | } 76 | 77 | statistics.get(label).addValue(rowData); 78 | 79 | } 80 | } 81 | 82 | public RealMatrix predict(RealMatrix input) { 83 | 84 | int numRows = input.getRowDimension(); 85 | int numCols = statistics.size(); 86 | RealMatrix predictions = new BlockRealMatrix(numRows, numCols); 87 | 88 | for (int i = 0; i < numRows; i++) { 89 | 90 | // double[] rowData = input.getRow(i); 91 | double[] probs = new double[numCols]; 92 | double sumProbs = 0; 93 | 94 | for (Map.Entry entrySet : statistics.entrySet()) { 95 | 96 | Integer classNumber = entrySet.getKey(); // assumes these are 0, 1, 2 ... n-1 97 | MultivariateSummaryStatistics mss = entrySet.getValue(); 98 | 99 | /* prior prob n_k / N ie num points in class divided by total points */ 100 | double prob = new Long(mss.getN()).doubleValue() / numberOfPoints; 101 | 102 | /* depends on type ... 
Gaussian, Multinomial or Bernoulli */ 103 | prob *= conditionalProbabilityEstimator.getProbability(mss, input.getRow(i)); 104 | 105 | probs[classNumber] = prob; 106 | sumProbs += prob; 107 | } 108 | 109 | /* L1 norm the probs */ 110 | for (int j = 0; j < numCols; j++) { 111 | probs[j] /= sumProbs; 112 | } 113 | 114 | predictions.setRow(i, probs); 115 | 116 | } 117 | 118 | return predictions; 119 | } 120 | 121 | } 122 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/learn/NaiveBayesGaussianIrisExample.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.oreilly.dswj.learn; 17 | 18 | import com.oreilly.dswj.dataops.MatrixResampler; 19 | import com.oreilly.dswj.datasets.Iris; 20 | import java.io.IOException; 21 | import org.apache.commons.math3.linear.RealMatrix; 22 | 23 | /** 24 | * 25 | * @author Michael Brzustowicz 26 | */ 27 | public class NaiveBayesGaussianIrisExample { 28 | 29 | /** 30 | * @param args the command line arguments 31 | * @throws java.io.IOException 32 | */ 33 | public static void main(String[] args) throws IOException { 34 | 35 | Iris iris = new Iris(); 36 | MatrixResampler mr = new MatrixResampler(iris.getData(), iris.getLabels()); 37 | mr.calculateTestTrainSplit(0.4, 0L); 38 | 39 | NaiveBayes nb = new NaiveBayes(new GaussianConditionalProbabilityEstimator()); 40 | nb.learn(mr.getTrainingFeatures(), mr.getTrainingLabels()); 41 | 42 | RealMatrix predictions = nb.predict(mr.getTestingFeatures()); 43 | 44 | ClassifierAccuracy acc = new ClassifierAccuracy(predictions, mr.getTestingLabels()); 45 | System.out.println(acc.getAccuracyPerDimension()); 46 | System.out.println(acc.getAccuracy()); 47 | 48 | 49 | 50 | 51 | } 52 | 53 | } 54 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/learn/NetworkLayer.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.learn; 17 | 18 | import org.apache.commons.math3.linear.ArrayRealVector; 19 | import org.apache.commons.math3.linear.RealMatrix; 20 | import org.apache.commons.math3.linear.RealVector; 21 | 22 | /** 23 | * 24 | * @author Michael Brzustowicz 25 | */ 26 | public class NetworkLayer extends LinearModel { 27 | 28 | RealMatrix input; 29 | RealMatrix inputError; 30 | RealMatrix output; 31 | RealMatrix outputError; 32 | Optimizer optimizer; 33 | 34 | public NetworkLayer(int inputDimension, int outputDimension, 35 | OutputFunction outputFunction, Optimizer optimizer) { 36 | super(inputDimension, outputDimension, outputFunction); 37 | this.optimizer = optimizer; 38 | } 39 | 40 | public void update() { 41 | //back propagate error 42 | /* D = eps o f'(XW) where o is Hadamard product or J f'(XW) where J is Jacobian */ 43 | RealMatrix deltas = getOutputFunction().getDelta(outputError, output); 44 | 45 | /* E_out = D W^T */ 46 | inputError = deltas.multiply(getWeight().transpose()); 47 | 48 | /* W = W - alpha * delta * input */ 49 | RealMatrix weightGradient = input.transpose().multiply(deltas); 50 | 51 | /* w_{t+1} = w_{t} + \delta w_{t} */ 52 | addUpdateToWeight(optimizer.getWeightUpdate(weightGradient)); 53 | 54 | // this essentially sums the columns of delta and that vector is grad_b 55 | RealVector h = new ArrayRealVector(input.getRowDimension(), 1.0); 56 | RealVector biasGradient = deltas.preMultiply(h); 57 | addUpdateToBias(optimizer.getBiasUpdate(biasGradient)); 58 | } 59 | 60 | public void setOutputError(RealMatrix outputError) { 61 | this.outputError = outputError; 62 | } 63 | 64 | public void setInputError(RealMatrix inputError) { 65 | this.inputError = inputError; 66 | } 67 | 68 | public void setInput(RealMatrix input) { 69 | this.input = input; 70 | } 71 | 72 | public void setOutput(RealMatrix output) { 73 | 
this.output = output; 74 | } 75 | 76 | public RealMatrix getInputError() { 77 | return inputError; 78 | } 79 | 80 | } 81 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/learn/Optimizer.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.learn; 17 | 18 | import org.apache.commons.math3.linear.RealMatrix; 19 | import org.apache.commons.math3.linear.RealVector; 20 | 21 | /** 22 | * 23 | * @author Michael Brzustowicz 24 | */ 25 | public interface Optimizer { 26 | RealMatrix getWeightUpdate(RealMatrix weightGradient); 27 | RealVector getBiasUpdate(RealVector biasGradient); 28 | } 29 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/learn/OutputFunction.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.learn; 17 | 18 | import org.apache.commons.math3.linear.RealMatrix; 19 | import org.apache.commons.math3.linear.RealVector; 20 | 21 | /** 22 | * 23 | * @author Michael Brzustowicz 24 | */ 25 | public interface OutputFunction { 26 | RealMatrix getOutput(RealMatrix input, RealMatrix weight, RealVector bias); 27 | RealMatrix getDelta(RealMatrix error, RealMatrix output); 28 | } 29 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/learn/SilhouetteCoefficient.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.oreilly.dswj.learn; 17 | 18 | import java.util.List; 19 | import org.apache.commons.math3.exception.OutOfRangeException; 20 | import org.apache.commons.math3.linear.ArrayRealVector; 21 | import org.apache.commons.math3.linear.RealVector; 22 | import org.apache.commons.math3.ml.clustering.Cluster; 23 | import org.apache.commons.math3.ml.clustering.DoublePoint; 24 | import org.apache.commons.math3.stat.descriptive.SummaryStatistics; 25 | 26 | /** 27 | * 28 | * @author Michael Brzustowicz 29 | */ 30 | public class SilhouetteCoefficient { 31 | private final List> clusters; 32 | private double coefficient;// = 0.0; 33 | private int numClusters; 34 | private int numSamples; 35 | 36 | public SilhouetteCoefficient(List> clusters) throws OutOfRangeException { 37 | 38 | this.clusters = clusters; 39 | 40 | if(checkClusterSize()) { 41 | calculateMeanCoefficient(); 42 | } else { 43 | throw new OutOfRangeException(clusters.size(), 2, numSamples - 1); 44 | } 45 | } 46 | 47 | public double getCoefficient() { 48 | return coefficient; 49 | } 50 | 51 | 52 | private void calculateMeanCoefficient() { 53 | SummaryStatistics stats = new SummaryStatistics(); 54 | int clusterNumber = 0; 55 | for (Cluster cluster : clusters) { 56 | for (DoublePoint point : cluster.getPoints()) { 57 | double s = calculateCoefficientForOnePoint(point, clusterNumber); 58 | stats.addValue(s); 59 | } 60 | clusterNumber++; 61 | } 62 | coefficient = stats.getMean(); 63 | } 64 | 65 | private double calculateCoefficientForOnePoint(DoublePoint onePoint, int clusterLabel) { 66 | 67 | /* all other points will compared to this one */ 68 | RealVector vector = new ArrayRealVector(onePoint.getPoint()); 69 | 70 | double a = 0; 71 | double b = Double.MAX_VALUE; 72 | 73 | int clusterNumber = 0; 74 | 75 | for (Cluster cluster : clusters) { 76 | 77 | SummaryStatistics clusterStats = new SummaryStatistics(); 78 | 79 | for (DoublePoint otherPoint : cluster.getPoints()) { 80 | RealVector otherVector = new 
ArrayRealVector(otherPoint.getPoint()); 81 | double dist = vector.getDistance(otherVector); 82 | clusterStats.addValue(dist); 83 | } 84 | 85 | double avgDistance = clusterStats.getMean(); 86 | 87 | if(clusterNumber==clusterLabel) { 88 | /* we have included a 0 distance of point with itself and need to subtract it out of the mean */ 89 | double n = new Long(clusterStats.getN()).doubleValue(); 90 | double correction = n / (n - 1.0); 91 | a = correction * avgDistance; 92 | } else { 93 | b = Math.min(avgDistance, b); 94 | } 95 | 96 | clusterNumber++; 97 | } 98 | 99 | return (b-a) / Math.max(a, b); 100 | } 101 | 102 | private boolean checkClusterSize() throws OutOfRangeException { 103 | numClusters = clusters.size(); 104 | numSamples = 0; 105 | for (Cluster cluster : clusters) { 106 | numSamples += cluster.getPoints().size(); 107 | } 108 | return numClusters >= 2 && numClusters <= (numSamples - 1) ; 109 | } 110 | 111 | } 112 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/learn/SoftmaxLinearModelExample.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.oreilly.dswj.learn; 17 | 18 | import com.oreilly.dswj.dataops.MatrixResampler; 19 | import com.oreilly.dswj.datasets.Iris; 20 | import com.oreilly.dswj.dataops.SoftMaxCrossEntropyLossFunction; 21 | import java.io.IOException; 22 | import org.apache.commons.math3.linear.RealMatrix; 23 | 24 | /** 25 | * 26 | * @author Michael Brzustowicz 27 | */ 28 | public class SoftmaxLinearModelExample { 29 | 30 | /** 31 | * @param args the command line arguments 32 | * @throws java.io.IOException 33 | */ 34 | public static void main(String[] args) throws IOException { 35 | Iris iris = new Iris(); 36 | MatrixResampler resampler = new MatrixResampler(iris.getData(), iris.getLabels()); 37 | resampler.calculateTestTrainSplit(0.40, 0L); 38 | 39 | LinearModelEstimator estimator = new LinearModelEstimator( 40 | new LinearModel(4, 3, new SoftmaxOutputFunction()), 41 | new SoftMaxCrossEntropyLossFunction(), 42 | new DeltaRule(0.001)); 43 | 44 | 45 | // /* this is the SAME thing as a lone layer network */ 46 | // DeepNetwork estimator = new DeepNetwork(); 47 | // estimator.setLossFunction(new SoftMaxCrossEntropyLossFunction()); 48 | // estimator.addLayer(new NetworkLayer(4, 3, new SoftmaxOutputFunction(), new GradientDescent(0.001))); 49 | // estimator.setBatchSize(0); 50 | 51 | 52 | estimator.setMaxIterations(6000); 53 | estimator.setTolerance(10E-6); 54 | 55 | estimator.learn(resampler.getTrainingFeatures(), resampler.getTrainingLabels()); 56 | 57 | RealMatrix prediction = estimator.predict(resampler.getTestingFeatures()); 58 | 59 | ClassifierAccuracy accuracy = new ClassifierAccuracy(prediction, resampler.getTestingLabels()); 60 | 61 | System.out.println("isConverged " + estimator.isConverged()); 62 | System.out.println("numIterations " + estimator.getNumIterations()); 63 | System.out.println("loss " + estimator.getLoss()); 64 | System.out.println("accuracy " + accuracy.getAccuracy()); 65 | System.out.println("accuracy per dim " + 
accuracy.getAccuracyPerDimension()); 66 | 67 | //isConverged true 68 | //numIterations 3094 69 | //loss 0.07695531148974591 70 | //accuracy 0.9833333333333333 71 | //accuracy per dim {1; 0.9230769231; 1} 72 | 73 | 74 | } 75 | 76 | } 77 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/learn/SoftmaxOutputFunction.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.oreilly.dswj.learn; 17 | 18 | import com.oreilly.dswj.dataops.MatrixScaler; 19 | import com.oreilly.dswj.linalg.MatrixOperations; 20 | import org.apache.commons.math3.analysis.function.Exp; 21 | import org.apache.commons.math3.linear.BlockRealMatrix; 22 | import org.apache.commons.math3.linear.RealMatrix; 23 | import org.apache.commons.math3.linear.RealVector; 24 | 25 | /** 26 | * 27 | * @author Michael Brzustowicz 28 | */ 29 | public class SoftmaxOutputFunction implements OutputFunction { 30 | 31 | @Override 32 | public RealMatrix getOutput(RealMatrix input, RealMatrix weight, RealVector bias) { 33 | RealMatrix output = MatrixOperations.XWplusB(input, weight, bias, new Exp()); 34 | MatrixScaler.l1(output); 35 | return output; 36 | } 37 | 38 | @Override 39 | public RealMatrix getDelta(RealMatrix error, RealMatrix output) { 40 | 41 | RealMatrix delta = new BlockRealMatrix(error.getRowDimension(), error.getColumnDimension()); 42 | 43 | for (int i = 0; i < output.getRowDimension(); i++) { 44 | delta.setRowVector(i, getJacobian(output.getRowVector(i)).preMultiply(error.getRowVector(i))); 45 | } 46 | 47 | return delta; 48 | } 49 | 50 | private RealMatrix getJacobian(RealVector output) { 51 | 52 | int numRows = output.getDimension(); 53 | 54 | int numCols = output.getDimension(); 55 | 56 | RealMatrix jacobian = new BlockRealMatrix(numRows, numCols); 57 | 58 | for (int i = 0; i < numRows; i++) { 59 | 60 | double output_i = output.getEntry(i); 61 | 62 | for (int j = i; j < numCols; j++) { 63 | 64 | double output_j = output.getEntry(j); 65 | 66 | if(i==j) { 67 | 68 | jacobian.setEntry(i, i, output_i*(1-output_i)); 69 | 70 | } else { 71 | 72 | jacobian.setEntry(i, j, -1.0 * output_i * output_j); 73 | jacobian.setEntry(j, i, -1.0 * output_j * output_i); 74 | } 75 | 76 | } 77 | } 78 | return jacobian; 79 | } 80 | 81 | } 82 | -------------------------------------------------------------------------------- 
/src/main/java/com/oreilly/dswj/learn/TanhOutputFunction.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.learn; 17 | 18 | import com.oreilly.dswj.linalg.MatrixOperations; 19 | import com.oreilly.dswj.linalg.UnivariateFunctionMapper; 20 | import org.apache.commons.math3.analysis.UnivariateFunction; 21 | import org.apache.commons.math3.analysis.function.Tanh; 22 | import org.apache.commons.math3.linear.RealMatrix; 23 | import org.apache.commons.math3.linear.RealVector; 24 | 25 | /** 26 | * 27 | * @author Michael Brzustowicz 28 | */ 29 | public class TanhOutputFunction implements OutputFunction { 30 | 31 | @Override 32 | public RealMatrix getOutput(RealMatrix input, RealMatrix weight, RealVector bias) { 33 | return MatrixOperations.XWplusB(input, weight, bias, new Tanh()); 34 | } 35 | 36 | @Override 37 | public RealMatrix getDelta(RealMatrix errorGradient, RealMatrix output) { 38 | // this changes output permanently 39 | output.walkInOptimizedOrder(new UnivariateFunctionMapper(new TanhGradient())); 40 | 41 | // output is now the output gradient 42 | return MatrixOperations.ebeMultiply(errorGradient, output); 43 | } 44 | 45 | private class TanhGradient implements UnivariateFunction { 46 | 47 | @Override 48 | public double value(double x) { 49 | return (1 - x * 
x); 50 | } 51 | 52 | } 53 | 54 | } 55 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/linalg/FunctionMapper.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.linalg; 17 | 18 | import org.apache.commons.math3.linear.RealMatrixChangingVisitor; 19 | 20 | /** 21 | * 22 | * @author Michael Brzustowicz 23 | */ 24 | public class FunctionMapper implements RealMatrixChangingVisitor { 25 | 26 | private final double power; 27 | 28 | public FunctionMapper(double power) { 29 | this.power = power; 30 | } 31 | 32 | @Override 33 | public void start(int rows, int columns, int startRow, int endRow, 34 | int startColumn, int endColumn) { 35 | // do nothing 36 | } 37 | 38 | @Override 39 | public double visit(int row, int column, double value) { 40 | return Math.pow(value, power); 41 | } 42 | 43 | @Override 44 | public double end() { 45 | return 0; 46 | } 47 | 48 | } 49 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/linalg/LinearSystemExample.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 
3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.linalg; 17 | 18 | import org.apache.commons.math3.linear.Array2DRowRealMatrix; 19 | import org.apache.commons.math3.linear.RealMatrix; 20 | import org.apache.commons.math3.linear.SingularValueDecomposition; 21 | 22 | /** 23 | * 24 | * @author Michael Brzustowicz 25 | */ 26 | public class LinearSystemExample { 27 | 28 | /** 29 | * @param args the command line arguments 30 | */ 31 | public static void main(String[] args) { 32 | double[][] xData = {{0, 0.5, 0.2}, {1, 1.2, .9}, {2, 2.5, 1.9}, {3, 3.6, 4.2}}; 33 | double[][] yData = {{-1, -0.5}, {0.2, 1}, {0.9, 1.2}, {2.1, 1.5}}; 34 | double[] ones = {1.0, 1.0, 1.0, 1.0}; 35 | int xRows = 4; 36 | int xCols = 3; 37 | RealMatrix x = new Array2DRowRealMatrix(xRows, xCols + 1); 38 | x.setSubMatrix(xData, 0, 0); 39 | x.setColumn(3, ones); // 4th column is index of 3 !!! 40 | RealMatrix y = new Array2DRowRealMatrix(yData); 41 | 42 | SingularValueDecomposition svd = new SingularValueDecomposition(x); 43 | RealMatrix solution = svd.getSolver().solve(y); 44 | System.out.println(solution); 45 | } 46 | 47 | } 48 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/linalg/MatrixOperations.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 
3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.linalg; 17 | 18 | import org.apache.commons.math3.analysis.UnivariateFunction; 19 | import org.apache.commons.math3.distribution.AbstractRealDistribution; 20 | import org.apache.commons.math3.distribution.UniformRealDistribution; 21 | import org.apache.commons.math3.linear.ArrayRealVector; 22 | import org.apache.commons.math3.linear.BlockRealMatrix; 23 | import org.apache.commons.math3.linear.RealMatrix; 24 | import org.apache.commons.math3.linear.RealVector; 25 | 26 | /** 27 | * 28 | * @author Michael Brzustowicz 29 | */ 30 | public class MatrixOperations { 31 | 32 | // TODO name this similar to BLAS ??? 
33 | public static RealMatrix XWplusB(RealMatrix X, RealMatrix W, RealVector b) { 34 | RealVector h = new ArrayRealVector(X.getRowDimension(), 1.0); 35 | return X.multiply(W).add(h.outerProduct(b)); 36 | } 37 | 38 | public static RealMatrix XWplusB(RealMatrix X, RealMatrix W, RealVector b, UnivariateFunction univariateFunction) { 39 | RealMatrix z = XWplusB(X, W, b); 40 | z.walkInOptimizedOrder(new UnivariateFunctionMapper(univariateFunction)); 41 | return z; 42 | } 43 | 44 | public static RealMatrix ebeMultiply(RealMatrix a, RealMatrix b) { 45 | int rowDimension = a.getRowDimension(); 46 | int columnDimension = a.getColumnDimension(); 47 | //TODO a and b should have same dimensions 48 | RealMatrix output = new BlockRealMatrix(rowDimension, columnDimension); 49 | for (int i = 0; i < rowDimension; i++) { 50 | for (int j = 0; j < columnDimension; j++) { 51 | output.setEntry(i, j, a.getEntry(i, j) * b.getEntry(i, j)); 52 | } 53 | } 54 | return output; 55 | } 56 | 57 | } 58 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/linalg/RandomizedMatrix.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.oreilly.dswj.linalg; 17 | 18 | import org.apache.commons.math3.distribution.AbstractRealDistribution; 19 | import org.apache.commons.math3.distribution.NormalDistribution; 20 | import org.apache.commons.math3.distribution.UniformRealDistribution; 21 | import org.apache.commons.math3.linear.Array2DRowRealMatrix; 22 | import org.apache.commons.math3.linear.ArrayRealVector; 23 | import org.apache.commons.math3.linear.BlockRealMatrix; 24 | import org.apache.commons.math3.linear.RealMatrix; 25 | import org.apache.commons.math3.linear.RealVector; 26 | 27 | /** 28 | * 29 | * @author Michael Brzustowicz 30 | */ 31 | public class RandomizedMatrix { 32 | 33 | private AbstractRealDistribution distribution; 34 | 35 | public RandomizedMatrix(AbstractRealDistribution distribution, long seed) { 36 | this.distribution = distribution; 37 | distribution.reseedRandomGenerator(seed); 38 | } 39 | 40 | public RandomizedMatrix() { 41 | this(new UniformRealDistribution(-1, 1), 0L); 42 | } 43 | 44 | public void fillMatrix(RealMatrix matrix) { 45 | for (int i = 0; i < matrix.getRowDimension(); i++) { 46 | matrix.setRow(i, distribution.sample(matrix.getColumnDimension())); 47 | } 48 | } 49 | 50 | public RealMatrix getMatrix(int numRows, int numCols) { 51 | RealMatrix output = new BlockRealMatrix(numRows, numCols); 52 | for (int i = 0; i < numRows; i++) { 53 | output.setRow(i, distribution.sample(numCols)); 54 | } 55 | return output; 56 | } 57 | 58 | public void fillVector(RealVector vector) { 59 | for (int i = 0; i < vector.getDimension(); i++) { 60 | vector.setEntry(i, distribution.sample()); 61 | } 62 | } 63 | 64 | public RealVector getVector(int dim) { 65 | return new ArrayRealVector(distribution.sample(dim)); 66 | } 67 | 68 | public static RealMatrix getTruncatedGaussian(int numRows, int numCols, long seed) { 69 | RealMatrix out = new BlockRealMatrix(numRows, numCols); 70 | NormalDistribution dist = new NormalDistribution(0.0, 0.5); 71 | 
dist.reseedRandomGenerator(seed); 72 | for (int i = 0; i < numRows; i++) { 73 | out.setRow(i, dist.sample(numCols)); 74 | } 75 | return out; 76 | } 77 | 78 | public static RealMatrix getUniform(int numRows, int numCols, long seed) { 79 | RealMatrix out = new Array2DRowRealMatrix(numRows, numCols); 80 | UniformRealDistribution dist = new UniformRealDistribution(-1, 1); 81 | dist.reseedRandomGenerator(seed); 82 | for (int i = 0; i < numRows; i++) { 83 | out.setRow(i, dist.sample(numCols)); 84 | } 85 | return out; 86 | } 87 | 88 | 89 | 90 | } 91 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/linalg/UnivariateFunctionMapper.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.oreilly.dswj.linalg; 17 | 18 | import org.apache.commons.math3.analysis.UnivariateFunction; 19 | import org.apache.commons.math3.linear.RealMatrixChangingVisitor; 20 | 21 | /** 22 | * 23 | * @author Michael Brzustowicz 24 | */ 25 | public class UnivariateFunctionMapper implements RealMatrixChangingVisitor { 26 | 27 | UnivariateFunction univariateFunction; 28 | 29 | public UnivariateFunctionMapper(UnivariateFunction univariateFunction) { 30 | this.univariateFunction = univariateFunction; 31 | } 32 | 33 | @Override 34 | public void start(int rows, int columns, int startRow, int endRow, 35 | int startColumn, int endColumn) { 36 | //NA 37 | } 38 | 39 | @Override 40 | public double visit(int row, int column, double value) { 41 | return univariateFunction.value(value); 42 | } 43 | 44 | @Override 45 | public double end() { 46 | return 0.0; 47 | } 48 | 49 | } 50 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/mapreduce/BasicMapReduceExample.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.oreilly.dswj.mapreduce; 17 | 18 | import org.apache.hadoop.conf.Configured; 19 | import org.apache.hadoop.fs.Path; 20 | import org.apache.hadoop.mapreduce.Job; 21 | import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; 22 | import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; 23 | import org.apache.hadoop.util.Tool; 24 | import org.apache.hadoop.util.ToolRunner; 25 | 26 | /** 27 | * 28 | * @author Michael Brzustowicz 29 | */ 30 | public class BasicMapReduceExample extends Configured implements Tool { 31 | 32 | /** 33 | * @param args the command line arguments 34 | * @throws java.lang.Exception 35 | */ 36 | public static void main(String[] args) throws Exception { 37 | int exitCode = ToolRunner.run(new BasicMapReduceExample(), args); 38 | System.exit(exitCode); 39 | } 40 | 41 | @Override 42 | public int run(String[] args) throws Exception { 43 | /* this is deprecated */ 44 | // Job job = new Job(getConf()); 45 | /* the singleton method is preferred */ 46 | Job job = Job.getInstance(getConf()); 47 | job.setJarByClass(BasicMapReduceExample.class); 48 | job.setJobName("BasicMapReduceExample"); 49 | 50 | FileInputFormat.addInputPath(job, new Path(args[0])); 51 | FileOutputFormat.setOutputPath(job, new Path(args[1])); 52 | 53 | return job.waitForCompletion(true) ? 0 : 1; 54 | } 55 | 56 | } 57 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/mapreduce/CustomWordCountMapReduceExample.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.mapreduce; 17 | 18 | import org.apache.hadoop.conf.Configured; 19 | import org.apache.hadoop.fs.Path; 20 | import org.apache.hadoop.io.LongWritable; 21 | import org.apache.hadoop.io.Text; 22 | import org.apache.hadoop.mapreduce.Job; 23 | import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; 24 | import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; 25 | import org.apache.hadoop.mapreduce.lib.reduce.IntSumReducer; 26 | import org.apache.hadoop.util.Tool; 27 | import org.apache.hadoop.util.ToolRunner; 28 | 29 | /** 30 | * 31 | * @author Michael Brzustowicz 32 | */ 33 | public class CustomWordCountMapReduceExample extends Configured implements Tool { 34 | 35 | /** 36 | * @param args the command line arguments 37 | * @throws java.lang.Exception 38 | */ 39 | public static void main(String[] args) throws Exception { 40 | int exitCode = ToolRunner.run(new CustomWordCountMapReduceExample(), args); 41 | System.exit(exitCode); 42 | } 43 | 44 | @Override 45 | public int run(String[] args) throws Exception { 46 | Job job = Job.getInstance(getConf()); 47 | job.setJarByClass(CustomWordCountMapReduceExample.class); 48 | job.setJobName("CustomWordCountMapReduceExample"); 49 | 50 | FileInputFormat.addInputPath(job, new Path(args[0])); 51 | FileOutputFormat.setOutputPath(job, new Path(args[1])); 52 | 53 | job.setMapperClass(SimpleTokenMapper.class); 54 | job.setMapOutputKeyClass(Text.class); 55 | job.setMapOutputValueClass(LongWritable.class); 56 | 
job.setReducerClass(IntSumReducer.class); 57 | job.setOutputKeyClass(Text.class); 58 | job.setOutputValueClass(LongWritable.class); 59 | job.setNumReduceTasks(1); 60 | 61 | 62 | return job.waitForCompletion(true) ? 0 : 1; 63 | } 64 | 65 | } 66 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/mapreduce/JSONMapper.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.oreilly.dswj.mapreduce; 17 | 18 | import java.io.IOException; 19 | import org.apache.hadoop.io.LongWritable; 20 | import org.apache.hadoop.io.Text; 21 | import org.apache.hadoop.mapreduce.Mapper; 22 | import org.json.simple.JSONObject; 23 | import org.json.simple.parser.JSONParser; 24 | import org.json.simple.parser.ParseException; 25 | 26 | /** 27 | * 28 | * @author Michael Brzustowicz 29 | */ 30 | public class JSONMapper extends Mapper { 31 | 32 | @Override 33 | protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException { 34 | JSONParser parser = new JSONParser(); 35 | try { 36 | JSONObject obj = (JSONObject) parser.parse(value.toString()); 37 | 38 | // get what you need from this object 39 | String userID = obj.get("user_id").toString(); 40 | String productID = obj.get("product_id").toString(); 41 | int numUnits = Integer.parseInt(obj.get("num_units").toString()); 42 | 43 | JSONObject output = new JSONObject(); 44 | output.put("productID", productID); 45 | output.put("numUnits", numUnits); 46 | 47 | context.write(new Text(userID), new Text(output.toString())); 48 | 49 | 50 | } catch (ParseException ex) { 51 | //TODO error parsing json 52 | } 53 | 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/mapreduce/SimpleTokenMapper.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.mapreduce; 17 | 18 | import com.oreilly.dswj.dataops.SimpleTokenizer; 19 | import java.io.IOException; 20 | import org.apache.hadoop.io.LongWritable; 21 | import org.apache.hadoop.io.Text; 22 | import org.apache.hadoop.mapreduce.Mapper; 23 | 24 | /** 25 | * @author Michael Brzustowicz 26 | */ 27 | public class SimpleTokenMapper extends Mapper { 28 | 29 | SimpleTokenizer tokenizer; 30 | 31 | @Override 32 | protected void setup(Context context) throws IOException { 33 | 34 | tokenizer = new SimpleTokenizer(3); // mintokensize = 3 !!! 35 | } 36 | 37 | @Override 38 | protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException { 39 | 40 | String[] tokens = tokenizer.getTokens(value.toString()); 41 | 42 | for (String token : tokens) { 43 | context.write(new Text(token), new LongWritable(1L)); 44 | } 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/mapreduce/SparseAlgebraMapReduceExample.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.mapreduce; 17 | 18 | import org.apache.hadoop.conf.Configured; 19 | import org.apache.hadoop.fs.Path; 20 | import org.apache.hadoop.io.DoubleWritable; 21 | import org.apache.hadoop.io.IntWritable; 22 | import org.apache.hadoop.mapreduce.Job; 23 | import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; 24 | import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; 25 | import org.apache.hadoop.util.Tool; 26 | import org.apache.hadoop.util.ToolRunner; 27 | 28 | /** 29 | * a sparse matrix stored as i,j,v is mapped to 30 | * k:v -> i:SparseMatrixWritables. 
The reducer classes each create a sparse 31 | * or dense vector from a serialized file of that object, and then each reducer 32 | * method creates a sparse vector for that matrix row (i) and then dot product 33 | * of row and vector is output as i:dotproduct 34 | * @author Michael Brzustowicz 35 | */ 36 | public class SparseAlgebraMapReduceExample extends Configured implements Tool { 37 | 38 | /** 39 | * @param args the command line arguments 40 | * @throws java.lang.Exception 41 | */ 42 | public static void main(String[] args) throws Exception { 43 | int exitCode = ToolRunner.run(new SparseAlgebraMapReduceExample(), args); 44 | System.exit(exitCode); 45 | } 46 | 47 | @Override 48 | public int run(String[] args) throws Exception { 49 | Job job = Job.getInstance(getConf()); 50 | job.setJarByClass(SparseAlgebraMapReduceExample.class); 51 | job.setJobName("SparseAlgebraMapReduceExample"); 52 | 53 | // third command line arg is the filepath to the serialized vector file 54 | job.getConfiguration().set("vectorFileName", args[2]); 55 | 56 | FileInputFormat.addInputPath(job, new Path(args[0])); 57 | FileOutputFormat.setOutputPath(job, new Path(args[1])); 58 | 59 | job.setMapperClass(SparseMatrixMultiplicationMapper.class); 60 | job.setMapOutputKeyClass(IntWritable.class); 61 | job.setMapOutputValueClass(SparseMatrixWritable.class); 62 | job.setReducerClass(SparseMatrixMultiplicationReducer.class); 63 | job.setOutputKeyClass(IntWritable.class); 64 | job.setOutputValueClass(DoubleWritable.class); 65 | job.setNumReduceTasks(1); 66 | 67 | 68 | return job.waitForCompletion(true) ? 0 : 1; 69 | } 70 | 71 | } 72 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/mapreduce/SparseMatrixMultiplicationMapper.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 
3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */
package com.oreilly.dswj.mapreduce;

import java.io.IOException;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;

/**
 * Maps one "row,column,value" text line to a pair of row index and
 * {@link SparseMatrixWritable}, so that all entries of a matrix row reach
 * the same reduce call. Suited to multiplying a sparse matrix by a vector.
 *
 * @author Michael Brzustowicz
 */
public class SparseMatrixMultiplicationMapper
        extends Mapper<LongWritable, Text, IntWritable, SparseMatrixWritable> {

    /** Counter group for skipped records. */
    private static final String ERROR_GROUP = "mapperErrors";

    /**
     * Parses one matrix entry and emits it keyed by its row index. Lines
     * with fewer than three comma-separated fields, or with fields that do
     * not parse as int, int, double, are skipped and counted.
     *
     * @param key input key (not used)
     * @param value text line of the form "row,column,value"
     * @param context task context used for output and counters
     * @throws IOException if writing the output fails
     * @throws InterruptedException if the task is interrupted while writing
     */
    @Override
    protected void map(LongWritable key, Text value, Context context)
            throws IOException, InterruptedException {
        String[] items = value.toString().split(",");
        if (items.length < 3) {
            context.getCounter(ERROR_GROUP, "tooFewFields").increment(1L);
            return;
        }

        final int rowIndex;
        final int columnIndex;
        final double entry;
        try {
            rowIndex = Integer.parseInt(items[0]);
            columnIndex = Integer.parseInt(items[1]);
            entry = Double.parseDouble(items[2]);
        } catch (NumberFormatException e) {
            // fixed counter name: exception messages are unbounded and would
            // create one counter per distinct bad value
            context.getCounter(ERROR_GROUP, "numberFormat").increment(1L);
            return;
        }

        // write failures are real task failures and are left to propagate
        SparseMatrixWritable smw = new SparseMatrixWritable(rowIndex, columnIndex, entry);
        context.write(new IntWritable(rowIndex), smw);
        //NOTE can add another context.write() for e.g. a symmetric matrix entry if matrix is sparse upper triag
    }

}
-------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/mapreduce/SparseMatrixMultiplicationReducer.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License.
15 | */
package com.oreilly.dswj.mapreduce;

import java.io.FileInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import org.apache.commons.math3.linear.OpenMapRealVector;
import org.apache.commons.math3.linear.RealVector;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.IntWritable;

import org.apache.hadoop.mapreduce.Reducer;

/**
 * Computes one component of a sparse matrix-vector product per reduce call:
 * the entries of matrix row i are assembled into a sparse vector and its dot
 * product with the input vector is written as (i, dot product).
 *
 * @author Michael Brzustowicz
 */
public class SparseMatrixMultiplicationReducer
        extends Reducer<IntWritable, SparseMatrixWritable, IntWritable, DoubleWritable> {

    /** The vector the matrix is multiplied with, loaded in {@link #setup}. */
    private RealVector vector;

    /**
     * Deserializes the input vector from the file named by the job
     * configuration property "vectorFileName".
     *
     * @param context task context providing the configuration
     * @throws IOException if the property is missing, the file cannot be
     *         read, or the serialized class cannot be resolved
     * @throws InterruptedException declared by the base class
     */
    @Override
    protected void setup(Context context) throws IOException, InterruptedException {
        /* unserialize the RealVector object */
        // NOTE this is just the filename, please include the resource itself in the dist cache via -files at runtime
        String vectorFileName = context.getConfiguration().get("vectorFileName");
        if (vectorFileName == null) {
            throw new IOException("Job configuration property 'vectorFileName' is not set");
        }
        // NOTE(review): Java-native deserialization; only use vector files
        // from a trusted source.
        try (ObjectInputStream in = new ObjectInputStream(new FileInputStream(vectorFileName))) {
            vector = (RealVector) in.readObject();
        } catch (ClassNotFoundException e) {
            // fail the task here rather than with a NullPointerException in reduce()
            throw new IOException("Cannot deserialize vector from " + vectorFileName, e);
        }
    }

    /**
     * @param key matrix row index
     * @param values all entries of that matrix row
     * @param context task context receiving (row index, dot product) when
     *        the dot product is non-zero
     * @throws IOException if writing the output fails
     * @throws InterruptedException if the task is interrupted while writing
     */
    @Override
    protected void reduce(IntWritable key, Iterable<SparseMatrixWritable> values, Context context)
            throws IOException, InterruptedException {

        /* rely on the fact that rowVector dim has to be same as input vector dim */
        RealVector rowVector = new OpenMapRealVector(vector.getDimension());

        for (SparseMatrixWritable value : values) {
            rowVector.setEntry(value.columnIndex, value.entry);
        }

        double dotProduct = rowVector.dotProduct(vector);

        /* only write the nonzero outputs, since the Matrix-Vector product is probably sparse */
        if (dotProduct != 0.0) {
            /* this outputs the vector index and its value */
            context.write(key, new DoubleWritable(dotProduct));
        }
    }
}
-------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/mapreduce/SparseMatrixWritable.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.mapreduce; 17 | 18 | import java.io.DataInput; 19 | import java.io.DataOutput; 20 | import java.io.IOException; 21 | import org.apache.hadoop.io.Writable; 22 | /** 23 | * ONE element of a sparse matrix 24 | * @author Michael Brzustowicz 25 | */ 26 | public class SparseMatrixWritable implements Writable { 27 | int rowIndex; // i 28 | int columnIndex; // j 29 | double entry; // the value at i,j 30 | 31 | public SparseMatrixWritable() { 32 | } 33 | 34 | public SparseMatrixWritable(int rowIndex, int columnIndex, double entry) { 35 | this.rowIndex = rowIndex; 36 | this.columnIndex = columnIndex; 37 | this.entry = entry; 38 | } 39 | 40 | @Override 41 | public void write(DataOutput d) throws IOException { 42 | d.writeInt(rowIndex); 43 | d.writeInt(rowIndex); 44 | d.writeDouble(entry); 45 | } 46 | 47 | @Override 48 | public void readFields(DataInput di) throws IOException { 49 | rowIndex = di.readInt(); 50 | columnIndex = di.readInt(); 51 | entry = di.readDouble(); 52 | } 53 | 54 | // THIS IS OPTIONAL 55 | public static SparseMatrixWritable read(DataInput di) throws IOException 
{ 56 | SparseMatrixWritable smw = new SparseMatrixWritable(); 57 | smw.readFields(di); 58 | return smw; 59 | 60 | } 61 | 62 | } 63 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/mapreduce/WordCountMapReduceExample.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.oreilly.dswj.mapreduce; 17 | 18 | import org.apache.hadoop.conf.Configured; 19 | import org.apache.hadoop.fs.Path; 20 | import org.apache.hadoop.io.IntWritable; 21 | import org.apache.hadoop.io.Text; 22 | import org.apache.hadoop.mapreduce.Job; 23 | import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; 24 | import org.apache.hadoop.mapreduce.lib.map.TokenCounterMapper; 25 | import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; 26 | import org.apache.hadoop.mapreduce.lib.reduce.IntSumReducer; 27 | import org.apache.hadoop.util.Tool; 28 | import org.apache.hadoop.util.ToolRunner; 29 | 30 | /** 31 | * 32 | * @author Michael Brzustowicz 33 | */ 34 | public class WordCountMapReduceExample extends Configured implements Tool { 35 | 36 | /** 37 | * @param args the command line arguments 38 | * @throws java.lang.Exception 39 | */ 40 | public static void main(String[] args) throws Exception { 41 | int exitCode = ToolRunner.run(new WordCountMapReduceExample(), args); 42 | System.exit(exitCode); 43 | } 44 | 45 | @Override 46 | public int run(String[] args) throws Exception { 47 | Job job = Job.getInstance(getConf()); 48 | job.setJarByClass(WordCountMapReduceExample.class); 49 | job.setJobName("WordCountMapReduceExample"); 50 | 51 | FileInputFormat.addInputPath(job, new Path(args[0])); 52 | FileOutputFormat.setOutputPath(job, new Path(args[1])); 53 | 54 | job.setMapperClass(TokenCounterMapper.class); 55 | job.setMapOutputKeyClass(Text.class); 56 | job.setMapOutputValueClass(IntWritable.class); 57 | job.setReducerClass(IntSumReducer.class); 58 | job.setOutputKeyClass(Text.class); 59 | job.setOutputValueClass(IntWritable.class); 60 | job.setNumReduceTasks(1); 61 | 62 | 63 | return job.waitForCompletion(true) ? 
0 : 1; 64 | } 65 | 66 | } 67 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/statistics/AggStatsExample.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.statistics; 17 | 18 | import java.util.ArrayList; 19 | import java.util.List; 20 | import org.apache.commons.math3.stat.descriptive.AggregateSummaryStatistics; 21 | import org.apache.commons.math3.stat.descriptive.StatisticalSummaryValues; 22 | import org.apache.commons.math3.stat.descriptive.SummaryStatistics; 23 | 24 | /** 25 | * 26 | * @author Michael Brzustowicz 27 | */ 28 | public class AggStatsExample { 29 | 30 | /** 31 | * @param args the command line arguments 32 | */ 33 | public static void main(String[] args) { 34 | 35 | List ls = new ArrayList<>(); 36 | 37 | SummaryStatistics ss = new SummaryStatistics(); 38 | ss.addValue(1.0); 39 | ss.addValue(11.0); 40 | ss.addValue(5.0); 41 | 42 | SummaryStatistics ss2 = new SummaryStatistics(); 43 | ss2.addValue(2.0); 44 | ss2.addValue(12.0); 45 | ss2.addValue(6.0); 46 | 47 | SummaryStatistics ss3 = new SummaryStatistics(); 48 | ss3.addValue(0.0); 49 | ss3.addValue(10.0); 50 | ss3.addValue(4.0); 51 | 52 | 53 | ls.add(ss); 54 | ls.add(ss2); 55 | ls.add(ss3); 56 | 57 | StatisticalSummaryValues s = 
AggregateSummaryStatistics.aggregate(ls); 58 | 59 | System.out.println(s); 60 | 61 | } 62 | 63 | } 64 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/statistics/AnscombeStatsExample.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.statistics; 17 | 18 | import com.oreilly.dswj.datasets.Anscombe; 19 | import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics; 20 | 21 | /** 22 | * 23 | * @author Michael Brzustowicz 24 | */ 25 | public class AnscombeStatsExample { 26 | 27 | /** 28 | * @param args the command line arguments 29 | */ 30 | public static void main(String[] args) { 31 | DescriptiveStatistics x1 = new DescriptiveStatistics(Anscombe.x1); 32 | DescriptiveStatistics y1 = new DescriptiveStatistics(Anscombe.y1); 33 | DescriptiveStatistics x2 = new DescriptiveStatistics(Anscombe.x2); 34 | DescriptiveStatistics y2 = new DescriptiveStatistics(Anscombe.y2); 35 | DescriptiveStatistics x3 = new DescriptiveStatistics(Anscombe.x3); 36 | DescriptiveStatistics y3 = new DescriptiveStatistics(Anscombe.y3); 37 | DescriptiveStatistics x4 = new DescriptiveStatistics(Anscombe.x4); 38 | DescriptiveStatistics y4 = new DescriptiveStatistics(Anscombe.y4); 39 | 40 | System.out.println(x1); 41 | 
System.out.println(y1); 42 | System.out.println(x2); 43 | System.out.println(y2); 44 | System.out.println(x3); 45 | System.out.println(y3); 46 | System.out.println(x4); 47 | System.out.println(y4); 48 | 49 | } 50 | 51 | } 52 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/statistics/ContinuousDistributionPlot.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.oreilly.dswj.statistics; 17 | 18 | 19 | import java.io.File; 20 | import java.io.IOException; 21 | import javafx.application.Application; 22 | import javafx.embed.swing.SwingFXUtils; 23 | import javafx.scene.Scene; 24 | import javafx.scene.SnapshotParameters; 25 | import javafx.scene.chart.LineChart; 26 | import javafx.scene.chart.NumberAxis; 27 | import javafx.scene.chart.XYChart; 28 | import javafx.scene.image.WritableImage; 29 | import javafx.stage.Stage; 30 | import javax.imageio.ImageIO; 31 | import org.apache.commons.math3.distribution.LogNormalDistribution; 32 | import org.apache.commons.math3.distribution.NormalDistribution; 33 | import org.apache.commons.math3.distribution.UniformRealDistribution; 34 | 35 | /** 36 | * 37 | * @author Michael Brzustowicz 38 | */ 39 | public class ContinuousDistributionPlot extends Application { 40 | 41 | /** 42 | * @param args the command line arguments 43 | */ 44 | public static void main(String[] args) { 45 | Application.launch(args); 46 | } 47 | 48 | @Override 49 | public void start(Stage primaryStage) throws Exception { 50 | 51 | /* make the dataset */ 52 | NormalDistribution dist = new NormalDistribution(); 53 | // LogNormalDistribution dist = new LogNormalDistribution(); 54 | // UniformRealDistribution dist = new UniformRealDistribution(); 55 | 56 | int n = 1000; 57 | double[] x = new double[n]; 58 | double[] y = new double[n]; 59 | double min = -5; 60 | double max = 5; 61 | double delta = max - min; 62 | for (int i = 0; i < n; i++) { 63 | 64 | x[i] = min + i * delta / n; 65 | y[i] = dist.density(x[i]); // this is for PDF 66 | // y[i] = dist.cumulativeProbability(x[i]); // this is for CDF 67 | } 68 | 69 | /* CREATE THE PLOT */ 70 | XYChart.Series series = new XYChart.Series(); 71 | for (int i = 0; i < x.length; i++) { 72 | series.getData().add(new XYChart.Data(x[i], y[i])); 73 | } 74 | NumberAxis xAxis = new NumberAxis("x", -5, 5, 1); 75 | NumberAxis yAxis = new NumberAxis(); 76 | 
xAxis.setLabel("x"); 77 | yAxis.setLabel("f(x)"); // this is for PDF 78 | // yAxis.setLabel("F(x)"); // this is for CDF 79 | xAxis.setMinorTickVisible(false); 80 | yAxis.setMinorTickVisible(false); 81 | 82 | 83 | LineChart lineChart = new LineChart<>(xAxis,yAxis); 84 | lineChart.setAnimated(false); // need this to save file 85 | lineChart.getData().addAll(series); 86 | lineChart.setBackground(null); 87 | lineChart.setLegendVisible(false); 88 | lineChart.setHorizontalGridLinesVisible(false); 89 | lineChart.setVerticalGridLinesVisible(false); 90 | lineChart.setVerticalZeroLineVisible(false); 91 | lineChart.setCreateSymbols(false); 92 | 93 | Scene scene = new Scene(lineChart,800,600); 94 | // scene.getStylesheets().add("css/"); 95 | primaryStage.setScene(scene); 96 | primaryStage.show(); 97 | 98 | /* uncomment below to write to file in addition to screen rendering */ 99 | // WritableImage image = lineChart.snapshot(new SnapshotParameters(), null); 100 | 101 | // TODO: probably use a file chooser here 102 | // File file = new File("plot.png"); 103 | 104 | // try { 105 | // ImageIO.write(SwingFXUtils.fromFXImage(image, null), "png", file); 106 | // } catch (IOException e) { 107 | // TODO: handle exception here 108 | // } 109 | 110 | 111 | } 112 | 113 | } 114 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/statistics/DiscreteDistributionPlot.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.statistics; 17 | 18 | 19 | import java.io.File; 20 | import java.io.IOException; 21 | import javafx.application.Application; 22 | import javafx.embed.swing.SwingFXUtils; 23 | import javafx.scene.Scene; 24 | import javafx.scene.SnapshotParameters; 25 | import javafx.scene.chart.NumberAxis; 26 | import javafx.scene.chart.ScatterChart; 27 | import javafx.scene.chart.XYChart; 28 | import javafx.scene.image.WritableImage; 29 | import javafx.stage.Stage; 30 | import javax.imageio.ImageIO; 31 | import org.apache.commons.math3.distribution.BinomialDistribution; 32 | import org.apache.commons.math3.distribution.PoissonDistribution; 33 | 34 | /** 35 | * 36 | * @author Michael Brzustowicz 37 | */ 38 | public class DiscreteDistributionPlot extends Application { 39 | 40 | /** 41 | * @param args the command line arguments 42 | */ 43 | public static void main(String[] args) { 44 | Application.launch(args); 45 | } 46 | 47 | @Override 48 | public void start(Stage primaryStage) throws Exception { 49 | /* make the dataset */ 50 | 51 | // BinomialDistribution dist = new BinomialDistribution(40, 0.5); 52 | PoissonDistribution dist = new PoissonDistribution(5.0); 53 | int n = 21; 54 | double[] x = new double[n]; 55 | double[] y = new double[n]; 56 | double min = 0; 57 | double max = 21; 58 | double delta = max - min; 59 | for (int i = 0; i < n; i++) { 60 | x[i] = i; 61 | y[i] = dist.probability(i); // PMF 62 | // y[i] = dist.cumulativeProbability(i); //CDF 63 | } 64 | 65 | /* */ 66 | XYChart.Series series = 
new XYChart.Series(); 67 | for (int i = 0; i < x.length; i++) { 68 | series.getData().add(new XYChart.Data(x[i], y[i])); 69 | } 70 | 71 | final NumberAxis xAxis = new NumberAxis("k", 0, 20, 5); 72 | final NumberAxis yAxis = new NumberAxis(); 73 | 74 | yAxis.setLabel("f(k)"); 75 | // xAxis.setTickLabelFont(new Font(12)); 76 | // xAxis.setTickUnit(1.0 / 41.0); 77 | xAxis.setMinorTickVisible(false); 78 | yAxis.setMinorTickVisible(false); 79 | 80 | //discrete 81 | final ScatterChart chart = new ScatterChart<>(xAxis,yAxis); 82 | 83 | chart.setAnimated(false); 84 | 85 | chart.getData().addAll(series); 86 | chart.setBackground(null); 87 | chart.setLegendVisible(false); 88 | 89 | chart.setHorizontalGridLinesVisible(false); 90 | chart.setVerticalGridLinesVisible(false); 91 | chart.setVerticalZeroLineVisible(false); 92 | // lineChart.setCreateSymbols(false); 93 | // lineChart.setStyle(".chart-plot-background {-fx-background-color: #ffffff;}"); 94 | 95 | Scene scene = new Scene(chart,800,600); 96 | scene.getStylesheets().add("css/chart_lineplot.css"); 97 | primaryStage.setScene(scene); 98 | primaryStage.show(); 99 | 100 | /* uncomment below to write to file in addition to screen rendering */ 101 | // WritableImage image = lineChart.snapshot(new SnapshotParameters(), null); 102 | 103 | // TODO: probably use a file chooser here 104 | // File file = new File("plot.png"); 105 | 106 | // try { 107 | // ImageIO.write(SwingFXUtils.fromFXImage(image, null), "png", file); 108 | // } catch (IOException e) { 109 | // TODO: handle exception here 110 | // } 111 | 112 | 113 | } 114 | 115 | } 116 | -------------------------------------------------------------------------------- /src/main/java/com/oreilly/dswj/statistics/EntropyPlotExample.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Michael Brzustowicz. 
3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.oreilly.dswj.statistics; 17 | 18 | import java.io.File; 19 | import java.io.IOException; 20 | import javafx.application.Application; 21 | import javafx.embed.swing.SwingFXUtils; 22 | import javafx.scene.Scene; 23 | import javafx.scene.SnapshotParameters; 24 | import javafx.scene.chart.LineChart; 25 | import javafx.scene.chart.NumberAxis; 26 | import javafx.scene.chart.XYChart; 27 | import javafx.scene.image.WritableImage; 28 | import javafx.stage.Stage; 29 | import javax.imageio.ImageIO; 30 | 31 | /** 32 | * 33 | * @author Michael Brzustowicz 34 | */ 35 | public class EntropyPlotExample extends Application { 36 | 37 | public static double entropy(double p) { 38 | double out = 0; 39 | if (p>0 && p<1){ 40 | out = - ((1-p) * Math.log(1-p) + p * Math.log(p)) / Math.log(2); 41 | } 42 | return out; 43 | } 44 | /** 45 | * @param args the command line arguments 46 | */ 47 | public static void main(String[] args) { 48 | Application.launch(args); 49 | } 50 | 51 | @Override 52 | public void start(Stage primaryStage) throws Exception { 53 | 54 | int n = 100; 55 | double[] p = new double[n]; 56 | double[] h = new double[n]; 57 | double delta = 1.0 / (n-1); 58 | for (int i = 0; i < n; i++) { 59 | p[i] = i * delta; 60 | h[i] = entropy(p[i]); 61 | } 62 | 63 | XYChart.Series series = new XYChart.Series(); 64 | 65 | for (int i = 0; i < p.length; i++) { 66 | 
series.getData().add(new XYChart.Data(p[i], h[i])); 67 | } 68 | 69 | NumberAxis xAxis = new NumberAxis("p", 0, 1, 0.25); 70 | NumberAxis yAxis = new NumberAxis(); 71 | yAxis.setLabel("entropy"); 72 | xAxis.setMinorTickVisible(false); 73 | 74 | LineChart lineChart = new LineChart<>(xAxis,yAxis); 75 | lineChart.setAnimated(false); 76 | lineChart.getData().addAll(series); 77 | lineChart.setBackground(null); 78 | lineChart.setLegendVisible(false); 79 | lineChart.setHorizontalGridLinesVisible(false); 80 | lineChart.setVerticalGridLinesVisible(false); 81 | lineChart.setVerticalZeroLineVisible(false); 82 | lineChart.setCreateSymbols(false); 83 | 84 | Scene scene = new Scene(lineChart,800,600); 85 | scene.getStylesheets().add("css/chart_lineplot.css"); 86 | primaryStage.setScene(scene); 87 | primaryStage.show(); 88 | 89 | /* uncomment below to write to file in addition to screen rendering */ 90 | // WritableImage image = lineChart.snapshot(new SnapshotParameters(), null); 91 | 92 | 93 | // TODO: probably use a file chooser here 94 | // File file = new File("entropy.png"); 95 | // 96 | // try { 97 | // ImageIO.write(SwingFXUtils.fromFXImage(image, null), "png", file); 98 | // } catch (IOException e) { 99 | // // TODO: handle exception here 100 | // } 101 | 102 | 103 | 104 | 105 | 106 | } 107 | 108 | } 109 | -------------------------------------------------------------------------------- /src/main/resources/css/chart.css: -------------------------------------------------------------------------------- 1 | /* 2 | To change this license header, choose License Headers in Project Properties. 3 | To change this template file, choose Tools | Templates 4 | and open the template in the editor. 
5 | */ 6 | /* 7 | Created on : Apr 29, 2015, 6:23:44 PM 8 | Author : mbrzusto 9 | */ 10 | .chart { 11 | /*-fx-background-color: red;*/ 12 | } 13 | 14 | .chart-plot-background { 15 | -fx-background-color: #ffffff; 16 | } 17 | 18 | .default-color0.chart-series-line { -fx-stroke: transparent; } 19 | .default-color1.chart-series-line { -fx-stroke: blue; -fx-stroke-width: 1; } 20 | .default-color2.chart-series-line { 21 | -fx-stroke: blue; 22 | -fx-stroke-width: 1; 23 | -fx-stroke-dash-array: 1 4 1 4; 24 | } 25 | .default-color3.chart-series-line { 26 | -fx-stroke: blue; 27 | -fx-stroke-width: 1; 28 | -fx-stroke-dash-array: 1 4 1 4; 29 | } 30 | 31 | /*.default-color0.chart-line-symbol { 32 | -fx-background-color: white, green; 33 | }*/ 34 | .default-color1.chart-line-symbol { 35 | -fx-background-color: transparent, transparent; 36 | } 37 | .default-color2.chart-line-symbol { 38 | -fx-background-color: transparent, transparent; 39 | } 40 | .default-color3.chart-line-symbol { 41 | -fx-background-color: transparent, transparent; 42 | } 43 | /*.default-color0.chart-legend-item-symbol{ 44 | -fx-background-color: blue; 45 | } 46 | .default-color1.chart-legend-item-symbol{ 47 | -fx-background-color: red; 48 | } 49 | .default-color2.chart-legend-item-symbol{ 50 | -fx-background-color: black; 51 | } 52 | .default-color3.chart-legend-item-symbol{ 53 | -fx-background-color: black; 54 | }*/ -------------------------------------------------------------------------------- /src/main/resources/css/chart_lineplot.css: -------------------------------------------------------------------------------- 1 | /* 2 | To change this license header, choose License Headers in Project Properties. 3 | To change this template file, choose Tools | Templates 4 | and open the template in the editor. 
5 | */ 6 | /* 7 | Created on : Apr 29, 2015, 6:23:44 PM 8 | Author : mbrzusto 9 | */ 10 | /*.chart { 11 | -fx-background-color: red; 12 | }*/ 13 | 14 | .chart-plot-background { 15 | -fx-background-color: #ffffff; 16 | } 17 | 18 | /*.axis { 19 | -fx-font-size: 1.6em; 20 | }*/ 21 | 22 | /*.axis-label { 23 | -fx-font-size: 1.6em; 24 | }*/ 25 | 26 | .default-color0.chart-series-line { 27 | -fx-stroke: black; 28 | -fx-stroke-width: 1; 29 | } 30 | 31 | .default-color1.chart-series-line { 32 | -fx-stroke: black; 33 | -fx-stroke-width: 1; 34 | -fx-stroke-dash-array: 1 4 1 4; 35 | } 36 | 37 | .default-color2.chart-series-line { 38 | -fx-stroke: black; 39 | -fx-stroke-width: 1; 40 | -fx-stroke-dash-array: 1 6 1 6; 41 | } 42 | .default-color3.chart-series-line { 43 | -fx-stroke: black; 44 | -fx-stroke-width: 1; 45 | -fx-stroke-dash-array: 1 8 1 8; 46 | } 47 | 48 | .default-color4.chart-series-line { 49 | -fx-stroke: black; 50 | -fx-stroke-width: 1; 51 | -fx-stroke-dash-array: 4 4 4 4; 52 | } 53 | 54 | .default-color0.chart-line-symbol { 55 | /*-fx-shape: "circle";*/ 56 | -fx-background-color: green, white; 57 | } 58 | 59 | .default-color1.chart-line-symbol { 60 | -fx-background-color: transparent, transparent; 61 | } 62 | .default-color2.chart-line-symbol { 63 | -fx-background-color: transparent, transparent; 64 | } 65 | .default-color3.chart-line-symbol { 66 | -fx-background-color: transparent, transparent; 67 | } 68 | /*.default-color0.chart-legend-item-symbol{ 69 | -fx-background-color: blue; 70 | } 71 | .default-color1.chart-legend-item-symbol{ 72 | -fx-background-color: red; 73 | } 74 | .default-color2.chart-legend-item-symbol{ 75 | -fx-background-color: black; 76 | } 77 | .default-color3.chart-legend-item-symbol{ 78 | -fx-background-color: black; 79 | }*/ -------------------------------------------------------------------------------- /src/main/resources/css/data_with_fit_line.css: -------------------------------------------------------------------------------- 1 | /* 2 
| To change this license header, choose License Headers in Project Properties. 3 | To change this template file, choose Tools | Templates 4 | and open the template in the editor. 5 | */ 6 | /* 7 | Created on : Apr 29, 2015, 6:23:44 PM 8 | Author : mbrzusto 9 | */ 10 | .chart { 11 | /*-fx-background-color: red;*/ 12 | } 13 | 14 | .chart-plot-background { 15 | -fx-background-color: #ffffff; 16 | } 17 | 18 | .default-color0.chart-series-line { -fx-stroke: blue; -fx-stroke-width: 1; } 19 | .default-color1.chart-series-line { -fx-stroke: transparent; } 20 | 21 | 22 | 23 | 24 | .default-color0.chart-line-symbol { 25 | -fx-background-color: transparent, transparent; 26 | } 27 | /*.default-color1.chart-line-symbol { 28 | -fx-background-color: white, green; 29 | }*/ 30 | 31 | /*.default-color0.chart-legend-item-symbol{ 32 | -fx-background-color: blue; 33 | } 34 | .default-color1.chart-legend-item-symbol{ 35 | -fx-background-color: red; 36 | } 37 | .default-color2.chart-legend-item-symbol{ 38 | -fx-background-color: black; 39 | } 40 | .default-color3.chart-legend-item-symbol{ 41 | -fx-background-color: black; 42 | }*/ -------------------------------------------------------------------------------- /src/main/resources/css/overlay-chart.css: -------------------------------------------------------------------------------- 1 | /** file: overlay-chart.css (place in same directory as LayeredXyChartsSample */ 2 | .chart-plot-background { 3 | -fx-background-color: transparent; 4 | } 5 | .default-color0.chart-series-line { 6 | -fx-stroke: forestgreen; 7 | } 8 | 9 | -------------------------------------------------------------------------------- /src/main/resources/datasets/iris/iris_data.csv: -------------------------------------------------------------------------------- 1 | 5.1,3.5,1.4,0.2,Iris-setosa 2 | 4.9,3.0,1.4,0.2,Iris-setosa 3 | 4.7,3.2,1.3,0.2,Iris-setosa 4 | 4.6,3.1,1.5,0.2,Iris-setosa 5 | 5.0,3.6,1.4,0.2,Iris-setosa 6 | 5.4,3.9,1.7,0.4,Iris-setosa 7 | 
4.6,3.4,1.4,0.3,Iris-setosa 8 | 5.0,3.4,1.5,0.2,Iris-setosa 9 | 4.4,2.9,1.4,0.2,Iris-setosa 10 | 4.9,3.1,1.5,0.1,Iris-setosa 11 | 5.4,3.7,1.5,0.2,Iris-setosa 12 | 4.8,3.4,1.6,0.2,Iris-setosa 13 | 4.8,3.0,1.4,0.1,Iris-setosa 14 | 4.3,3.0,1.1,0.1,Iris-setosa 15 | 5.8,4.0,1.2,0.2,Iris-setosa 16 | 5.7,4.4,1.5,0.4,Iris-setosa 17 | 5.4,3.9,1.3,0.4,Iris-setosa 18 | 5.1,3.5,1.4,0.3,Iris-setosa 19 | 5.7,3.8,1.7,0.3,Iris-setosa 20 | 5.1,3.8,1.5,0.3,Iris-setosa 21 | 5.4,3.4,1.7,0.2,Iris-setosa 22 | 5.1,3.7,1.5,0.4,Iris-setosa 23 | 4.6,3.6,1.0,0.2,Iris-setosa 24 | 5.1,3.3,1.7,0.5,Iris-setosa 25 | 4.8,3.4,1.9,0.2,Iris-setosa 26 | 5.0,3.0,1.6,0.2,Iris-setosa 27 | 5.0,3.4,1.6,0.4,Iris-setosa 28 | 5.2,3.5,1.5,0.2,Iris-setosa 29 | 5.2,3.4,1.4,0.2,Iris-setosa 30 | 4.7,3.2,1.6,0.2,Iris-setosa 31 | 4.8,3.1,1.6,0.2,Iris-setosa 32 | 5.4,3.4,1.5,0.4,Iris-setosa 33 | 5.2,4.1,1.5,0.1,Iris-setosa 34 | 5.5,4.2,1.4,0.2,Iris-setosa 35 | 4.9,3.1,1.5,0.1,Iris-setosa 36 | 5.0,3.2,1.2,0.2,Iris-setosa 37 | 5.5,3.5,1.3,0.2,Iris-setosa 38 | 4.9,3.1,1.5,0.1,Iris-setosa 39 | 4.4,3.0,1.3,0.2,Iris-setosa 40 | 5.1,3.4,1.5,0.2,Iris-setosa 41 | 5.0,3.5,1.3,0.3,Iris-setosa 42 | 4.5,2.3,1.3,0.3,Iris-setosa 43 | 4.4,3.2,1.3,0.2,Iris-setosa 44 | 5.0,3.5,1.6,0.6,Iris-setosa 45 | 5.1,3.8,1.9,0.4,Iris-setosa 46 | 4.8,3.0,1.4,0.3,Iris-setosa 47 | 5.1,3.8,1.6,0.2,Iris-setosa 48 | 4.6,3.2,1.4,0.2,Iris-setosa 49 | 5.3,3.7,1.5,0.2,Iris-setosa 50 | 5.0,3.3,1.4,0.2,Iris-setosa 51 | 7.0,3.2,4.7,1.4,Iris-versicolor 52 | 6.4,3.2,4.5,1.5,Iris-versicolor 53 | 6.9,3.1,4.9,1.5,Iris-versicolor 54 | 5.5,2.3,4.0,1.3,Iris-versicolor 55 | 6.5,2.8,4.6,1.5,Iris-versicolor 56 | 5.7,2.8,4.5,1.3,Iris-versicolor 57 | 6.3,3.3,4.7,1.6,Iris-versicolor 58 | 4.9,2.4,3.3,1.0,Iris-versicolor 59 | 6.6,2.9,4.6,1.3,Iris-versicolor 60 | 5.2,2.7,3.9,1.4,Iris-versicolor 61 | 5.0,2.0,3.5,1.0,Iris-versicolor 62 | 5.9,3.0,4.2,1.5,Iris-versicolor 63 | 6.0,2.2,4.0,1.0,Iris-versicolor 64 | 6.1,2.9,4.7,1.4,Iris-versicolor 65 | 5.6,2.9,3.6,1.3,Iris-versicolor 
66 | 6.7,3.1,4.4,1.4,Iris-versicolor 67 | 5.6,3.0,4.5,1.5,Iris-versicolor 68 | 5.8,2.7,4.1,1.0,Iris-versicolor 69 | 6.2,2.2,4.5,1.5,Iris-versicolor 70 | 5.6,2.5,3.9,1.1,Iris-versicolor 71 | 5.9,3.2,4.8,1.8,Iris-versicolor 72 | 6.1,2.8,4.0,1.3,Iris-versicolor 73 | 6.3,2.5,4.9,1.5,Iris-versicolor 74 | 6.1,2.8,4.7,1.2,Iris-versicolor 75 | 6.4,2.9,4.3,1.3,Iris-versicolor 76 | 6.6,3.0,4.4,1.4,Iris-versicolor 77 | 6.8,2.8,4.8,1.4,Iris-versicolor 78 | 6.7,3.0,5.0,1.7,Iris-versicolor 79 | 6.0,2.9,4.5,1.5,Iris-versicolor 80 | 5.7,2.6,3.5,1.0,Iris-versicolor 81 | 5.5,2.4,3.8,1.1,Iris-versicolor 82 | 5.5,2.4,3.7,1.0,Iris-versicolor 83 | 5.8,2.7,3.9,1.2,Iris-versicolor 84 | 6.0,2.7,5.1,1.6,Iris-versicolor 85 | 5.4,3.0,4.5,1.5,Iris-versicolor 86 | 6.0,3.4,4.5,1.6,Iris-versicolor 87 | 6.7,3.1,4.7,1.5,Iris-versicolor 88 | 6.3,2.3,4.4,1.3,Iris-versicolor 89 | 5.6,3.0,4.1,1.3,Iris-versicolor 90 | 5.5,2.5,4.0,1.3,Iris-versicolor 91 | 5.5,2.6,4.4,1.2,Iris-versicolor 92 | 6.1,3.0,4.6,1.4,Iris-versicolor 93 | 5.8,2.6,4.0,1.2,Iris-versicolor 94 | 5.0,2.3,3.3,1.0,Iris-versicolor 95 | 5.6,2.7,4.2,1.3,Iris-versicolor 96 | 5.7,3.0,4.2,1.2,Iris-versicolor 97 | 5.7,2.9,4.2,1.3,Iris-versicolor 98 | 6.2,2.9,4.3,1.3,Iris-versicolor 99 | 5.1,2.5,3.0,1.1,Iris-versicolor 100 | 5.7,2.8,4.1,1.3,Iris-versicolor 101 | 6.3,3.3,6.0,2.5,Iris-virginica 102 | 5.8,2.7,5.1,1.9,Iris-virginica 103 | 7.1,3.0,5.9,2.1,Iris-virginica 104 | 6.3,2.9,5.6,1.8,Iris-virginica 105 | 6.5,3.0,5.8,2.2,Iris-virginica 106 | 7.6,3.0,6.6,2.1,Iris-virginica 107 | 4.9,2.5,4.5,1.7,Iris-virginica 108 | 7.3,2.9,6.3,1.8,Iris-virginica 109 | 6.7,2.5,5.8,1.8,Iris-virginica 110 | 7.2,3.6,6.1,2.5,Iris-virginica 111 | 6.5,3.2,5.1,2.0,Iris-virginica 112 | 6.4,2.7,5.3,1.9,Iris-virginica 113 | 6.8,3.0,5.5,2.1,Iris-virginica 114 | 5.7,2.5,5.0,2.0,Iris-virginica 115 | 5.8,2.8,5.1,2.4,Iris-virginica 116 | 6.4,3.2,5.3,2.3,Iris-virginica 117 | 6.5,3.0,5.5,1.8,Iris-virginica 118 | 7.7,3.8,6.7,2.2,Iris-virginica 119 | 7.7,2.6,6.9,2.3,Iris-virginica 
120 | 6.0,2.2,5.0,1.5,Iris-virginica 121 | 6.9,3.2,5.7,2.3,Iris-virginica 122 | 5.6,2.8,4.9,2.0,Iris-virginica 123 | 7.7,2.8,6.7,2.0,Iris-virginica 124 | 6.3,2.7,4.9,1.8,Iris-virginica 125 | 6.7,3.3,5.7,2.1,Iris-virginica 126 | 7.2,3.2,6.0,1.8,Iris-virginica 127 | 6.2,2.8,4.8,1.8,Iris-virginica 128 | 6.1,3.0,4.9,1.8,Iris-virginica 129 | 6.4,2.8,5.6,2.1,Iris-virginica 130 | 7.2,3.0,5.8,1.6,Iris-virginica 131 | 7.4,2.8,6.1,1.9,Iris-virginica 132 | 7.9,3.8,6.4,2.0,Iris-virginica 133 | 6.4,2.8,5.6,2.2,Iris-virginica 134 | 6.3,2.8,5.1,1.5,Iris-virginica 135 | 6.1,2.6,5.6,1.4,Iris-virginica 136 | 7.7,3.0,6.1,2.3,Iris-virginica 137 | 6.3,3.4,5.6,2.4,Iris-virginica 138 | 6.4,3.1,5.5,1.8,Iris-virginica 139 | 6.0,3.0,4.8,1.8,Iris-virginica 140 | 6.9,3.1,5.4,2.1,Iris-virginica 141 | 6.7,3.1,5.6,2.4,Iris-virginica 142 | 6.9,3.1,5.1,2.3,Iris-virginica 143 | 5.8,2.7,5.1,1.9,Iris-virginica 144 | 6.8,3.2,5.9,2.3,Iris-virginica 145 | 6.7,3.3,5.7,2.5,Iris-virginica 146 | 6.7,3.0,5.2,2.3,Iris-virginica 147 | 6.3,2.5,5.0,1.9,Iris-virginica 148 | 6.5,3.0,5.2,2.0,Iris-virginica 149 | 6.2,3.4,5.4,2.3,Iris-virginica 150 | 5.9,3.0,5.1,1.8,Iris-virginica 151 | -------------------------------------------------------------------------------- /src/main/resources/datasets/mnist/README: -------------------------------------------------------------------------------- 1 | download the following files 2 | 3 | train-images-idx3-ubyte 4 | train-labels-idx1-ubyte 5 | t10k-images-idx3-ubyte 6 | t10k-labels-idx1-ubyte 7 | 8 | from 9 | 10 | http://yann.lecun.com/exdb/mnist/ 11 | 12 | and put them in this directory --------------------------------------------------------------------------------