├── README.md
├── pom.xml
└── src
└── main
├── java
└── com
│ └── oreilly
│ └── dswj
│ ├── dataops
│ ├── Batch.java
│ ├── CrossEntropyLossFunction.java
│ ├── Dictionary.java
│ ├── HashingDictionary.java
│ ├── IrisPCAExample.java
│ ├── LabelEncoder.java
│ ├── LinearLossFunction.java
│ ├── ListResampler.java
│ ├── LossFunction.java
│ ├── MatrixResampler.java
│ ├── MatrixScaleType.java
│ ├── MatrixScaler.java
│ ├── MatrixScalingOperator.java
│ ├── PCA.java
│ ├── PCAEIGImplementation.java
│ ├── PCAImplementation.java
│ ├── PCASVDImplementation.java
│ ├── ProbabilityEncoder.java
│ ├── QuadraticLossFunction.java
│ ├── SimpleTokenizer.java
│ ├── SoftMaxCrossEntropyLossFunction.java
│ ├── TFIDF.java
│ ├── TFIDFVectorizer.java
│ ├── TermDictionary.java
│ ├── Tokenizer.java
│ ├── TwoPointLossFunction.java
│ ├── Vectorizer.java
│ └── VectorizerExample.java
│ ├── datasets
│ ├── Anscombe.java
│ ├── Iris.java
│ ├── MNIST.java
│ ├── MultiNormalMixtureDataset.java
│ └── Sentiment.java
│ ├── io
│ ├── BasicBarChart.java
│ ├── BasicScatterChart.java
│ ├── DBApp.java
│ ├── DBInsertBatchApp.java
│ ├── FileIOExample.java
│ ├── FitPlot.java
│ └── Record.java
│ ├── learn
│ ├── BernoulliConditionalProbabilityEstimator.java
│ ├── ClassifierAccuracy.java
│ ├── ConditionalProbabilityEstimator.java
│ ├── DBSCANOptimizeExample.java
│ ├── DBSCANPlotExample.java
│ ├── DeepNetwork.java
│ ├── DeepNetworkIrisExample.java
│ ├── DeepNetworkMNISTExample.java
│ ├── DeltaRule.java
│ ├── GaussianConditionalProbabilityEstimator.java
│ ├── GaussianMixtureClusteringExample.java
│ ├── GradientDescent.java
│ ├── GradientDescentMomentum.java
│ ├── IterativeLearningProcess.java
│ ├── KMeansExample.java
│ ├── LinearModel.java
│ ├── LinearModelEstimator.java
│ ├── LinearOutputFunction.java
│ ├── LogisticOutputFunction.java
│ ├── MultinomialConditionalProbabilityEstimator.java
│ ├── NaiveBayes.java
│ ├── NaiveBayesGaussianIrisExample.java
│ ├── NetworkLayer.java
│ ├── Optimizer.java
│ ├── OutputFunction.java
│ ├── SilhouetteCoefficient.java
│ ├── SoftmaxLinearModelExample.java
│ ├── SoftmaxOutputFunction.java
│ └── TanhOutputFunction.java
│ ├── linalg
│ ├── FunctionMapper.java
│ ├── LinearSystemExample.java
│ ├── MatrixOperations.java
│ ├── RandomizedMatrix.java
│ └── UnivariateFunctionMapper.java
│ ├── mapreduce
│ ├── BasicMapReduceExample.java
│ ├── CustomWordCountMapReduceExample.java
│ ├── JSONMapper.java
│ ├── SimpleTokenMapper.java
│ ├── SparseAlgebraMapReduceExample.java
│ ├── SparseMatrixMultiplicationMapper.java
│ ├── SparseMatrixMultiplicationReducer.java
│ ├── SparseMatrixWritable.java
│ └── WordCountMapReduceExample.java
│ └── statistics
│ ├── AggStatsExample.java
│ ├── AnscombeStatsExample.java
│ ├── AnscombesPlotExample.java
│ ├── ContinuousDistributionPlot.java
│ ├── DiscreteDistributionPlot.java
│ ├── EntropyPlotExample.java
│ ├── Guassian2DExample.java
│ └── HistogramExample.java
└── resources
├── css
├── chart.css
├── chart_lineplot.css
├── data_with_fit_line.css
└── overlay-chart.css
└── datasets
├── iris
└── iris_data.csv
├── mnist
└── README
└── sentiment
├── amazon_cells_labelled.txt
├── imdb_labelled.txt
└── yelp_labelled.txt
/README.md:
--------------------------------------------------------------------------------
1 | Data Science with Java
2 | ==========
3 |
4 | This is the example code that accompanies Data Science with Java by Michael R. Brzustowicz, PhD (9781491934111).
5 |
6 | Click the **Download ZIP** button to the right to download the example code.
7 |
8 | Visit the catalog page [here](http://shop.oreilly.com/product/0636920043171.do).
9 |
10 | See an error? Report it [here](http://oreilly.com/catalog/errata.csp?isbn=0636920043171), or simply fork and send us a pull request.
11 |
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 | 4.0.0
4 | com.oreilly.javabook
5 | Dava_Science_with_Java
6 | 1.0-SNAPSHOT
7 | jar
8 |
9 |
10 | org.apache.commons
11 | commons-math3
12 | 3.6
13 |
14 |
15 | com.googlecode.json-simple
16 | json-simple
17 | 1.1
18 |
19 |
20 | mysql
21 | mysql-connector-java
22 | 5.1.35
23 |
24 |
25 | org.apache.hadoop
26 | hadoop-client
27 | 2.7.2
28 |
29 |
30 |
31 | UTF-8
32 | 1.8
33 | 1.8
34 |
35 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/dataops/Batch.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.dataops;
17 |
18 | import org.apache.commons.math3.linear.RealMatrix;
19 |
20 | /**
21 | *
22 | * @author Michael Brzustowicz
23 | */
24 | public class Batch extends MatrixResampler {
25 |
26 | public Batch(RealMatrix features, RealMatrix labels) {
27 | super(features, labels);
28 | }
29 |
30 | public void calcNextBatch(int batchSize) {
31 | super.calculateTestTrainSplit(batchSize);
32 | }
33 |
34 | public RealMatrix getInputBatch() {
35 | return super.getTestingFeatures();
36 | }
37 |
38 | public RealMatrix getTargetBatch() {
39 | return super.getTestingLabels();
40 | }
41 |
42 | }
43 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/dataops/CrossEntropyLossFunction.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.dataops;
17 |
18 | import org.apache.commons.math3.linear.Array2DRowRealMatrix;
19 | import org.apache.commons.math3.linear.ArrayRealVector;
20 | import org.apache.commons.math3.linear.RealMatrix;
21 | import org.apache.commons.math3.linear.RealVector;
22 | import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
23 | import org.apache.commons.math3.util.FastMath;
24 |
25 | /**
26 | *
27 | * @author Michael Brzustowicz
28 | */
29 | public class CrossEntropyLossFunction implements LossFunction {
30 |
31 | @Override
32 | public double getSampleLoss(double predicted, double target) {
33 | return -1.0 * (target * ((predicted>0)?FastMath.log(predicted):0) +
34 | (1.0 - target)*(predicted<1?FastMath.log(1.0-predicted):0));
35 | }
36 |
37 | public double getSampleLoss(double[] predicted, double[] target) {
38 | double loss = 0.0;
39 | for (int i = 0; i < predicted.length; i++) {
40 | loss += getSampleLoss(predicted[i], target[i]);
41 | }
42 | return loss;
43 | }
44 |
45 | @Override
46 | public double getSampleLoss(RealVector predicted, RealVector target) {
47 | double loss = 0.0;
48 | for (int i = 0; i < predicted.getDimension(); i++) {
49 | loss += getSampleLoss(predicted.getEntry(i), target.getEntry(i));
50 | }
51 | return loss;
52 | }
53 |
54 | @Override
55 | public double getMeanLoss(RealMatrix predicted, RealMatrix target) {
56 | SummaryStatistics stats = new SummaryStatistics();
57 | for (int i = 0; i < predicted.getRowDimension(); i++) {
58 | stats.addValue(getSampleLoss(predicted.getRowVector(i), target.getRowVector(i)));
59 | }
60 | return stats.getMean();
61 | }
62 |
63 | @Override
64 | public double getSampleLossGradient(double predicted, double target) {
65 | // NOTE this blows up if predicted = 0 or 1, which it should never be
66 | return (predicted - target) / (predicted * (1 - predicted));
67 | }
68 |
69 | @Override
70 | public RealVector getSampleLossGradient(RealVector predicted, RealVector target) {
71 | RealVector loss = new ArrayRealVector(predicted.getDimension());
72 | for (int i = 0; i < predicted.getDimension(); i++) {
73 | loss.setEntry(i, getSampleLossGradient(predicted.getEntry(i), target.getEntry(i)));
74 | }
75 | return loss;
76 | }
77 |
78 | @Override
79 | public RealMatrix getLossGradient(RealMatrix predicted, RealMatrix target) {
80 | RealMatrix loss = new Array2DRowRealMatrix(predicted.getRowDimension(), predicted.getColumnDimension());
81 | for (int i = 0; i < predicted.getRowDimension(); i++) {
82 | loss.setRowVector(i, getSampleLossGradient(predicted.getRowVector(i), target.getRowVector(i)));
83 | }
84 | return loss;
85 | }
86 |
87 | }
88 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/dataops/Dictionary.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.dataops;
17 |
/**
 * Maps string terms to integer indices.
 *
 * @author Michael Brzustowicz
 */
public interface Dictionary {

    /**
     * Looks up the index assigned to a term.
     * The boxed {@code Integer} return type lets implementations signal an unknown term.
     *
     * @param term the term to look up
     * @return the term's index, or {@code null} if the term is not in the dictionary
     */
    Integer getTermIndex(String term);

    /**
     * @return the number of terms (index slots) in this dictionary
     */
    int getNumTerms();
}
31 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/dataops/HashingDictionary.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.dataops;
17 |
18 | /**
19 | * uses hashing trick to store terms in large hashmap to avoid collisions
20 | * @author Michael Brzustowicz
21 | */
22 | public class HashingDictionary implements Dictionary {
23 | private int numTerms; // 2^n is optimal
24 |
25 | public HashingDictionary() {
26 | // 2^20 = 1048576
27 | this(new Double(Math.pow(2,20)).intValue());
28 | }
29 |
30 | public HashingDictionary(int numTerms) {
31 | this.numTerms = numTerms;
32 | }
33 |
34 | @Override
35 | public Integer getTermIndex(String term) {
36 | return Math.floorMod(term.hashCode(), numTerms);
37 | }
38 |
39 | @Override
40 | public int getNumTerms() {
41 | return numTerms;
42 | }
43 | }
44 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/dataops/IrisPCAExample.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.dataops;
17 |
18 | import com.oreilly.dswj.datasets.Iris;
19 | import java.io.File;
20 | import java.io.IOException;
21 | import javafx.application.Application;
22 | import javafx.embed.swing.SwingFXUtils;
23 | import javafx.scene.Scene;
24 | import javafx.scene.SnapshotParameters;
25 | import javafx.scene.chart.NumberAxis;
26 | import javafx.scene.chart.ScatterChart;
27 | import javafx.scene.chart.XYChart;
28 | import javafx.scene.image.WritableImage;
29 | import javafx.stage.Stage;
30 | import javax.imageio.ImageIO;
31 | import org.apache.commons.math3.linear.RealMatrix;
32 |
33 | /**
34 | *
35 | * @author Michael Brzustowicz
36 | */
37 | public class IrisPCAExample extends Application {
38 |
39 | /**
40 | * @param args the command line arguments
41 | */
42 | public static void main(String[] args) {
43 | launch(args);
44 | }
45 |
46 | @Override
47 | public void start(Stage stage) throws Exception {
48 | Iris iris = new Iris();
49 |
50 | PCA pca = new PCA(iris.getData(), new PCAEIGImplementation());
51 | System.out.println(pca.getExplainedVariances());
52 | System.out.println(pca.getCumulativeVariances());
53 | RealMatrix irispca = pca.getPrincipalComponents(0.98);//2);
54 | System.out.println(irispca);
55 |
56 |
57 | XYChart.Series series = new XYChart.Series();
58 | for (int i = 0; i < irispca.getRowDimension(); i++) {
59 | series.getData().add(new XYChart.Data(irispca.getEntry(i, 0), irispca.getEntry(i, 1)));
60 | }
61 |
62 | NumberAxis xAxis = new NumberAxis();
63 | xAxis.setLabel("X1");
64 | NumberAxis yAxis = new NumberAxis();
65 | yAxis.setLabel("X2");
66 | ScatterChart scatterChart = new ScatterChart<>(xAxis,yAxis);
67 | scatterChart.getData().add(series);
68 | scatterChart.setAnimated(false);
69 |
70 | Scene scene = new Scene(scatterChart,800,600);
71 |
72 | /*
73 | tell the stage what scene to use and render it!
74 | */
75 | stage.setScene(scene);
76 | stage.show();
77 |
78 | // WritableImage image = scatterChart.snapshot(new SnapshotParameters(), null);
79 |
80 | // TODO: probably use a file chooser here
81 | // File outFile = new File("Iris2PCA.png");
82 | //
83 | // try {
84 | // ImageIO.write(SwingFXUtils.fromFXImage(image, null), "png", outFile);
85 | // } catch (IOException e) {
86 | // // TODO: handle exception here
87 | // }
88 |
89 | }
90 |
91 | }
92 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/dataops/LabelEncoder.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.dataops;
17 |
18 | import java.util.Arrays;
19 | import java.util.List;
20 |
/**
 * Encodes labels of type {@code T} as integer indices and one-hot vectors. A label's index is
 * its position in the array supplied to the constructor.
 *
 * @author Michael Brzustowicz
 * @param <T> the label type
 */
public class LabelEncoder<T> {

    private final List<T> classes;

    /**
     * @param labels the class labels; a label's index is its position in this array. The array
     *               is copied, so later changes to it do not affect the encoder. Sorting first
     *               (e.g. {@code Arrays.sort(labels)}) is possible but not required.
     */
    public LabelEncoder(T[] labels) {
        classes = Arrays.asList(labels.clone());
    }

    /**
     * @return the labels in index order (a fixed-size list backed by the encoder's own copy)
     */
    public List<T> getClasses() {
        return classes;
    }

    /**
     * @param label the label to encode
     * @return the label's index, or -1 if it is not a known class
     */
    public int encode(T label) {
        return classes.indexOf(label);
    }

    /**
     * @param index a class index
     * @return the label at that index
     * @throws IndexOutOfBoundsException if the index is out of range
     */
    public T decode(int index) {
        return classes.get(index);
    }

    /**
     * @param label a known label
     * @return an array with a 1 at the label's index and 0 elsewhere
     * @throws IllegalArgumentException if the label is not a known class
     */
    public int[] encodeOneHot(T label) {
        int index = encode(label);
        if (index < 0) {
            throw new IllegalArgumentException("Unknown label: " + label);
        }
        int[] oneHot = new int[classes.size()];
        oneHot[index] = 1;
        return oneHot;
    }

    /**
     * @param oneHot a one-hot vector as produced by {@link #encodeOneHot(Object)}
     * @return the label at the first position holding a 1
     * @throws IllegalArgumentException if the vector contains no 1
     */
    public T decodeOneHot(int[] oneHot) {
        // A one-hot array is not sorted, so binary search is not applicable; scan linearly.
        for (int i = 0; i < oneHot.length; i++) {
            if (oneHot[i] == 1) {
                return decode(i);
            }
        }
        throw new IllegalArgumentException("One-hot vector contains no 1: " + Arrays.toString(oneHot));
    }
}
57 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/dataops/LinearLossFunction.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.dataops;
17 |
18 | import org.apache.commons.math3.analysis.function.Signum;
19 | import org.apache.commons.math3.linear.Array2DRowRealMatrix;
20 | import org.apache.commons.math3.linear.RealMatrix;
21 | import org.apache.commons.math3.linear.RealVector;
22 | import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
23 |
24 | /**
25 | *
26 | * @author Michael Brzustowicz
27 | */
28 | public class LinearLossFunction implements LossFunction {
29 |
30 | @Override
31 | public double getSampleLoss(double predicted, double target) {
32 | return Math.abs(predicted - target);
33 | }
34 |
35 | @Override
36 | public double getSampleLoss(RealVector predicted, RealVector target) {
37 | return predicted.getL1Distance(target);
38 | }
39 |
40 | @Override
41 | public double getMeanLoss(RealMatrix predicted, RealMatrix target) {
42 | SummaryStatistics stats = new SummaryStatistics();
43 | for (int i = 0; i < predicted.getRowDimension(); i++) {
44 | double dist = getSampleLoss(predicted.getRowVector(i), target.getRowVector(i));
45 | stats.addValue(dist);
46 | }
47 | return stats.getMean();
48 | }
49 |
50 | @Override
51 | public double getSampleLossGradient(double predicted, double target) {
52 | return Math.signum(predicted - target); // -1, 0, 1
53 | }
54 |
55 | @Override
56 | public RealVector getSampleLossGradient(RealVector predicted, RealVector target) {
57 | return predicted.subtract(target).map(new Signum());
58 | }
59 |
60 | //TODO SparseToSignum would be nice!!! ie only process elements of the iterable
61 | @Override
62 | public RealMatrix getLossGradient(RealMatrix predicted, RealMatrix target) {
63 | RealMatrix loss = new Array2DRowRealMatrix(predicted.getRowDimension(), predicted.getColumnDimension());
64 | for (int i = 0; i < predicted.getRowDimension(); i++) {
65 | loss.setRowVector(i, getSampleLossGradient(predicted.getRowVector(i), target.getRowVector(i)));
66 | }
67 | return loss;
68 | }
69 |
70 | }
71 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/dataops/ListResampler.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.dataops;
17 |
18 | import java.util.Collections;
19 | import java.util.List;
20 | import java.util.Random;
21 |
/**
 * Shuffles a list once with a seeded {@link Random} and partitions it into validation, test
 * and training sub-lists, in that order. The returned subsets are {@link List#subList} views
 * of the (possibly copied) shuffled list.
 *
 * @author Michael Brzustowicz
 * @param <T> the element type
 */
public class ListResampler<T> {

    private final List<T> data;
    private final int trainingSetSize;
    private final int testSetSize;
    private final int validationSetSize;

    /**
     * Test/training split without a validation set; shuffles {@code data} in place.
     */
    public ListResampler(List<T> data, double testFraction, long seed) {
        this(data, testFraction, 0.0, seed);
    }

    /**
     * Validation/test/training split; shuffles {@code data} in place.
     */
    public ListResampler(List<T> data, double testFraction, double validationFraction, long seed) {
        this(data, testFraction, validationFraction, seed, false);
    }

    /**
     * @param data the list to partition
     * @param testFraction fraction in [0, 1] of elements for the test set (count truncated)
     * @param validationFraction fraction in [0, 1] of elements for the validation set
     *                           (count truncated)
     * @param seed seed for the shuffle
     * @param deepCopy if true, the list is copied before shuffling so the caller's list keeps its
     *                 order (elements are shared, not cloned); if false, {@code data} itself is
     *                 shuffled in place
     * @throws IllegalArgumentException if a fraction lies outside [0, 1] or NaN, or the test and
     *                                  validation sets together exceed the list size
     */
    public ListResampler(List<T> data, double testFraction, double validationFraction, long seed, boolean deepCopy) {
        if (!(testFraction >= 0.0 && testFraction <= 1.0)) {
            throw new IllegalArgumentException("testFraction must be in [0, 1]: " + testFraction);
        }
        if (!(validationFraction >= 0.0 && validationFraction <= 1.0)) {
            throw new IllegalArgumentException("validationFraction must be in [0, 1]: " + validationFraction);
        }
        // Previously both branches returned the caller's list, so deepCopy had no effect.
        this.data = deepCopy ? new java.util.ArrayList<T>(data) : data;
        int size = this.data.size();
        validationSetSize = (int) (validationFraction * size);
        testSetSize = (int) (testFraction * size);
        if (testSetSize + validationSetSize > size) {
            throw new IllegalArgumentException("testFraction + validationFraction must not exceed 1");
        }
        trainingSetSize = size - (testSetSize + validationSetSize);
        // shuffle our own reference so a copied list is shuffled, not the caller's
        Collections.shuffle(this.data, new Random(seed));
    }

    /** @return the first {@code validationSetSize} shuffled elements */
    public List<T> getValidationSet() {
        return data.subList(0, validationSetSize);
    }

    /** @return the next {@code testSetSize} shuffled elements */
    public List<T> getTestSet() {
        return data.subList(validationSetSize, validationSetSize + testSetSize);
    }

    /** @return all remaining shuffled elements */
    public List<T> getTrainingSet() {
        return data.subList(validationSetSize + testSetSize, data.size());
    }

}
64 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/dataops/LossFunction.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.dataops;
17 |
18 | import org.apache.commons.math3.linear.RealMatrix;
19 | import org.apache.commons.math3.linear.RealVector;
20 |
21 | /**
22 | *
23 | * @author Michael Brzustowicz
24 | */
25 | public interface LossFunction {
26 |
27 | /**
28 | * loss of one dimension of one sample output
29 | * @param predicted
30 | * @param target
31 | * @return
32 | */
33 | public double getSampleLoss(double predicted, double target);
34 |
35 | /**
36 | * combined loss over all dimensions of one sample output
37 | * @param predicted
38 | * @param target
39 | * @return
40 | */
41 | public double getSampleLoss(RealVector predicted, RealVector target);
42 |
43 | /**
44 | * average loss over all samples
45 | * @param predicted
46 | * @param target
47 | * @return
48 | */
49 | public double getMeanLoss(RealMatrix predicted, RealMatrix target);
50 |
51 | /**
52 | * derivative of loss of one dimension of one sample output
53 | * @param predicted
54 | * @param target
55 | * @return
56 | */
57 | public double getSampleLossGradient(double predicted, double target);
58 |
59 | /**
60 | * derivative of loss over all dimensions of one sample output
61 | * @param predicted
62 | * @param target
63 | * @return
64 | */
65 | public RealVector getSampleLossGradient(RealVector predicted, RealVector target);
66 |
67 | /**
68 | *
69 | * @param predicted
70 | * @param target
71 | * @return
72 | */
73 | public RealMatrix getLossGradient(RealMatrix predicted, RealMatrix target);
74 | }
75 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/dataops/MatrixScaleType.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.dataops;
17 |
/**
 * Scaling strategies for matrix data.
 *
 * @author Michael Brzustowicz
 */
public enum MatrixScaleType {
    /** Column-wise: maps values to (x - min) / (max - min). */
    MINMAX,
    /** Column-wise: subtracts the column mean. */
    CENTER,
    /** Column-wise: (x - mean) / standard deviation. */
    ZSCORE,
    /** Row-wise: divides by a per-row normalizer (presumably the row's L1 norm). */
    L1,
    /** Row-wise: divides by a per-row normalizer (presumably the row's L2 norm). */
    L2
}
25 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/dataops/MatrixScalingOperator.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.dataops;
17 |
18 | import org.apache.commons.math3.linear.RealMatrixChangingVisitor;
19 | import org.apache.commons.math3.linear.RealVector;
20 | import org.apache.commons.math3.stat.descriptive.MultivariateSummaryStatistics;
21 |
22 | /**
23 | *
24 | * @author Michael Brzustowicz
25 | */
26 | public class MatrixScalingOperator implements RealMatrixChangingVisitor {
27 |
28 | RealVector normals;
29 | MultivariateSummaryStatistics mss;
30 | MatrixScaleType scaleType;
31 |
32 | public MatrixScalingOperator(MultivariateSummaryStatistics mss, MatrixScaleType scaleType) {
33 | this.normals = null;
34 | this.mss = mss;
35 | this.scaleType = scaleType;
36 | }
37 |
38 | public MatrixScalingOperator(RealVector normals, MatrixScaleType scaleType) {
39 | this.normals = normals;
40 | this.mss = null;
41 | this.scaleType = scaleType;
42 | }
43 |
44 | @Override
45 | public void start(int rows, int columns, int startRow, int endRow,
46 | int startColumn, int endColumn) {
47 | // nothing
48 | }
49 |
50 | @Override
51 | public double visit(int row, int column, double value) {
52 | double min, max, avg, std, rowNormal;
53 | double scaledValue = Double.NaN;
54 |
55 | switch (scaleType) {
56 |
57 | case MINMAX:
58 | min = mss.getMin()[column];
59 | max = mss.getMax()[column];
60 | scaledValue = (max > min) ? (value - min) / (max - min) : (value - min);
61 | break;
62 |
63 | case CENTER:
64 | avg = mss.getMean()[column];
65 | scaledValue = value - avg;
66 | break;
67 |
68 | case ZSCORE:
69 | avg = mss.getMean()[column];
70 | std = mss.getStandardDeviation()[column];
71 | scaledValue = std > 0 ? (value - avg) / std : (value - avg);
72 | break;
73 |
74 | case L1:
75 | rowNormal = normals.getEntry(row);
76 | scaledValue = rowNormal > 0 ? value / rowNormal : 0;
77 | break;
78 |
79 | case L2:
80 | rowNormal = normals.getEntry(row);
81 | scaledValue = rowNormal > 0 ? value / rowNormal : 0;
82 | break;
83 |
84 | default:
85 | break;
86 | }
87 |
88 | return scaledValue;
89 | }
90 |
91 | @Override
92 | public double end() {
93 | return 0.0;
94 | }
95 |
96 | }
97 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/dataops/PCA.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.dataops;
17 |
18 | import org.apache.commons.math3.linear.RealMatrix;
19 | import org.apache.commons.math3.linear.RealVector;
20 |
21 | /**
22 | *
23 | * @author Michael Brzustowicz
24 | */
25 | public class PCA {
26 |
27 | private final PCAImplementation pCAImplementation;
28 |
29 | /**
30 | * default is SVD implementation
31 | * @param data
32 | */
33 | public PCA(RealMatrix data) {
34 | this(data, new PCASVDImplementation());
35 | }
36 |
37 | public PCA(RealMatrix data, PCAImplementation pCAImplementation) {
38 | this.pCAImplementation = pCAImplementation;
39 | this.pCAImplementation.compute(data);
40 | }
41 |
42 | /**
43 | * Projects the centered data onto the new basis with k components
44 | * @param k number of components to use
45 | * @return
46 | */
47 | public RealMatrix getPrincipalComponents(int k) {
48 | return pCAImplementation.getPrincipalComponents(k);
49 | }
50 |
51 | public RealMatrix getPrincipalComponents(int k, RealMatrix otherData) {
52 | return pCAImplementation.getPrincipalComponents(k, otherData);
53 | }
54 |
55 | public RealVector getExplainedVariances() {
56 | return pCAImplementation.getExplainedVariances();
57 | }
58 |
59 | public RealVector getCumulativeVariances() {
60 | RealVector variances = getExplainedVariances();
61 | RealVector cumulative = variances.copy();
62 | double sum = 0;
63 | for (int i = 0; i < cumulative.getDimension(); i++) {
64 | sum += cumulative.getEntry(i);
65 | cumulative.setEntry(i, sum);
66 | }
67 | return cumulative;
68 | }
69 |
70 | public int getNumberOfComponents(double threshold) {
71 | RealVector cumulative = getCumulativeVariances();
72 | int numComponents=1;
73 | for (int i = 0; i < cumulative.getDimension(); i++) {
74 | numComponents = i + 1;
75 | if(cumulative.getEntry(i) >= threshold) {
76 | break;
77 | }
78 | }
79 | return numComponents;
80 | }
81 |
82 | public RealMatrix getPrincipalComponents(double threshold) {
83 | int numComponents = getNumberOfComponents(threshold);
84 | return getPrincipalComponents(numComponents);
85 | }
86 |
87 | public RealMatrix getPrincipalComponents(double threshold, RealMatrix otherData) {
88 | int numComponents = getNumberOfComponents(threshold);
89 | return getPrincipalComponents(numComponents, otherData);
90 | }
91 |
92 | }
93 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/dataops/PCAEIGImplementation.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.dataops;
17 |
18 | import org.apache.commons.math3.linear.ArrayRealVector;
19 | import org.apache.commons.math3.linear.EigenDecomposition;
20 | import org.apache.commons.math3.linear.RealMatrix;
21 | import org.apache.commons.math3.linear.RealVector;
22 | import org.apache.commons.math3.stat.correlation.Covariance;
23 |
24 | /**
25 | *
26 | * @author Michael Brzustowicz
27 | */
28 | public class PCAEIGImplementation implements PCAImplementation {
29 |
30 | private transient RealMatrix data;
31 | private RealMatrix d; // eigenvalue matrix
32 | private RealMatrix v; // eigenvector matrix
33 | private RealVector explainedVariances;
34 | private transient EigenDecomposition eig;
35 | private final MatrixScaler matrixScaler;
36 |
37 | public PCAEIGImplementation() {
38 | matrixScaler = new MatrixScaler(MatrixScaleType.CENTER);
39 | }
40 |
41 | @Override
42 | public void compute(RealMatrix data) {
43 | this.data = data;
44 | eig = new EigenDecomposition(new Covariance(data).getCovarianceMatrix());
45 | d = eig.getD();
46 | v = eig.getV();
47 | }
48 |
49 | @Override
50 | public RealVector getExplainedVariances() {
51 | //TODO just make this a getter and compute in compute method
52 | int n = eig.getD().getColumnDimension(); //colD = rowD
53 | explainedVariances = new ArrayRealVector(n);
54 | double[] eigenValues = eig.getRealEigenvalues();
55 | double cumulative = 0.0;
56 | for (int i = 0; i < n; i++) {
57 | double var = eigenValues[i];
58 | cumulative += var;
59 | explainedVariances.setEntry(i, var);
60 | }
61 | /* dividing the vector by the last (highest) value maximizes to 1 */
62 | return explainedVariances.mapDivideToSelf(cumulative);
63 | }
64 |
65 | @Override
66 | public RealMatrix getPrincipalComponents(int k) {
67 | int m = eig.getV().getColumnDimension(); // rowD = colD
68 | matrixScaler.transform(data);
69 | // MatrixScaler.center(data);
70 | return data.multiply(eig.getV().getSubMatrix(0, m-1, 0, k-1));
71 | }
72 |
73 |
74 |
75 | @Override
76 | public RealMatrix getPrincipalComponents(int numComponents, RealMatrix otherData) {
77 | int numRows = v.getRowDimension();
78 | // NEW data transformed under OLD means
79 | matrixScaler.transform(otherData);
80 | return otherData.multiply(v.getSubMatrix(0, numRows-1, 0, numComponents-1));
81 | }
82 |
83 | }
84 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/dataops/PCAImplementation.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.dataops;
17 |
18 | import org.apache.commons.math3.linear.RealMatrix;
19 | import org.apache.commons.math3.linear.RealVector;
20 |
21 | /**
22 | *
23 | * @author Michael Brzustowicz
24 | */
25 | public interface PCAImplementation {
26 |
27 | void compute(RealMatrix data);
28 |
29 | RealVector getExplainedVariances();
30 |
31 | RealMatrix getPrincipalComponents(int numComponents);
32 |
33 | RealMatrix getPrincipalComponents(int numComponents, RealMatrix otherData);
34 | }
35 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/dataops/PCASVDImplementation.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.dataops;
17 |
18 | import org.apache.commons.math3.linear.ArrayRealVector;
19 | import org.apache.commons.math3.linear.RealMatrix;
20 | import org.apache.commons.math3.linear.RealVector;
21 | import org.apache.commons.math3.linear.SingularValueDecomposition;
22 |
23 | /**
24 | *
25 | * @author Michael Brzustowicz
26 | */
27 | public class PCASVDImplementation implements PCAImplementation {
28 | private RealMatrix u;
29 | private RealMatrix s;
30 | private RealMatrix v;
31 | private MatrixScaler matrixScaler;
32 | private SingularValueDecomposition svd;
33 |
34 | @Override
35 | public void compute(RealMatrix data) {
36 | //centers the data in place and stores the column stats for later use
37 | matrixScaler = new MatrixScaler(data, MatrixScaleType.CENTER);
38 | svd = new SingularValueDecomposition(data);
39 | u = svd.getU();
40 | s = svd.getS();
41 | v = svd.getV();
42 | }
43 |
44 | @Override
45 | public RealVector getExplainedVariances() {
46 |
47 | double[] singularValues = svd.getSingularValues();
48 | int n = singularValues.length;
49 | int m = u.getRowDimension(); // number of rows in U is same as in data
50 | RealVector explainedVariances = new ArrayRealVector(n);
51 | double sum = 0.0;
52 | for (int i = 0; i < n; i++) {
53 | double var = Math.pow(singularValues[i], 2) / (double)(m-1);
54 | sum += var;
55 | explainedVariances.setEntry(i, var);
56 | }
57 | /* dividing the vector by the last (highest) value maximizes to 1 */
58 | return explainedVariances.mapDivideToSelf(sum);
59 |
60 | }
61 |
62 | @Override
63 | public RealMatrix getPrincipalComponents(int numComponents) {
64 | int numRows = svd.getU().getRowDimension();
65 | /* submatrix limits are inclusive */
66 | RealMatrix uk = u.getSubMatrix(0, numRows-1, 0, numComponents-1);
67 | RealMatrix sk = s.getSubMatrix(0, numComponents-1, 0, numComponents-1);
68 | return uk.multiply(sk);
69 | }
70 |
71 | @Override
72 | public RealMatrix getPrincipalComponents(int numComponents, RealMatrix otherData) {
73 | // center the (new) data on means from original data
74 | matrixScaler.transform(otherData);
75 | int numRows = v.getRowDimension();
76 | // subMatrix indeces are inclusive
77 | return otherData.multiply(v.getSubMatrix(0, numRows-1, 0, numComponents-1));
78 | }
79 |
80 | }
81 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/dataops/ProbabilityEncoder.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.dataops;
17 |
18 | import org.apache.commons.math3.linear.Array2DRowRealMatrix;
19 | import org.apache.commons.math3.linear.ArrayRealVector;
20 | import org.apache.commons.math3.linear.RealMatrix;
21 | import org.apache.commons.math3.linear.RealVector;
22 |
23 | /**
24 | *
25 | * @author Michael Brzustowicz
26 | */
27 | public class ProbabilityEncoder {
28 |
29 | public RealVector getBinary(RealVector probabilities, double threshhold) {
30 | RealVector out = new ArrayRealVector(probabilities.getDimension());
31 | for (int i = 0; i < probabilities.getDimension(); i++) {
32 | double entry = probabilities.getEntry(i) >= threshhold ? 1 : 0;
33 | out.setEntry(i, entry);
34 | }
35 | return out;
36 | }
37 |
38 | public RealVector getOneHot(RealVector probabilities) {
39 | RealVector out = new ArrayRealVector(probabilities.getDimension());
40 | out.setEntry(probabilities.getMaxIndex(), 1);
41 | return out;
42 | }
43 |
44 | public RealMatrix getOneHot(RealMatrix probabilities) {
45 | int numRows = probabilities.getRowDimension();
46 | int numCols = probabilities.getColumnDimension();
47 | RealMatrix out = new Array2DRowRealMatrix(numRows, numCols);
48 | for (int i = 0; i < numRows; i++) {
49 | int maxIndex = probabilities.getRowVector(i).getMaxIndex();
50 | out.setEntry(i, maxIndex, 1);
51 | }
52 | return out;
53 | }
54 | }
55 |
56 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/dataops/QuadraticLossFunction.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.dataops;
17 |
18 | import org.apache.commons.math3.linear.RealMatrix;
19 | import org.apache.commons.math3.linear.RealVector;
20 | import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
21 |
22 | /**
23 | *
24 | * @author Michael Brzustowicz
25 | */
26 | public class QuadraticLossFunction implements LossFunction {
27 | @Override
28 | public double getSampleLoss(double predicted, double target) {
29 | double diff = predicted - target;
30 | return 0.5 * diff * diff;
31 | }
32 |
33 | @Override
34 | public double getSampleLoss(RealVector predicted, RealVector target) {
35 | double dist = predicted.getDistance(target);
36 | return 0.5 * dist * dist;
37 | }
38 |
39 | @Override
40 | public double getMeanLoss(RealMatrix predicted, RealMatrix target) {
41 | SummaryStatistics stats = new SummaryStatistics();
42 | for (int i = 0; i < predicted.getRowDimension(); i++) {
43 | double dist = getSampleLoss(predicted.getRowVector(i), target.getRowVector(i));
44 | stats.addValue(dist);
45 | }
46 | return stats.getMean();
47 | }
48 |
49 | @Override
50 | public double getSampleLossGradient(double predicted, double target) {
51 | return predicted - target;
52 | }
53 |
54 | @Override
55 | public RealVector getSampleLossGradient(RealVector predicted, RealVector target) {
56 | return predicted.subtract(target);
57 | }
58 |
59 | @Override
60 | public RealMatrix getLossGradient(RealMatrix predicted, RealMatrix target) {
61 | return predicted.subtract(target);
62 | }
63 |
64 | }
65 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/dataops/SimpleTokenizer.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.dataops;
17 |
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.regex.Pattern;
20 |
21 | /**
22 | *
23 | * @author Michael Brzustowicz
24 | */
25 | public class SimpleTokenizer implements Tokenizer {
26 |
27 | private final int minTokenSize;
28 |
29 | public SimpleTokenizer(int minTokenSize) {
30 | this.minTokenSize = minTokenSize;
31 | }
32 |
33 | public SimpleTokenizer() {
34 | this(0);
35 | }
36 |
37 |
38 |
39 | @Override
40 | public String[] getTokens(String document) {
41 | String[] tokens = document.trim().split("\\s+");
42 | List cleanTokens = new ArrayList<>();
43 | for (String token : tokens) {
44 | String cleanToken = token.trim().toLowerCase().replaceAll("[^A-Za-z\']+", "");
45 | if(cleanToken.length() > minTokenSize) {
46 | cleanTokens.add(cleanToken);
47 | }
48 | }
49 | return cleanTokens.toArray(new String[0]);
50 | }
51 | }
52 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/dataops/SoftMaxCrossEntropyLossFunction.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.dataops;
17 |
18 | import org.apache.commons.math3.linear.Array2DRowRealMatrix;
19 | import org.apache.commons.math3.linear.RealMatrix;
20 | import org.apache.commons.math3.linear.RealVector;
21 | import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
22 | import org.apache.commons.math3.util.FastMath;
23 |
24 | /**
25 | *
26 | * @author Michael Brzustowicz
27 | */
28 | public class SoftMaxCrossEntropyLossFunction implements LossFunction {
29 |
30 | @Override
31 | public double getSampleLoss(double predicted, double target) {
32 | return predicted > 0 ? -1.0 * target * FastMath.log(predicted) : 0;
33 | }
34 |
35 | @Override
36 | public double getSampleLoss(RealVector predicted, RealVector target) {
37 | double sampleLoss = 0.0;
38 | for (int i = 0; i < predicted.getDimension(); i++) {
39 | sampleLoss += getSampleLoss(predicted.getEntry(i), target.getEntry(i));
40 | }
41 | return sampleLoss;
42 | }
43 |
44 | @Override
45 | public double getMeanLoss(RealMatrix predicted, RealMatrix target) {
46 | SummaryStatistics stats = new SummaryStatistics();
47 | for (int i = 0; i < predicted.getRowDimension(); i++) {
48 | stats.addValue(getSampleLoss(predicted.getRowVector(i), target.getRowVector(i)));
49 | }
50 | return stats.getMean();
51 | }
52 |
53 | /**
54 | * dE/dy = - t_i / y_i
55 | * @param predicted
56 | * @param target
57 | * @return
58 | */
59 | @Override
60 | public double getSampleLossGradient(double predicted, double target) {
61 | return -1.0 * target / predicted;
62 | }
63 |
64 | @Override
65 | public RealVector getSampleLossGradient(RealVector predicted, RealVector target) {
66 | return target.ebeDivide(predicted).mapMultiplyToSelf(-1.0);
67 | }
68 |
69 | @Override
70 | public RealMatrix getLossGradient(RealMatrix predicted, RealMatrix target) {
71 | RealMatrix loss = new Array2DRowRealMatrix(predicted.getRowDimension(), predicted.getColumnDimension());
72 | for (int i = 0; i < predicted.getRowDimension(); i++) {
73 | loss.setRowVector(i, getSampleLossGradient(predicted.getRowVector(i), target.getRowVector(i)));
74 | }
75 | return loss;
76 | }
77 | }
78 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/dataops/TFIDF.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.dataops;
17 |
18 | import org.apache.commons.math3.linear.RealMatrixChangingVisitor;
19 | import org.apache.commons.math3.linear.RealVector;
20 |
21 | /**
22 | *
23 | * @author Michael Brzustowicz
24 | */
25 | public class TFIDF implements RealMatrixChangingVisitor {
26 |
27 | private final int numDocuments;
28 | private final RealVector termDocumentFrequency;
29 | private final double logNumDocuments;
30 |
31 | public TFIDF(int numDocuments, RealVector termDocumentFrequency) {
32 | this.numDocuments = numDocuments;
33 | this.termDocumentFrequency = termDocumentFrequency;
34 | this.logNumDocuments = numDocuments > 0 ? Math.log(numDocuments) : 0;
35 | }
36 |
37 | @Override
38 | public void start(int rows, int columns, int startRow, int endRow,
39 | int startColumn, int endColumn) {
40 | //NA
41 | }
42 |
43 | @Override
44 | public double visit(int row, int column, double value) {
45 | double df = termDocumentFrequency.getEntry(column);
46 | double logDF = df > 0 ? Math.log(df) : 0.0;
47 | // TFIDF = TF_i * log(N/DF_i) = TF_i * ( log(N) - log(DF_i) )
48 | return value * (logNumDocuments - logDF);
49 | }
50 |
51 | @Override
52 | public double end() {
53 | return 0.0;
54 | }
55 |
56 | }
57 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/dataops/TFIDFVectorizer.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.dataops;
17 |
18 | import java.util.List;
19 | import org.apache.commons.math3.linear.OpenMapRealVector;
20 | import org.apache.commons.math3.linear.RealMatrix;
21 | import org.apache.commons.math3.linear.RealVector;
22 |
23 | /**
24 | *
25 | * @author Michael Brzustowicz
26 | */
27 | public class TFIDFVectorizer {
28 |
29 | private Vectorizer vectorizer;
30 | private Vectorizer binaryVectorizer;
31 | private int numTerms;
32 |
33 | public TFIDFVectorizer(Dictionary dictionary, Tokenizer tokenzier) {
34 | vectorizer = new Vectorizer(dictionary, tokenzier, false);
35 | binaryVectorizer = new Vectorizer(dictionary, tokenzier, true);
36 | numTerms = dictionary.getNumTerms();
37 | }
38 |
39 | public TFIDFVectorizer() {
40 | this(new HashingDictionary(16384), new SimpleTokenizer());
41 | }
42 |
43 | public RealVector getTermDocumentCount(List documents) {
44 | RealVector vector = new OpenMapRealVector(numTerms);
45 | for (String document : documents) {
46 | vector.add(binaryVectorizer.getCountVector(document));
47 | }
48 | return vector;
49 | }
50 |
51 | public RealMatrix getTFIDF(List documents) {
52 | int numDocuments = documents.size();
53 | RealVector df = getTermDocumentCount(documents);
54 | RealMatrix tfidf = vectorizer.getCountMatrix(documents);
55 | tfidf.walkInOptimizedOrder(new TFIDF(numDocuments, df));
56 | return tfidf;
57 | }
58 | }
59 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/dataops/TermDictionary.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.dataops;
17 |
18 | import java.util.HashMap;
19 | import java.util.Map;
20 |
21 | /**
22 | *
23 | * @author Michael Brzustowicz
24 | */
25 | public class TermDictionary implements Dictionary {
26 |
27 | private final Map indexedTerms;
28 | private int counter;
29 |
30 | public TermDictionary() {
31 | indexedTerms = new HashMap<>();
32 | counter = 0;
33 | }
34 |
35 | public void addTerm(String term) {
36 | if(!indexedTerms.containsKey(term)) {
37 | indexedTerms.put(term, counter++);
38 | }
39 | }
40 |
41 | public void addTerms(String[] terms) {
42 | for (String term : terms) {
43 | addTerm(term);
44 | }
45 | }
46 |
47 | @Override
48 | public Integer getTermIndex(String term) {
49 | return indexedTerms.get(term);
50 | }
51 |
52 | @Override
53 | public int getNumTerms() {
54 | return indexedTerms.size();
55 | }
56 |
57 | }
58 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/dataops/Tokenizer.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.dataops;
17 |
/**
 * Splits a document into string tokens.
 *
 * @author Michael Brzustowicz
 */
@FunctionalInterface
public interface Tokenizer {

    /**
     * @param text the raw document
     * @return the tokens extracted from {@code text}
     */
    String[] getTokens(String text);
}
25 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/dataops/TwoPointLossFunction.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.dataops;
17 |
18 | import org.apache.commons.math3.linear.Array2DRowRealMatrix;
19 | import org.apache.commons.math3.linear.ArrayRealVector;
20 | import org.apache.commons.math3.linear.RealMatrix;
21 | import org.apache.commons.math3.linear.RealVector;
22 | import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
23 | import org.apache.commons.math3.util.FastMath;
24 |
25 | /**
26 | * target = -1 or 1
27 | * predicted is double between -1.0 and 1.0
28 | * @author Michael Brzustowicz
29 | */
30 | public class TwoPointLossFunction implements LossFunction {
31 |
32 | @Override
33 | public double getSampleLoss(double predicted, double target) {
34 | // convert -1:1 to 0:1 scale
35 | double y = 0.5 * (predicted + 1);
36 | double t = 0.5 * (target + 1);
37 | return -1.0 * (t * ((y>0)?FastMath.log(y):0) +
38 | (1.0 - t)*(y<1?FastMath.log(1.0-y):0));
39 | }
40 |
41 | @Override
42 | public double getSampleLoss(RealVector predicted, RealVector target) {
43 | double loss = 0.0;
44 | for (int i = 0; i < predicted.getDimension(); i++) {
45 | loss += getSampleLoss(predicted.getEntry(i), target.getEntry(i));
46 | }
47 | return loss;
48 | }
49 |
50 | @Override
51 | public double getMeanLoss(RealMatrix predicted, RealMatrix target) {
52 | SummaryStatistics stats = new SummaryStatistics();
53 | for (int i = 0; i < predicted.getRowDimension(); i++) {
54 | stats.addValue(getSampleLoss(predicted.getRowVector(i), target.getRowVector(i)));
55 | }
56 | return stats.getMean();
57 | }
58 |
59 | @Override
60 | public double getSampleLossGradient(double predicted, double target) {
61 | return (predicted - target) / (1 - predicted * predicted);
62 | }
63 |
64 | @Override
65 | public RealVector getSampleLossGradient(RealVector predicted, RealVector target) {
66 | RealVector loss = new ArrayRealVector(predicted.getDimension());
67 | for (int i = 0; i < predicted.getDimension(); i++) {
68 | loss.setEntry(i, getSampleLossGradient(predicted.getEntry(i), target.getEntry(i)));
69 | }
70 | return loss;
71 | }
72 |
73 | @Override
74 | public RealMatrix getLossGradient(RealMatrix predicted, RealMatrix target) {
75 | RealMatrix loss = new Array2DRowRealMatrix(predicted.getRowDimension(), predicted.getColumnDimension());
76 | for (int i = 0; i < predicted.getRowDimension(); i++) {
77 | loss.setRowVector(i, getSampleLossGradient(predicted.getRowVector(i), target.getRowVector(i)));
78 | }
79 | return loss;
80 | }
81 |
82 | }
83 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/dataops/Vectorizer.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.dataops;
17 |
18 | import java.util.List;
19 | import org.apache.commons.math3.linear.OpenMapRealMatrix;
20 | import org.apache.commons.math3.linear.OpenMapRealVector;
21 | import org.apache.commons.math3.linear.RealMatrix;
22 | import org.apache.commons.math3.linear.RealVector;
23 |
24 | /**
25 | *
26 | * @author Michael Brzustowicz
27 | */
28 | public class Vectorizer {
29 |
30 | private final Dictionary dictionary;
31 | private final Tokenizer tokenzier;
32 | private final boolean isBinary;
33 |
34 | public Vectorizer(Dictionary dictionary, Tokenizer tokenzier, boolean isBinary) {
35 | this.dictionary = dictionary;
36 | this.tokenzier = tokenzier;
37 | this.isBinary = isBinary;
38 | }
39 |
40 | public Vectorizer() {
41 | this(new HashingDictionary(), new SimpleTokenizer(), false);
42 | }
43 |
44 | public RealVector getCountVector(String document) {
45 |
46 | RealVector vector = new OpenMapRealVector(dictionary.getNumTerms());
47 |
48 | String[] tokens = tokenzier.getTokens(document);
49 |
50 | for (String token : tokens) {
51 |
52 | Integer index = dictionary.getTermIndex(token);
53 |
54 | if(index != null) {
55 |
56 | if(isBinary) {
57 | vector.setEntry(index, 1);
58 | } else {
59 | vector.addToEntry(index, 1); // increment !
60 | }
61 | }
62 | }
63 | return vector;
64 | }
65 |
66 | public RealMatrix getCountMatrix(List documents) {
67 | int rowDimension = documents.size();
68 | int columnDimension = dictionary.getNumTerms();
69 | RealMatrix matrix = new OpenMapRealMatrix(rowDimension, columnDimension);
70 | int counter = 0;
71 | for (String document : documents) {
72 | matrix.setRowVector(counter++, getCountVector(document));
73 | }
74 | return matrix;
75 | }
76 | }
77 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/dataops/VectorizerExample.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.dataops;
17 |
18 | import com.oreilly.dswj.datasets.Sentiment;
19 | import java.io.IOException;
20 | import org.apache.commons.math3.linear.RealMatrix;
21 |
22 | /**
23 | *
24 | * @author Michael Brzustowicz
25 | */
26 | public class VectorizerExample {
27 |
28 | /**
29 | * @param args the command line arguments
30 | * @throws java.io.IOException
31 | */
32 | public static void main(String[] args) throws IOException {
33 |
34 | Sentiment sentiment = new Sentiment();
35 |
36 | TermDictionary termDictionary = new TermDictionary();
37 |
38 | SimpleTokenizer tokenizer = new SimpleTokenizer();
39 |
40 | for (String document : sentiment.getDocuments()) {
41 | String[]tokens = tokenizer.getTokens(document);
42 | termDictionary.addTerms(tokens);
43 | }
44 |
45 | // System.out.println(termDictionary.getTermIndex("iasdfasd"));
46 | // System.exit(0);
47 |
48 |
49 | Vectorizer vectorizer = new Vectorizer(termDictionary, tokenizer, false);
50 | RealMatrix counts = vectorizer.getCountMatrix(sentiment.getDocuments());
51 | System.out.println(counts.getSubMatrix(0, 5, 0, 5));
52 |
53 | Vectorizer binaryVectorizer = new Vectorizer(termDictionary, tokenizer, true);
54 | RealMatrix binCounts = binaryVectorizer.getCountMatrix(sentiment.getDocuments());
55 | System.out.println(binCounts.getSubMatrix(0, 5, 0, 5));
56 |
57 | TFIDFVectorizer tfidfVectorizer = new TFIDFVectorizer(termDictionary, tokenizer);
58 | RealMatrix tfidf = tfidfVectorizer.getTFIDF(sentiment.getDocuments());
59 | System.out.println(tfidf.getSubMatrix(0, 5, 0, 5));
60 |
61 | }
62 |
63 | }
64 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/datasets/Anscombe.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.datasets;
17 |
/**
 * Anscombe's quartet: four x/y data sets of 11 points each.
 *
 * Each array is a distinct instance, so mutating one does not affect the others.
 *
 * @author Michael Brzustowicz
 */
public class Anscombe {

    // ---- set 1 ----
    public static final double[] x1 = {
        10.0, 8.0, 13.0, 9.0, 11.0, 14.0,
        6.0, 4.0, 12.0, 7.0, 5.0
    };
    public static final double[] y1 = {
        8.04, 6.95, 7.58, 8.81, 8.33, 9.96,
        7.24, 4.26, 10.84, 4.82, 5.68
    };

    // ---- set 2 (same x values as set 1) ----
    public static final double[] x2 = {
        10.0, 8.0, 13.0, 9.0, 11.0, 14.0,
        6.0, 4.0, 12.0, 7.0, 5.0
    };
    public static final double[] y2 = {
        9.14, 8.14, 8.74, 8.77, 9.26, 8.10,
        6.13, 3.10, 9.13, 7.26, 4.74
    };

    // ---- set 3 (same x values as set 1) ----
    public static final double[] x3 = {
        10.0, 8.0, 13.0, 9.0, 11.0, 14.0,
        6.0, 4.0, 12.0, 7.0, 5.0
    };
    public static final double[] y3 = {
        7.46, 6.77, 12.74, 7.11, 7.81, 8.84,
        6.08, 5.39, 8.15, 6.42, 5.73
    };

    // ---- set 4 (constant x except a single point at 19.0) ----
    public static final double[] x4 = {
        8.0, 8.0, 8.0, 8.0, 8.0, 8.0,
        8.0, 19.0, 8.0, 8.0, 8.0
    };
    public static final double[] y4 = {
        6.58, 5.76, 7.71, 8.84, 8.47, 7.04,
        5.25, 12.50, 5.56, 7.91, 6.89
    };

}
33 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/datasets/Iris.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.datasets;
17 |
18 | import java.io.BufferedReader;
19 |
20 | import java.io.IOException;
21 | import java.io.InputStream;
22 | import java.io.InputStreamReader;
23 | import org.apache.commons.math3.linear.Array2DRowRealMatrix;
24 | import org.apache.commons.math3.linear.RealMatrix;
25 |
26 | /**
27 | *
28 | * @author Michael Brzustowicz
29 | */
30 | public class Iris {
31 |
32 | private final RealMatrix data;
33 | private final RealMatrix labels;
34 | private static final String FILEPATH = "/datasets/iris/iris_data.csv";
35 |
36 | public Iris() throws IOException {
37 |
38 | data = new Array2DRowRealMatrix(150, 4);
39 | labels = new Array2DRowRealMatrix(150, 3); // binarized
40 |
41 | try(InputStream inputStream = getClass().getResourceAsStream(FILEPATH)) {
42 | BufferedReader br = new BufferedReader(new InputStreamReader(inputStream));
43 | String line;
44 | int rowCounter = 0;
45 | while ((line = br.readLine()) != null) {
46 |
47 | String[] s = line.split(",");
48 | double sepalLength = Double.parseDouble(s[0].trim());
49 | double sepalWidth = Double.parseDouble(s[1].trim());
50 | double petalLength = Double.parseDouble(s[2].trim());
51 | double petalWidth = Double.parseDouble(s[3].trim());
52 | String plantClass = s[4].trim();
53 |
54 | data.setEntry(rowCounter, 0, sepalLength);
55 | data.setEntry(rowCounter, 1, sepalWidth);
56 | data.setEntry(rowCounter, 2, petalLength);
57 | data.setEntry(rowCounter, 3, petalWidth);
58 |
59 | if (null != plantClass) switch (plantClass) {
60 | case "Iris-setosa":
61 | labels.setEntry(rowCounter, 0, 1);
62 | break;
63 | case "Iris-versicolor":
64 | labels.setEntry(rowCounter, 1, 1);
65 | break;
66 | case "Iris-virginica":
67 | labels.setEntry(rowCounter, 3, 1);
68 | break;
69 | default:
70 | System.out.println("something wrong with " + plantClass);
71 | break;
72 | }
73 |
74 | rowCounter++;
75 | }
76 | }
77 | }
78 |
79 | public RealMatrix getData() {
80 | return data;
81 | }
82 |
83 | public RealMatrix getLabels() {
84 | return labels;
85 | }
86 |
87 | }
88 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/datasets/MNIST.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.datasets;
17 |
18 | import java.io.BufferedInputStream;
19 | import java.io.DataInputStream;
20 | import java.io.FileNotFoundException;
21 | import java.io.IOException;
22 | import org.apache.commons.math3.linear.BlockRealMatrix;
23 | import org.apache.commons.math3.linear.OpenMapRealMatrix;
24 | import org.apache.commons.math3.linear.RealMatrix;
25 |
26 | /**
27 | * all external data files are in src/main/resources/
28 | * these four files in src/main/resources/datasets/mnist/
29 | * @author Michael Brzustowicz
30 | */
31 | public class MNIST {
32 |
33 | public RealMatrix trainingData;
34 | public RealMatrix trainingLabels;
35 | public RealMatrix testingData;
36 | public RealMatrix testingLabels;
37 |
38 | public MNIST() throws IOException {
39 | trainingData = new BlockRealMatrix(60000, 784); // image to vector
40 | trainingLabels = new OpenMapRealMatrix(60000, 10); // the one hot label
41 | testingData = new BlockRealMatrix(10000, 784); // image to vector
42 | testingLabels = new OpenMapRealMatrix(10000, 10); // the one hot label
43 | loadTrainingData("/datasets/mnist/train-images-idx3-ubyte"); // loads from jar
44 | loadTrainingLabels("/datasets/mnist/train-labels-idx1-ubyte");
45 | loadTestingData("/datasets/mnist/t10k-images-idx3-ubyte");
46 | loadTestingLabels("/datasets/mnist/t10k-labels-idx1-ubyte");
47 |
48 | }
49 |
50 | private void loadTrainingData(String filename) throws FileNotFoundException, IOException {
51 | try (DataInputStream di = new DataInputStream(new BufferedInputStream(getClass().getResourceAsStream(filename)))) {
52 | int magicNumber = di.readInt(); //2051
53 | int numImages = di.readInt(); // 60000
54 | int numRows = di.readInt(); // 28
55 | int numCols = di.readInt(); // 28
56 | int vecSize = numRows * numCols; // 784
57 | for (int i = 0; i < numImages; i++) {
58 | for (int j = 0; j < vecSize; j++) {
59 | // values are 0 to 255, so normalize
60 | trainingData.setEntry(i, j, di.readUnsignedByte() / 255.0);
61 | }
62 | }
63 | }
64 | }
65 |
66 | private void loadTestingData(String filename) throws FileNotFoundException, IOException {
67 |
68 | try (DataInputStream di = new DataInputStream(new BufferedInputStream(getClass().getResourceAsStream(filename)))) {
69 | int magicNumber = di.readInt(); //2051
70 | int numImages = di.readInt(); // 10000
71 | int numRows = di.readInt(); // 28
72 | int numCols = di.readInt(); // 28
73 | for (int i = 0; i < numImages; i++) {
74 | for (int j = 0; j < 784; j++) {
75 | // values are 0 to 255, so normalize
76 | testingData.setEntry(i, j, di.readUnsignedByte() / 255.0);
77 | }
78 | }
79 | }
80 | }
81 |
82 | private void loadTrainingLabels(String filename) throws FileNotFoundException, IOException {
83 | try (DataInputStream di = new DataInputStream( new BufferedInputStream(getClass().getResourceAsStream(filename)))) {
84 | int magicNumber = di.readInt(); //2049
85 | int numImages = di.readInt(); // 60000
86 | for (int i = 0; i < numImages; i++) {
87 | // one-hot-encoding, column of 0-9 is given one all else 0
88 | trainingLabels.setEntry(i, di.readUnsignedByte(), 1.0);
89 | }
90 | }
91 | }
92 |
93 | private void loadTestingLabels(String filename) throws FileNotFoundException, IOException {
94 | try (DataInputStream di = new DataInputStream( new BufferedInputStream(getClass().getResourceAsStream(filename)))) {
95 | int magicNumber = di.readInt(); //2049
96 | int numImages = di.readInt(); // 10000
97 | for (int i = 0; i < numImages; i++) {
98 | // one-hot-encoding, column of 0-9 is given one all else 0
99 | testingLabels.setEntry(i, di.readUnsignedByte(), 1.0);
100 | }
101 | }
102 | }
103 |
104 | }
105 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/datasets/MultiNormalMixtureDataset.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.datasets;
17 |
18 | import java.util.ArrayList;
19 | import java.util.List;
20 | import java.util.Random;
21 | import org.apache.commons.math3.distribution.AbstractRealDistribution;
22 | import org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution;
23 | import org.apache.commons.math3.distribution.MultivariateNormalDistribution;
24 | import org.apache.commons.math3.distribution.UniformRealDistribution;
25 | import org.apache.commons.math3.linear.CholeskyDecomposition;
26 | import org.apache.commons.math3.stat.correlation.Covariance;
27 | import org.apache.commons.math3.util.Pair;
28 |
29 | /**
30 | *
31 | * @author Michael Brzustowicz
32 | */
33 | public class MultiNormalMixtureDataset {
34 | int dimension;
35 | List> mixture;
36 | MixtureMultivariateNormalDistribution mixtureDistribution;
37 |
38 | public MultiNormalMixtureDataset(int dimension) {
39 | this.dimension = dimension;
40 | mixture = new ArrayList<>();
41 | }
42 |
43 | public MixtureMultivariateNormalDistribution getMixtureDistribution() {
44 | return mixtureDistribution;
45 | }
46 |
47 | public void createRandomMixtureModel(int numComponents, double boxSize, long seed) {
48 | Random rnd = new Random(seed);
49 | double limit = boxSize / dimension;
50 | UniformRealDistribution dist = new UniformRealDistribution(-limit, limit);
51 | UniformRealDistribution disC = new UniformRealDistribution(-1, 1);
52 | dist.reseedRandomGenerator(seed);
53 | disC.reseedRandomGenerator(seed);
54 | for (int i = 0; i < numComponents; i++) {
55 | double alpha = rnd.nextDouble();
56 | double[] means = dist.sample(dimension);
57 | double[][] cov = getRandomCovariance(disC);
58 | MultivariateNormalDistribution multiNorm = new MultivariateNormalDistribution(means, cov);
59 | addMultinormalDistributionToModel(alpha, multiNorm);
60 | }
61 | mixtureDistribution = new MixtureMultivariateNormalDistribution(mixture);
62 | mixtureDistribution.reseedRandomGenerator(seed); // calls to sample() will return same results
63 | }
64 |
65 | /**
66 | * NOTE this is for adding both internal and external, known distros but
67 | * need to figure out clean way to add the mixture to mixtureDistribution!!!
68 | * @param alpha
69 | * @param dist
70 | */
71 | public void addMultinormalDistributionToModel(double alpha, MultivariateNormalDistribution dist) {
72 | // note all alpha will be L1 normed
73 | mixture.add(new Pair(alpha, dist));
74 | }
75 |
76 | public double[][] getSimulatedData(int size) {
77 | return mixtureDistribution.sample(size);
78 | }
79 |
80 | private double[][] getRandomCovariance(AbstractRealDistribution dist) {
81 | double[][] data = new double[2*dimension][dimension];
82 | double determinant = 0.0;
83 | Covariance cov = new Covariance();
84 | while(Math.abs(determinant) == 0) {
85 | for (int i = 0; i < data.length; i++) {
86 | data[i] = dist.sample(dimension);
87 | }
88 | // check if cov matrix is singular ... if so, keep going
89 | cov = new Covariance(data);
90 | determinant = new CholeskyDecomposition(cov.getCovarianceMatrix()).getDeterminant();
91 | }
92 | return cov.getCovarianceMatrix().getData();
93 | }
94 |
95 | }
96 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/datasets/Sentiment.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.datasets;
17 |
18 | import java.io.BufferedReader;
19 | import java.io.IOException;
20 | import java.io.InputStream;
21 | import java.io.InputStreamReader;
22 | import java.util.ArrayList;
23 | import java.util.List;
24 |
25 |
/**
 * Sentiment labelled sentences
 * https://archive.ics.uci.edu/ml/datasets/Sentiment+Labelled+Sentences
 * contains data from imdb, yelp and amazon
 * imdb has 1000 sent with 500 pos (1) and 500 neg (0)
 * yelp has 3729 sent with 500 pos (1) and 500 neg (0)
 * amzn has 15004 sent with 500 pos (1) and 500 neg (0)
 * Only tab-separated lines that carry a label are loaded. Sources are loaded in the
 * order IMDB, YELP, AMZN, and the boundary between sources is recorded while parsing.
 *
 * @author Michael Brzustowicz
 */
public class Sentiment {

    private final List<String> documents = new ArrayList<>();
    private final List<Integer> sentiments = new ArrayList<>();
    private static final String IMDB_RESOURCE = "/datasets/sentiment/imdb_labelled.txt";
    private static final String YELP_RESOURCE = "/datasets/sentiment/yelp_labelled.txt";
    private static final String AMZN_RESOURCE = "/datasets/sentiment/amazon_cells_labelled.txt";
    public enum DataSource {IMDB, YELP, AMZN};

    /** index of the first YELP entry (equals the number of IMDB entries loaded) */
    private final int yelpStart;
    /** index of the first AMZN entry */
    private final int amznStart;

    /**
     * Loads all three labelled sources from the classpath.
     *
     * @throws IOException if a resource is missing or cannot be read
     */
    public Sentiment() throws IOException {
        parseResource(IMDB_RESOURCE);
        yelpStart = documents.size();
        parseResource(YELP_RESOURCE);
        amznStart = documents.size();
        parseResource(AMZN_RESOURCE);
    }

    /**
     * @param dataSource which source to select
     * @return a subList view of the labels (0 or 1) that came from dataSource
     */
    public List<Integer> getSentiments(DataSource dataSource) {
        return slice(sentiments, dataSource);
    }

    /**
     * @param dataSource which source to select
     * @return a subList view of the sentences that came from dataSource
     */
    public List<String> getDocuments(DataSource dataSource) {
        return slice(documents, dataSource);
    }

    /**
     * @return all labels, in the same order as getDocuments()
     */
    public List<Integer> getSentiments() {
        return sentiments;
    }

    /**
     * @return all sentences, in the same order as getSentiments()
     */
    public List<String> getDocuments() {
        return documents;
    }

    /** Returns the portion of list that belongs to dataSource, using the boundaries recorded during parsing. */
    private <T> List<T> slice(List<T> list, DataSource dataSource) {
        switch (dataSource) {
            case IMDB:
                return list.subList(0, yelpStart);
            case YELP:
                return list.subList(yelpStart, amznStart);
            case AMZN:
                return list.subList(amznStart, list.size());
            default:
                return list.subList(0, list.size());
        }
    }

    /** Appends every labelled "sentence TAB label" line of the resource to documents/sentiments. */
    private void parseResource(String resource) throws IOException {
        try (InputStream inputStream = getClass().getResourceAsStream(resource)) {
            if (inputStream == null) {
                // getResourceAsStream returns null when missing; report it instead of an NPE
                throw new IOException("resource not found: " + resource);
            }
            BufferedReader br = new BufferedReader(new InputStreamReader(inputStream));
            String line;
            while ((line = br.readLine()) != null) {
                String[] splitLine = line.split("\t");
                // both yelp and amzn have many sentences with no label
                if (splitLine.length > 1) {
                    // parse first so a bad label cannot leave the two lists out of step
                    int label = Integer.parseInt(splitLine[1].trim());
                    documents.add(splitLine[0]);
                    sentiments.add(label);
                }
            }
        }
    }
}
113 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/io/BasicBarChart.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.io;
17 |
18 |
19 | import java.io.File;
20 | import javafx.application.Application;
21 | import javafx.collections.FXCollections;
22 | import javafx.embed.swing.SwingFXUtils;
23 | import javafx.scene.Scene;
24 | import javafx.scene.SnapshotParameters;
25 | import javafx.scene.chart.BarChart;
26 | import javafx.scene.chart.CategoryAxis;
27 | import javafx.scene.chart.NumberAxis;
28 | import javafx.scene.chart.XYChart.Series;
29 | import javafx.scene.chart.XYChart.Data;
30 | import javafx.scene.image.WritableImage;
31 | import javafx.stage.Stage;
32 | import javax.imageio.ImageIO;
33 |
34 | /**
35 | *
36 | * @author Michael Brzustowicz
37 | */
38 | public class BasicBarChart extends Application {
39 |
40 |
41 | /**
42 | * @param args the command line arguments
43 | */
44 | public static void main(String[] args) {
45 |
46 | launch(args);
47 | }
48 |
49 | @Override
50 | public void start(Stage stage) throws Exception {
51 |
52 | String[] catData = {"Mon", "Tues", "Wed", "Thurs", "Fri"};
53 | double[] yData = {1.3, 2.1, 3.3, 4.0, 4.8};
54 |
55 | /*
56 | create some data
57 | */
58 | Series series = new Series();
59 | for (int i = 0; i < yData.length; i++) {
60 | series.getData().add(new Data(catData[i], yData[i]));
61 | }
62 |
63 |
64 | //defining the axes
65 | CategoryAxis xAxis = new CategoryAxis();
66 | NumberAxis yAxis = new NumberAxis();
67 | xAxis.setLabel("x");
68 | yAxis.setLabel("y");
69 |
70 | //creating the bar chart;
71 | BarChart barChart = new BarChart<>(xAxis, yAxis);
72 | barChart.setAnimated(false);
73 | barChart.getData().add(series);
74 | barChart.setTitle("x vs. y");
75 | barChart.setHorizontalGridLinesVisible(false);
76 | barChart.setVerticalGridLinesVisible(false);
77 | barChart.setVerticalZeroLineVisible(false);
78 |
79 | /*
80 | create a scene using the chart
81 | */
82 | Scene scene = new Scene(barChart,800,600);
83 |
84 | /*
85 | tell the stage what scene to use and render it!
86 | */
87 | stage.setScene(scene);
88 | stage.show();
89 |
90 | // WritableImage image = scatterChart.snapshot(new SnapshotParameters(), null);
91 | // File file = new File("chart.png");
92 | // ImageIO.write(SwingFXUtils.fromFXImage(image, null), "png", file);
93 |
94 | }
95 |
96 | }
97 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/io/BasicScatterChart.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.io;
17 |
18 |
19 | import java.io.File;
20 | import javafx.application.Application;
21 | import javafx.embed.swing.SwingFXUtils;
22 | import javafx.scene.Scene;
23 | import javafx.scene.SnapshotParameters;
24 | import javafx.scene.chart.NumberAxis;
25 | import javafx.scene.chart.ScatterChart;
26 | import javafx.scene.chart.XYChart.Series;
27 | import javafx.scene.chart.XYChart.Data;
28 | import javafx.scene.image.WritableImage;
29 | import javafx.stage.Stage;
30 | import javax.imageio.ImageIO;
31 |
32 | /**
33 | *
34 | * @author Michael Brzustowicz
35 | */
36 | public class BasicScatterChart extends Application {
37 |
38 |
39 | /**
40 | * @param args the command line arguments
41 | */
42 | public static void main(String[] args) {
43 |
44 | launch(args);
45 | }
46 |
47 | @Override
48 | public void start(Stage stage) throws Exception {
49 | int[] xData = {1, 2, 3, 4, 5};
50 | double[] yData = {1.3, 2.1, 3.3, 4.0, 4.8};
51 |
52 | /*
53 | create some data
54 | */
55 | Series series = new Series();
56 | for (int i = 0; i < yData.length; i++) {
57 | series.getData().add(new Data(xData[i], yData[i]));
58 | }
59 |
60 |
61 | //defining the axes
62 | NumberAxis xAxis = new NumberAxis();
63 | NumberAxis yAxis = new NumberAxis();
64 | xAxis.setLabel("x");
65 | yAxis.setLabel("y");
66 |
67 | //creating the scatter chart
68 | ScatterChart scatterChart = new ScatterChart<>(xAxis,yAxis);
69 | scatterChart.setAnimated(false);
70 | scatterChart.getData().add(series);
71 | scatterChart.setTitle("x vs. y");
72 | scatterChart.setHorizontalGridLinesVisible(false);
73 | scatterChart.setVerticalGridLinesVisible(false);
74 | scatterChart.setVerticalZeroLineVisible(false);
75 |
76 | /*
77 | create a scene using the chart
78 | */
79 | Scene scene = new Scene(scatterChart,800,600);
80 |
81 | /*
82 | tell the stage what scene to use and render it!
83 | */
84 | stage.setScene(scene);
85 | stage.show();
86 |
87 | // WritableImage image = scatterChart.snapshot(new SnapshotParameters(), null);
88 | // File file = new File("chart.png");
89 | // ImageIO.write(SwingFXUtils.fromFXImage(image, null), "png", file);
90 |
91 | }
92 |
93 | }
94 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/io/DBApp.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.io;
17 |
18 | import java.sql.Connection;
19 | import java.sql.DriverManager;
20 | import java.sql.PreparedStatement;
21 | import java.sql.ResultSet;
22 | import java.sql.SQLException;
23 | import java.sql.Statement;
24 |
/**
 * JDBC walkthrough: recreate a table, insert one row with a PreparedStatement,
 * then select and print every row.
 *
 * @author Michael Brzustowicz
 */
public class DBApp {

    private static final String DROP_SQL = "DROP TABLE IF EXISTS data";
    private static final String CREATE_SQL =
            "CREATE TABLE IF NOT EXISTS data(id INTEGER PRIMARY KEY, yr INTEGER, city VARCHAR(80))";
    private static final String INSERT_SQL = "INSERT INTO data(id, yr, city) VALUES(?, ?, ?)";
    private static final String SELECT_SQL = "SELECT id, yr, city FROM data";

    /**
     * Runs the walkthrough against a local MySQL database. Any SQLException's
     * message is printed to standard output.
     *
     * @param args the command line arguments (unused)
     */
    public static void main(String[] args) {
        final String uri = "jdbc:mysql://localhost:3306/mydb?user=root";
        try (Connection connection = DriverManager.getConnection(uri)) {
            recreateTable(connection);
            insertOneRow(connection);
            printAllRows(connection);
        } catch (SQLException ex) {
            System.out.println(ex.getMessage());
        }
    }

    /** Drops the data table if present, then creates it. */
    private static void recreateTable(Connection connection) throws SQLException {
        try (Statement ddl = connection.createStatement()) {
            ddl.execute(DROP_SQL);
            ddl.execute(CREATE_SQL);
        }
    }

    /** Inserts the single example row (1, 2015, San Francisco). */
    private static void insertOneRow(Connection connection) throws SQLException {
        try (PreparedStatement insert = connection.prepareStatement(INSERT_SQL)) {
            insert.setInt(1, 1);
            insert.setInt(2, 2015);
            insert.setString(3, "San Francisco");
            insert.execute();
        }
    }

    /** Prints each row as "id yr city". */
    private static void printAllRows(Connection connection) throws SQLException {
        try (Statement query = connection.createStatement();
             ResultSet rows = query.executeQuery(SELECT_SQL)) {
            while (rows.next()) {
                // TODO ... do something with data
                System.out.println(rows.getInt("id") + " " + rows.getInt("yr") + " " + rows.getString("city"));
            }
        }
    }
}
79 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/io/DBInsertBatchApp.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.io;
17 |
18 | import java.sql.Connection;
19 | import java.sql.DriverManager;
20 | import java.sql.PreparedStatement;
21 | import java.sql.ResultSet;
22 | import java.sql.SQLException;
23 | import java.sql.Statement;
24 | import java.util.ArrayList;
25 | import java.util.List;
26 |
27 | /**
28 | *
29 | * @author Michael Brzustowicz
30 | */
31 | public class DBInsertBatchApp {
32 |
33 | public static List getData() {
34 | List data = new ArrayList<>();
35 | data.add(new Record(1, 2015, "San Francisco"));
36 | data.add(new Record(2, 2014, "New York"));
37 | data.add(new Record(3, 2012, "Los Angeles"));
38 | return data;
39 | }
40 |
41 | /**
42 | * @param args the command line arguments
43 | */
44 | public static void main(String[] args) {
45 |
46 | String uri = "jdbc:mysql://localhost:3306/mydb?user=root";
47 |
48 | try(Connection c = DriverManager.getConnection(uri)) {
49 |
50 | /* DROP / CREATE TABLE */
51 | String dropSQL = "DROP TABLE IF EXISTS data";
52 | String createSQL = "CREATE TABLE IF NOT EXISTS data(id INTEGER PRIMARY KEY, yr INTEGER, city VARCHAR(80))";
53 |
54 | try (Statement stmt = c.createStatement()) {
55 |
56 | stmt.execute(dropSQL);
57 |
58 | stmt.execute(createSQL);
59 |
60 | }
61 |
62 | /* INSERT BATCH DATA */
63 | String insertSQL = "INSERT INTO data(id, yr, city) VALUES(?, ?, ?)";
64 |
65 | try (PreparedStatement ps = c.prepareStatement(insertSQL)) {
66 |
67 | for (Record data : getData()) {
68 |
69 | ps.setInt(1, data.id);
70 | ps.setInt(2, data.year);
71 | ps.setString(3, data.city);
72 |
73 | /* add record to the batch !!! */
74 | ps.addBatch();
75 | }
76 |
77 | /* note this is different than the regular execute */
78 | ps.executeBatch();
79 |
80 | }
81 |
82 | /* SELECT DATA */
83 | String selectSQL = "SELECT id, yr, city FROM data";
84 | try (Statement st = c.createStatement(); ResultSet rs = st.executeQuery(selectSQL)) {
85 |
86 | while(rs.next()) {
87 | // TODO ... do something with data
88 | System.out.println(rs.getInt("id")+" "+rs.getInt("yr")+" "+rs.getString("city"));
89 | }
90 | }
91 |
92 | } catch (SQLException e) {
93 | System.out.println(e.getMessage());
94 | }
95 | }
96 | }
97 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/io/FileIOExample.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.io;
17 |
18 | import java.io.BufferedReader;
19 | import java.io.FileReader;
20 | import java.io.IOException;
21 |
/**
 * Reads a text file line by line and echoes each line to standard output.
 *
 * @author Michael Brzustowicz
 */
public class FileIOExample {

    /**
     * Echoes the file named by the first argument.
     *
     * With no argument, the historical default (an empty path) is used. That path
     * cannot be opened, so the IOException is reported on standard error.
     *
     * @param args optional; args[0] is the path of the file to read
     */
    public static void main(String[] args) {

        String filename = (args != null && args.length > 0) ? args[0] : "";

        try (BufferedReader br = new BufferedReader(new FileReader(filename))) {

            String line;

            while ((line = br.readLine()) != null) {
                // TODO ... do something with line
                // TODO ... parse line e.g. CSV, TSV, JSON
                // TODO ... check each value if required
                System.out.println(line);
            }

        } catch (IOException e) {
            System.err.println(e);
        }

    }

}
55 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/io/Record.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.io;
17 |
18 | /**
19 | *
20 | * @author Michael Brzustowicz
21 | */
22 | public class Record {
23 | public int id;
24 | public int year;
25 | public String city;
26 |
27 | public Record() {
28 | }
29 |
30 | public Record(int id, int year, String city) {
31 | this.id = id;
32 | this.year = year;
33 | this.city = city;
34 | }
35 |
36 | }
37 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/learn/BernoulliConditionalProbabilityEstimator.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.learn;
17 |
18 | import org.apache.commons.math3.stat.descriptive.MultivariateSummaryStatistics;
19 |
20 | /**
21 | *
22 | * @author Michael Brzustowicz
23 | */
24 | public class BernoulliConditionalProbabilityEstimator implements ConditionalProbabilityEstimator {
25 |
26 | @Override
27 | public double getProbability(MultivariateSummaryStatistics mss, double[] features) {
28 | int n = features.length;
29 | double[] means = mss.getMean(); // this is actually the prob per features e.g. count / total
30 | double prob = 1.0;
31 | for (int i = 0; i < n; i++) {
32 | // if x_i = 1, the p, if x_i = 0 then 1-p ... but here x_i is a double
33 | prob *= (features[i] > 0.0) ? means[i] : 1-means[i];
34 | }
35 | return prob;
36 | }
37 |
38 | }
39 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/learn/ClassifierAccuracy.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.learn;
17 |
18 | import com.oreilly.dswj.dataops.ProbabilityEncoder;
19 | import org.apache.commons.math3.analysis.function.Abs;
20 | import org.apache.commons.math3.linear.ArrayRealVector;
21 | import org.apache.commons.math3.linear.RealMatrix;
22 | import org.apache.commons.math3.linear.RealVector;
23 |
24 | /**
25 | *
26 | * @author Michael Brzustowicz
27 | */
28 | public class ClassifierAccuracy {
29 |
30 | private final RealMatrix predictions;
31 | private final RealMatrix targets;
32 | private final ProbabilityEncoder probabilityEncoder;
33 | private RealVector classCount;
34 |
35 | /**
36 | *
37 | * @param predictions assumes probabilities
38 | * @param targets assumes binary multi-label OR one-hot-encoding
39 | */
40 | public ClassifierAccuracy(RealMatrix predictions, RealMatrix targets) {
41 | this.predictions = predictions;
42 | this.targets = targets;
43 | probabilityEncoder = new ProbabilityEncoder();
44 | //tally the binary class occurances per dimension
45 | classCount = new ArrayRealVector(targets.getColumnDimension());
46 | for (int i = 0; i < targets.getRowDimension(); i++) {
47 | classCount = classCount.add(targets.getRowVector(i));
48 | }
49 | }
50 |
51 | /**
52 | * assumes one hot encoding
53 | * @return
54 | */
55 | public RealVector getAccuracyPerDimension() {
56 |
57 | RealVector accuracy = new ArrayRealVector(predictions.getColumnDimension());
58 |
59 | for (int i = 0; i < predictions.getRowDimension(); i++) {
60 |
61 | RealVector binarized = probabilityEncoder.getOneHot(predictions.getRowVector(i));
62 |
63 | // 0*0, 0*1, 1*1 = 0 and 1*1 = 1 giving only true positives as 1 and all other 0
64 | RealVector decision = binarized.ebeMultiply(targets.getRowVector(i));
65 |
66 | // append TP counts to accuracy
67 | accuracy = accuracy.add(decision);
68 | }
69 | return accuracy.ebeDivide(classCount);
70 | }
71 |
72 | /**
73 | * assumes one hot encoding
74 | * @return
75 | */
76 | public double getAccuracy() {
77 | // convert accuracy_per_dim back to counts, then sum and divide by total rows
78 | return getAccuracyPerDimension().ebeMultiply(classCount).getL1Norm() / targets.getRowDimension();
79 | }
80 |
81 | // implements jaccard similarity scores
82 | public RealVector getAccuracyPerDimension(double threshold) { // assumes un-correlated multi-output
83 |
84 | RealVector accuracy = new ArrayRealVector(targets.getColumnDimension());
85 |
86 | for (int i = 0; i < predictions.getRowDimension(); i++) {
87 |
88 | //binarize the row vector according to the threshold
89 | RealVector binarized = probabilityEncoder.getBinary(predictions.getRowVector(i), threshold);
90 |
91 | // 0-0 (TN) and 1-1 (TP) = 0 while 1-0 = 1 and 0-1 = -1
92 | RealVector decision = binarized.subtract(targets.getRowVector(i)).map(new Abs()).mapMultiply(-1).mapAdd(1);
93 |
94 | // append either TP and TN counts to accuracy
95 | accuracy = accuracy.add(decision);
96 | }
97 | return accuracy.mapDivide((double) predictions.getRowDimension()); // accuracy for each dimension, given the threshold
98 | }
99 |
100 | public double getAccuracy(double threshold) {
101 | // mean of the accuracy vector
102 | return getAccuracyPerDimension(threshold).getL1Norm() / targets.getColumnDimension();
103 | }
104 |
105 |
106 |
107 | }
108 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/learn/ConditionalProbabilityEstimator.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.learn;
17 |
18 | import org.apache.commons.math3.stat.descriptive.MultivariateSummaryStatistics;
19 |
20 | /**
21 | *
22 | * @author Michael Brzustowicz
23 | */
24 | public interface ConditionalProbabilityEstimator {
25 |
26 | /**
27 | *
28 | * @param mss multivariate statistics object for this class
29 | * @param features
30 | * @return
31 | */
32 | double getProbability(MultivariateSummaryStatistics mss, double[] features);
33 | }
34 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/learn/DeepNetwork.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.learn;
17 |
18 | import java.util.ArrayList;
19 | import java.util.List;
20 | import java.util.ListIterator;
21 | import org.apache.commons.math3.linear.RealMatrix;
22 |
23 | /**
24 | *
25 | * @author Michael Brzustowicz
26 | */
27 | public class DeepNetwork extends IterativeLearningProcess {
28 |
29 | private final List layers;
30 |
31 | public DeepNetwork() {
32 | this.layers = new ArrayList<>();
33 | }
34 |
35 | public void addLayer(NetworkLayer networkLayer) {
36 | layers.add(networkLayer);
37 | }
38 |
39 | public List getLayers() {
40 | return layers;
41 | }
42 |
43 | @Override
44 | public RealMatrix predict(RealMatrix input) {
45 |
46 | /* the initial input MUST BE DEEP COPIED or is overwritten */
47 | RealMatrix layerInput = input.copy();
48 |
49 | for (NetworkLayer layer : layers) {
50 |
51 | layer.setInput(layerInput);
52 |
53 | /* calc the output and set to next layer input*/
54 | RealMatrix output = layer.getOutput(layerInput);
55 | layer.setOutput(output);
56 |
57 | /*
58 | does not need a deep copy, but be aware that
59 | every layer input shares memory of prior layer output
60 | */
61 | layerInput = output;
62 |
63 | }
64 |
65 | /* layerInput is holding the final output ... get a deep copy */
66 | return layerInput.copy();
67 |
68 | }
69 |
70 | @Override
71 | protected void update(RealMatrix input, RealMatrix target, RealMatrix output) {
72 |
73 | /* this is the gradient of the network error and starts the back prop process */
74 | RealMatrix layerError = getLossFunction().getLossGradient(output, target).copy();
75 |
76 | /* create list iterator and set cursor to last! */
77 | ListIterator li = layers.listIterator(layers.size());
78 |
79 | while (li.hasPrevious()) {
80 |
81 | NetworkLayer layer = (NetworkLayer) li.previous();
82 |
83 | /* get error input from higher layer */
84 | layer.setOutputError(layerError);
85 |
86 | /* this back propagates the error and updates weights */
87 | layer.update();
88 |
89 | /* pass along error to next layer down */
90 | layerError = layer.getInputError();
91 |
92 | }
93 | }
94 | }
95 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/learn/DeepNetworkIrisExample.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.learn;
17 |
18 | import com.oreilly.dswj.dataops.MatrixResampler;
19 | import com.oreilly.dswj.datasets.Iris;
20 | import com.oreilly.dswj.dataops.SoftMaxCrossEntropyLossFunction;
21 | import java.io.IOException;
22 | import org.apache.commons.math3.linear.RealMatrix;
23 |
24 | /**
25 | *
26 | * @author Michael Brzustowicz
27 | */
28 | public class DeepNetworkIrisExample {
29 |
30 | /**
31 | * @param args the command line arguments
32 | * @throws java.io.IOException
33 | */
34 | public static void main(String[] args) throws IOException {
35 | Iris iris = new Iris();
36 | MatrixResampler mr = new MatrixResampler(iris.getData(), iris.getLabels());
37 | mr.calculateTestTrainSplit(0.4, 0L);
38 | DeepNetwork net = new DeepNetwork();
39 | net.addLayer(new NetworkLayer(4, 10, new TanhOutputFunction(), new GradientDescent(0.001)));
40 | net.addLayer(new NetworkLayer(10, 3, new SoftmaxOutputFunction(), new GradientDescent(0.001)));
41 | net.setLossFunction(new SoftMaxCrossEntropyLossFunction());
42 | net.setBatchSize(0);
43 | net.setMaxIterations(6000);
44 | net.setTolerance(10E-6);
45 | net.learn(mr.getTrainingFeatures(), mr.getTrainingLabels());
46 | RealMatrix predictions = net.predict(mr.getTestingFeatures());
47 | ClassifierAccuracy acc = new ClassifierAccuracy(predictions, mr.getTestingLabels());
48 | System.out.println("converged = " + net.isConverged());
49 | System.out.println("iterations = " + net.getNumIterations());
50 | System.out.println("accuracy = " + acc.getAccuracy());
51 | }
52 |
53 | }
54 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/learn/DeepNetworkMNISTExample.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.learn;
17 |
18 | import com.oreilly.dswj.datasets.MNIST;
19 | import com.oreilly.dswj.dataops.SoftMaxCrossEntropyLossFunction;
20 | import java.io.IOException;
21 | import org.apache.commons.math3.linear.RealMatrix;
22 |
23 | /**
24 | *
25 | * @author Michael Brzustowicz
26 | */
27 | public class DeepNetworkMNISTExample {
28 |
29 | /**
30 | * @param args the command line arguments
31 | * @throws java.io.IOException
32 | */
33 | public static void main(String[] args) throws IOException {
34 |
35 | MNIST mnist = new MNIST();
36 |
37 | DeepNetwork network = new DeepNetwork();
38 |
39 | /* input, hidden and output layers */
40 | network.addLayer(new NetworkLayer(784, 500, new TanhOutputFunction(),
41 | new GradientDescentMomentum(0.0001, 0.95)));
42 |
43 | network.addLayer(new NetworkLayer(500, 300, new TanhOutputFunction(),
44 | new GradientDescentMomentum(0.0001, 0.95)));
45 |
46 | network.addLayer(new NetworkLayer(300, 10, new SoftmaxOutputFunction(),
47 | new GradientDescentMomentum(0.0001, 0.95)));
48 |
49 | /* runtime parameters */
50 | network.setLossFunction(new SoftMaxCrossEntropyLossFunction());
51 | network.setMaxIterations(10000);
52 | network.setTolerance(10E-6);
53 | network.setBatchSize(200);
54 |
55 | /* learn */
56 | network.learn(mnist.trainingData, mnist.trainingLabels);
57 |
58 | /* predict */
59 | RealMatrix prediction = network.predict(mnist.testingData);
60 |
61 | /* compute accuracy */
62 | ClassifierAccuracy accuracy = new ClassifierAccuracy(prediction, mnist.testingLabels);
63 |
64 | /* print report */
65 | System.out.println("isConverged = " + network.isConverged());
66 | System.out.println("numIter = " + network.getNumIterations());
67 | System.out.println("error = " + network.getLoss());
68 | System.out.println("accuracy = " + accuracy.getAccuracy());
69 | System.out.println("accuracy per dim = " + accuracy.getAccuracyPerDimension());
70 |
71 | }
72 |
73 | }
74 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/learn/DeltaRule.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.learn;
17 |
18 | import org.apache.commons.math3.linear.RealMatrix;
19 | import org.apache.commons.math3.linear.RealVector;
20 |
21 | /**
22 | *
23 | * @author Michael Brzustowicz
24 | */
25 | public class DeltaRule implements Optimizer {
26 |
27 | private final double learningRate;
28 |
29 | public DeltaRule(double learningRate) {
30 | this.learningRate = learningRate;
31 | }
32 |
33 | @Override
34 | public RealMatrix getWeightUpdate(RealMatrix weightGradient) {
35 | return weightGradient.scalarMultiply(-1.0 * learningRate);
36 | }
37 |
38 | @Override
39 | public RealVector getBiasUpdate(RealVector biasGradient) {
40 | return biasGradient.mapMultiply(-1.0 * learningRate);
41 | }
42 | }
43 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/learn/GaussianConditionalProbabilityEstimator.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.learn;
17 |
18 | import org.apache.commons.math3.distribution.NormalDistribution;
19 | import org.apache.commons.math3.stat.descriptive.MultivariateSummaryStatistics;
20 |
21 |
22 | /**
23 | *
24 | * @author Michael Brzustowicz
25 | */
26 | public class GaussianConditionalProbabilityEstimator implements ConditionalProbabilityEstimator{
27 |
28 | @Override
29 | public double getProbability(MultivariateSummaryStatistics mss, double[] features) {
30 | double[] means = mss.getMean();
31 | double[] stds = mss.getStandardDeviation();
32 | double prob = 1.0;
33 | for (int i = 0; i < features.length; i++) {
34 | prob *= new NormalDistribution(means[i], stds[i]).density(features[i]);
35 | }
36 | return prob;
37 | }
38 |
39 | }
40 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/learn/GradientDescent.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.learn;
17 |
18 | import org.apache.commons.math3.linear.RealMatrix;
19 | import org.apache.commons.math3.linear.RealVector;
20 |
21 | /**
22 | *
23 | * @author Michael Brzustowicz
24 | */
25 | public class GradientDescent implements Optimizer {
26 |
27 | private double learningRate;
28 |
29 | public GradientDescent(double learningRate) {
30 | this.learningRate = learningRate;
31 | }
32 |
33 | public void setLearningRate(double learningRate) {
34 | this.learningRate = learningRate;
35 | }
36 |
37 | @Override
38 | public RealMatrix getWeightUpdate(RealMatrix weightGradient) {
39 | return weightGradient.scalarMultiply(-1.0 * learningRate);
40 | }
41 |
42 | @Override
43 | public RealVector getBiasUpdate(RealVector biasGradient) {
44 | return biasGradient.mapMultiply(-1.0 * learningRate);
45 | }
46 | }
47 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/learn/GradientDescentMomentum.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.learn;
17 |
18 | import org.apache.commons.math3.linear.ArrayRealVector;
19 | import org.apache.commons.math3.linear.BlockRealMatrix;
20 | import org.apache.commons.math3.linear.RealMatrix;
21 | import org.apache.commons.math3.linear.RealVector;
22 |
23 | /**
24 | *
25 | * @author Michael Brzustowicz
26 | */
27 | public class GradientDescentMomentum extends GradientDescent {
28 |
29 | private final double momentum;
30 | private RealMatrix priorWeightUpdate;
31 | private RealVector priorBiasUpdate;
32 |
33 | /**
34 | *
35 | * @param learningRate good default is 0.0001
36 | * @param momentum good default is 0.95
37 | */
38 | public GradientDescentMomentum(double learningRate, double momentum) {
39 | super(learningRate);
40 | this.momentum = momentum;
41 | priorWeightUpdate = null;
42 | priorBiasUpdate = null;
43 | }
44 |
45 | @Override
46 | public RealMatrix getWeightUpdate(RealMatrix weightGradient) {
47 | // creates matrix of zeros same size as gradients if not already exists
48 | if(priorWeightUpdate == null) {
49 | priorWeightUpdate = new BlockRealMatrix(weightGradient.getRowDimension(), weightGradient.getColumnDimension());
50 | }
51 | // add term from GradientDescent since it is already negative ( - eta * gradW )
52 | RealMatrix update = priorWeightUpdate.scalarMultiply(momentum).add(super.getWeightUpdate(weightGradient));
53 | priorWeightUpdate = update;
54 | return update;
55 | }
56 |
57 | @Override
58 | public RealVector getBiasUpdate(RealVector biasGradient) {
59 | if(priorBiasUpdate == null) {
60 | priorBiasUpdate = new ArrayRealVector(biasGradient.getDimension());
61 | }
62 | // add term from GradientDescent since it is already negative ( - eta * gradW )
63 | RealVector update = priorBiasUpdate.mapMultiply(momentum).add(super.getBiasUpdate(biasGradient));
64 | priorBiasUpdate = update;
65 | return update;
66 | }
67 | }
68 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/learn/IterativeLearningProcess.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.learn;
17 |
18 | import com.oreilly.dswj.dataops.Batch;
19 | import com.oreilly.dswj.dataops.LossFunction;
20 | import com.oreilly.dswj.dataops.QuadraticLossFunction;
21 | import org.apache.commons.math3.linear.RealMatrix;
22 |
23 | /**
24 | *
25 | * @author Michael Brzustowicz
26 | */
27 | public class IterativeLearningProcess {
28 |
29 | private boolean converged;
30 | private int numIterations;
31 | private int maxIterations;
32 | private double loss;
33 | private double tolerance;
34 | private int batchSize; // if == 0 then uses ALL data
35 | private LossFunction lossFunction;
36 |
37 | public IterativeLearningProcess(LossFunction lossFunction) {
38 | this.lossFunction = lossFunction;
39 | loss = 0;
40 | converged = false;
41 | numIterations = 0;
42 | maxIterations = 200;
43 | tolerance = 10E-6;
44 | batchSize = 100;
45 | }
46 |
47 | public IterativeLearningProcess() {
48 | this(new QuadraticLossFunction());
49 | }
50 |
51 | public void learn(RealMatrix input, RealMatrix target) {
52 |
53 | double priorLoss = tolerance;
54 |
55 | numIterations = 0;
56 |
57 | loss = 0;
58 |
59 | converged = false;
60 |
61 | Batch batch = new Batch(input, target);
62 | RealMatrix inputBatch;
63 | RealMatrix targetBatch;
64 |
65 |
66 | while(numIterations < maxIterations && !converged) {
67 |
68 | if(batchSize > 0 && batchSize < input.getRowDimension()) {
69 |
70 | batch.calcNextBatch(batchSize);
71 | inputBatch = batch.getInputBatch();
72 | targetBatch = batch.getTargetBatch();
73 |
74 | } else {
75 |
76 | inputBatch = input;
77 | targetBatch = target;
78 |
79 | }
80 |
81 | RealMatrix outputBatch = predict(inputBatch);
82 |
83 | loss = lossFunction.getMeanLoss(outputBatch, targetBatch);
84 |
85 | if(Math.abs(priorLoss - loss) < tolerance) {
86 |
87 | converged = true;
88 |
89 | } else {
90 |
91 | update(inputBatch, targetBatch, outputBatch);
92 |
93 | priorLoss = loss;
94 |
95 | }
96 |
97 | numIterations++;
98 | }
99 |
100 | }
101 |
102 | public RealMatrix predict(RealMatrix input) {
103 | throw new UnsupportedOperationException("Implement the predict method!");
104 | }
105 |
106 | protected void update(RealMatrix input, RealMatrix target, RealMatrix output) {
107 | throw new UnsupportedOperationException("Implement the update method!");
108 | }
109 |
110 | public void setBatchSize(int batchSize) {
111 | this.batchSize = batchSize;
112 | }
113 |
114 | public void setMaxIterations(int maxIterations) {
115 | this.maxIterations = maxIterations;
116 | }
117 |
118 | public void setTolerance(double tolerance) {
119 | this.tolerance = tolerance;
120 | }
121 |
122 | public int getNumIterations() {
123 | return numIterations;
124 | }
125 |
126 | public double getLoss() {
127 | return loss;
128 | }
129 |
130 | public LossFunction getLossFunction() {
131 | return lossFunction;
132 | }
133 |
134 | public void setLossFunction(LossFunction lossFunction) {
135 | this.lossFunction = lossFunction;
136 | }
137 |
138 | public boolean isConverged() {
139 | return converged;
140 | }
141 | }
142 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/learn/KMeansExample.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.learn;
17 |
18 | import java.util.ArrayList;
19 | import java.util.List;
20 | import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer;
21 | import org.apache.commons.math3.ml.clustering.MultiKMeansPlusPlusClusterer;
22 | import org.apache.commons.math3.distribution.NormalDistribution;
23 | import org.apache.commons.math3.ml.clustering.CentroidCluster;
24 | import org.apache.commons.math3.ml.clustering.DoublePoint;
25 | import org.apache.commons.math3.ml.clustering.evaluation.SumOfClusterVariances;
26 | import org.apache.commons.math3.ml.distance.EuclideanDistance;
27 |
28 | /**
29 | *
30 | * @author Michael Brzustowicz
31 | */
32 | public class KMeansExample {
33 |
34 | /**
35 | * @param args the command line arguments
36 | */
37 | public static void main(String[] args) {
38 | int p = 7; //dimensions
39 | int n = 1000; //num points
40 |
41 | List data = new ArrayList<>();
42 | NormalDistribution dist = new NormalDistribution();
43 | for (int i = 0; i < 1000; i++) {
44 | data.add(new DoublePoint(dist.sample(p)));
45 | }
46 |
47 | for (int i = 1; i < 5; i++) {
48 |
49 | System.out.println("cluster: " + i);
50 |
51 | KMeansPlusPlusClusterer kmpp = new KMeansPlusPlusClusterer<>(i);
52 | List> results = kmpp.cluster(data);
53 |
54 | /* use cluster vars to observe cluster quality */
55 | SumOfClusterVariances clusterVar = new SumOfClusterVariances<>(new EuclideanDistance());
56 | System.out.println("score: " + clusterVar.score(results));
57 |
58 | for (CentroidCluster result : results) {
59 | DoublePoint centroid = (DoublePoint) result.getCenter();
60 | // System.out.println(centroid);
61 | // result.getPoints();
62 | }
63 | }
64 |
65 | /* performs k++ numTrials times and returns only the best result */
66 | int numTrials = 10;
67 | for (int i = 1; i < 5; i++) {
68 | System.out.println("MULTI " + i);
69 | KMeansPlusPlusClusterer kmpp2 = new KMeansPlusPlusClusterer<>(i);
70 | MultiKMeansPlusPlusClusterer multiKMPP = new MultiKMeansPlusPlusClusterer<>(kmpp2, numTrials);
71 | List> multiResults = multiKMPP.cluster(data);
72 | /* use cluster vars to observe cluster quality */
73 | SumOfClusterVariances clusterVar = new SumOfClusterVariances<>(new EuclideanDistance());
74 | System.out.println("score: " + clusterVar.score(multiResults));
75 | for (CentroidCluster multiResult : multiResults) {
76 | // System.out.println(multiResult.getCenter());
77 | }
78 | }
79 |
80 |
81 | }
82 |
83 | }
84 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/learn/LinearModel.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.learn;
17 |
18 | import com.oreilly.dswj.linalg.RandomizedMatrix;
19 | import org.apache.commons.math3.linear.RealMatrix;
20 | import org.apache.commons.math3.linear.RealVector;
21 |
22 | /**
23 | *
24 | * @author Michael Brzustowicz
25 | */
26 | public class LinearModel {
27 |
28 | private RealMatrix weight;
29 | private RealVector bias;
30 | private final OutputFunction outputFunction;
31 |
32 | public LinearModel(int inputDimension, int outputDimension,
33 | OutputFunction outputFunction) {
34 | RandomizedMatrix randM = new RandomizedMatrix();
35 | weight = randM.getMatrix(inputDimension, outputDimension);
36 | bias = randM.getVector(outputDimension);
37 | this.outputFunction = outputFunction;
38 | }
39 |
40 | public RealMatrix getOutput(RealMatrix input) {
41 | return outputFunction.getOutput(input, weight, bias);
42 | }
43 |
44 | public void addUpdateToWeight(RealMatrix weightUpdate) {
45 | weight = weight.add(weightUpdate);
46 | }
47 |
48 | public void addUpdateToBias(RealVector biasUpdate) {
49 | bias = bias.add(biasUpdate);
50 | }
51 |
52 | /* setter and getters */
53 |
54 | public void setWeight(RealMatrix weight) {
55 | this.weight = weight;
56 | }
57 |
58 | public void setBias(RealVector bias) {
59 | this.bias = bias;
60 | }
61 |
62 | public RealMatrix getWeight() {
63 | return weight;
64 | }
65 |
66 | public RealVector getBias() {
67 | return bias;
68 | }
69 |
70 | public OutputFunction getOutputFunction() {
71 | return outputFunction;
72 | }
73 | }
74 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/learn/LinearModelEstimator.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.learn;
17 |
18 | import com.oreilly.dswj.dataops.LossFunction;
19 | import org.apache.commons.math3.linear.ArrayRealVector;
20 | import org.apache.commons.math3.linear.RealMatrix;
21 | import org.apache.commons.math3.linear.RealVector;
22 |
23 | /**
24 | *
25 | * @author Michael Brzustowicz
26 | */
27 | public class LinearModelEstimator extends IterativeLearningProcess {
28 |
29 | private final LinearModel linearModel;
30 | private final Optimizer optimizer;
31 |
32 | public LinearModelEstimator(
33 | LinearModel linearModel,
34 | LossFunction lossFunction,
35 | Optimizer optimizer) {
36 | super(lossFunction);
37 | this.linearModel = linearModel;
38 | this.optimizer = optimizer;
39 | }
40 |
41 | @Override
42 | public RealMatrix predict(RealMatrix input) {
43 | return linearModel.getOutput(input);
44 | }
45 |
46 | @Override
47 | protected void update(RealMatrix input, RealMatrix target, RealMatrix output) {
48 | RealMatrix weightGradient = input.transpose().multiply(output.subtract(target));
49 | RealMatrix weightUpdate = optimizer.getWeightUpdate(weightGradient);
50 | linearModel.addUpdateToWeight(weightUpdate);
51 |
52 | RealVector h = new ArrayRealVector(input.getRowDimension(), 1.0);
53 | RealVector biasGradient = output.subtract(target).preMultiply(h);
54 | RealVector biasUpdate = optimizer.getBiasUpdate(biasGradient);
55 | linearModel.addUpdateToBias(biasUpdate);
56 |
57 | }
58 |
59 | public LinearModel getLinearModel() {
60 | return linearModel;
61 | }
62 |
63 | public Optimizer getOptimizer() {
64 | return optimizer;
65 | }
66 |
67 | }
68 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/learn/LinearOutputFunction.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.learn;
17 |
18 | import com.oreilly.dswj.linalg.MatrixOperations;
19 | import org.apache.commons.math3.linear.RealMatrix;
20 | import org.apache.commons.math3.linear.RealVector;
21 |
22 | /**
23 | *
24 | * @author Michael Brzustowicz
25 | */
26 | public class LinearOutputFunction implements OutputFunction {
27 |
28 | @Override
29 | public RealMatrix getOutput(RealMatrix input, RealMatrix weight, RealVector bias) {
30 | return MatrixOperations.XWplusB(input, weight, bias);
31 | }
32 |
33 | @Override
34 | public RealMatrix getDelta(RealMatrix errorGradient, RealMatrix output) {
35 | // output gradient is all 1's ... so just return errorGradient
36 | return errorGradient;
37 | }
38 |
39 | }
40 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/learn/LogisticOutputFunction.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.learn;
17 |
18 | import com.oreilly.dswj.linalg.MatrixOperations;
19 | import com.oreilly.dswj.linalg.UnivariateFunctionMapper;
20 | import org.apache.commons.math3.analysis.UnivariateFunction;
21 | import org.apache.commons.math3.analysis.function.Sigmoid;
22 | import org.apache.commons.math3.linear.RealMatrix;
23 | import org.apache.commons.math3.linear.RealVector;
24 |
25 | /**
26 | *
27 | * @author Michael Brzustowicz
28 | */
29 | public class LogisticOutputFunction implements OutputFunction {
30 |
31 | @Override
32 | public RealMatrix getOutput(RealMatrix input, RealMatrix weight, RealVector bias) {
33 | return MatrixOperations.XWplusB(input, weight, bias, new Sigmoid());
34 | }
35 |
36 | @Override
37 | public RealMatrix getDelta(RealMatrix errorGradient, RealMatrix output) {
38 |
39 | // this changes output permanently
40 | output.walkInOptimizedOrder(new UnivariateFunctionMapper(new LogisticGradient()));
41 |
42 | // output is now the output gradient
43 | return MatrixOperations.ebeMultiply(errorGradient, output);
44 | }
45 |
46 | private class LogisticGradient implements UnivariateFunction {
47 |
48 | @Override
49 | public double value(double x) {
50 | return x * (1 - x);
51 | }
52 |
53 | }
54 |
55 | }
56 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/learn/MultinomialConditionalProbabilityEstimator.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.learn;
17 |
18 | import org.apache.commons.math3.stat.descriptive.MultivariateSummaryStatistics;
19 |
20 | /**
21 | *
22 | * @author Michael Brzustowicz
23 | */
24 | public class MultinomialConditionalProbabilityEstimator implements ConditionalProbabilityEstimator {
25 |
26 | private double alpha;
27 |
28 | public MultinomialConditionalProbabilityEstimator(double alpha) {
29 | this.alpha = alpha; // Lidstone smoothing 0 > alpha > 1
30 | }
31 |
32 | public MultinomialConditionalProbabilityEstimator() {
33 | this(1); // Laplace smoothing
34 | }
35 |
36 | @Override
37 | public double getProbability(MultivariateSummaryStatistics mss, double[] features) {
38 | int n = features.length;
39 | double prob = 1.0;
40 | double[] sum = mss.getSum(); // array of x_i sums for this class
41 | double total = 0.0; // total count of all features
42 | for (int i = 0; i < n; i++) {
43 | total += sum[i];
44 | }
45 |
46 | /* works great for smaller x_i ie features[i] */
47 | // for (int i = 0; i < n; i++) {
48 | // prob *= Math.pow((sum[i] + alpha) / (total + alpha * n), features[i]);
49 | // }
50 | // return prob;
51 |
52 | /* for large x_i need to solve in log space and convert back with exp */
53 | prob = 0;
54 | for (int i = 0; i < n; i++) {
55 | prob += features[i] * Math.log((sum[i] + alpha) / (total + alpha * n));
56 | }
57 | return Math.exp(prob);
58 | }
59 |
60 | }
61 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/learn/NaiveBayes.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.learn;
17 |
18 | import java.util.HashMap;
19 | import java.util.Map;
20 | import org.apache.commons.math3.linear.BlockRealMatrix;
21 | import org.apache.commons.math3.linear.RealMatrix;
22 | import org.apache.commons.math3.stat.descriptive.MultivariateSummaryStatistics;
23 |
24 | /**
25 | *
26 | * @author Michael Brzustowicz
27 | */
28 | public class NaiveBayes {
29 |
30 | private Map statistics;
31 | private final Map stats;
32 | private final ConditionalProbabilityEstimator conditionalProbabilityEstimator;
33 | private int numberOfPoints; // total number of points the model was trained on
34 |
35 | public NaiveBayes() {
36 | this(new GaussianConditionalProbabilityEstimator());
37 | }
38 |
39 | /**
40 | * provide the strategy for implementing the conditional probability
41 | * @param conditionalProbabilityEstimator
42 | */
43 | public NaiveBayes(ConditionalProbabilityEstimator conditionalProbabilityEstimator) {
44 | stats = new HashMap<>();
45 | statistics = new HashMap<>();
46 | this.conditionalProbabilityEstimator = conditionalProbabilityEstimator;
47 | numberOfPoints = 0;
48 | }
49 |
50 |
51 | /**
52 | *
53 | * @param input
54 | * @param target multi class OR one hot encoded labels
55 | */
56 | public void learn(RealMatrix input, RealMatrix target) {
57 | // if numTargetCols == 1 then multiclass e.g. 0, 1, 2, 3
58 | // else one-hot e.g. 1000, 0100, 0010, 0001
59 |
60 | numberOfPoints += input.getRowDimension();
61 |
62 | for (int i = 0; i < input.getRowDimension(); i++) {
63 |
64 | double[] rowData = input.getRow(i);
65 | int label;
66 |
67 | if (target.getColumnDimension()==1) {
68 | label = new Double(target.getEntry(i, 0)).intValue();
69 | } else {
70 | label = target.getRowVector(i).getMaxIndex();
71 | }
72 |
73 | if(!statistics.containsKey(label)) {
74 | statistics.put(label, new MultivariateSummaryStatistics(rowData.length, true));
75 | }
76 |
77 | statistics.get(label).addValue(rowData);
78 |
79 | }
80 | }
81 |
82 | public RealMatrix predict(RealMatrix input) {
83 |
84 | int numRows = input.getRowDimension();
85 | int numCols = statistics.size();
86 | RealMatrix predictions = new BlockRealMatrix(numRows, numCols);
87 |
88 | for (int i = 0; i < numRows; i++) {
89 |
90 | // double[] rowData = input.getRow(i);
91 | double[] probs = new double[numCols];
92 | double sumProbs = 0;
93 |
94 | for (Map.Entry entrySet : statistics.entrySet()) {
95 |
96 | Integer classNumber = entrySet.getKey(); // assumes these are 0, 1, 2 ... n-1
97 | MultivariateSummaryStatistics mss = entrySet.getValue();
98 |
99 | /* prior prob n_k / N ie num points in class divided by total points */
100 | double prob = new Long(mss.getN()).doubleValue() / numberOfPoints;
101 |
102 | /* depends on type ... Gaussian, Multinomial or Bernoulli */
103 | prob *= conditionalProbabilityEstimator.getProbability(mss, input.getRow(i));
104 |
105 | probs[classNumber] = prob;
106 | sumProbs += prob;
107 | }
108 |
109 | /* L1 norm the probs */
110 | for (int j = 0; j < numCols; j++) {
111 | probs[j] /= sumProbs;
112 | }
113 |
114 | predictions.setRow(i, probs);
115 |
116 | }
117 |
118 | return predictions;
119 | }
120 |
121 | }
122 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/learn/NaiveBayesGaussianIrisExample.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.learn;
17 |
18 | import com.oreilly.dswj.dataops.MatrixResampler;
19 | import com.oreilly.dswj.datasets.Iris;
20 | import java.io.IOException;
21 | import org.apache.commons.math3.linear.RealMatrix;
22 |
23 | /**
24 | *
25 | * @author Michael Brzustowicz
26 | */
27 | public class NaiveBayesGaussianIrisExample {
28 |
29 | /**
30 | * @param args the command line arguments
31 | * @throws java.io.IOException
32 | */
33 | public static void main(String[] args) throws IOException {
34 |
35 | Iris iris = new Iris();
36 | MatrixResampler mr = new MatrixResampler(iris.getData(), iris.getLabels());
37 | mr.calculateTestTrainSplit(0.4, 0L);
38 |
39 | NaiveBayes nb = new NaiveBayes(new GaussianConditionalProbabilityEstimator());
40 | nb.learn(mr.getTrainingFeatures(), mr.getTrainingLabels());
41 |
42 | RealMatrix predictions = nb.predict(mr.getTestingFeatures());
43 |
44 | ClassifierAccuracy acc = new ClassifierAccuracy(predictions, mr.getTestingLabels());
45 | System.out.println(acc.getAccuracyPerDimension());
46 | System.out.println(acc.getAccuracy());
47 |
48 |
49 |
50 |
51 | }
52 |
53 | }
54 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/learn/NetworkLayer.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.learn;
17 |
18 | import org.apache.commons.math3.linear.ArrayRealVector;
19 | import org.apache.commons.math3.linear.RealMatrix;
20 | import org.apache.commons.math3.linear.RealVector;
21 |
22 | /**
23 | *
24 | * @author Michael Brzustowicz
25 | */
26 | public class NetworkLayer extends LinearModel {
27 |
28 | RealMatrix input;
29 | RealMatrix inputError;
30 | RealMatrix output;
31 | RealMatrix outputError;
32 | Optimizer optimizer;
33 |
34 | public NetworkLayer(int inputDimension, int outputDimension,
35 | OutputFunction outputFunction, Optimizer optimizer) {
36 | super(inputDimension, outputDimension, outputFunction);
37 | this.optimizer = optimizer;
38 | }
39 |
40 | public void update() {
41 | //back propagate error
42 | /* D = eps o f'(XW) where o is Hadamard product or J f'(XW) where J is Jacobian */
43 | RealMatrix deltas = getOutputFunction().getDelta(outputError, output);
44 |
45 | /* E_out = D W^T */
46 | inputError = deltas.multiply(getWeight().transpose());
47 |
48 | /* W = W - alpha * delta * input */
49 | RealMatrix weightGradient = input.transpose().multiply(deltas);
50 |
51 | /* w_{t+1} = w_{t} + \delta w_{t} */
52 | addUpdateToWeight(optimizer.getWeightUpdate(weightGradient));
53 |
54 | // this essentially sums the columns of delta and that vector is grad_b
55 | RealVector h = new ArrayRealVector(input.getRowDimension(), 1.0);
56 | RealVector biasGradient = deltas.preMultiply(h);
57 | addUpdateToBias(optimizer.getBiasUpdate(biasGradient));
58 | }
59 |
60 | public void setOutputError(RealMatrix outputError) {
61 | this.outputError = outputError;
62 | }
63 |
64 | public void setInputError(RealMatrix inputError) {
65 | this.inputError = inputError;
66 | }
67 |
68 | public void setInput(RealMatrix input) {
69 | this.input = input;
70 | }
71 |
72 | public void setOutput(RealMatrix output) {
73 | this.output = output;
74 | }
75 |
76 | public RealMatrix getInputError() {
77 | return inputError;
78 | }
79 |
80 | }
81 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/learn/Optimizer.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.learn;
17 |
18 | import org.apache.commons.math3.linear.RealMatrix;
19 | import org.apache.commons.math3.linear.RealVector;
20 |
21 | /**
22 | *
23 | * @author Michael Brzustowicz
24 | */
25 | public interface Optimizer {
26 | RealMatrix getWeightUpdate(RealMatrix weightGradient);
27 | RealVector getBiasUpdate(RealVector biasGradient);
28 | }
29 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/learn/OutputFunction.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.learn;
17 |
18 | import org.apache.commons.math3.linear.RealMatrix;
19 | import org.apache.commons.math3.linear.RealVector;
20 |
21 | /**
22 | *
23 | * @author Michael Brzustowicz
24 | */
25 | public interface OutputFunction {
26 | RealMatrix getOutput(RealMatrix input, RealMatrix weight, RealVector bias);
27 | RealMatrix getDelta(RealMatrix error, RealMatrix output);
28 | }
29 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/learn/SilhouetteCoefficient.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.learn;
17 |
18 | import java.util.List;
19 | import org.apache.commons.math3.exception.OutOfRangeException;
20 | import org.apache.commons.math3.linear.ArrayRealVector;
21 | import org.apache.commons.math3.linear.RealVector;
22 | import org.apache.commons.math3.ml.clustering.Cluster;
23 | import org.apache.commons.math3.ml.clustering.DoublePoint;
24 | import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
25 |
26 | /**
27 | *
28 | * @author Michael Brzustowicz
29 | */
30 | public class SilhouetteCoefficient {
31 | private final List> clusters;
32 | private double coefficient;// = 0.0;
33 | private int numClusters;
34 | private int numSamples;
35 |
36 | public SilhouetteCoefficient(List> clusters) throws OutOfRangeException {
37 |
38 | this.clusters = clusters;
39 |
40 | if(checkClusterSize()) {
41 | calculateMeanCoefficient();
42 | } else {
43 | throw new OutOfRangeException(clusters.size(), 2, numSamples - 1);
44 | }
45 | }
46 |
47 | public double getCoefficient() {
48 | return coefficient;
49 | }
50 |
51 |
52 | private void calculateMeanCoefficient() {
53 | SummaryStatistics stats = new SummaryStatistics();
54 | int clusterNumber = 0;
55 | for (Cluster cluster : clusters) {
56 | for (DoublePoint point : cluster.getPoints()) {
57 | double s = calculateCoefficientForOnePoint(point, clusterNumber);
58 | stats.addValue(s);
59 | }
60 | clusterNumber++;
61 | }
62 | coefficient = stats.getMean();
63 | }
64 |
65 | private double calculateCoefficientForOnePoint(DoublePoint onePoint, int clusterLabel) {
66 |
67 | /* all other points will compared to this one */
68 | RealVector vector = new ArrayRealVector(onePoint.getPoint());
69 |
70 | double a = 0;
71 | double b = Double.MAX_VALUE;
72 |
73 | int clusterNumber = 0;
74 |
75 | for (Cluster cluster : clusters) {
76 |
77 | SummaryStatistics clusterStats = new SummaryStatistics();
78 |
79 | for (DoublePoint otherPoint : cluster.getPoints()) {
80 | RealVector otherVector = new ArrayRealVector(otherPoint.getPoint());
81 | double dist = vector.getDistance(otherVector);
82 | clusterStats.addValue(dist);
83 | }
84 |
85 | double avgDistance = clusterStats.getMean();
86 |
87 | if(clusterNumber==clusterLabel) {
88 | /* we have included a 0 distance of point with itself and need to subtract it out of the mean */
89 | double n = new Long(clusterStats.getN()).doubleValue();
90 | double correction = n / (n - 1.0);
91 | a = correction * avgDistance;
92 | } else {
93 | b = Math.min(avgDistance, b);
94 | }
95 |
96 | clusterNumber++;
97 | }
98 |
99 | return (b-a) / Math.max(a, b);
100 | }
101 |
102 | private boolean checkClusterSize() throws OutOfRangeException {
103 | numClusters = clusters.size();
104 | numSamples = 0;
105 | for (Cluster cluster : clusters) {
106 | numSamples += cluster.getPoints().size();
107 | }
108 | return numClusters >= 2 && numClusters <= (numSamples - 1) ;
109 | }
110 |
111 | }
112 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/learn/SoftmaxLinearModelExample.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.learn;
17 |
18 | import com.oreilly.dswj.dataops.MatrixResampler;
19 | import com.oreilly.dswj.datasets.Iris;
20 | import com.oreilly.dswj.dataops.SoftMaxCrossEntropyLossFunction;
21 | import java.io.IOException;
22 | import org.apache.commons.math3.linear.RealMatrix;
23 |
24 | /**
25 | *
26 | * @author Michael Brzustowicz
27 | */
28 | public class SoftmaxLinearModelExample {
29 |
30 | /**
31 | * @param args the command line arguments
32 | * @throws java.io.IOException
33 | */
34 | public static void main(String[] args) throws IOException {
35 | Iris iris = new Iris();
36 | MatrixResampler resampler = new MatrixResampler(iris.getData(), iris.getLabels());
37 | resampler.calculateTestTrainSplit(0.40, 0L);
38 |
39 | LinearModelEstimator estimator = new LinearModelEstimator(
40 | new LinearModel(4, 3, new SoftmaxOutputFunction()),
41 | new SoftMaxCrossEntropyLossFunction(),
42 | new DeltaRule(0.001));
43 |
44 |
45 | // /* this is the SAME thing as a lone layer network */
46 | // DeepNetwork estimator = new DeepNetwork();
47 | // estimator.setLossFunction(new SoftMaxCrossEntropyLossFunction());
48 | // estimator.addLayer(new NetworkLayer(4, 3, new SoftmaxOutputFunction(), new GradientDescent(0.001)));
49 | // estimator.setBatchSize(0);
50 |
51 |
52 | estimator.setMaxIterations(6000);
53 | estimator.setTolerance(10E-6);
54 |
55 | estimator.learn(resampler.getTrainingFeatures(), resampler.getTrainingLabels());
56 |
57 | RealMatrix prediction = estimator.predict(resampler.getTestingFeatures());
58 |
59 | ClassifierAccuracy accuracy = new ClassifierAccuracy(prediction, resampler.getTestingLabels());
60 |
61 | System.out.println("isConverged " + estimator.isConverged());
62 | System.out.println("numIterations " + estimator.getNumIterations());
63 | System.out.println("loss " + estimator.getLoss());
64 | System.out.println("accuracy " + accuracy.getAccuracy());
65 | System.out.println("accuracy per dim " + accuracy.getAccuracyPerDimension());
66 |
67 | //isConverged true
68 | //numIterations 3094
69 | //loss 0.07695531148974591
70 | //accuracy 0.9833333333333333
71 | //accuracy per dim {1; 0.9230769231; 1}
72 |
73 |
74 | }
75 |
76 | }
77 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/learn/SoftmaxOutputFunction.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.learn;
17 |
18 | import com.oreilly.dswj.dataops.MatrixScaler;
19 | import com.oreilly.dswj.linalg.MatrixOperations;
20 | import org.apache.commons.math3.analysis.function.Exp;
21 | import org.apache.commons.math3.linear.BlockRealMatrix;
22 | import org.apache.commons.math3.linear.RealMatrix;
23 | import org.apache.commons.math3.linear.RealVector;
24 |
25 | /**
26 | *
27 | * @author Michael Brzustowicz
28 | */
29 | public class SoftmaxOutputFunction implements OutputFunction {
30 |
31 | @Override
32 | public RealMatrix getOutput(RealMatrix input, RealMatrix weight, RealVector bias) {
33 | RealMatrix output = MatrixOperations.XWplusB(input, weight, bias, new Exp());
34 | MatrixScaler.l1(output);
35 | return output;
36 | }
37 |
38 | @Override
39 | public RealMatrix getDelta(RealMatrix error, RealMatrix output) {
40 |
41 | RealMatrix delta = new BlockRealMatrix(error.getRowDimension(), error.getColumnDimension());
42 |
43 | for (int i = 0; i < output.getRowDimension(); i++) {
44 | delta.setRowVector(i, getJacobian(output.getRowVector(i)).preMultiply(error.getRowVector(i)));
45 | }
46 |
47 | return delta;
48 | }
49 |
50 | private RealMatrix getJacobian(RealVector output) {
51 |
52 | int numRows = output.getDimension();
53 |
54 | int numCols = output.getDimension();
55 |
56 | RealMatrix jacobian = new BlockRealMatrix(numRows, numCols);
57 |
58 | for (int i = 0; i < numRows; i++) {
59 |
60 | double output_i = output.getEntry(i);
61 |
62 | for (int j = i; j < numCols; j++) {
63 |
64 | double output_j = output.getEntry(j);
65 |
66 | if(i==j) {
67 |
68 | jacobian.setEntry(i, i, output_i*(1-output_i));
69 |
70 | } else {
71 |
72 | jacobian.setEntry(i, j, -1.0 * output_i * output_j);
73 | jacobian.setEntry(j, i, -1.0 * output_j * output_i);
74 | }
75 |
76 | }
77 | }
78 | return jacobian;
79 | }
80 |
81 | }
82 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/learn/TanhOutputFunction.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.learn;
17 |
18 | import com.oreilly.dswj.linalg.MatrixOperations;
19 | import com.oreilly.dswj.linalg.UnivariateFunctionMapper;
20 | import org.apache.commons.math3.analysis.UnivariateFunction;
21 | import org.apache.commons.math3.analysis.function.Tanh;
22 | import org.apache.commons.math3.linear.RealMatrix;
23 | import org.apache.commons.math3.linear.RealVector;
24 |
25 | /**
26 | *
27 | * @author Michael Brzustowicz
28 | */
29 | public class TanhOutputFunction implements OutputFunction {
30 |
31 | @Override
32 | public RealMatrix getOutput(RealMatrix input, RealMatrix weight, RealVector bias) {
33 | return MatrixOperations.XWplusB(input, weight, bias, new Tanh());
34 | }
35 |
36 | @Override
37 | public RealMatrix getDelta(RealMatrix errorGradient, RealMatrix output) {
38 | // this changes output permanently
39 | output.walkInOptimizedOrder(new UnivariateFunctionMapper(new TanhGradient()));
40 |
41 | // output is now the output gradient
42 | return MatrixOperations.ebeMultiply(errorGradient, output);
43 | }
44 |
45 | private class TanhGradient implements UnivariateFunction {
46 |
47 | @Override
48 | public double value(double x) {
49 | return (1 - x * x);
50 | }
51 |
52 | }
53 |
54 | }
55 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/linalg/FunctionMapper.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.linalg;
17 |
18 | import org.apache.commons.math3.linear.RealMatrixChangingVisitor;
19 |
20 | /**
21 | *
22 | * @author Michael Brzustowicz
23 | */
24 | public class FunctionMapper implements RealMatrixChangingVisitor {
25 |
26 | private final double power;
27 |
28 | public FunctionMapper(double power) {
29 | this.power = power;
30 | }
31 |
32 | @Override
33 | public void start(int rows, int columns, int startRow, int endRow,
34 | int startColumn, int endColumn) {
35 | // do nothing
36 | }
37 |
38 | @Override
39 | public double visit(int row, int column, double value) {
40 | return Math.pow(value, power);
41 | }
42 |
43 | @Override
44 | public double end() {
45 | return 0;
46 | }
47 |
48 | }
49 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/linalg/LinearSystemExample.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.linalg;
17 |
18 | import org.apache.commons.math3.linear.Array2DRowRealMatrix;
19 | import org.apache.commons.math3.linear.RealMatrix;
20 | import org.apache.commons.math3.linear.SingularValueDecomposition;
21 |
22 | /**
23 | *
24 | * @author Michael Brzustowicz
25 | */
26 | public class LinearSystemExample {
27 |
28 | /**
29 | * @param args the command line arguments
30 | */
31 | public static void main(String[] args) {
32 | double[][] xData = {{0, 0.5, 0.2}, {1, 1.2, .9}, {2, 2.5, 1.9}, {3, 3.6, 4.2}};
33 | double[][] yData = {{-1, -0.5}, {0.2, 1}, {0.9, 1.2}, {2.1, 1.5}};
34 | double[] ones = {1.0, 1.0, 1.0, 1.0};
35 | int xRows = 4;
36 | int xCols = 3;
37 | RealMatrix x = new Array2DRowRealMatrix(xRows, xCols + 1);
38 | x.setSubMatrix(xData, 0, 0);
39 | x.setColumn(3, ones); // 4th column is index of 3 !!!
40 | RealMatrix y = new Array2DRowRealMatrix(yData);
41 |
42 | SingularValueDecomposition svd = new SingularValueDecomposition(x);
43 | RealMatrix solution = svd.getSolver().solve(y);
44 | System.out.println(solution);
45 | }
46 |
47 | }
48 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/linalg/MatrixOperations.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.linalg;
17 |
18 | import org.apache.commons.math3.analysis.UnivariateFunction;
19 | import org.apache.commons.math3.distribution.AbstractRealDistribution;
20 | import org.apache.commons.math3.distribution.UniformRealDistribution;
21 | import org.apache.commons.math3.linear.ArrayRealVector;
22 | import org.apache.commons.math3.linear.BlockRealMatrix;
23 | import org.apache.commons.math3.linear.RealMatrix;
24 | import org.apache.commons.math3.linear.RealVector;
25 |
26 | /**
27 | *
28 | * @author Michael Brzustowicz
29 | */
30 | public class MatrixOperations {
31 |
32 | // TODO name this similar to BLAS ???
33 | public static RealMatrix XWplusB(RealMatrix X, RealMatrix W, RealVector b) {
34 | RealVector h = new ArrayRealVector(X.getRowDimension(), 1.0);
35 | return X.multiply(W).add(h.outerProduct(b));
36 | }
37 |
38 | public static RealMatrix XWplusB(RealMatrix X, RealMatrix W, RealVector b, UnivariateFunction univariateFunction) {
39 | RealMatrix z = XWplusB(X, W, b);
40 | z.walkInOptimizedOrder(new UnivariateFunctionMapper(univariateFunction));
41 | return z;
42 | }
43 |
44 | public static RealMatrix ebeMultiply(RealMatrix a, RealMatrix b) {
45 | int rowDimension = a.getRowDimension();
46 | int columnDimension = a.getColumnDimension();
47 | //TODO a and b should have same dimensions
48 | RealMatrix output = new BlockRealMatrix(rowDimension, columnDimension);
49 | for (int i = 0; i < rowDimension; i++) {
50 | for (int j = 0; j < columnDimension; j++) {
51 | output.setEntry(i, j, a.getEntry(i, j) * b.getEntry(i, j));
52 | }
53 | }
54 | return output;
55 | }
56 |
57 | }
58 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/linalg/RandomizedMatrix.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.linalg;
17 |
18 | import org.apache.commons.math3.distribution.AbstractRealDistribution;
19 | import org.apache.commons.math3.distribution.NormalDistribution;
20 | import org.apache.commons.math3.distribution.UniformRealDistribution;
21 | import org.apache.commons.math3.linear.Array2DRowRealMatrix;
22 | import org.apache.commons.math3.linear.ArrayRealVector;
23 | import org.apache.commons.math3.linear.BlockRealMatrix;
24 | import org.apache.commons.math3.linear.RealMatrix;
25 | import org.apache.commons.math3.linear.RealVector;
26 |
27 | /**
28 | *
29 | * @author Michael Brzustowicz
30 | */
31 | public class RandomizedMatrix {
32 |
33 | private AbstractRealDistribution distribution;
34 |
35 | public RandomizedMatrix(AbstractRealDistribution distribution, long seed) {
36 | this.distribution = distribution;
37 | distribution.reseedRandomGenerator(seed);
38 | }
39 |
40 | public RandomizedMatrix() {
41 | this(new UniformRealDistribution(-1, 1), 0L);
42 | }
43 |
44 | public void fillMatrix(RealMatrix matrix) {
45 | for (int i = 0; i < matrix.getRowDimension(); i++) {
46 | matrix.setRow(i, distribution.sample(matrix.getColumnDimension()));
47 | }
48 | }
49 |
50 | public RealMatrix getMatrix(int numRows, int numCols) {
51 | RealMatrix output = new BlockRealMatrix(numRows, numCols);
52 | for (int i = 0; i < numRows; i++) {
53 | output.setRow(i, distribution.sample(numCols));
54 | }
55 | return output;
56 | }
57 |
58 | public void fillVector(RealVector vector) {
59 | for (int i = 0; i < vector.getDimension(); i++) {
60 | vector.setEntry(i, distribution.sample());
61 | }
62 | }
63 |
64 | public RealVector getVector(int dim) {
65 | return new ArrayRealVector(distribution.sample(dim));
66 | }
67 |
68 | public static RealMatrix getTruncatedGaussian(int numRows, int numCols, long seed) {
69 | RealMatrix out = new BlockRealMatrix(numRows, numCols);
70 | NormalDistribution dist = new NormalDistribution(0.0, 0.5);
71 | dist.reseedRandomGenerator(seed);
72 | for (int i = 0; i < numRows; i++) {
73 | out.setRow(i, dist.sample(numCols));
74 | }
75 | return out;
76 | }
77 |
78 | public static RealMatrix getUniform(int numRows, int numCols, long seed) {
79 | RealMatrix out = new Array2DRowRealMatrix(numRows, numCols);
80 | UniformRealDistribution dist = new UniformRealDistribution(-1, 1);
81 | dist.reseedRandomGenerator(seed);
82 | for (int i = 0; i < numRows; i++) {
83 | out.setRow(i, dist.sample(numCols));
84 | }
85 | return out;
86 | }
87 |
88 |
89 |
90 | }
91 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/linalg/UnivariateFunctionMapper.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.linalg;
17 |
18 | import org.apache.commons.math3.analysis.UnivariateFunction;
19 | import org.apache.commons.math3.linear.RealMatrixChangingVisitor;
20 |
21 | /**
22 | *
23 | * @author Michael Brzustowicz
24 | */
25 | public class UnivariateFunctionMapper implements RealMatrixChangingVisitor {
26 |
27 | UnivariateFunction univariateFunction;
28 |
29 | public UnivariateFunctionMapper(UnivariateFunction univariateFunction) {
30 | this.univariateFunction = univariateFunction;
31 | }
32 |
33 | @Override
34 | public void start(int rows, int columns, int startRow, int endRow,
35 | int startColumn, int endColumn) {
36 | //NA
37 | }
38 |
39 | @Override
40 | public double visit(int row, int column, double value) {
41 | return univariateFunction.value(value);
42 | }
43 |
44 | @Override
45 | public double end() {
46 | return 0.0;
47 | }
48 |
49 | }
50 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/mapreduce/BasicMapReduceExample.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.mapreduce;
17 |
18 | import org.apache.hadoop.conf.Configured;
19 | import org.apache.hadoop.fs.Path;
20 | import org.apache.hadoop.mapreduce.Job;
21 | import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
22 | import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
23 | import org.apache.hadoop.util.Tool;
24 | import org.apache.hadoop.util.ToolRunner;
25 |
26 | /**
27 | *
28 | * @author Michael Brzustowicz
29 | */
30 | public class BasicMapReduceExample extends Configured implements Tool {
31 |
32 | /**
33 | * @param args the command line arguments
34 | * @throws java.lang.Exception
35 | */
36 | public static void main(String[] args) throws Exception {
37 | int exitCode = ToolRunner.run(new BasicMapReduceExample(), args);
38 | System.exit(exitCode);
39 | }
40 |
41 | @Override
42 | public int run(String[] args) throws Exception {
43 | /* this is deprecated */
44 | // Job job = new Job(getConf());
45 | /* the singleton method is preferred */
46 | Job job = Job.getInstance(getConf());
47 | job.setJarByClass(BasicMapReduceExample.class);
48 | job.setJobName("BasicMapReduceExample");
49 |
50 | FileInputFormat.addInputPath(job, new Path(args[0]));
51 | FileOutputFormat.setOutputPath(job, new Path(args[1]));
52 |
53 | return job.waitForCompletion(true) ? 0 : 1;
54 | }
55 |
56 | }
57 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/mapreduce/CustomWordCountMapReduceExample.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.mapreduce;
17 |
18 | import org.apache.hadoop.conf.Configured;
19 | import org.apache.hadoop.fs.Path;
20 | import org.apache.hadoop.io.LongWritable;
21 | import org.apache.hadoop.io.Text;
22 | import org.apache.hadoop.mapreduce.Job;
23 | import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
24 | import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
25 | import org.apache.hadoop.mapreduce.lib.reduce.IntSumReducer;
26 | import org.apache.hadoop.util.Tool;
27 | import org.apache.hadoop.util.ToolRunner;
28 |
29 | /**
30 | *
31 | * @author Michael Brzustowicz
32 | */
33 | public class CustomWordCountMapReduceExample extends Configured implements Tool {
34 |
35 | /**
36 | * @param args the command line arguments
37 | * @throws java.lang.Exception
38 | */
39 | public static void main(String[] args) throws Exception {
40 | int exitCode = ToolRunner.run(new CustomWordCountMapReduceExample(), args);
41 | System.exit(exitCode);
42 | }
43 |
44 | @Override
45 | public int run(String[] args) throws Exception {
46 | Job job = Job.getInstance(getConf());
47 | job.setJarByClass(CustomWordCountMapReduceExample.class);
48 | job.setJobName("CustomWordCountMapReduceExample");
49 |
50 | FileInputFormat.addInputPath(job, new Path(args[0]));
51 | FileOutputFormat.setOutputPath(job, new Path(args[1]));
52 |
53 | job.setMapperClass(SimpleTokenMapper.class);
54 | job.setMapOutputKeyClass(Text.class);
55 | job.setMapOutputValueClass(LongWritable.class);
56 | job.setReducerClass(IntSumReducer.class);
57 | job.setOutputKeyClass(Text.class);
58 | job.setOutputValueClass(LongWritable.class);
59 | job.setNumReduceTasks(1);
60 |
61 |
62 | return job.waitForCompletion(true) ? 0 : 1;
63 | }
64 |
65 | }
66 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/mapreduce/JSONMapper.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.mapreduce;
17 |
18 | import java.io.IOException;
19 | import org.apache.hadoop.io.LongWritable;
20 | import org.apache.hadoop.io.Text;
21 | import org.apache.hadoop.mapreduce.Mapper;
22 | import org.json.simple.JSONObject;
23 | import org.json.simple.parser.JSONParser;
24 | import org.json.simple.parser.ParseException;
25 |
26 | /**
27 | *
28 | * @author Michael Brzustowicz
29 | */
30 | public class JSONMapper extends Mapper {
31 |
32 | @Override
33 | protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
34 | JSONParser parser = new JSONParser();
35 | try {
36 | JSONObject obj = (JSONObject) parser.parse(value.toString());
37 |
38 | // get what you need from this object
39 | String userID = obj.get("user_id").toString();
40 | String productID = obj.get("product_id").toString();
41 | int numUnits = Integer.parseInt(obj.get("num_units").toString());
42 |
43 | JSONObject output = new JSONObject();
44 | output.put("productID", productID);
45 | output.put("numUnits", numUnits);
46 |
47 | context.write(new Text(userID), new Text(output.toString()));
48 |
49 |
50 | } catch (ParseException ex) {
51 | //TODO error parsing json
52 | }
53 |
54 | }
55 | }
56 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/mapreduce/SimpleTokenMapper.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.mapreduce;
17 |
18 | import com.oreilly.dswj.dataops.SimpleTokenizer;
19 | import java.io.IOException;
20 | import org.apache.hadoop.io.LongWritable;
21 | import org.apache.hadoop.io.Text;
22 | import org.apache.hadoop.mapreduce.Mapper;
23 |
24 | /**
25 | * @author Michael Brzustowicz
26 | */
27 | public class SimpleTokenMapper extends Mapper {
28 |
29 | SimpleTokenizer tokenizer;
30 |
31 | @Override
32 | protected void setup(Context context) throws IOException {
33 |
34 | tokenizer = new SimpleTokenizer(3); // mintokensize = 3 !!!
35 | }
36 |
37 | @Override
38 | protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
39 |
40 | String[] tokens = tokenizer.getTokens(value.toString());
41 |
42 | for (String token : tokens) {
43 | context.write(new Text(token), new LongWritable(1L));
44 | }
45 | }
46 | }
47 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/mapreduce/SparseAlgebraMapReduceExample.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.mapreduce;
17 |
18 | import org.apache.hadoop.conf.Configured;
19 | import org.apache.hadoop.fs.Path;
20 | import org.apache.hadoop.io.DoubleWritable;
21 | import org.apache.hadoop.io.IntWritable;
22 | import org.apache.hadoop.mapreduce.Job;
23 | import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
24 | import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
25 | import org.apache.hadoop.util.Tool;
26 | import org.apache.hadoop.util.ToolRunner;
27 |
28 | /**
29 | * a sparse matrix stored as i,j,v is mapped to
30 | * k:v -> i:SparseMatrixWritables. The reducer classes each create a sparse
31 | * or dense vector from a serialized file of that object, and then each reducer
32 | * method creates a sparse vector for that matrix row (i) and then dot product
33 | * of row and vector is output as i:dotproduct
34 | * @author Michael Brzustowicz
35 | */
36 | public class SparseAlgebraMapReduceExample extends Configured implements Tool {
37 |
38 | /**
39 | * @param args the command line arguments
40 | * @throws java.lang.Exception
41 | */
42 | public static void main(String[] args) throws Exception {
43 | int exitCode = ToolRunner.run(new SparseAlgebraMapReduceExample(), args);
44 | System.exit(exitCode);
45 | }
46 |
47 | @Override
48 | public int run(String[] args) throws Exception {
49 | Job job = Job.getInstance(getConf());
50 | job.setJarByClass(SparseAlgebraMapReduceExample.class);
51 | job.setJobName("SparseAlgebraMapReduceExample");
52 |
53 | // third command line arg is the filepath to the serialized vector file
54 | job.getConfiguration().set("vectorFileName", args[2]);
55 |
56 | FileInputFormat.addInputPath(job, new Path(args[0]));
57 | FileOutputFormat.setOutputPath(job, new Path(args[1]));
58 |
59 | job.setMapperClass(SparseMatrixMultiplicationMapper.class);
60 | job.setMapOutputKeyClass(IntWritable.class);
61 | job.setMapOutputValueClass(SparseMatrixWritable.class);
62 | job.setReducerClass(SparseMatrixMultiplicationReducer.class);
63 | job.setOutputKeyClass(IntWritable.class);
64 | job.setOutputValueClass(DoubleWritable.class);
65 | job.setNumReduceTasks(1);
66 |
67 |
68 | return job.waitForCompletion(true) ? 0 : 1;
69 | }
70 |
71 | }
72 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/mapreduce/SparseMatrixMultiplicationMapper.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.mapreduce;
17 |
18 | import java.io.IOException;
19 | import org.apache.hadoop.io.IntWritable;
20 | import org.apache.hadoop.io.LongWritable;
21 | import org.apache.hadoop.io.Text;
22 | import org.apache.hadoop.mapreduce.Mapper;
23 |
24 | /**
25 | * outputs key = row number of sparse entry and value is sparsematrix writable
26 | * this is a great mapper for matrix multiplication with a vector
27 | * @author Michael Brzustowicz
28 | */
29 | public class SparseMatrixMultiplicationMapper extends Mapper {
30 |
31 | /**
32 | *
33 | * @param key
34 | * @param value
35 | * @param context
36 | * @throws IOException
37 | * @throws InterruptedException
38 | */
39 | @Override
40 | protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
41 | try {
42 | String[] items = value.toString().split(",");
43 | int rowIndex = Integer.parseInt(items[0]);
44 | int columnIndex = Integer.parseInt(items[1]);
45 | double entry = Double.parseDouble(items[2]);
46 | SparseMatrixWritable smw = new SparseMatrixWritable(rowIndex, columnIndex, entry);
47 | context.write(new IntWritable(rowIndex), smw);
48 | //NOTE can add another context.write() for e.g. a symmetric matrix entry if matrix is sparse upper triag
49 | } catch (NumberFormatException | IOException | InterruptedException e) {
50 | context.getCounter("mapperErrors", e.getMessage()).increment(1L);
51 | }
52 | }
53 |
54 | }
55 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/mapreduce/SparseMatrixMultiplicationReducer.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.mapreduce;
17 |
18 | import java.io.FileInputStream;
19 | import java.io.IOException;
20 | import java.io.ObjectInputStream;
21 | import org.apache.commons.math3.linear.OpenMapRealVector;
22 | import org.apache.commons.math3.linear.RealVector;
23 | import org.apache.hadoop.io.DoubleWritable;
24 | import org.apache.hadoop.io.IntWritable;
25 |
26 | import org.apache.hadoop.mapreduce.Reducer;
27 |
28 | /**
29 | *
30 | * @author Michael Brzustowicz
31 | */
32 | public class SparseMatrixMultiplicationReducer extends Reducer{
33 |
34 | private RealVector vector;
35 |
36 | @Override
37 | protected void setup(Context context) throws IOException, InterruptedException {
38 | /* unserialize the RealVector object */
39 | // NOTE this is just the filename, please include the resource itself in the dist cache via -files at runtime
40 | //TODO set the filename in Job conf with set("vectorFileName", "actual file name here")
41 | String vectorFileName = context.getConfiguration().get("vectorFileName");
42 | try (ObjectInputStream in = new ObjectInputStream(new FileInputStream(vectorFileName))) {
43 | vector = (RealVector) in.readObject();
44 | } catch(ClassNotFoundException e) {
45 | //TODO
46 | }
47 | }
48 |
49 | @Override
50 | protected void reduce(IntWritable key, Iterable values, Context context) throws IOException, InterruptedException {
51 |
52 | /* rely on the fact that rowVector dim has to be same as input vector dim */
53 | RealVector rowVector = new OpenMapRealVector(vector.getDimension());
54 |
55 | for (SparseMatrixWritable value : values) {
56 | rowVector.setEntry(value.columnIndex, value.entry);
57 | }
58 |
59 | double dotProduct = rowVector.dotProduct(vector);
60 |
61 | /* only write the nonzero outputs, since the Matrix-Vector product is probably sparse */
62 | if(dotProduct != 0.0) {
63 | /* this outputs the vector index and it's value */
64 | context.write(key, new DoubleWritable(dotProduct));
65 | }
66 | }
67 | }
68 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/mapreduce/SparseMatrixWritable.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.mapreduce;
17 |
18 | import java.io.DataInput;
19 | import java.io.DataOutput;
20 | import java.io.IOException;
21 | import org.apache.hadoop.io.Writable;
22 | /**
23 | * ONE element of a sparse matrix
24 | * @author Michael Brzustowicz
25 | */
26 | public class SparseMatrixWritable implements Writable {
27 | int rowIndex; // i
28 | int columnIndex; // j
29 | double entry; // the value at i,j
30 |
31 | public SparseMatrixWritable() {
32 | }
33 |
34 | public SparseMatrixWritable(int rowIndex, int columnIndex, double entry) {
35 | this.rowIndex = rowIndex;
36 | this.columnIndex = columnIndex;
37 | this.entry = entry;
38 | }
39 |
40 | @Override
41 | public void write(DataOutput d) throws IOException {
42 | d.writeInt(rowIndex);
43 | d.writeInt(rowIndex);
44 | d.writeDouble(entry);
45 | }
46 |
47 | @Override
48 | public void readFields(DataInput di) throws IOException {
49 | rowIndex = di.readInt();
50 | columnIndex = di.readInt();
51 | entry = di.readDouble();
52 | }
53 |
54 | // THIS IS OPTIONAL
55 | public static SparseMatrixWritable read(DataInput di) throws IOException {
56 | SparseMatrixWritable smw = new SparseMatrixWritable();
57 | smw.readFields(di);
58 | return smw;
59 |
60 | }
61 |
62 | }
63 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/mapreduce/WordCountMapReduceExample.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.mapreduce;
17 |
18 | import org.apache.hadoop.conf.Configured;
19 | import org.apache.hadoop.fs.Path;
20 | import org.apache.hadoop.io.IntWritable;
21 | import org.apache.hadoop.io.Text;
22 | import org.apache.hadoop.mapreduce.Job;
23 | import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
24 | import org.apache.hadoop.mapreduce.lib.map.TokenCounterMapper;
25 | import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
26 | import org.apache.hadoop.mapreduce.lib.reduce.IntSumReducer;
27 | import org.apache.hadoop.util.Tool;
28 | import org.apache.hadoop.util.ToolRunner;
29 |
30 | /**
31 | *
32 | * @author Michael Brzustowicz
33 | */
34 | public class WordCountMapReduceExample extends Configured implements Tool {
35 |
36 | /**
37 | * @param args the command line arguments
38 | * @throws java.lang.Exception
39 | */
40 | public static void main(String[] args) throws Exception {
41 | int exitCode = ToolRunner.run(new WordCountMapReduceExample(), args);
42 | System.exit(exitCode);
43 | }
44 |
45 | @Override
46 | public int run(String[] args) throws Exception {
47 | Job job = Job.getInstance(getConf());
48 | job.setJarByClass(WordCountMapReduceExample.class);
49 | job.setJobName("WordCountMapReduceExample");
50 |
51 | FileInputFormat.addInputPath(job, new Path(args[0]));
52 | FileOutputFormat.setOutputPath(job, new Path(args[1]));
53 |
54 | job.setMapperClass(TokenCounterMapper.class);
55 | job.setMapOutputKeyClass(Text.class);
56 | job.setMapOutputValueClass(IntWritable.class);
57 | job.setReducerClass(IntSumReducer.class);
58 | job.setOutputKeyClass(Text.class);
59 | job.setOutputValueClass(IntWritable.class);
60 | job.setNumReduceTasks(1);
61 |
62 |
63 | return job.waitForCompletion(true) ? 0 : 1;
64 | }
65 |
66 | }
67 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/statistics/AggStatsExample.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.statistics;
17 |
18 | import java.util.ArrayList;
19 | import java.util.List;
20 | import org.apache.commons.math3.stat.descriptive.AggregateSummaryStatistics;
21 | import org.apache.commons.math3.stat.descriptive.StatisticalSummaryValues;
22 | import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
23 |
24 | /**
25 | *
26 | * @author Michael Brzustowicz
27 | */
28 | public class AggStatsExample {
29 |
30 | /**
31 | * @param args the command line arguments
32 | */
33 | public static void main(String[] args) {
34 |
35 | List ls = new ArrayList<>();
36 |
37 | SummaryStatistics ss = new SummaryStatistics();
38 | ss.addValue(1.0);
39 | ss.addValue(11.0);
40 | ss.addValue(5.0);
41 |
42 | SummaryStatistics ss2 = new SummaryStatistics();
43 | ss2.addValue(2.0);
44 | ss2.addValue(12.0);
45 | ss2.addValue(6.0);
46 |
47 | SummaryStatistics ss3 = new SummaryStatistics();
48 | ss3.addValue(0.0);
49 | ss3.addValue(10.0);
50 | ss3.addValue(4.0);
51 |
52 |
53 | ls.add(ss);
54 | ls.add(ss2);
55 | ls.add(ss3);
56 |
57 | StatisticalSummaryValues s = AggregateSummaryStatistics.aggregate(ls);
58 |
59 | System.out.println(s);
60 |
61 | }
62 |
63 | }
64 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/statistics/AnscombeStatsExample.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.statistics;
17 |
18 | import com.oreilly.dswj.datasets.Anscombe;
19 | import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
20 |
21 | /**
22 | *
23 | * @author Michael Brzustowicz
24 | */
25 | public class AnscombeStatsExample {
26 |
27 | /**
28 | * @param args the command line arguments
29 | */
30 | public static void main(String[] args) {
31 | DescriptiveStatistics x1 = new DescriptiveStatistics(Anscombe.x1);
32 | DescriptiveStatistics y1 = new DescriptiveStatistics(Anscombe.y1);
33 | DescriptiveStatistics x2 = new DescriptiveStatistics(Anscombe.x2);
34 | DescriptiveStatistics y2 = new DescriptiveStatistics(Anscombe.y2);
35 | DescriptiveStatistics x3 = new DescriptiveStatistics(Anscombe.x3);
36 | DescriptiveStatistics y3 = new DescriptiveStatistics(Anscombe.y3);
37 | DescriptiveStatistics x4 = new DescriptiveStatistics(Anscombe.x4);
38 | DescriptiveStatistics y4 = new DescriptiveStatistics(Anscombe.y4);
39 |
40 | System.out.println(x1);
41 | System.out.println(y1);
42 | System.out.println(x2);
43 | System.out.println(y2);
44 | System.out.println(x3);
45 | System.out.println(y3);
46 | System.out.println(x4);
47 | System.out.println(y4);
48 |
49 | }
50 |
51 | }
52 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/statistics/ContinuousDistributionPlot.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.statistics;
17 |
18 |
19 | import java.io.File;
20 | import java.io.IOException;
21 | import javafx.application.Application;
22 | import javafx.embed.swing.SwingFXUtils;
23 | import javafx.scene.Scene;
24 | import javafx.scene.SnapshotParameters;
25 | import javafx.scene.chart.LineChart;
26 | import javafx.scene.chart.NumberAxis;
27 | import javafx.scene.chart.XYChart;
28 | import javafx.scene.image.WritableImage;
29 | import javafx.stage.Stage;
30 | import javax.imageio.ImageIO;
31 | import org.apache.commons.math3.distribution.LogNormalDistribution;
32 | import org.apache.commons.math3.distribution.NormalDistribution;
33 | import org.apache.commons.math3.distribution.UniformRealDistribution;
34 |
35 | /**
36 | *
37 | * @author Michael Brzustowicz
38 | */
39 | public class ContinuousDistributionPlot extends Application {
40 |
41 | /**
42 | * @param args the command line arguments
43 | */
44 | public static void main(String[] args) {
45 | Application.launch(args);
46 | }
47 |
48 | @Override
49 | public void start(Stage primaryStage) throws Exception {
50 |
51 | /* make the dataset */
52 | NormalDistribution dist = new NormalDistribution();
53 | // LogNormalDistribution dist = new LogNormalDistribution();
54 | // UniformRealDistribution dist = new UniformRealDistribution();
55 |
56 | int n = 1000;
57 | double[] x = new double[n];
58 | double[] y = new double[n];
59 | double min = -5;
60 | double max = 5;
61 | double delta = max - min;
62 | for (int i = 0; i < n; i++) {
63 |
64 | x[i] = min + i * delta / n;
65 | y[i] = dist.density(x[i]); // this is for PDF
66 | // y[i] = dist.cumulativeProbability(x[i]); // this is for CDF
67 | }
68 |
69 | /* CREATE THE PLOT */
70 | XYChart.Series series = new XYChart.Series();
71 | for (int i = 0; i < x.length; i++) {
72 | series.getData().add(new XYChart.Data(x[i], y[i]));
73 | }
74 | NumberAxis xAxis = new NumberAxis("x", -5, 5, 1);
75 | NumberAxis yAxis = new NumberAxis();
76 | xAxis.setLabel("x");
77 | yAxis.setLabel("f(x)"); // this is for PDF
78 | // yAxis.setLabel("F(x)"); // this is for CDF
79 | xAxis.setMinorTickVisible(false);
80 | yAxis.setMinorTickVisible(false);
81 |
82 |
83 | LineChart lineChart = new LineChart<>(xAxis,yAxis);
84 | lineChart.setAnimated(false); // need this to save file
85 | lineChart.getData().addAll(series);
86 | lineChart.setBackground(null);
87 | lineChart.setLegendVisible(false);
88 | lineChart.setHorizontalGridLinesVisible(false);
89 | lineChart.setVerticalGridLinesVisible(false);
90 | lineChart.setVerticalZeroLineVisible(false);
91 | lineChart.setCreateSymbols(false);
92 |
93 | Scene scene = new Scene(lineChart,800,600);
94 | // scene.getStylesheets().add("css/");
95 | primaryStage.setScene(scene);
96 | primaryStage.show();
97 |
98 | /* uncomment below to write to file in addition to screen rendering */
99 | // WritableImage image = lineChart.snapshot(new SnapshotParameters(), null);
100 |
101 | // TODO: probably use a file chooser here
102 | // File file = new File("plot.png");
103 |
104 | // try {
105 | // ImageIO.write(SwingFXUtils.fromFXImage(image, null), "png", file);
106 | // } catch (IOException e) {
107 | // TODO: handle exception here
108 | // }
109 |
110 |
111 | }
112 |
113 | }
114 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/statistics/DiscreteDistributionPlot.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.statistics;
17 |
18 |
19 | import java.io.File;
20 | import java.io.IOException;
21 | import javafx.application.Application;
22 | import javafx.embed.swing.SwingFXUtils;
23 | import javafx.scene.Scene;
24 | import javafx.scene.SnapshotParameters;
25 | import javafx.scene.chart.NumberAxis;
26 | import javafx.scene.chart.ScatterChart;
27 | import javafx.scene.chart.XYChart;
28 | import javafx.scene.image.WritableImage;
29 | import javafx.stage.Stage;
30 | import javax.imageio.ImageIO;
31 | import org.apache.commons.math3.distribution.BinomialDistribution;
32 | import org.apache.commons.math3.distribution.PoissonDistribution;
33 |
34 | /**
35 | *
36 | * @author Michael Brzustowicz
37 | */
38 | public class DiscreteDistributionPlot extends Application {
39 |
40 | /**
41 | * @param args the command line arguments
42 | */
43 | public static void main(String[] args) {
44 | Application.launch(args);
45 | }
46 |
47 | @Override
48 | public void start(Stage primaryStage) throws Exception {
49 | /* make the dataset */
50 |
51 | // BinomialDistribution dist = new BinomialDistribution(40, 0.5);
52 | PoissonDistribution dist = new PoissonDistribution(5.0);
53 | int n = 21;
54 | double[] x = new double[n];
55 | double[] y = new double[n];
56 | double min = 0;
57 | double max = 21;
58 | double delta = max - min;
59 | for (int i = 0; i < n; i++) {
60 | x[i] = i;
61 | y[i] = dist.probability(i); // PMF
62 | // y[i] = dist.cumulativeProbability(i); //CDF
63 | }
64 |
65 | /* */
66 | XYChart.Series series = new XYChart.Series();
67 | for (int i = 0; i < x.length; i++) {
68 | series.getData().add(new XYChart.Data(x[i], y[i]));
69 | }
70 |
71 | final NumberAxis xAxis = new NumberAxis("k", 0, 20, 5);
72 | final NumberAxis yAxis = new NumberAxis();
73 |
74 | yAxis.setLabel("f(k)");
75 | // xAxis.setTickLabelFont(new Font(12));
76 | // xAxis.setTickUnit(1.0 / 41.0);
77 | xAxis.setMinorTickVisible(false);
78 | yAxis.setMinorTickVisible(false);
79 |
80 | //discrete
81 | final ScatterChart chart = new ScatterChart<>(xAxis,yAxis);
82 |
83 | chart.setAnimated(false);
84 |
85 | chart.getData().addAll(series);
86 | chart.setBackground(null);
87 | chart.setLegendVisible(false);
88 |
89 | chart.setHorizontalGridLinesVisible(false);
90 | chart.setVerticalGridLinesVisible(false);
91 | chart.setVerticalZeroLineVisible(false);
92 | // lineChart.setCreateSymbols(false);
93 | // lineChart.setStyle(".chart-plot-background {-fx-background-color: #ffffff;}");
94 |
95 | Scene scene = new Scene(chart,800,600);
96 | scene.getStylesheets().add("css/chart_lineplot.css");
97 | primaryStage.setScene(scene);
98 | primaryStage.show();
99 |
100 | /* uncomment below to write to file in addition to screen rendering */
101 | // WritableImage image = lineChart.snapshot(new SnapshotParameters(), null);
102 |
103 | // TODO: probably use a file chooser here
104 | // File file = new File("plot.png");
105 |
106 | // try {
107 | // ImageIO.write(SwingFXUtils.fromFXImage(image, null), "png", file);
108 | // } catch (IOException e) {
109 | // TODO: handle exception here
110 | // }
111 |
112 |
113 | }
114 |
115 | }
116 |
--------------------------------------------------------------------------------
/src/main/java/com/oreilly/dswj/statistics/EntropyPlotExample.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Michael Brzustowicz.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.oreilly.dswj.statistics;
17 |
18 | import java.io.File;
19 | import java.io.IOException;
20 | import javafx.application.Application;
21 | import javafx.embed.swing.SwingFXUtils;
22 | import javafx.scene.Scene;
23 | import javafx.scene.SnapshotParameters;
24 | import javafx.scene.chart.LineChart;
25 | import javafx.scene.chart.NumberAxis;
26 | import javafx.scene.chart.XYChart;
27 | import javafx.scene.image.WritableImage;
28 | import javafx.stage.Stage;
29 | import javax.imageio.ImageIO;
30 |
31 | /**
32 | *
33 | * @author Michael Brzustowicz
34 | */
35 | public class EntropyPlotExample extends Application {
36 |
37 | public static double entropy(double p) {
38 | double out = 0;
39 | if (p>0 && p<1){
40 | out = - ((1-p) * Math.log(1-p) + p * Math.log(p)) / Math.log(2);
41 | }
42 | return out;
43 | }
44 | /**
45 | * @param args the command line arguments
46 | */
47 | public static void main(String[] args) {
48 | Application.launch(args);
49 | }
50 |
51 | @Override
52 | public void start(Stage primaryStage) throws Exception {
53 |
54 | int n = 100;
55 | double[] p = new double[n];
56 | double[] h = new double[n];
57 | double delta = 1.0 / (n-1);
58 | for (int i = 0; i < n; i++) {
59 | p[i] = i * delta;
60 | h[i] = entropy(p[i]);
61 | }
62 |
63 | XYChart.Series series = new XYChart.Series();
64 |
65 | for (int i = 0; i < p.length; i++) {
66 | series.getData().add(new XYChart.Data(p[i], h[i]));
67 | }
68 |
69 | NumberAxis xAxis = new NumberAxis("p", 0, 1, 0.25);
70 | NumberAxis yAxis = new NumberAxis();
71 | yAxis.setLabel("entropy");
72 | xAxis.setMinorTickVisible(false);
73 |
74 | LineChart lineChart = new LineChart<>(xAxis,yAxis);
75 | lineChart.setAnimated(false);
76 | lineChart.getData().addAll(series);
77 | lineChart.setBackground(null);
78 | lineChart.setLegendVisible(false);
79 | lineChart.setHorizontalGridLinesVisible(false);
80 | lineChart.setVerticalGridLinesVisible(false);
81 | lineChart.setVerticalZeroLineVisible(false);
82 | lineChart.setCreateSymbols(false);
83 |
84 | Scene scene = new Scene(lineChart,800,600);
85 | scene.getStylesheets().add("css/chart_lineplot.css");
86 | primaryStage.setScene(scene);
87 | primaryStage.show();
88 |
89 | /* uncomment below to write to file in addition to screen rendering */
90 | // WritableImage image = lineChart.snapshot(new SnapshotParameters(), null);
91 |
92 |
93 | // TODO: probably use a file chooser here
94 | // File file = new File("entropy.png");
95 | //
96 | // try {
97 | // ImageIO.write(SwingFXUtils.fromFXImage(image, null), "png", file);
98 | // } catch (IOException e) {
99 | // // TODO: handle exception here
100 | // }
101 |
102 |
103 |
104 |
105 |
106 | }
107 |
108 | }
109 |
--------------------------------------------------------------------------------
/src/main/resources/css/chart.css:
--------------------------------------------------------------------------------
/*
 * Chart styling:
 *   - white plot background;
 *   - series 0 has a transparent (invisible) connecting line;
 *   - series 1-3 are thin blue lines, with 2 and 3 dotted;
 *   - the point symbols of series 1-3 are hidden.
 *
 * Created on: Apr 29, 2015, 6:23:44 PM
 * Author: mbrzusto
 */

