accuracy = tf.math.mean(tf.dtypes.cast(tf.math.equal(predicted, expected), TFloat32.class), tf.array(0));
96 |
97 | // Run the graph
98 | try (Session session = new Session(graph)) {
99 |
100 | // Train the model
101 | for (ImageBatch trainingBatch : dataset.trainingBatches(TRAINING_BATCH_SIZE)) {
102 | try (TFloat32 batchImages = preprocessImages(trainingBatch.images());
103 | TFloat32 batchLabels = preprocessLabels(trainingBatch.labels())) {
104 | session.runner()
105 | .addTarget(minimize)
106 | .feed(images.asOutput(), batchImages)
107 | .feed(labels.asOutput(), batchLabels)
108 | .run();
109 | }
110 | }
111 |
112 | // Test the model
113 | ImageBatch testBatch = dataset.testBatch();
114 | try (TFloat32 testImages = preprocessImages(testBatch.images());
115 | TFloat32 testLabels = preprocessLabels(testBatch.labels());
116 | var result = session.runner()
117 | .fetch(accuracy)
118 | .feed(images.asOutput(), testImages)
119 | .feed(labels.asOutput(), testLabels)
120 | .run()) {
121 | TFloat32 accuracyValue = (TFloat32) result.get(0);
122 | System.out.println("Accuracy: " + accuracyValue.getFloat());
123 | }
124 | }
125 | }
126 |
127 | private static final int VALIDATION_SIZE = 0;
128 | private static final int TRAINING_BATCH_SIZE = 100;
129 | private static final float LEARNING_RATE = 0.2f;
130 |
131 | private static TFloat32 preprocessImages(ByteNdArray rawImages) {
132 | Ops tf = Ops.create();
133 |
134 | // Flatten images in a single dimension and normalize their pixels as floats.
135 | long imageSize = rawImages.get(0).shape().size();
136 | return tf.math.div(
137 | tf.reshape(
138 | tf.dtypes.cast(tf.constant(rawImages), TFloat32.class),
139 | tf.array(-1L, imageSize)
140 | ),
141 | tf.constant(255.0f)
142 | ).asTensor();
143 | }
144 |
145 | private static TFloat32 preprocessLabels(ByteNdArray rawLabels) {
146 | Ops tf = Ops.create();
147 |
148 | // Map labels to one hot vectors where only the expected predictions as a value of 1.0
149 | return tf.oneHot(
150 | tf.constant(rawLabels),
151 | tf.constant(MnistDataset.NUM_CLASSES),
152 | tf.constant(1.0f),
153 | tf.constant(0.0f)
154 | ).asTensor();
155 | }
156 |
/** Graph on which the training/testing {@code Session} is opened. */
private final Graph graph;
/** Supplies the training mini-batches and the test batch. */
private final MnistDataset dataset;

/**
 * Stores the collaborators used when training and evaluating the model.
 *
 * @param graph graph to open the session on
 * @param dataset source of training and test batches
 */
private SimpleMnist(Graph graph, MnistDataset dataset) {
  this.dataset = dataset;
  this.graph = graph;
}
164 | }
165 |
--------------------------------------------------------------------------------
/src/main/java/org/tensorflow/model/examples/regression/linear/LinearRegressionExample.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =======================================================================
16 | */
17 | package org.tensorflow.model.examples.regression.linear;
18 |
19 | import java.util.List;
20 | import java.util.Random;
21 | import org.tensorflow.Graph;
22 | import org.tensorflow.Session;
23 | import org.tensorflow.framework.optimizers.GradientDescent;
24 | import org.tensorflow.framework.optimizers.Optimizer;
25 | import org.tensorflow.ndarray.Shape;
26 | import org.tensorflow.op.Op;
27 | import org.tensorflow.op.Ops;
28 | import org.tensorflow.op.core.Placeholder;
29 | import org.tensorflow.op.core.Variable;
30 | import org.tensorflow.op.math.Add;
31 | import org.tensorflow.op.math.Div;
32 | import org.tensorflow.op.math.Mul;
33 | import org.tensorflow.op.math.Pow;
34 | import org.tensorflow.types.TFloat32;
35 |
36 | /**
37 | * In this example TensorFlow finds the weight and bias of the linear regression during 1 epoch,
38 | * training on observations one by one.
39 | *
40 | * Also, the weight and bias are extracted and printed.
41 | */
42 | public class LinearRegressionExample {
43 | /**
44 | * Amount of data points.
45 | */
46 | private static final int N = 10;
47 |
48 | /**
49 | * This value is used to fill the Y placeholder in prediction.
50 | */
51 | public static final float LEARNING_RATE = 0.1f;
52 | public static final String WEIGHT_VARIABLE_NAME = "weight";
53 | public static final String BIAS_VARIABLE_NAME = "bias";
54 |
55 | public static void main(String[] args) {
56 | // Prepare the data
57 | float[] xValues = {1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f};
58 | float[] yValues = new float[N];
59 |
60 | Random rnd = new Random(42);
61 |
62 | for (int i = 0; i < yValues.length; i++) {
63 | yValues[i] = (float) (10 * xValues[i] + 2 + 0.1 * (rnd.nextDouble() - 0.5));
64 | }
65 |
66 | try (Graph graph = new Graph()) {
67 | Ops tf = Ops.create(graph);
68 |
69 | // Define placeholders
70 | Placeholder xData = tf.placeholder(TFloat32.class, Placeholder.shape(Shape.scalar()));
71 | Placeholder yData = tf.placeholder(TFloat32.class, Placeholder.shape(Shape.scalar()));
72 |
73 | // Define variables
74 | Variable weight = tf.withName(WEIGHT_VARIABLE_NAME).variable(tf.constant(1f));
75 | Variable bias = tf.withName(BIAS_VARIABLE_NAME).variable(tf.constant(1f));
76 |
77 | // Define the model function weight*x + bias
78 | Mul mul = tf.math.mul(xData, weight);
79 | Add yPredicted = tf.math.add(mul, bias);
80 |
81 | // Define loss function MSE
82 | Pow sum = tf.math.pow(tf.math.sub(yPredicted, yData), tf.constant(2f));
83 | Div mse = tf.math.div(sum, tf.constant(2f * N));
84 |
85 | // Back-propagate gradients to variables for training
86 | Optimizer optimizer = new GradientDescent(graph, LEARNING_RATE);
87 | Op minimize = optimizer.minimize(mse);
88 |
89 | try (Session session = new Session(graph)) {
90 |
91 | // Train the model on data
92 | for (int i = 0; i < xValues.length; i++) {
93 | float y = yValues[i];
94 | float x = xValues[i];
95 |
96 | try (TFloat32 xTensor = TFloat32.scalarOf(x);
97 | TFloat32 yTensor = TFloat32.scalarOf(y)) {
98 |
99 | session.runner()
100 | .addTarget(minimize)
101 | .feed(xData.asOutput(), xTensor)
102 | .feed(yData.asOutput(), yTensor)
103 | .run();
104 |
105 | System.out.println("Training phase");
106 | System.out.println("x is " + x + " y is " + y);
107 | }
108 | }
109 |
110 | // Extract linear regression model weight and bias values
111 | try (var result = session.runner()
112 | .fetch(WEIGHT_VARIABLE_NAME)
113 | .fetch(BIAS_VARIABLE_NAME)
114 | .run()) {
115 | System.out.println("Weight is " + result.get(WEIGHT_VARIABLE_NAME));
116 | System.out.println("Bias is " + result.get(BIAS_VARIABLE_NAME));
117 | }
118 |
119 | // Let's predict y for x = 10f
120 | float x = 10f;
121 | float predictedY = 0f;
122 |
123 | try (TFloat32 xTensor = TFloat32.scalarOf(x);
124 | TFloat32 yTensor = TFloat32.scalarOf(predictedY);
125 | TFloat32 yPredictedTensor = (TFloat32)session.runner()
126 | .feed(xData.asOutput(), xTensor)
127 | .feed(yData.asOutput(), yTensor)
128 | .fetch(yPredicted)
129 | .run().get(0)) {
130 |
131 | predictedY = yPredictedTensor.getFloat();
132 |
133 | System.out.println("Predicted value: " + predictedY);
134 | }
135 | }
136 | }
137 | }
138 | }
139 |
--------------------------------------------------------------------------------
/src/main/java/org/tensorflow/model/examples/tensors/TensorCreation.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | * =======================================================================
16 | */
17 | package org.tensorflow.model.examples.tensors;
18 |
19 | import org.tensorflow.ndarray.Shape;
20 | import org.tensorflow.ndarray.IntNdArray;
21 | import org.tensorflow.ndarray.NdArrays;
22 | import org.tensorflow.types.TInt32;
23 |
24 | import java.util.Arrays;
25 |
26 | /**
27 | * Creates a few tensors of ranks: 0, 1, 2, 3.
28 | */
29 | public class TensorCreation {
30 |
31 | public static void main(String[] args) {
32 | // Rank 0 Tensor
33 | TInt32 rank0Tensor = TInt32.scalarOf(42);
34 |
35 | System.out.println("---- Scalar tensor ---------");
36 |
37 | System.out.println("DataType: " + rank0Tensor.dataType().name());
38 |
39 | System.out.println("Rank: " + rank0Tensor.shape().size());
40 |
41 | System.out.println("Shape: " + Arrays.toString(rank0Tensor.shape().asArray()));
42 |
43 | rank0Tensor.scalars().forEach(value -> System.out.println("Value: " + value.getObject()));
44 |
45 | // Rank 1 Tensor
46 | TInt32 rank1Tensor = TInt32.vectorOf(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
47 |
48 | System.out.println("---- Vector tensor ---------");
49 |
50 | System.out.println("DataType: " + rank1Tensor.dataType().name());
51 |
52 | System.out.println("Rank: " + rank1Tensor.shape().size());
53 |
54 | System.out.println("Shape: " + Arrays.toString(rank1Tensor.shape().asArray()));
55 |
56 | System.out.println("6th element: " + rank1Tensor.getInt(5));
57 |
58 | // Rank 2 Tensor
59 | // 3x2 matrix of ints.
60 | IntNdArray matrix2d = NdArrays.ofInts(Shape.of(3, 2));
61 |
62 | matrix2d.set(NdArrays.vectorOf(1, 2), 0)
63 | .set(NdArrays.vectorOf(3, 4), 1)
64 | .set(NdArrays.vectorOf(5, 6), 2);
65 |
66 | TInt32 rank2Tensor = TInt32.tensorOf(matrix2d);
67 |
68 | System.out.println("---- Matrix tensor ---------");
69 |
70 | System.out.println("DataType: " + rank2Tensor.dataType().name());
71 |
72 | System.out.println("Rank: " + rank2Tensor.shape().size());
73 |
74 | System.out.println("Shape: " + Arrays.toString(rank2Tensor.shape().asArray()));
75 |
76 | System.out.println("6th element: " + rank2Tensor.getInt(2, 1));
77 |
78 | // Rank 3 Tensor
79 | // 3*2*4 matrix of ints.
80 | IntNdArray matrix3d = NdArrays.ofInts(Shape.of(3, 2, 4));
81 |
82 | matrix3d.elements(0).forEach(matrix -> {
83 | matrix
84 | .set(NdArrays.vectorOf(1, 2, 3, 4), 0)
85 | .set(NdArrays.vectorOf(5, 6, 7, 8), 1);
86 | });
87 |
88 | TInt32 rank3Tensor = TInt32.tensorOf(matrix3d);
89 |
90 | System.out.println("---- Matrix tensor ---------");
91 |
92 | System.out.println("DataType: " + rank3Tensor.dataType().name());
93 |
94 | System.out.println("Rank: " + rank3Tensor.shape().size());
95 |
96 | System.out.println("Shape: " + Arrays.toString(rank3Tensor.shape().asArray()));
97 |
98 | System.out.println("n-th element: " + rank3Tensor.getInt(2, 1, 3));
99 | }
100 | }
101 |
--------------------------------------------------------------------------------
/src/main/resources/META-INF/MANIFEST.MF:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/java-models/c7c0e907e11fb6a55ea4f66b85a3452b35a17c3f/src/main/resources/META-INF/MANIFEST.MF
--------------------------------------------------------------------------------
/src/main/resources/fashionmnist/Readme.md:
--------------------------------------------------------------------------------
1 | This dataset is distributed under the MIT License and was introduced in the following paper:
2 | Fashion-MNIST: a Novel Image Dataset for Benchmarking Machine Learning Algorithms. Han Xiao, Kashif Rasul, Roland Vollgraf. arXiv:1708.07747
3 |
4 | The data was downloaded from the Fashion-MNIST repository:
5 | https://github.com/zalandoresearch/fashion-mnist/tree/master/data/fashion
6 |
7 |
--------------------------------------------------------------------------------
/src/main/resources/fashionmnist/t10k-images-idx3-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/java-models/c7c0e907e11fb6a55ea4f66b85a3452b35a17c3f/src/main/resources/fashionmnist/t10k-images-idx3-ubyte.gz
--------------------------------------------------------------------------------
/src/main/resources/fashionmnist/t10k-labels-idx1-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/java-models/c7c0e907e11fb6a55ea4f66b85a3452b35a17c3f/src/main/resources/fashionmnist/t10k-labels-idx1-ubyte.gz
--------------------------------------------------------------------------------
/src/main/resources/fashionmnist/train-images-idx3-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/java-models/c7c0e907e11fb6a55ea4f66b85a3452b35a17c3f/src/main/resources/fashionmnist/train-images-idx3-ubyte.gz
--------------------------------------------------------------------------------
/src/main/resources/fashionmnist/train-labels-idx1-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/java-models/c7c0e907e11fb6a55ea4f66b85a3452b35a17c3f/src/main/resources/fashionmnist/train-labels-idx1-ubyte.gz
--------------------------------------------------------------------------------
/src/main/resources/fasterrcnninception/image2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/java-models/c7c0e907e11fb6a55ea4f66b85a3452b35a17c3f/src/main/resources/fasterrcnninception/image2.jpg
--------------------------------------------------------------------------------
/src/main/resources/fasterrcnninception/image2rcnn.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/java-models/c7c0e907e11fb6a55ea4f66b85a3452b35a17c3f/src/main/resources/fasterrcnninception/image2rcnn.jpg
--------------------------------------------------------------------------------
/src/main/resources/mnist/t10k-images-idx3-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/java-models/c7c0e907e11fb6a55ea4f66b85a3452b35a17c3f/src/main/resources/mnist/t10k-images-idx3-ubyte.gz
--------------------------------------------------------------------------------
/src/main/resources/mnist/t10k-labels-idx1-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/java-models/c7c0e907e11fb6a55ea4f66b85a3452b35a17c3f/src/main/resources/mnist/t10k-labels-idx1-ubyte.gz
--------------------------------------------------------------------------------
/src/main/resources/mnist/train-images-idx3-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/java-models/c7c0e907e11fb6a55ea4f66b85a3452b35a17c3f/src/main/resources/mnist/train-images-idx3-ubyte.gz
--------------------------------------------------------------------------------
/src/main/resources/mnist/train-labels-idx1-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/java-models/c7c0e907e11fb6a55ea4f66b85a3452b35a17c3f/src/main/resources/mnist/train-labels-idx1-ubyte.gz
--------------------------------------------------------------------------------