.chart-plot-background {
    -fx-background-color: #ffffff;
}

/* series 0: no visible line */
.default-color0.chart-series-line {
    -fx-stroke: transparent;
}

/* series 1-3: thin blue lines */
.default-color1.chart-series-line,
.default-color2.chart-series-line,
.default-color3.chart-series-line {
    -fx-stroke: blue;
    -fx-stroke-width: 1;
}

/* series 2 and 3 are dotted; series 1 stays solid */
.default-color2.chart-series-line,
.default-color3.chart-series-line {
    -fx-stroke-dash-array: 1 4 1 4;
}

/* hide the point symbols of series 1-3 */
.default-color1.chart-line-symbol,
.default-color2.chart-line-symbol,
.default-color3.chart-line-symbol {
    -fx-background-color: transparent, transparent;
}
--------------------------------------------------------------------------------
/src/main/resources/css/chart_lineplot.css:
--------------------------------------------------------------------------------
/*
 * Line-plot styling:
 *   - white plot background;
 *   - series 0-4 are thin black lines; series 0 is solid and series 1-4 use
 *     distinct dash patterns;
 *   - series 0 symbols are green/white; series 1-3 symbols are hidden.
 *
 * Created on: Apr 29, 2015, 6:23:44 PM
 * Author: mbrzusto
 */

.chart-plot-background {
    -fx-background-color: #ffffff;
}

/* shared stroke for every series line */
.default-color0.chart-series-line,
.default-color1.chart-series-line,
.default-color2.chart-series-line,
.default-color3.chart-series-line,
.default-color4.chart-series-line {
    -fx-stroke: black;
    -fx-stroke-width: 1;
}

/* per-series dash patterns (series 0 has none, so it is solid) */
.default-color1.chart-series-line {
    -fx-stroke-dash-array: 1 4 1 4;
}

.default-color2.chart-series-line {
    -fx-stroke-dash-array: 1 6 1 6;
}

.default-color3.chart-series-line {
    -fx-stroke-dash-array: 1 8 1 8;
}

.default-color4.chart-series-line {
    -fx-stroke-dash-array: 4 4 4 4;
}

.default-color0.chart-line-symbol {
    -fx-background-color: green, white;
}

/* hide the point symbols of series 1-3 */
.default-color1.chart-line-symbol,
.default-color2.chart-line-symbol,
.default-color3.chart-line-symbol {
    -fx-background-color: transparent, transparent;
}
--------------------------------------------------------------------------------
/src/main/resources/css/data_with_fit_line.css:
--------------------------------------------------------------------------------
/*
 * Data-with-fit-line styling:
 *   - white plot background;
 *   - series 0 is a thin blue line with its point symbols hidden
 *     (presumably the fitted line);
 *   - series 1 has a transparent line, so only its points can show
 *     (presumably the raw data).
 *
 * Created on: Apr 29, 2015, 6:23:44 PM
 * Author: mbrzusto
 */

.chart-plot-background {
    -fx-background-color: #ffffff;
}

.default-color0.chart-series-line {
    -fx-stroke: blue;
    -fx-stroke-width: 1;
}

.default-color1.chart-series-line {
    -fx-stroke: transparent;
}

.default-color0.chart-line-symbol {
    -fx-background-color: transparent, transparent;
}
--------------------------------------------------------------------------------
/src/main/resources/css/overlay-chart.css:
--------------------------------------------------------------------------------
/*
 * file: overlay-chart.css (originally written for the LayeredXyChartsSample example)
 * Makes the plot background transparent, presumably so that a chart layered
 * underneath stays visible, and draws series 0 in forest green.
 */
.chart-plot-background {
    -fx-background-color: transparent;
}

.default-color0.chart-series-line {
    -fx-stroke: forestgreen;
}
8 |
9 |
--------------------------------------------------------------------------------
/src/main/resources/datasets/iris/iris_data.csv:
--------------------------------------------------------------------------------
1 | 5.1,3.5,1.4,0.2,Iris-setosa
2 | 4.9,3.0,1.4,0.2,Iris-setosa
3 | 4.7,3.2,1.3,0.2,Iris-setosa
4 | 4.6,3.1,1.5,0.2,Iris-setosa
5 | 5.0,3.6,1.4,0.2,Iris-setosa
6 | 5.4,3.9,1.7,0.4,Iris-setosa
7 | 4.6,3.4,1.4,0.3,Iris-setosa
8 | 5.0,3.4,1.5,0.2,Iris-setosa
9 | 4.4,2.9,1.4,0.2,Iris-setosa
10 | 4.9,3.1,1.5,0.1,Iris-setosa
11 | 5.4,3.7,1.5,0.2,Iris-setosa
12 | 4.8,3.4,1.6,0.2,Iris-setosa
13 | 4.8,3.0,1.4,0.1,Iris-setosa
14 | 4.3,3.0,1.1,0.1,Iris-setosa
15 | 5.8,4.0,1.2,0.2,Iris-setosa
16 | 5.7,4.4,1.5,0.4,Iris-setosa
17 | 5.4,3.9,1.3,0.4,Iris-setosa
18 | 5.1,3.5,1.4,0.3,Iris-setosa
19 | 5.7,3.8,1.7,0.3,Iris-setosa
20 | 5.1,3.8,1.5,0.3,Iris-setosa
21 | 5.4,3.4,1.7,0.2,Iris-setosa
22 | 5.1,3.7,1.5,0.4,Iris-setosa
23 | 4.6,3.6,1.0,0.2,Iris-setosa
24 | 5.1,3.3,1.7,0.5,Iris-setosa
25 | 4.8,3.4,1.9,0.2,Iris-setosa
26 | 5.0,3.0,1.6,0.2,Iris-setosa
27 | 5.0,3.4,1.6,0.4,Iris-setosa
28 | 5.2,3.5,1.5,0.2,Iris-setosa
29 | 5.2,3.4,1.4,0.2,Iris-setosa
30 | 4.7,3.2,1.6,0.2,Iris-setosa
31 | 4.8,3.1,1.6,0.2,Iris-setosa
32 | 5.4,3.4,1.5,0.4,Iris-setosa
33 | 5.2,4.1,1.5,0.1,Iris-setosa
34 | 5.5,4.2,1.4,0.2,Iris-setosa
35 | 4.9,3.1,1.5,0.1,Iris-setosa
36 | 5.0,3.2,1.2,0.2,Iris-setosa
37 | 5.5,3.5,1.3,0.2,Iris-setosa
38 | 4.9,3.1,1.5,0.1,Iris-setosa
39 | 4.4,3.0,1.3,0.2,Iris-setosa
40 | 5.1,3.4,1.5,0.2,Iris-setosa
41 | 5.0,3.5,1.3,0.3,Iris-setosa
42 | 4.5,2.3,1.3,0.3,Iris-setosa
43 | 4.4,3.2,1.3,0.2,Iris-setosa
44 | 5.0,3.5,1.6,0.6,Iris-setosa
45 | 5.1,3.8,1.9,0.4,Iris-setosa
46 | 4.8,3.0,1.4,0.3,Iris-setosa
47 | 5.1,3.8,1.6,0.2,Iris-setosa
48 | 4.6,3.2,1.4,0.2,Iris-setosa
49 | 5.3,3.7,1.5,0.2,Iris-setosa
50 | 5.0,3.3,1.4,0.2,Iris-setosa
51 | 7.0,3.2,4.7,1.4,Iris-versicolor
52 | 6.4,3.2,4.5,1.5,Iris-versicolor
53 | 6.9,3.1,4.9,1.5,Iris-versicolor
54 | 5.5,2.3,4.0,1.3,Iris-versicolor
55 | 6.5,2.8,4.6,1.5,Iris-versicolor
56 | 5.7,2.8,4.5,1.3,Iris-versicolor
57 | 6.3,3.3,4.7,1.6,Iris-versicolor
58 | 4.9,2.4,3.3,1.0,Iris-versicolor
59 | 6.6,2.9,4.6,1.3,Iris-versicolor
60 | 5.2,2.7,3.9,1.4,Iris-versicolor
61 | 5.0,2.0,3.5,1.0,Iris-versicolor
62 | 5.9,3.0,4.2,1.5,Iris-versicolor
63 | 6.0,2.2,4.0,1.0,Iris-versicolor
64 | 6.1,2.9,4.7,1.4,Iris-versicolor
65 | 5.6,2.9,3.6,1.3,Iris-versicolor
66 | 6.7,3.1,4.4,1.4,Iris-versicolor
67 | 5.6,3.0,4.5,1.5,Iris-versicolor
68 | 5.8,2.7,4.1,1.0,Iris-versicolor
69 | 6.2,2.2,4.5,1.5,Iris-versicolor
70 | 5.6,2.5,3.9,1.1,Iris-versicolor
71 | 5.9,3.2,4.8,1.8,Iris-versicolor
72 | 6.1,2.8,4.0,1.3,Iris-versicolor
73 | 6.3,2.5,4.9,1.5,Iris-versicolor
74 | 6.1,2.8,4.7,1.2,Iris-versicolor
75 | 6.4,2.9,4.3,1.3,Iris-versicolor
76 | 6.6,3.0,4.4,1.4,Iris-versicolor
77 | 6.8,2.8,4.8,1.4,Iris-versicolor
78 | 6.7,3.0,5.0,1.7,Iris-versicolor
79 | 6.0,2.9,4.5,1.5,Iris-versicolor
80 | 5.7,2.6,3.5,1.0,Iris-versicolor
81 | 5.5,2.4,3.8,1.1,Iris-versicolor
82 | 5.5,2.4,3.7,1.0,Iris-versicolor
83 | 5.8,2.7,3.9,1.2,Iris-versicolor
84 | 6.0,2.7,5.1,1.6,Iris-versicolor
85 | 5.4,3.0,4.5,1.5,Iris-versicolor
86 | 6.0,3.4,4.5,1.6,Iris-versicolor
87 | 6.7,3.1,4.7,1.5,Iris-versicolor
88 | 6.3,2.3,4.4,1.3,Iris-versicolor
89 | 5.6,3.0,4.1,1.3,Iris-versicolor
90 | 5.5,2.5,4.0,1.3,Iris-versicolor
91 | 5.5,2.6,4.4,1.2,Iris-versicolor
92 | 6.1,3.0,4.6,1.4,Iris-versicolor
93 | 5.8,2.6,4.0,1.2,Iris-versicolor
94 | 5.0,2.3,3.3,1.0,Iris-versicolor
95 | 5.6,2.7,4.2,1.3,Iris-versicolor
96 | 5.7,3.0,4.2,1.2,Iris-versicolor
97 | 5.7,2.9,4.2,1.3,Iris-versicolor
98 | 6.2,2.9,4.3,1.3,Iris-versicolor
99 | 5.1,2.5,3.0,1.1,Iris-versicolor
100 | 5.7,2.8,4.1,1.3,Iris-versicolor
101 | 6.3,3.3,6.0,2.5,Iris-virginica
102 | 5.8,2.7,5.1,1.9,Iris-virginica
103 | 7.1,3.0,5.9,2.1,Iris-virginica
104 | 6.3,2.9,5.6,1.8,Iris-virginica
105 | 6.5,3.0,5.8,2.2,Iris-virginica
106 | 7.6,3.0,6.6,2.1,Iris-virginica
107 | 4.9,2.5,4.5,1.7,Iris-virginica
108 | 7.3,2.9,6.3,1.8,Iris-virginica
109 | 6.7,2.5,5.8,1.8,Iris-virginica
110 | 7.2,3.6,6.1,2.5,Iris-virginica
111 | 6.5,3.2,5.1,2.0,Iris-virginica
112 | 6.4,2.7,5.3,1.9,Iris-virginica
113 | 6.8,3.0,5.5,2.1,Iris-virginica
114 | 5.7,2.5,5.0,2.0,Iris-virginica
115 | 5.8,2.8,5.1,2.4,Iris-virginica
116 | 6.4,3.2,5.3,2.3,Iris-virginica
117 | 6.5,3.0,5.5,1.8,Iris-virginica
118 | 7.7,3.8,6.7,2.2,Iris-virginica
119 | 7.7,2.6,6.9,2.3,Iris-virginica
120 | 6.0,2.2,5.0,1.5,Iris-virginica
121 | 6.9,3.2,5.7,2.3,Iris-virginica
122 | 5.6,2.8,4.9,2.0,Iris-virginica
123 | 7.7,2.8,6.7,2.0,Iris-virginica
124 | 6.3,2.7,4.9,1.8,Iris-virginica
125 | 6.7,3.3,5.7,2.1,Iris-virginica
126 | 7.2,3.2,6.0,1.8,Iris-virginica
127 | 6.2,2.8,4.8,1.8,Iris-virginica
128 | 6.1,3.0,4.9,1.8,Iris-virginica
129 | 6.4,2.8,5.6,2.1,Iris-virginica
130 | 7.2,3.0,5.8,1.6,Iris-virginica
131 | 7.4,2.8,6.1,1.9,Iris-virginica
132 | 7.9,3.8,6.4,2.0,Iris-virginica
133 | 6.4,2.8,5.6,2.2,Iris-virginica
134 | 6.3,2.8,5.1,1.5,Iris-virginica
135 | 6.1,2.6,5.6,1.4,Iris-virginica
136 | 7.7,3.0,6.1,2.3,Iris-virginica
137 | 6.3,3.4,5.6,2.4,Iris-virginica
138 | 6.4,3.1,5.5,1.8,Iris-virginica
139 | 6.0,3.0,4.8,1.8,Iris-virginica
140 | 6.9,3.1,5.4,2.1,Iris-virginica
141 | 6.7,3.1,5.6,2.4,Iris-virginica
142 | 6.9,3.1,5.1,2.3,Iris-virginica
143 | 5.8,2.7,5.1,1.9,Iris-virginica
144 | 6.8,3.2,5.9,2.3,Iris-virginica
145 | 6.7,3.3,5.7,2.5,Iris-virginica
146 | 6.7,3.0,5.2,2.3,Iris-virginica
147 | 6.3,2.5,5.0,1.9,Iris-virginica
148 | 6.5,3.0,5.2,2.0,Iris-virginica
149 | 6.2,3.4,5.4,2.3,Iris-virginica
150 | 5.9,3.0,5.1,1.8,Iris-virginica
151 |
--------------------------------------------------------------------------------
/src/main/resources/datasets/mnist/README:
--------------------------------------------------------------------------------
Download the following files from http://yann.lecun.com/exdb/mnist/ :

train-images-idx3-ubyte
train-labels-idx1-ubyte
t10k-images-idx3-ubyte
t10k-labels-idx1-ubyte

The site serves them gzip-compressed (with a .gz suffix). Decompress each one
and put the uncompressed files, named exactly as listed above, in this directory.
--------------------------------------------------------------------------------