├── .gitignore ├── README.md ├── pom.xml └── src └── main ├── java └── org │ └── tensorflow │ └── model │ └── examples │ ├── cnn │ ├── fastrcnn │ │ ├── FasterRcnnInception.java │ │ └── Readme.md │ ├── lenet │ │ └── CnnMnist.java │ └── vgg │ │ ├── VGG11OnFashionMnist.java │ │ └── VGGModel.java │ ├── datasets │ ├── ImageBatch.java │ ├── ImageBatchIterator.java │ └── mnist │ │ └── MnistDataset.java │ ├── dense │ └── SimpleMnist.java │ ├── regression │ └── linear │ │ └── LinearRegressionExample.java │ └── tensors │ └── TensorCreation.java └── resources ├── META-INF └── MANIFEST.MF ├── fashionmnist ├── Readme.md ├── t10k-images-idx3-ubyte.gz ├── t10k-labels-idx1-ubyte.gz ├── train-images-idx3-ubyte.gz └── train-labels-idx1-ubyte.gz ├── fasterrcnninception ├── image2.jpg └── image2rcnn.jpg └── mnist ├── t10k-images-idx3-ubyte.gz ├── t10k-labels-idx1-ubyte.gz ├── train-images-idx3-ubyte.gz └── train-labels-idx1-ubyte.gz /.gitignore: -------------------------------------------------------------------------------- 1 | *.class 2 | *.jar 3 | target/ 4 | bin/ 5 | build/ 6 | .idea/ 7 | .idea_modules/ 8 | *.iml 9 | .settings/ 10 | bin/ 11 | tmp/ 12 | .metadata 13 | .classpath 14 | .project 15 | dist/ 16 | .DS_Store 17 | models/ 18 | outputs/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow Java Examples 2 | 3 | This repository contains examples for [TensorFlow-Java](https://github.com/tensorflow/java). 4 | 5 | ## Example Models 6 | 7 | There are five example models: a LeNet CNN, a VGG CNN, inference using Faster-RCNN, a linear regression and a logistic regression. 8 | 9 | ### Faster-RCNN 10 | 11 | The Faster-RCNN inference example is in `org.tensorflow.model.examples.cnn.fastrcnn`. 
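After inference, `FasterRcnnInception` keeps only the detections whose score is greater than 0.3. According to the `saved_model_cli` output quoted in its source, `detection_classes` is returned as `DT_FLOAT` rather than as an integer type. The example's COCO label list also starts with `person`, which is class id 1 in COCO numbering. The sketch below shows this post-processing on plain Java arrays that stand in for the output tensors. The class `DetectionPostProcessing` and its methods are hypothetical (not part of this repository), and treating the class ids as 1-based is an assumption about the model's label map.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper, not part of this repository. Plain arrays stand in for the
// detection_scores and detection_classes tensors with the batch dimension removed.
class DetectionPostProcessing {

    // Indices of detections whose score is strictly greater than the threshold,
    // mirroring the `detectionScore > 0.3f` test in FasterRcnnInception.
    static List<Integer> keptIndices(float[] scores, int numDetections, float threshold) {
        List<Integer> kept = new ArrayList<>();
        int limit = Math.min(numDetections, scores.length);
        for (int n = 0; n < limit; n++) {
            if (scores[n] > threshold) {
                kept.add(n);
            }
        }
        return kept;
    }

    // Assumption: class ids are 1-based COCO ids stored as floats (DT_FLOAT),
    // so id 1 maps to labels[0] ("person" in the example's label list).
    static String labelFor(float classId, String[] labels) {
        int id = Math.round(classId);
        return (id >= 1 && id <= labels.length) ? labels[id - 1] : "unknown";
    }

    public static void main(String[] args) {
        String[] labels = {"person", "bicycle", "car"};
        float[] scores = {0.95f, 0.20f, 0.50f};
        float[] classes = {1.0f, 2.0f, 3.0f};
        for (int n : keptIndices(scores, scores.length, 0.3f)) {
            System.out.println(labelFor(classes[n], labels) + " " + scores[n]);
        }
    }
}
```

Rounding the float id before indexing guards against values such as `0.9999999f` that a float-typed output could in principle contain. Ids outside the label range map to `"unknown"` rather than throwing.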
12 | 13 | Download the model from https://www.kaggle.com/models/tensorflow/faster-rcnn-inception-resnet-v2/tensorFlow2/1024x1024/1 (formerly published on TF Hub at https://tfhub.dev/tensorflow/faster_rcnn/inception_resnet_v2_1024x1024/1). 14 | 15 | Unzip and untar the model into `models/faster_rcnn_inception_resnet_v2_1024x1024`; `FasterRcnnInception` loads the model from this hard-coded path. 16 | 17 | Create a `testimages` folder and add some test images to it. 18 | 19 | To run the example, pass the input image path and the output image path as parameters. With no parameters, it reads `src/main/resources/fasterrcnninception/image2.jpg` and writes `outputs/image2rcnn.jpg`. 20 | 21 | ```shell 22 | java -cp target/tensorflow-examples-1.0.0-tfj-1.0.0-rc.2-jar-with-dependencies.jar org.tensorflow.model.examples.cnn.fastrcnn.FasterRcnnInception testimages/image2.jpg image2rcnn.jpg 23 | ``` 24 | 25 | ### LeNet CNN 26 | 27 | The LeNet example runs on MNIST, stored in the project's resource directory. It is found in 28 | `org.tensorflow.model.examples.cnn.lenet`, and can be run with: 29 | 30 | ```shell 31 | java -cp target/tensorflow-examples-1.0.0-tfj-1.0.0-rc.2-jar-with-dependencies.jar org.tensorflow.model.examples.cnn.lenet.CnnMnist 32 | ``` 33 | 34 | ### VGG 35 | 36 | The VGG11 example runs on FashionMNIST, stored in the project's resource directory. It is found in 37 | `org.tensorflow.model.examples.cnn.vgg`, and can be run with: 38 | 39 | ```shell 40 | java -cp target/tensorflow-examples-1.0.0-tfj-1.0.0-rc.2-jar-with-dependencies.jar org.tensorflow.model.examples.cnn.vgg.VGG11OnFashionMnist 41 | ``` 42 | 43 | ### Linear Regression 44 | 45 | The linear regression example runs on hard-coded data. It is found in `org.tensorflow.model.examples.regression.linear` 46 | and can be run with: 47 | 48 | ```shell 49 | java -cp target/tensorflow-examples-1.0.0-tfj-1.0.0-rc.2-jar-with-dependencies.jar org.tensorflow.model.examples.regression.linear.LinearRegressionExample 50 | ``` 51 | 52 | ### Logistic Regression 53 | 54 | The logistic regression example runs on MNIST, stored in the project's resource directory.
It is found in 55 | `org.tensorflow.model.examples.dense`, and can be run with: 56 | 57 | ```shell 58 | java -cp target/tensorflow-examples-1.0.0-tfj-1.0.0-rc.2-jar-with-dependencies.jar org.tensorflow.model.examples.dense.SimpleMnist 59 | ``` 60 | 61 | ## Contributions 62 | 63 | Contributions of other example models are welcome; for instructions, please see the 64 | [Contributor guidelines](https://github.com/tensorflow/java/blob/master/CONTRIBUTING.md) in TensorFlow-Java. 65 | 66 | ## Development 67 | 68 | This repository tracks TensorFlow-Java, and its head is updated with each new release of TensorFlow-Java. 69 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 5 | 6 | 4.0.0 7 | 8 | org.tensorflow 9 | tensorflow-examples 10 | 1.0.0-tfj-1.0.0-rc.2 11 | 12 | TensorFlow Examples 13 | A suite of executable examples using TensorFlow Java 14 | 15 | 16 | 17 | 18 | 17 19 | 17 20 | 17 21 | 1.0.0-rc.2 22 | 23 | 24 | 25 | 26 | org.tensorflow 27 | tensorflow-core-platform 28 | ${tensorflow.version} 29 | 30 | 31 | org.tensorflow 32 | tensorflow-framework 33 | ${tensorflow.version} 34 | 35 | 36 | 37 | 38 | 39 | 40 | org.apache.maven.plugins 41 | maven-assembly-plugin 42 | 43 | 44 | package 45 | 46 | single 47 | 48 | 49 | 50 | 51 | 52 | org.tensorflow.model.examples.dense.SimpleMnist 53 | 54 | 55 | 56 | 57 | jar-with-dependencies 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | -------------------------------------------------------------------------------- /src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/FasterRcnnInception.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021, 2024 The TensorFlow Authors. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * ======================================================================= 16 | */ 17 | 18 | package org.tensorflow.model.examples.cnn.fastrcnn; 19 | /* 20 | 21 | From the web page this is the output dictionary 22 | 23 | num_detections: a tf.int tensor with only one value, the number of detections [N]. 24 | detection_boxes: a tf.float32 tensor of shape [N, 4] containing bounding box coordinates in the following order: [ymin, xmin, ymax, xmax]. 25 | detection_classes: a tf.int tensor of shape [N] containing detection class index from the label file. 26 | detection_scores: a tf.float32 tensor of shape [N] containing detection scores. 27 | raw_detection_boxes: a tf.float32 tensor of shape [1, M, 4] containing decoded detection boxes without Non-Max suppression. M is the number of raw detections. 28 | raw_detection_scores: a tf.float32 tensor of shape [1, M, 90] and contains class score logits for raw detection boxes. M is the number of raw detections. 29 | detection_anchor_indices: a tf.float32 tensor of shape [N] and contains the anchor indices of the detections after NMS. 30 | detection_multiclass_scores: a tf.float32 tensor of shape [1, N, 90] and contains class score distribution (including background) for detection boxes in the image including background class. 
31 | 32 | However using 33 | venv\Scripts\python.exe venv\Lib\site-packages\tensorflow\python\tools\saved_model_cli.py show --dir models\faster_rcnn_inception_resnet_v2_1024x1024 --all 34 | 2021-03-19 12:25:37.000143: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library cudart64_110.dll 35 | 36 | MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs: 37 | 38 | signature_def['__saved_model_init_op']: 39 | The given SavedModel SignatureDef contains the following input(s): 40 | The given SavedModel SignatureDef contains the following output(s): 41 | outputs['__saved_model_init_op'] tensor_info: 42 | dtype: DT_INVALID 43 | shape: unknown_rank 44 | name: NoOp 45 | Method name is: 46 | 47 | signature_def['serving_default']: 48 | The given SavedModel SignatureDef contains the following input(s): 49 | inputs['input_tensor'] tensor_info: 50 | dtype: DT_UINT8 51 | shape: (1, -1, -1, 3) 52 | name: serving_default_input_tensor:0 53 | The given SavedModel SignatureDef contains the following output(s): 54 | outputs['detection_anchor_indices'] tensor_info: 55 | dtype: DT_FLOAT 56 | shape: (1, 300) 57 | name: StatefulPartitionedCall:0 58 | outputs['detection_boxes'] tensor_info: 59 | dtype: DT_FLOAT 60 | shape: (1, 300, 4) 61 | name: StatefulPartitionedCall:1 62 | outputs['detection_classes'] tensor_info: 63 | dtype: DT_FLOAT 64 | shape: (1, 300) 65 | name: StatefulPartitionedCall:2 66 | outputs['detection_multiclass_scores'] tensor_info: 67 | dtype: DT_FLOAT 68 | shape: (1, 300, 91) 69 | name: StatefulPartitionedCall:3 70 | outputs['detection_scores'] tensor_info: 71 | dtype: DT_FLOAT 72 | shape: (1, 300) 73 | name: StatefulPartitionedCall:4 74 | outputs['num_detections'] tensor_info: 75 | dtype: DT_FLOAT 76 | shape: (1) 77 | name: StatefulPartitionedCall:5 78 | outputs['raw_detection_boxes'] tensor_info: 79 | dtype: DT_FLOAT 80 | shape: (1, 300, 4) 81 | name: StatefulPartitionedCall:6 82 | 
outputs['raw_detection_scores'] tensor_info: 83 | dtype: DT_FLOAT 84 | shape: (1, 300, 91) 85 | name: StatefulPartitionedCall:7 86 | Method name is: tensorflow/serving/predict 87 | 88 | Defined Functions: 89 | Function Name: '__call__' 90 | Option #1 91 | Callable with: 92 | Argument #1 93 | input_tensor: TensorSpec(shape=(1, None, None, 3), dtype=tf.uint8, name='input_tensor') 94 | 95 | So it appears there's a discrepancy between the web page and running saved_model_cli as 96 | num_detections: a tf.int tensor with only one value, the number of detections [N]. 97 | but the actual tensor is DT_FLOAT according to saved_model_cli 98 | also the web page states 99 | detection_classes: a tf.int tensor of shape [N] containing detection class index from the label file. 100 | but again the actual tensor is DT_FLOAT according to saved_model_cli. 101 | */ 102 | 103 | 104 | import java.util.ArrayList; 105 | import java.util.HashMap; 106 | import java.util.Map; 107 | import java.util.TreeMap; 108 | import org.tensorflow.Graph; 109 | import org.tensorflow.Operand; 110 | import org.tensorflow.Result; 111 | import org.tensorflow.SavedModelBundle; 112 | import org.tensorflow.Session; 113 | import org.tensorflow.Tensor; 114 | import org.tensorflow.ndarray.FloatNdArray; 115 | import org.tensorflow.ndarray.Shape; 116 | import org.tensorflow.op.Ops; 117 | import org.tensorflow.op.core.Constant; 118 | import org.tensorflow.op.core.Placeholder; 119 | import org.tensorflow.op.core.Reshape; 120 | import org.tensorflow.op.image.DecodeJpeg; 121 | import org.tensorflow.op.image.EncodeJpeg; 122 | import org.tensorflow.op.io.ReadFile; 123 | import org.tensorflow.op.io.WriteFile; 124 | import org.tensorflow.types.TFloat32; 125 | import org.tensorflow.types.TString; 126 | import org.tensorflow.types.TUint8; 127 | 128 | 129 | /** 130 | * Loads an image using ReadFile and DecodeJpeg and then uses the saved model 131 | * faster_rcnn/inception_resnet_v2_1024x1024/1 to detect objects with a detection 
score greater than 0.3. 132 | * Draws the detected bounding boxes onto the image with DrawBoundingBoxes and writes the result as a JPEG. 133 | */ 134 | public class FasterRcnnInception { 135 | 136 | private static final String[] cocoLabels = new String[]{ 137 | "person", 138 | "bicycle", 139 | "car", 140 | "motorcycle", 141 | "airplane", 142 | "bus", 143 | "train", 144 | "truck", 145 | "boat", 146 | "traffic light", 147 | "fire hydrant", 148 | "street sign", 149 | "stop sign", 150 | "parking meter", 151 | "bench", 152 | "bird", 153 | "cat", 154 | "dog", 155 | "horse", 156 | "sheep", 157 | "cow", 158 | "elephant", 159 | "bear", 160 | "zebra", 161 | "giraffe", 162 | "hat", 163 | "backpack", 164 | "umbrella", 165 | "shoe", 166 | "eye glasses", 167 | "handbag", 168 | "tie", 169 | "suitcase", 170 | "frisbee", 171 | "skis", 172 | "snowboard", 173 | "sports ball", 174 | "kite", 175 | "baseball bat", 176 | "baseball glove", 177 | "skateboard", 178 | "surfboard", 179 | "tennis racket", 180 | "bottle", 181 | "plate", 182 | "wine glass", 183 | "cup", 184 | "fork", 185 | "knife", 186 | "spoon", 187 | "bowl", 188 | "banana", 189 | "apple", 190 | "sandwich", 191 | "orange", 192 | "broccoli", 193 | "carrot", 194 | "hot dog", 195 | "pizza", 196 | "donut", 197 | "cake", 198 | "chair", 199 | "couch", 200 | "potted plant", 201 | "bed", 202 | "mirror", 203 | "dining table", 204 | "window", 205 | "desk", 206 | "toilet", 207 | "door", 208 | "tv", 209 | "laptop", 210 | "mouse", 211 | "remote", 212 | "keyboard", 213 | "cell phone", 214 | "microwave", 215 | "oven", 216 | "toaster", 217 | "sink", 218 | "refrigerator", 219 | "blender", 220 | "book", 221 | "clock", 222 | "vase", 223 | "scissors", 224 | "teddy bear", 225 | "hair drier", 226 | "toothbrush", 227 | "hair brush" 228 | }; 229 | 230 | public static void main(String[] params) { 231 | String outputImagePath; 232 | String imagePath; 233 | 234 | if (params.length == 0) { 235 | imagePath = "src/main/resources/fasterrcnninception/image2.jpg"; 236 | outputImagePath = "outputs/image2rcnn.jpg"; 237 | 238 | } else if
(params.length == 2) { 239 | imagePath = params[0]; 240 | outputImagePath = params[1]; 241 | 242 | } else { 243 | throw new IllegalArgumentException("Exactly 0 or 2 parameters required: java FasterRcnnInception [<inputImagePath> <outputImagePath>]"); 244 | } 245 | // get path to model folder 246 | String modelPath = "models/faster_rcnn_inception_resnet_v2_1024x1024"; 247 | // load saved model 248 | SavedModelBundle model = SavedModelBundle.load(modelPath, "serve"); 249 | //map the COCO 2017 class ids (floats, like the DT_FLOAT detection_classes output) to labels 250 | TreeMap<Float, String> cocoTreeMap = new TreeMap<>(); 251 | float cocoCount = 1; // COCO class ids are 1-based: "person" is id 1 252 | for (String cocoLabel : cocoLabels) { 253 | cocoTreeMap.put(cocoCount, cocoLabel); 254 | cocoCount++; 255 | } 256 | try (Graph g = new Graph(); Session s = new Session(g)) { 257 | Ops tf = Ops.create(g); 258 | Constant<TString> fileName = tf.constant(imagePath); 259 | ReadFile readFile = tf.io.readFile(fileName); 260 | Session.Runner runner = s.runner(); 261 | DecodeJpeg.Options options = DecodeJpeg.channels(3L); 262 | DecodeJpeg decodeImage = tf.image.decodeJpeg(readFile.contents(), options); 263 | //fetch image from file 264 | Shape imageShape; 265 | try (var shapeResult = runner.fetch(decodeImage).run()) { 266 | imageShape = shapeResult.get(0).shape(); 267 | } 268 | //reshape the tensor to 4D for input to model 269 | Reshape<TUint8> reshape = tf.reshape(decodeImage, 270 | tf.array(1, 271 | imageShape.asArray()[0], 272 | imageShape.asArray()[1], 273 | imageShape.asArray()[2] 274 | ) 275 | ); 276 | try (var reshapeResult = s.runner().fetch(reshape).run()) { 277 | TUint8 reshapeTensor = (TUint8) reshapeResult.get(0); 278 | Map<String, Tensor> feedDict = new HashMap<>(); 279 | //The given SavedModel SignatureDef input 280 | feedDict.put("input_tensor", reshapeTensor); 281 | //The given SavedModel MetaGraphDef key 282 | //detection_classes, detectionBoxes etc.
are model output names 283 | try (Result outputTensorMap = model.function("serving_default").call(feedDict)) { 284 | TFloat32 numDetections = (TFloat32) outputTensorMap.get("num_detections").get(); 285 | int numDetects = (int) numDetections.getFloat(0); 286 | if (numDetects > 0) { 287 | TFloat32 detectionBoxes = (TFloat32) outputTensorMap.get("detection_boxes").get(); 288 | TFloat32 detectionScores = (TFloat32) outputTensorMap.get("detection_scores").get(); 289 | ArrayList<FloatNdArray> boxArray = new ArrayList<>(); 290 | //TODO tf.image.combinedNonMaxSuppression 291 | for (int n = 0; n < numDetects; n++) { 292 | //score of the n-th detection 293 | float detectionScore = detectionScores.getFloat(0, n); 294 | //only keep boxes whose detection score is greater than 0.3f 295 | if (detectionScore > 0.3f) { 296 | boxArray.add(detectionBoxes.get(0, n)); 297 | } 298 | } 299 | /* These values are also returned by the Faster R-CNN model, but we don't use them in this example. 300 | * TFloat32 detectionClasses = (TFloat32) outputTensorMap.get("detection_classes").get(); 301 | * TFloat32 rawDetectionBoxes = (TFloat32) outputTensorMap.get("raw_detection_boxes").get(); 302 | * TFloat32 rawDetectionScores = (TFloat32) outputTensorMap.get("raw_detection_scores").get(); 303 | * TFloat32 detectionAnchorIndices = (TFloat32) outputTensorMap.get("detection_anchor_indices").get(); 304 | * TFloat32 detectionMulticlassScores = (TFloat32) outputTensorMap.get("detection_multiclass_scores").get(); 305 | */ 306 | //2-D. A list of RGBA colors to cycle through for the boxes.
307 | Operand<TFloat32> colors = tf.constant(new float[][]{ 308 | {0.9f, 0.3f, 0.3f, 0.0f}, 309 | {0.3f, 0.3f, 0.9f, 0.0f}, 310 | {0.3f, 0.9f, 0.3f, 0.0f} 311 | }); 312 | Shape boxesShape = Shape.of(1, boxArray.size(), 4); 313 | int boxCount = 0; 314 | //3-D with shape `[batch, num_bounding_boxes, 4]` containing bounding boxes 315 | try (TFloat32 boxes = TFloat32.tensorOf(boxesShape)) { 316 | //batch size of 1, so every kept box is written at batch index 0 317 | //(boxArray may be empty, in which case the tensor has shape [1, 0, 4]) 318 | for (FloatNdArray floatNdArray : boxArray) { 319 | boxes.set(floatNdArray, 0, boxCount); 320 | boxCount++; 321 | } 322 | //Placeholders for the boxes and the output image path 323 | Placeholder<TFloat32> boxesPlaceHolder = tf.placeholder(TFloat32.class, Placeholder.shape(boxesShape)); 324 | Placeholder<TString> outImagePathPlaceholder = tf.placeholder(TString.class); 325 | //Encode the JPEG with a quality of 100 326 | EncodeJpeg.Options jpgOptions = EncodeJpeg.quality(100L); 327 | //convert the 4D input image to floats normalised to [0.0f, 1.0f] 328 | //Draw bounding boxes using boxes tensor and list of colors 329 | //multiply by 255 then reshape and recast to TUint8 3D tensor 330 | WriteFile writeFile = tf.io.writeFile(outImagePathPlaceholder, 331 | tf.image.encodeJpeg( 332 | tf.dtypes.cast(tf.reshape( 333 | tf.math.mul( 334 | tf.image.drawBoundingBoxes(tf.math.div( 335 | tf.dtypes.cast(tf.constant(reshapeTensor), 336 | TFloat32.class), 337 | tf.constant(255.0f) 338 | ), 339 | boxesPlaceHolder, colors), 340 | tf.constant(255.0f) 341 | ), 342 | tf.array( 343 | imageShape.asArray()[0], 344 | imageShape.asArray()[1], 345 | imageShape.asArray()[2] 346 | ) 347 | ), TUint8.class), 348 | jpgOptions)); 349 | //output the JPEG to file 350 | s.runner().feed(outImagePathPlaceholder, TString.scalarOf(outputImagePath)) 351 | .feed(boxesPlaceHolder, boxes) 352 | .addTarget(writeFile).run(); 353 | } 354 | } 355 | } 356 | } 357 | } 358 | } 359 | } 360 | --------------------------------------------------------------------------------
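`FasterRcnnInception` passes the kept `detection_boxes` rows straight to `DrawBoundingBoxes`. That op expects each box as `[ymin, xmin, ymax, xmax]`, with coordinates given as fractions in `[0, 1]` of the image height and width. To use a detection elsewhere (for example, to crop it), you need to map the box to pixel coordinates. The following is a minimal sketch under that assumption. `BoxToPixels` is a hypothetical helper, and scaling by `height - 1` and `width - 1` (rather than by `height` and `width`) is a convention choice that the example does not prescribe.

```java
import java.util.Arrays;

// Hypothetical helper, not part of this repository.
class BoxToPixels {

    // Converts a normalized [ymin, xmin, ymax, xmax] box to integer pixel coordinates
    // [top, left, bottom, right]. Values outside [0, 1] are clamped. Scaling by
    // (height - 1) and (width - 1) is an assumed convention.
    static int[] toPixels(float[] box, int height, int width) {
        if (box.length != 4) {
            throw new IllegalArgumentException("expected [ymin, xmin, ymax, xmax]");
        }
        int top = Math.round(clamp(box[0]) * (height - 1));
        int left = Math.round(clamp(box[1]) * (width - 1));
        int bottom = Math.round(clamp(box[2]) * (height - 1));
        int right = Math.round(clamp(box[3]) * (width - 1));
        return new int[]{top, left, bottom, right};
    }

    private static float clamp(float v) {
        return Math.max(0f, Math.min(1f, v));
    }

    public static void main(String[] args) {
        // A box covering the middle half vertically and the right half horizontally.
        int[] px = toPixels(new float[]{0.25f, 0.5f, 0.75f, 1.0f}, 101, 201);
        System.out.println(Arrays.toString(px)); // prints [25, 100, 75, 200]
    }
}
```

Clamping keeps slightly out-of-range model outputs inside the image, so a later crop cannot index past its edges.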
/src/main/java/org/tensorflow/model/examples/cnn/fastrcnn/Readme.md: -------------------------------------------------------------------------------- 1 | # FasterRcnnInception 2 | 3 | Download the model from https://www.kaggle.com/models/tensorflow/faster-rcnn-inception-resnet-v2/tensorFlow2/1024x1024/1 4 | 5 | Unzip and untar the model into `models/faster_rcnn_inception_resnet_v2_1024x1024`; the example loads the model from this hard-coded path. 6 | 7 | Create a `testimages` folder and add some test images to it. 8 | 9 | To run the example, pass the input image path and the output image path as parameters: 10 | 11 | FasterRcnnInception testimages/image2.jpg image2rcnn.jpg 12 | 13 | ### Example output 14 | Output for `image2.jpg` from https://github.com/tensorflow/models/tree/master/research/object_detection/test_images: 15 | ![image2rcnn.jpg](../../../../../../../resources/fasterrcnninception/image2rcnn.jpg "Beach") 16 | 17 | -------------------------------------------------------------------------------- /src/main/java/org/tensorflow/model/examples/cnn/lenet/CnnMnist.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2020, 2024 The TensorFlow Authors. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License.
15 | * ======================================================================= 16 | */ 17 | package org.tensorflow.model.examples.cnn.lenet; 18 | 19 | import java.util.Arrays; 20 | import java.util.logging.Level; 21 | import java.util.logging.Logger; 22 | import org.tensorflow.Graph; 23 | import org.tensorflow.Operand; 24 | import org.tensorflow.Session; 25 | import org.tensorflow.framework.optimizers.AdaDelta; 26 | import org.tensorflow.framework.optimizers.AdaGrad; 27 | import org.tensorflow.framework.optimizers.AdaGradDA; 28 | import org.tensorflow.framework.optimizers.Adam; 29 | import org.tensorflow.framework.optimizers.GradientDescent; 30 | import org.tensorflow.framework.optimizers.Momentum; 31 | import org.tensorflow.framework.optimizers.Optimizer; 32 | import org.tensorflow.framework.optimizers.RMSProp; 33 | import org.tensorflow.model.examples.datasets.ImageBatch; 34 | import org.tensorflow.model.examples.datasets.mnist.MnistDataset; 35 | import org.tensorflow.ndarray.ByteNdArray; 36 | import org.tensorflow.ndarray.FloatNdArray; 37 | import org.tensorflow.ndarray.Shape; 38 | import org.tensorflow.ndarray.index.Indices; 39 | import org.tensorflow.op.Op; 40 | import org.tensorflow.op.Ops; 41 | import org.tensorflow.op.core.Constant; 42 | import org.tensorflow.op.core.OneHot; 43 | import org.tensorflow.op.core.Placeholder; 44 | import org.tensorflow.op.core.Reshape; 45 | import org.tensorflow.op.core.Variable; 46 | import org.tensorflow.op.math.Add; 47 | import org.tensorflow.op.math.Mean; 48 | import org.tensorflow.op.nn.Conv2d; 49 | import org.tensorflow.op.nn.MaxPool; 50 | import org.tensorflow.op.nn.Relu; 51 | import org.tensorflow.op.nn.Softmax; 52 | import org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits; 53 | import org.tensorflow.op.random.TruncatedNormal; 54 | import org.tensorflow.types.TFloat32; 55 | import org.tensorflow.types.TUint8; 56 | 57 | /** 58 | * Builds a LeNet-5 style CNN for MNIST. 
59 | */ 60 | public class CnnMnist { 61 | 62 | private static final Logger logger = Logger.getLogger(CnnMnist.class.getName()); 63 | 64 | private static final int PIXEL_DEPTH = 255; 65 | private static final int NUM_CHANNELS = 1; 66 | private static final int IMAGE_SIZE = 28; 67 | private static final int NUM_LABELS = MnistDataset.NUM_CLASSES; 68 | private static final long SEED = 123456789L; 69 | 70 | private static final String PADDING_TYPE = "SAME"; 71 | 72 | public static final String INPUT_NAME = "input"; 73 | public static final String OUTPUT_NAME = "output"; 74 | public static final String TARGET = "target"; 75 | public static final String TRAIN = "train"; 76 | public static final String TRAINING_LOSS = "training_loss"; 77 | 78 | private static final String TRAINING_IMAGES_ARCHIVE = "mnist/train-images-idx3-ubyte.gz"; 79 | private static final String TRAINING_LABELS_ARCHIVE = "mnist/train-labels-idx1-ubyte.gz"; 80 | private static final String TEST_IMAGES_ARCHIVE = "mnist/t10k-images-idx3-ubyte.gz"; 81 | private static final String TEST_LABELS_ARCHIVE = "mnist/t10k-labels-idx1-ubyte.gz"; 82 | 83 | public static Graph build(String optimizerName) { 84 | Graph graph = new Graph(); 85 | 86 | Ops tf = Ops.create(graph); 87 | 88 | // Inputs 89 | Placeholder<TUint8> input = tf.withName(INPUT_NAME).placeholder(TUint8.class, 90 | Placeholder.shape(Shape.of(-1, IMAGE_SIZE, IMAGE_SIZE))); 91 | Reshape<TUint8> inputReshaped = tf 92 | .reshape(input, tf.array(-1, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS)); 93 | Placeholder<TUint8> labels = tf.withName(TARGET).placeholder(TUint8.class); 94 | 95 | // Scale pixel values from [0, 255] to [-0.5, 0.5] 96 | Constant<TFloat32> centeringFactor = tf.constant(PIXEL_DEPTH / 2.0f); 97 | Constant<TFloat32> scalingFactor = tf.constant((float) PIXEL_DEPTH); 98 | Operand<TFloat32> scaledInput = tf.math 99 | .div(tf.math.sub(tf.dtypes.cast(inputReshaped, TFloat32.class), centeringFactor), 100 | scalingFactor); 101 | 102 | // First conv layer 103 | Variable<TFloat32> conv1Weights = tf.variable(tf.math.mul(tf.random 104 |
.truncatedNormal(tf.array(5, 5, NUM_CHANNELS, 32), TFloat32.class, 105 | TruncatedNormal.seed(SEED)), tf.constant(0.1f))); 106 | Conv2d<TFloat32> conv1 = tf.nn 107 | .conv2d(scaledInput, conv1Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE); 108 | Variable<TFloat32> conv1Biases = tf 109 | .variable(tf.fill(tf.array(new int[]{32}), tf.constant(0.0f))); 110 | Relu<TFloat32> relu1 = tf.nn.relu(tf.nn.biasAdd(conv1, conv1Biases)); 111 | 112 | // First pooling layer 113 | MaxPool<TFloat32> pool1 = tf.nn 114 | .maxPool(relu1, tf.array(1, 2, 2, 1), tf.array(1, 2, 2, 1), 115 | PADDING_TYPE); 116 | 117 | // Second conv layer 118 | Variable<TFloat32> conv2Weights = tf.variable(tf.math.mul(tf.random 119 | .truncatedNormal(tf.array(5, 5, 32, 64), TFloat32.class, 120 | TruncatedNormal.seed(SEED)), tf.constant(0.1f))); 121 | Conv2d<TFloat32> conv2 = tf.nn 122 | .conv2d(pool1, conv2Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE); 123 | Variable<TFloat32> conv2Biases = tf 124 | .variable(tf.fill(tf.array(new int[]{64}), tf.constant(0.1f))); 125 | Relu<TFloat32> relu2 = tf.nn.relu(tf.nn.biasAdd(conv2, conv2Biases)); 126 | 127 | // Second pooling layer 128 | MaxPool<TFloat32> pool2 = tf.nn 129 | .maxPool(relu2, tf.array(1, 2, 2, 1), tf.array(1, 2, 2, 1), 130 | PADDING_TYPE); 131 | 132 | // Flatten inputs 133 | Reshape<TFloat32> flatten = tf.reshape(pool2, tf.concat(Arrays 134 | .asList(tf.slice(tf.shape(pool2), tf.array(new int[]{0}), tf.array(new int[]{1})), 135 | tf.array(new int[]{-1})), tf.constant(0))); 136 | 137 | // Fully connected layer: two 2x2 poolings leave 7 x 7 x 64 = 3136 = IMAGE_SIZE * IMAGE_SIZE * 4 inputs 138 | Variable<TFloat32> fc1Weights = tf.variable(tf.math.mul(tf.random 139 | .truncatedNormal(tf.array(IMAGE_SIZE * IMAGE_SIZE * 4, 512), TFloat32.class, 140 | TruncatedNormal.seed(SEED)), tf.constant(0.1f))); 141 | Variable<TFloat32> fc1Biases = tf 142 | .variable(tf.fill(tf.array(new int[]{512}), tf.constant(0.1f))); 143 | Relu<TFloat32> relu3 = tf.nn 144 | .relu(tf.math.add(tf.linalg.matMul(flatten, fc1Weights), fc1Biases)); 145 | 146 | // Softmax layer 147 | Variable<TFloat32> fc2Weights = tf.variable(tf.math.mul(tf.random 148 | .truncatedNormal(tf.array(512,
NUM_LABELS), TFloat32.class, 149 | TruncatedNormal.seed(SEED)), tf.constant(0.1f))); 150 | Variable<TFloat32> fc2Biases = tf 151 | .variable(tf.fill(tf.array(new int[]{NUM_LABELS}), tf.constant(0.1f))); 152 | 153 | Add<TFloat32> logits = tf.math.add(tf.linalg.matMul(relu3, fc2Weights), fc2Biases); 154 | 155 | // Predicted outputs 156 | Softmax<TFloat32> prediction = tf.withName(OUTPUT_NAME).nn.softmax(logits); 157 | 158 | // Loss function & regularization 159 | OneHot<TFloat32> oneHot = tf 160 | .oneHot(labels, tf.constant(NUM_LABELS), tf.constant(1.0f), tf.constant(0.0f)); 161 | SoftmaxCrossEntropyWithLogits<TFloat32> batchLoss = tf.nn.softmaxCrossEntropyWithLogits(logits, oneHot); 162 | Mean<TFloat32> labelLoss = tf.math.mean(batchLoss.loss(), tf.constant(0)); 163 | Add<TFloat32> regularizers = tf.math.add(tf.nn.l2Loss(fc1Weights), tf.math 164 | .add(tf.nn.l2Loss(fc1Biases), 165 | tf.math.add(tf.nn.l2Loss(fc2Weights), tf.nn.l2Loss(fc2Biases)))); 166 | Add<TFloat32> loss = tf.withName(TRAINING_LOSS).math 167 | .add(labelLoss, tf.math.mul(regularizers, tf.constant(5e-4f))); 168 | 169 | String lcOptimizerName = optimizerName.toLowerCase(); 170 | // Optimizer 171 | Optimizer optimizer = switch (lcOptimizerName) { 172 | case "adadelta" -> new AdaDelta(graph, 1f, 0.95f, 1e-8f); 173 | case "adagradda" -> new AdaGradDA(graph, 0.01f); 174 | case "adagrad" -> new AdaGrad(graph, 0.01f); 175 | case "adam" -> new Adam(graph, 0.001f, 0.9f, 0.999f, 1e-8f); 176 | case "sgd" -> new GradientDescent(graph, 0.01f); 177 | case "momentum" -> new Momentum(graph, 0.01f, 0.9f, false); 178 | case "rmsprop" -> new RMSProp(graph, 0.01f, 0.9f, 0.0f, 1e-10f, false); 179 | default -> throw new IllegalArgumentException("Unknown optimizer " + optimizerName); 180 | }; 181 | logger.info("Optimizer = " + optimizer); 182 | Op minimize = optimizer.minimize(loss, TRAIN); 183 | 184 | return graph; 185 | } 186 | 187 | public static void train(Session session, int epochs, int minibatchSize, MnistDataset dataset) { 188 | int interval = 0; 189 | // Train the model 190 | for (int i = 0; i <
epochs; i++) { 191 | for (ImageBatch trainingBatch : dataset.trainingBatches(minibatchSize)) { 192 | try (TUint8 batchImages = TUint8.tensorOf(trainingBatch.images()); 193 | TUint8 batchLabels = TUint8.tensorOf(trainingBatch.labels()); 194 | var result = session.runner() 195 | .feed(TARGET, batchLabels) 196 | .feed(INPUT_NAME, batchImages) 197 | .addTarget(TRAIN) 198 | .fetch(TRAINING_LOSS) 199 | .run()) { 200 | TFloat32 loss = (TFloat32) result.get(0); 201 | if (interval % 100 == 0) { 202 | logger.log(Level.INFO, 203 | "Iteration = " + interval + ", training loss = " + loss.getFloat()); 204 | } 205 | } 206 | interval++; 207 | } 208 | } 209 | } 210 | 211 | public static void test(Session session, int minibatchSize, MnistDataset dataset) { 212 | int correctCount = 0; 213 | int[][] confusionMatrix = new int[NUM_LABELS][NUM_LABELS]; 214 | 215 | for (ImageBatch testBatch : dataset.testBatches(minibatchSize)) { 216 | try (TUint8 transformedInput = TUint8.tensorOf(testBatch.images()); 217 | var result = session.runner().feed(INPUT_NAME, transformedInput).fetch(OUTPUT_NAME).run()) { 218 | TFloat32 outputTensor = (TFloat32) result.get(0); 219 | ByteNdArray labelBatch = testBatch.labels(); 220 | for (int k = 0; k < labelBatch.shape().get(0); k++) { 221 | byte trueLabel = labelBatch.getByte(k); 222 | int predLabel; 223 | 224 | predLabel = argmax(outputTensor.slice(Indices.at(k), Indices.all())); 225 | if (predLabel == trueLabel) { 226 | correctCount++; 227 | } 228 | 229 | confusionMatrix[trueLabel][predLabel]++; 230 | } 231 | } 232 | } 233 | 234 | logger.info("Final accuracy = " + ((float) correctCount) / dataset.numTestingExamples()); 235 | 236 | StringBuilder sb = new StringBuilder(); 237 | sb.append("Label"); 238 | for (int i = 0; i < confusionMatrix.length; i++) { 239 | sb.append(String.format("%1$5s", "" + i)); 240 | } 241 | sb.append("\n"); 242 | 243 | for (int i = 0; i < confusionMatrix.length; i++) { 244 | sb.append(String.format("%1$5s", "" + i)); 245 | for (int j = 0;
j < confusionMatrix[i].length; j++) { 246 | sb.append(String.format("%1$5s", "" + confusionMatrix[i][j])); 247 | } 248 | sb.append("\n"); 249 | } 250 | 251 | System.out.println(sb); 252 | } 253 | 254 | /** 255 | * Finds the maximum probability and returns its index. 256 | * 257 | * @param probabilities The probabilities. 258 | * @return The index of the max. 259 | */ 260 | public static int argmax(FloatNdArray probabilities) { 261 | float maxVal = Float.NEGATIVE_INFINITY; 262 | int idx = 0; 263 | for (int i = 0; i < probabilities.shape().get(0); i++) { 264 | float curVal = probabilities.getFloat(i); 265 | if (curVal > maxVal) { 266 | maxVal = curVal; 267 | idx = i; 268 | } 269 | } 270 | return idx; 271 | } 272 | 273 | public static void main(String[] args) { 274 | int epochs; 275 | int minibatchSize; 276 | String optimizerName; 277 | 278 | if (args.length == 0) { 279 | epochs = 1; 280 | minibatchSize = 64; 281 | optimizerName = "adam"; 282 | 283 | } else if (args.length == 3) { 284 | epochs = Integer.parseInt(args[0]); 285 | minibatchSize = Integer.parseInt(args[1]); 286 | optimizerName = args[2]; 287 | 288 | } else { 289 | throw new IllegalArgumentException("Usage: CnnMnist [<epochs> <minibatchSize> <optimizerName>]"); 290 | } 291 | 292 | MnistDataset dataset = MnistDataset.create(0, TRAINING_IMAGES_ARCHIVE, TRAINING_LABELS_ARCHIVE, 293 | TEST_IMAGES_ARCHIVE, TEST_LABELS_ARCHIVE); 294 | 295 | logger.info("Loaded data."); 296 | 297 | try (Graph graph = build(optimizerName); 298 | Session session = new Session(graph)) { 299 | train(session, epochs, minibatchSize, dataset); 300 | 301 | logger.info("Trained model"); 302 | 303 | test(session, minibatchSize, dataset); 304 | } 305 | } 306 | } 307 | -------------------------------------------------------------------------------- /src/main/java/org/tensorflow/model/examples/cnn/vgg/VGG11OnFashionMnist.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =======================================================================
 */
package org.tensorflow.model.examples.cnn.vgg;

import java.util.logging.Logger;
import org.tensorflow.model.examples.datasets.mnist.MnistDataset;

/**
 * Trains and evaluates a VGG11 model on the FashionMNIST dataset.
 */
public class VGG11OnFashionMnist {
  // Hyper-parameters
  public static final int EPOCHS = 1;
  public static final int BATCH_SIZE = 500;

  // Fashion MNIST dataset paths
  public static final String TRAINING_IMAGES_ARCHIVE = "fashionmnist/train-images-idx3-ubyte.gz";
  public static final String TRAINING_LABELS_ARCHIVE = "fashionmnist/train-labels-idx1-ubyte.gz";
  public static final String TEST_IMAGES_ARCHIVE = "fashionmnist/t10k-images-idx3-ubyte.gz";
  public static final String TEST_LABELS_ARCHIVE = "fashionmnist/t10k-labels-idx1-ubyte.gz";

  private static final Logger logger = Logger.getLogger(VGG11OnFashionMnist.class.getName());

  public static void main(String[] args) {
    logger.info("Data loading.");
    MnistDataset dataset = MnistDataset.create(0, TRAINING_IMAGES_ARCHIVE, TRAINING_LABELS_ARCHIVE,
        TEST_IMAGES_ARCHIVE, TEST_LABELS_ARCHIVE);

    try (VGGModel vggModel = new VGGModel()) {
      logger.info("Model training.");
      vggModel.train(dataset, EPOCHS, BATCH_SIZE);

      logger.info("Model evaluation.");
      vggModel.test(dataset, BATCH_SIZE);
    }
  }
}

--------------------------------------------------------------------------------
/src/main/java/org/tensorflow/model/examples/cnn/vgg/VGGModel.java:
--------------------------------------------------------------------------------
/*
 * Copyright 2020, 2024 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =======================================================================
 */
package org.tensorflow.model.examples.cnn.vgg;

import java.util.Arrays;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Session;
import org.tensorflow.framework.optimizers.Adam;
import org.tensorflow.framework.optimizers.Optimizer;
import org.tensorflow.model.examples.datasets.ImageBatch;
import org.tensorflow.model.examples.datasets.mnist.MnistDataset;
import org.tensorflow.ndarray.ByteNdArray;
import org.tensorflow.ndarray.FloatNdArray;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.index.Indices;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.OneHot;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.core.Reshape;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.math.Add;
import org.tensorflow.op.math.Mean;
import org.tensorflow.op.nn.Conv2d;
import org.tensorflow.op.nn.MaxPool;
import org.tensorflow.op.nn.Relu;
import org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits;
import org.tensorflow.op.random.TruncatedNormal;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TUint8;

/**
 * A VGG11-style convolutional network for 28x28 single-channel images (MNIST and FashionMNIST),
 * built, trained and evaluated as a TensorFlow graph.
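 *
 * <p>A minimal usage sketch that mirrors {@code VGG11OnFashionMnist}; the four archive path
 * variables are hypothetical placeholders for the resource paths defined in that class:
 *
 * <pre>{@code
 * MnistDataset dataset = MnistDataset.create(0, trainImagesPath, trainLabelsPath,
 *     testImagesPath, testLabelsPath);
 * try (VGGModel model = new VGGModel()) {
 *   model.train(dataset, 1, 500); // epochs, minibatch size
 *   model.test(dataset, 500);     // logs the accuracy and prints a confusion matrix
 * }
 * }</pre>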
 */
public class VGGModel implements AutoCloseable {
  private static final int PIXEL_DEPTH = 255;
  private static final int NUM_CHANNELS = 1;
  private static final int IMAGE_SIZE = 28;
  private static final int NUM_LABELS = MnistDataset.NUM_CLASSES;
  private static final long SEED = 123456789L;

  private static final String PADDING_TYPE = "SAME";
  public static final String INPUT_NAME = "input";
  public static final String OUTPUT_NAME = "output";
  public static final String TARGET = "target";
  public static final String TRAIN = "train";
  public static final String TRAINING_LOSS = "training_loss";

  private static final Logger logger = Logger.getLogger(VGGModel.class.getName());

  private final Graph graph;

  private final Session session;

  public VGGModel() {
    graph = compile();
    session = new Session(graph);
  }

  public static Graph compile() {
    Graph graph = new Graph();

    Ops tf = Ops.create(graph);

    // Inputs
    Placeholder<TUint8> input = tf.withName(INPUT_NAME).placeholder(TUint8.class,
        Placeholder.shape(Shape.of(-1, IMAGE_SIZE, IMAGE_SIZE)));
    Reshape<TUint8> input_reshaped = tf
        .reshape(input, tf.array(-1, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS));
    Placeholder<TUint8> labels = tf.withName(TARGET).placeholder(TUint8.class);

    // Scaling the features
    Constant<TFloat32> centeringFactor = tf.constant(PIXEL_DEPTH / 2.0f);
    Constant<TFloat32> scalingFactor = tf.constant((float) PIXEL_DEPTH);
    Operand<TFloat32> scaledInput = tf.math
        .div(tf.math.sub(tf.dtypes.cast(input_reshaped, TFloat32.class), centeringFactor),
            scalingFactor);

    Relu<TFloat32> relu1 = vggConv2DLayer("1", tf, scaledInput, new int[]{3, 3, NUM_CHANNELS, 32}, 32);

    MaxPool<TFloat32> pool1 = vggMaxPool(tf, relu1);

    Relu<TFloat32> relu2 = vggConv2DLayer("2", tf, pool1, new int[]{3, 3, 32, 64}, 64);

    MaxPool<TFloat32> pool2 = vggMaxPool(tf, relu2);

    Relu<TFloat32> relu3 =
        vggConv2DLayer("3", tf, pool2, new int[]{3, 3, 64, 128}, 128);
    Relu<TFloat32> relu4 = vggConv2DLayer("4", tf, relu3, new int[]{3, 3, 128, 128}, 128);

    MaxPool<TFloat32> pool3 = vggMaxPool(tf, relu4);

    Relu<TFloat32> relu5 = vggConv2DLayer("5", tf, pool3, new int[]{3, 3, 128, 256}, 256);
    Relu<TFloat32> relu6 = vggConv2DLayer("6", tf, relu5, new int[]{3, 3, 256, 256}, 256);

    MaxPool<TFloat32> pool4 = vggMaxPool(tf, relu6);

    Relu<TFloat32> relu7 = vggConv2DLayer("7", tf, pool4, new int[]{3, 3, 256, 256}, 256);
    Relu<TFloat32> relu8 = vggConv2DLayer("8", tf, relu7, new int[]{3, 3, 256, 256}, 256);

    MaxPool<TFloat32> pool5 = vggMaxPool(tf, relu8);

    Reshape<TFloat32> flatten = vggFlatten(tf, pool5);

    Add<TFloat32> loss = buildFCLayersAndRegularization(tf, labels, flatten);

    Optimizer optimizer = new Adam(graph, 0.001f, 0.9f, 0.999f, 1e-8f);

    optimizer.minimize(loss, TRAIN);

    return graph;
  }

  public static Add<TFloat32> buildFCLayersAndRegularization(Ops tf, Placeholder<TUint8> labels,
      Reshape<TFloat32> flatten) {
    int fcBiasShape = 100;
    // Five SAME-padded 2x2 max-pools shrink the 28x28 input to 14, 7, 4, 2 and finally 1,
    // so the flattened feature vector has 1 * 1 * 256 = 256 elements.
    int[] fcWeightShape = {256, fcBiasShape};

    Variable<TFloat32> fc1Weights = tf.variable(tf.math.mul(tf.random
        .truncatedNormal(tf.array(fcWeightShape), TFloat32.class,
            TruncatedNormal.seed(SEED)), tf.constant(0.1f)));
    Variable<TFloat32> fc1Biases = tf
        .variable(tf.fill(tf.array(new int[]{fcBiasShape}), tf.constant(0.1f)));
    Relu<TFloat32> fcRelu = tf.nn
        .relu(tf.math.add(tf.linalg.matMul(flatten, fc1Weights), fc1Biases));

    // Softmax layer
    Variable<TFloat32> fc2Weights = tf.variable(tf.math.mul(tf.random
        .truncatedNormal(tf.array(fcBiasShape, NUM_LABELS), TFloat32.class,
            TruncatedNormal.seed(SEED)), tf.constant(0.1f)));
    Variable<TFloat32> fc2Biases = tf
        .variable(tf.fill(tf.array(new int[]{NUM_LABELS}), tf.constant(0.1f)));

    Add<TFloat32> logits = tf.math.add(tf.linalg.matMul(fcRelu, fc2Weights), fc2Biases);

    // Predicted outputs
    tf.withName(OUTPUT_NAME).nn.softmax(logits);

    // Loss function & regularization
    OneHot<TFloat32> oneHot = tf
        .oneHot(labels, tf.constant(NUM_LABELS), tf.constant(1.0f), tf.constant(0.0f));
    SoftmaxCrossEntropyWithLogits<TFloat32> batchLoss =
        tf.nn.softmaxCrossEntropyWithLogits(logits, oneHot);
    Mean<TFloat32> labelLoss = tf.math.mean(batchLoss.loss(), tf.constant(0));
    Add<TFloat32> regularizers = tf.math.add(tf.nn.l2Loss(fc1Weights), tf.math
        .add(tf.nn.l2Loss(fc1Biases),
            tf.math.add(tf.nn.l2Loss(fc2Weights), tf.nn.l2Loss(fc2Biases))));
    return tf.withName(TRAINING_LOSS).math
        .add(labelLoss, tf.math.mul(regularizers, tf.constant(5e-4f)));
  }

  public static Reshape<TFloat32> vggFlatten(Ops tf, MaxPool<TFloat32> pool2) {
    return tf.reshape(pool2, tf.concat(Arrays
        .asList(tf.slice(tf.shape(pool2), tf.array(new int[]{0}), tf.array(new int[]{1})),
            tf.array(new int[]{-1})), tf.constant(0)));
  }

  public static MaxPool<TFloat32> vggMaxPool(Ops tf, Relu<TFloat32> relu1) {
    return tf.nn
        .maxPool(relu1, tf.array(1, 2, 2, 1), tf.array(1, 2, 2, 1),
            PADDING_TYPE);
  }

  public static Relu<TFloat32> vggConv2DLayer(String layerName, Ops tf, Operand<TFloat32> scaledInput,
      int[] convWeightsL1Shape, int convBiasL1Shape) {
    Variable<TFloat32> conv1Weights = tf.withName("conv2d_" + layerName).variable(tf.math.mul(tf.random
        .truncatedNormal(tf.array(convWeightsL1Shape), TFloat32.class,
            TruncatedNormal.seed(SEED)), tf.constant(0.1f)));
    Conv2d<TFloat32> conv = tf.nn
        .conv2d(scaledInput, conv1Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE);
    Variable<TFloat32> convBias = tf
        .withName("bias2d_" + layerName)
        .variable(tf.fill(tf.array(new int[]{convBiasL1Shape}), tf.constant(0.0f)));
    return tf.nn.relu(tf.withName("biasAdd_" + layerName).nn.biasAdd(conv, convBias));
  }

  public void train(MnistDataset dataset, int epochs, int minibatchSize) {
    int interval = 0;
    // Train the model
    for (int i = 0; i < epochs; i++) {
      for (ImageBatch trainingBatch : dataset.trainingBatches(minibatchSize)) {
        try (TUint8 batchImages = TUint8.tensorOf(trainingBatch.images());
            TUint8 batchLabels = TUint8.tensorOf(trainingBatch.labels());
            var result = session.runner()
                .feed(TARGET, batchLabels)
                .feed(INPUT_NAME, batchImages)
                .addTarget(TRAIN)
                .fetch(TRAINING_LOSS)
                .run()) {
          TFloat32 loss = (TFloat32) result.get(0);

          logger.log(Level.INFO,
              "Iteration = " + interval + ", training loss = " + loss.getFloat());

        }
        interval++;
      }
    }
  }

  public void test(MnistDataset dataset, int minibatchSize) {
    int correctCount = 0;
    int[][] confusionMatrix = new int[10][10];

    for (ImageBatch testBatch : dataset.testBatches(minibatchSize)) {
      try (TUint8 transformedInput = TUint8.tensorOf(testBatch.images());
          var result = session.runner()
              .feed(INPUT_NAME, transformedInput)
              .fetch(OUTPUT_NAME).run()) {
        TFloat32 outputTensor = (TFloat32) result.get(0);

        ByteNdArray labelBatch = testBatch.labels();
        for (int k = 0; k < labelBatch.shape().get(0); k++) {
          byte trueLabel = labelBatch.getByte(k);
          int predLabel;

          predLabel = argmax(outputTensor.slice(Indices.at(k), Indices.all()));
          if (predLabel == trueLabel) {
            correctCount++;
          }

          confusionMatrix[trueLabel][predLabel]++;
        }
      }
    }

    logger.info("Final accuracy = " + ((float) correctCount) / dataset.numTestingExamples());

    StringBuilder sb = new StringBuilder();
    sb.append("Label");
    for (int i = 0; i < confusionMatrix.length; i++) {
      sb.append(String.format("%1$5s", "" + i));
    }
    sb.append("\n");

    for (int i = 0; i < confusionMatrix.length; i++) {
      sb.append(String.format("%1$5s", "" + i));
      for (int j = 0; j < confusionMatrix[i].length; j++) {
        sb.append(String.format("%1$5s", "" + confusionMatrix[i][j]));
      }
      sb.append("\n");
    }

    System.out.println(sb);
  }

  /**
   * Finds the maximum probability and returns its index.
   *
   * @param probabilities The probabilities.
   * @return The index of the max.
   */
  public static int argmax(FloatNdArray probabilities) {
    float maxVal = Float.NEGATIVE_INFINITY;
    int idx = 0;
    for (int i = 0; i < probabilities.shape().get(0); i++) {
      float curVal = probabilities.getFloat(i);
      if (curVal > maxVal) {
        maxVal = curVal;
        idx = i;
      }
    }
    return idx;
  }

  @Override
  public void close() {
    session.close();
    graph.close();
  }
}

--------------------------------------------------------------------------------
/src/main/java/org/tensorflow/model/examples/datasets/ImageBatch.java:
--------------------------------------------------------------------------------
/*
 * Copyright 2020, 2024 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =======================================================================
 */
package org.tensorflow.model.examples.datasets;

import org.tensorflow.ndarray.ByteNdArray;

/**
 * A batch of images and their corresponding labels, used for batch training and evaluation.
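 *
 * <p>{@code images} stacks the images along the first dimension (for the MNIST-style archives
 * used in these examples, shape {@code [batchSize, 28, 28]}), and {@code labels} holds the
 * matching class indices with shape {@code [batchSize]}. Both store unsigned byte values in
 * Java's signed {@code byte} type, so a pixel value of 255 reads back as -1 from
 * {@code getByte}; the CNN examples wrap them with {@code TUint8.tensorOf} so that TensorFlow
 * sees the unsigned values.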
 */
public record ImageBatch(ByteNdArray images, ByteNdArray labels) { }

--------------------------------------------------------------------------------
/src/main/java/org/tensorflow/model/examples/datasets/ImageBatchIterator.java:
--------------------------------------------------------------------------------
/*
 * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =======================================================================
 */
package org.tensorflow.model.examples.datasets;

import static org.tensorflow.ndarray.index.Indices.range;

import java.util.Iterator;

import org.tensorflow.ndarray.ByteNdArray;
import org.tensorflow.ndarray.index.Index;

/**
 * Basic batch iterator across the images of a dataset.
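 *
 * <p>Each call to {@link #next()} returns the next {@code batchSize} images with their labels;
 * the final batch is smaller when the number of images is not a multiple of {@code batchSize}.
 * For example, 10 images with a batch size of 4 are returned as batches of 4, 4 and 2 images.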
 */
public class ImageBatchIterator implements Iterator<ImageBatch> {

  @Override
  public boolean hasNext() {
    return batchStart < numImages;
  }

  @Override
  public ImageBatch next() {
    if (!hasNext()) {
      throw new java.util.NoSuchElementException();
    }
    long nextBatchSize = Math.min(batchSize, numImages - batchStart);
    Index range = range(batchStart, batchStart + nextBatchSize);
    batchStart += nextBatchSize;
    return new ImageBatch(images.slice(range), labels.slice(range));
  }

  public ImageBatchIterator(int batchSize, ByteNdArray images, ByteNdArray labels) {
    this.batchSize = batchSize;
    this.images = images;
    this.labels = labels;
    this.numImages = images != null ? images.shape().size(0) : 0;
    this.batchStart = 0;
  }

  private final int batchSize;
  private final ByteNdArray images;
  private final ByteNdArray labels;
  private final long numImages;
  private long batchStart;
}

--------------------------------------------------------------------------------
/src/main/java/org/tensorflow/model/examples/datasets/mnist/MnistDataset.java:
--------------------------------------------------------------------------------
/*
 * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =======================================================================
 */
package org.tensorflow.model.examples.datasets.mnist;

import org.tensorflow.model.examples.datasets.ImageBatch;
import org.tensorflow.model.examples.datasets.ImageBatchIterator;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.DataBuffers;
import org.tensorflow.ndarray.ByteNdArray;
import org.tensorflow.ndarray.NdArrays;

import java.io.DataInputStream;
import java.io.IOException;
import java.util.zip.GZIPInputStream;

import static org.tensorflow.ndarray.index.Indices.sliceFrom;
import static org.tensorflow.ndarray.index.Indices.sliceTo;

/** Common loader and data preprocessor for MNIST and FashionMNIST datasets. */
public class MnistDataset {
  public static final int NUM_CLASSES = 10;

  public static MnistDataset create(int validationSize, String trainingImagesArchive,
      String trainingLabelsArchive, String testImagesArchive, String testLabelsArchive) {
    try {
      ByteNdArray trainingImages = readArchive(trainingImagesArchive);
      ByteNdArray trainingLabels = readArchive(trainingLabelsArchive);
      ByteNdArray testImages = readArchive(testImagesArchive);
      ByteNdArray testLabels = readArchive(testLabelsArchive);

      if (validationSize > 0) {
        return new MnistDataset(
            trainingImages.slice(sliceFrom(validationSize)),
            trainingLabels.slice(sliceFrom(validationSize)),
            trainingImages.slice(sliceTo(validationSize)),
            trainingLabels.slice(sliceTo(validationSize)),
            testImages,
            testLabels
        );
      }
      return new MnistDataset(trainingImages, trainingLabels, null, null, testImages, testLabels);

    } catch (IOException e) {
      throw new AssertionError(e);
    }
  }

  public Iterable<ImageBatch> trainingBatches(int batchSize) {
    return () -> new ImageBatchIterator(batchSize, trainingImages,
        trainingLabels);
  }

  public Iterable<ImageBatch> validationBatches(int batchSize) {
    return () -> new ImageBatchIterator(batchSize, validationImages, validationLabels);
  }

  public Iterable<ImageBatch> testBatches(int batchSize) {
    return () -> new ImageBatchIterator(batchSize, testImages, testLabels);
  }

  public ImageBatch testBatch() {
    return new ImageBatch(testImages, testLabels);
  }

  public long imageSize() {
    return imageSize;
  }

  public long numTrainingExamples() {
    return trainingLabels.shape().get(0);
  }

  public long numTestingExamples() {
    return testLabels.shape().get(0);
  }

  public long numValidationExamples() {
    // No validation split exists when the dataset was created with validationSize == 0
    return validationLabels != null ? validationLabels.shape().get(0) : 0;
  }

  private static final int TYPE_UBYTE = 0x08;

  private final ByteNdArray trainingImages;
  private final ByteNdArray trainingLabels;
  private final ByteNdArray validationImages;
  private final ByteNdArray validationLabels;
  private final ByteNdArray testImages;
  private final ByteNdArray testLabels;
  private final long imageSize;

  private MnistDataset(
      ByteNdArray trainingImages,
      ByteNdArray trainingLabels,
      ByteNdArray validationImages,
      ByteNdArray validationLabels,
      ByteNdArray testImages,
      ByteNdArray testLabels
  ) {
    this.trainingImages = trainingImages;
    this.trainingLabels = trainingLabels;
    this.validationImages = validationImages;
    this.validationLabels = validationLabels;
    this.testImages = testImages;
    this.testLabels = testLabels;
    this.imageSize = trainingImages.get(0).shape().size();
  }

  private static ByteNdArray readArchive(String archiveName) throws IOException {
    try (DataInputStream archiveStream = new DataInputStream(
        new GZIPInputStream(MnistDataset.class.getClassLoader().getResourceAsStream(archiveName))
    )) {
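      // IDX header layout: two zero bytes, a type byte (0x08 = unsigned byte), a byte holding the
      // number of dimensions, then one big-endian 32-bit size per dimension, followed by the data.
      // For example, the MNIST training images declare 3 dimensions, 60000 x 28 x 28, i.e.
      // 47,040,000 data bytes, which satisfies the Integer.MAX_VALUE assumption below.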
      archiveStream.readShort(); // first two bytes are always 0
      byte magic = archiveStream.readByte();
      if (magic != TYPE_UBYTE) {
        throw new IllegalArgumentException("\"" + archiveName + "\" is not a valid archive");
      }
      int numDims = archiveStream.readByte();
      long[] dimSizes = new long[numDims];
      int size = 1; // for simplicity, we assume that the total size does not exceed Integer.MAX_VALUE
      for (int i = 0; i < dimSizes.length; ++i) {
        dimSizes[i] = archiveStream.readInt();
        size *= dimSizes[i];
      }
      byte[] bytes = new byte[size];
      archiveStream.readFully(bytes);
      return NdArrays.wrap(Shape.of(dimSizes), DataBuffers.of(bytes, true, false));
    }
  }
}

--------------------------------------------------------------------------------
/src/main/java/org/tensorflow/model/examples/dense/SimpleMnist.java:
--------------------------------------------------------------------------------
/*
 * Copyright 2020, 2024 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =======================================================================
 */
package org.tensorflow.model.examples.dense;

import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Session;
import org.tensorflow.framework.optimizers.GradientDescent;
import org.tensorflow.framework.optimizers.Optimizer;
import org.tensorflow.model.examples.datasets.ImageBatch;
import org.tensorflow.model.examples.datasets.mnist.MnistDataset;
import org.tensorflow.ndarray.ByteNdArray;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.math.Mean;
import org.tensorflow.op.nn.Softmax;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt64;

public class SimpleMnist implements Runnable {
  private static final String TRAINING_IMAGES_ARCHIVE = "mnist/train-images-idx3-ubyte.gz";
  private static final String TRAINING_LABELS_ARCHIVE = "mnist/train-labels-idx1-ubyte.gz";
  private static final String TEST_IMAGES_ARCHIVE = "mnist/t10k-images-idx3-ubyte.gz";
  private static final String TEST_LABELS_ARCHIVE = "mnist/t10k-labels-idx1-ubyte.gz";

  public static void main(String[] args) {
    MnistDataset dataset = MnistDataset.create(VALIDATION_SIZE, TRAINING_IMAGES_ARCHIVE,
        TRAINING_LABELS_ARCHIVE, TEST_IMAGES_ARCHIVE, TEST_LABELS_ARCHIVE);

    try (Graph graph = new Graph()) {
      SimpleMnist mnist = new SimpleMnist(graph, dataset);
      mnist.run();
    }
  }

  @Override
  public void run() {
    Ops tf = Ops.create(graph);

    // Create placeholders that accept a batch of any size: flattened, normalized images of shape
    // [batch, imageSize] and one-hot labels of shape [batch, NUM_CLASSES]
    Placeholder<TFloat32> images = tf.placeholder(TFloat32.class);
    Placeholder<TFloat32> labels = tf.placeholder(TFloat32.class);

    // Create weights with an initial value of 0
    Shape weightShape = Shape.of(dataset.imageSize(), MnistDataset.NUM_CLASSES);
    Variable<TFloat32> weights = tf.variable(tf.zeros(tf.constant(weightShape), TFloat32.class));

    // Create biases with an initial value of 0
    Shape biasShape = Shape.of(MnistDataset.NUM_CLASSES);
    Variable<TFloat32> biases = tf.variable(tf.zeros(tf.constant(biasShape), TFloat32.class));

    // Predict the class of each image in the batch and compute the loss
    Softmax<TFloat32> softmax =
        tf.nn.softmax(
            tf.math.add(
                tf.linalg.matMul(images, weights),
                biases
            )
        );
    Mean<TFloat32> crossEntropy =
        tf.math.mean(
            tf.math.neg(
                tf.reduceSum(
                    tf.math.mul(labels, tf.math.log(softmax)),
                    tf.array(1)
                )
            ),
            tf.array(0)
        );

    // Back-propagate gradients to variables for training
    Optimizer optimizer = new GradientDescent(graph, LEARNING_RATE);
    Op minimize = optimizer.minimize(crossEntropy);

    // Compute the accuracy of the model
    Operand<TInt64> predicted = tf.math.argMax(softmax, tf.constant(1));
    Operand<TInt64> expected = tf.math.argMax(labels, tf.constant(1));
    Operand<TFloat32> accuracy = tf.math.mean(
        tf.dtypes.cast(tf.math.equal(predicted, expected), TFloat32.class), tf.array(0));

    // Run the graph
    try (Session session = new Session(graph)) {

      // Train the model
      for (ImageBatch trainingBatch : dataset.trainingBatches(TRAINING_BATCH_SIZE)) {
        try (TFloat32 batchImages = preprocessImages(trainingBatch.images());
            TFloat32 batchLabels = preprocessLabels(trainingBatch.labels())) {
          session.runner()
              .addTarget(minimize)
              .feed(images.asOutput(), batchImages)
              .feed(labels.asOutput(), batchLabels)
              .run();
        }
      }

      // Test the model
      ImageBatch testBatch = dataset.testBatch();
      try (TFloat32 testImages = preprocessImages(testBatch.images());
          TFloat32 testLabels =
              preprocessLabels(testBatch.labels());
          var result = session.runner()
              .fetch(accuracy)
              .feed(images.asOutput(), testImages)
              .feed(labels.asOutput(), testLabels)
              .run()) {
        TFloat32 accuracyValue = (TFloat32) result.get(0);
        System.out.println("Accuracy: " + accuracyValue.getFloat());
      }
    }
  }

  private static final int VALIDATION_SIZE = 0;
  private static final int TRAINING_BATCH_SIZE = 100;
  private static final float LEARNING_RATE = 0.2f;

  private static TFloat32 preprocessImages(ByteNdArray rawImages) {
    Ops tf = Ops.create();

    // Flatten images in a single dimension and normalize their pixels as floats.
    // tf.constant(ByteNdArray) yields signed int8 values, so the pixels are reinterpreted as
    // uint8 first; otherwise values above 127 would turn negative before normalization.
    long imageSize = rawImages.get(0).shape().size();
    return tf.math.div(
        tf.reshape(
            tf.dtypes.cast(
                tf.dtypes.cast(tf.constant(rawImages), org.tensorflow.types.TUint8.class),
                TFloat32.class),
            tf.array(-1L, imageSize)
        ),
        tf.constant(255.0f)
    ).asTensor();
  }

  private static TFloat32 preprocessLabels(ByteNdArray rawLabels) {
    Ops tf = Ops.create();

    // Map labels to one-hot vectors, where only the expected class has a value of 1.0
    return tf.oneHot(
        tf.constant(rawLabels),
        tf.constant(MnistDataset.NUM_CLASSES),
        tf.constant(1.0f),
        tf.constant(0.0f)
    ).asTensor();
  }

  private final Graph graph;
  private final MnistDataset dataset;

  private SimpleMnist(Graph graph, MnistDataset dataset) {
    this.graph = graph;
    this.dataset = dataset;
  }
}

--------------------------------------------------------------------------------
/src/main/java/org/tensorflow/model/examples/regression/linear/LinearRegressionExample.java:
--------------------------------------------------------------------------------
/*
 * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =======================================================================
 */
package org.tensorflow.model.examples.regression.linear;

import java.util.Random;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.framework.optimizers.GradientDescent;
import org.tensorflow.framework.optimizers.Optimizer;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.math.Add;
import org.tensorflow.op.math.Div;
import org.tensorflow.op.math.Mul;
import org.tensorflow.op.math.Pow;
import org.tensorflow.types.TFloat32;

/**
 * In this example TensorFlow finds the weight and bias of a linear regression in 1 epoch,
 * training on the observations one by one.
 *
 * Afterwards, the learned weight and bias are extracted and printed.
 */
public class LinearRegressionExample {
  /**
   * Number of data points.
   */
  private static final int N = 10;

  /**
   * Learning rate of the gradient descent optimizer.
   */
  public static final float LEARNING_RATE = 0.1f;
  public static final String WEIGHT_VARIABLE_NAME = "weight";
  public static final String BIAS_VARIABLE_NAME = "bias";

  public static void main(String[] args) {
    // Prepare the data: y = 10 * x + 2 plus a small amount of uniform noise
    float[] xValues = {1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f};
    float[] yValues = new float[N];

    Random rnd = new Random(42);

    for (int i = 0; i < yValues.length; i++) {
      yValues[i] = (float) (10 * xValues[i] + 2 + 0.1 * (rnd.nextDouble() - 0.5));
    }

    try (Graph graph = new Graph()) {
      Ops tf = Ops.create(graph);

      // Define placeholders
      Placeholder<TFloat32> xData = tf.placeholder(TFloat32.class, Placeholder.shape(Shape.scalar()));
      Placeholder<TFloat32> yData = tf.placeholder(TFloat32.class, Placeholder.shape(Shape.scalar()));

      // Define variables
      Variable<TFloat32> weight = tf.withName(WEIGHT_VARIABLE_NAME).variable(tf.constant(1f));
      Variable<TFloat32> bias = tf.withName(BIAS_VARIABLE_NAME).variable(tf.constant(1f));

      // Define the model function weight * x + bias
      Mul<TFloat32> mul = tf.math.mul(xData, weight);
      Add<TFloat32> yPredicted = tf.math.add(mul, bias);

      // Define the loss: squared error divided by 2N (an MSE-style loss on one observation)
      Pow<TFloat32> sum = tf.math.pow(tf.math.sub(yPredicted, yData), tf.constant(2f));
      Div<TFloat32> mse = tf.math.div(sum, tf.constant(2f * N));

      // Back-propagate gradients to variables for training
      Optimizer optimizer = new GradientDescent(graph, LEARNING_RATE);
      Op minimize = optimizer.minimize(mse);

      try (Session session = new Session(graph)) {

        // Train the model on data
        for (int i = 0; i < xValues.length; i++) {
          float y = yValues[i];
          float x = xValues[i];
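          // Each run below is one plain gradient-descent step on this single observation.
          // With loss L = (weight * x + bias - y)^2 / (2 * N), the gradients are
          //   dL/dweight = (weight * x + bias - y) * x / N  and  dL/dbias = (weight * x + bias - y) / N,
          // and the optimizer applies weight -= LEARNING_RATE * dL/dweight (likewise for bias).
          // For the first observation (x = 1, y close to 12, weight = bias = 1) the error is about
          // -10, so both variables grow by about 0.1 * 10 * 1 / 10 = 0.1 in that step.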
95 | 96 | try (TFloat32 xTensor = TFloat32.scalarOf(x); 97 | TFloat32 yTensor = TFloat32.scalarOf(y)) { 98 | 99 | session.runner() 100 | .addTarget(minimize) 101 | .feed(xData.asOutput(), xTensor) 102 | .feed(yData.asOutput(), yTensor) 103 | .run(); 104 | 105 | System.out.println("Training phase"); 106 | System.out.println("x is " + x + " y is " + y); 107 | } 108 | } 109 | 110 | // Extract linear regression model weight and bias values 111 | try (var result = session.runner() 112 | .fetch(WEIGHT_VARIABLE_NAME) 113 | .fetch(BIAS_VARIABLE_NAME) 114 | .run()) { 115 | System.out.println("Weight is " + result.get(WEIGHT_VARIABLE_NAME)); 116 | System.out.println("Bias is " + result.get(BIAS_VARIABLE_NAME)); 117 | } 118 | 119 | // Let's predict y for x = 10f 120 | float x = 10f; 121 | float predictedY = 0f; 122 | 123 | try (TFloat32 xTensor = TFloat32.scalarOf(x); 124 | TFloat32 yTensor = TFloat32.scalarOf(predictedY); 125 | TFloat32 yPredictedTensor = (TFloat32)session.runner() 126 | .feed(xData.asOutput(), xTensor) 127 | .feed(yData.asOutput(), yTensor) 128 | .fetch(yPredicted) 129 | .run().get(0)) { 130 | 131 | predictedY = yPredictedTensor.getFloat(); 132 | 133 | System.out.println("Predicted value: " + predictedY); 134 | } 135 | } 136 | } 137 | } 138 | } 139 | -------------------------------------------------------------------------------- /src/main/java/org/tensorflow/model/examples/tensors/TensorCreation.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2020 The TensorFlow Authors. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =======================================================================
 */
package org.tensorflow.model.examples.tensors;

import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.IntNdArray;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.types.TInt32;

import java.util.Arrays;

/**
 * Creates a few tensors of ranks: 0, 1, 2, 3.
 */
public class TensorCreation {

  public static void main(String[] args) {
    // Rank 0 Tensor
    TInt32 rank0Tensor = TInt32.scalarOf(42);

    System.out.println("---- Scalar tensor ---------");

    System.out.println("DataType: " + rank0Tensor.dataType().name());

    // The rank is the number of dimensions; Shape.size() would return the element count instead
    System.out.println("Rank: " + rank0Tensor.shape().numDimensions());

    System.out.println("Shape: " + Arrays.toString(rank0Tensor.shape().asArray()));

    rank0Tensor.scalars().forEach(value -> System.out.println("Value: " + value.getObject()));

    // Rank 1 Tensor
    TInt32 rank1Tensor = TInt32.vectorOf(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);

    System.out.println("---- Vector tensor ---------");

    System.out.println("DataType: " + rank1Tensor.dataType().name());

    System.out.println("Rank: " + rank1Tensor.shape().numDimensions());

    System.out.println("Shape: " + Arrays.toString(rank1Tensor.shape().asArray()));

    System.out.println("6th element: " + rank1Tensor.getInt(5));

    // Rank 2 Tensor
    // 3x2 matrix of ints.
    IntNdArray matrix2d = NdArrays.ofInts(Shape.of(3, 2));

    matrix2d.set(NdArrays.vectorOf(1, 2), 0)
        .set(NdArrays.vectorOf(3, 4), 1)
        .set(NdArrays.vectorOf(5, 6), 2);

    TInt32 rank2Tensor = TInt32.tensorOf(matrix2d);

    System.out.println("---- Matrix tensor ---------");

    System.out.println("DataType: " + rank2Tensor.dataType().name());

    System.out.println("Rank: " + rank2Tensor.shape().numDimensions());

    System.out.println("Shape: " + Arrays.toString(rank2Tensor.shape().asArray()));

    System.out.println("6th element: " + rank2Tensor.getInt(2, 1));

    // Rank 3 Tensor
    // 3x2x4 array of ints: each of the three 2x4 matrices holds the same values.
    IntNdArray matrix3d = NdArrays.ofInts(Shape.of(3, 2, 4));

    matrix3d.elements(0).forEach(matrix -> {
      matrix
          .set(NdArrays.vectorOf(1, 2, 3, 4), 0)
          .set(NdArrays.vectorOf(5, 6, 7, 8), 1);
    });

    TInt32 rank3Tensor = TInt32.tensorOf(matrix3d);

    System.out.println("---- Rank-3 tensor ---------");

    System.out.println("DataType: " + rank3Tensor.dataType().name());

    System.out.println("Rank: " + rank3Tensor.shape().numDimensions());

    System.out.println("Shape: " + Arrays.toString(rank3Tensor.shape().asArray()));

    System.out.println("Last element: " + rank3Tensor.getInt(2, 1, 3));
  }
}
--------------------------------------------------------------------------------
/src/main/resources/META-INF/MANIFEST.MF:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/java-models/c7c0e907e11fb6a55ea4f66b85a3452b35a17c3f/src/main/resources/META-INF/MANIFEST.MF
--------------------------------------------------------------------------------
/src/main/resources/fashionmnist/Readme.md:
--------------------------------------------------------------------------------
This dataset is distributed under the MIT License and was presented in the following paper:
Fashion-MNIST: a Novel Image Dataset for Benchmarking Machine Learning Algorithms. Han Xiao, Kashif Rasul, Roland Vollgraf. arXiv:1708.07747

The data was downloaded from the Fashion-MNIST repository:
https://github.com/zalandoresearch/fashion-mnist/tree/master/data/fashion

--------------------------------------------------------------------------------
/src/main/resources/fashionmnist/t10k-images-idx3-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/java-models/c7c0e907e11fb6a55ea4f66b85a3452b35a17c3f/src/main/resources/fashionmnist/t10k-images-idx3-ubyte.gz
--------------------------------------------------------------------------------
/src/main/resources/fashionmnist/t10k-labels-idx1-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/java-models/c7c0e907e11fb6a55ea4f66b85a3452b35a17c3f/src/main/resources/fashionmnist/t10k-labels-idx1-ubyte.gz
--------------------------------------------------------------------------------
/src/main/resources/fashionmnist/train-images-idx3-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/java-models/c7c0e907e11fb6a55ea4f66b85a3452b35a17c3f/src/main/resources/fashionmnist/train-images-idx3-ubyte.gz
--------------------------------------------------------------------------------
/src/main/resources/fashionmnist/train-labels-idx1-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/java-models/c7c0e907e11fb6a55ea4f66b85a3452b35a17c3f/src/main/resources/fashionmnist/train-labels-idx1-ubyte.gz
--------------------------------------------------------------------------------
/src/main/resources/fasterrcnninception/image2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/java-models/c7c0e907e11fb6a55ea4f66b85a3452b35a17c3f/src/main/resources/fasterrcnninception/image2.jpg
--------------------------------------------------------------------------------
/src/main/resources/fasterrcnninception/image2rcnn.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/java-models/c7c0e907e11fb6a55ea4f66b85a3452b35a17c3f/src/main/resources/fasterrcnninception/image2rcnn.jpg
--------------------------------------------------------------------------------
/src/main/resources/mnist/t10k-images-idx3-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/java-models/c7c0e907e11fb6a55ea4f66b85a3452b35a17c3f/src/main/resources/mnist/t10k-images-idx3-ubyte.gz
--------------------------------------------------------------------------------
/src/main/resources/mnist/t10k-labels-idx1-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/java-models/c7c0e907e11fb6a55ea4f66b85a3452b35a17c3f/src/main/resources/mnist/t10k-labels-idx1-ubyte.gz
--------------------------------------------------------------------------------
/src/main/resources/mnist/train-images-idx3-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/java-models/c7c0e907e11fb6a55ea4f66b85a3452b35a17c3f/src/main/resources/mnist/train-images-idx3-ubyte.gz
--------------------------------------------------------------------------------
/src/main/resources/mnist/train-labels-idx1-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/java-models/c7c0e907e11fb6a55ea4f66b85a3452b35a17c3f/src/main/resources/mnist/train-labels-idx1-ubyte.gz
--------------------------------------------------------------------------------
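LinearRegressionExample minimises the loss (w·x + b − y)² / (2N) with one GradientDescent step per data point, for a single pass over the ten points. Differentiating that loss gives ∂/∂w = (w·x + b − y)·x / N and ∂/∂b = (w·x + b − y) / N. The sketch below writes that update out by hand in plain Java (no TensorFlow) and compares it with the closed-form ordinary least-squares fit. It assumes the example's data generation, initial values (1 for both variables) and learning rate 0.1; the class and method names are hypothetical, and this is not how TensorFlow computes gradients internally (the library derives them from the graph).

```java
import java.util.Random;

/**
 * Hypothetical helper, not part of the repository: replays by hand the per-sample
 * gradient descent that LinearRegressionExample runs through TensorFlow, and
 * compares it with the closed-form ordinary least-squares fit.
 */
public class ManualLinearRegression {
  static final int N = 10;                 // assumed: same number of points as the example
  static final double LEARNING_RATE = 0.1; // assumed: same learning rate as the example

  /** Builds the example's data: y = 10x + 2 plus uniform noise in [-0.05, 0.05). */
  public static float[][] makeData() {
    float[] xs = {1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f};
    float[] ys = new float[N];
    Random rnd = new Random(42);
    for (int i = 0; i < N; i++) {
      ys[i] = (float) (10 * xs[i] + 2 + 0.1 * (rnd.nextDouble() - 0.5));
    }
    return new float[][] {xs, ys};
  }

  /** One pass of per-sample gradient descent on loss = (w*x + b - y)^2 / (2N). */
  public static double[] sgdOnePass(float[] xs, float[] ys) {
    double w = 1.0; // both variables start at 1, as in the example
    double b = 1.0;
    for (int i = 0; i < xs.length; i++) {
      double err = w * xs[i] + b - ys[i];
      double gradW = err * xs[i] / N; // d(loss)/dw
      double gradB = err / N;         // d(loss)/db
      // Both gradients come from the same forward pass before either variable changes.
      w -= LEARNING_RATE * gradW;
      b -= LEARNING_RATE * gradB;
    }
    return new double[] {w, b};
  }

  /** Closed-form fit: slope = cov(x, y) / var(x), intercept = mean(y) - slope * mean(x). */
  public static double[] leastSquares(float[] xs, float[] ys) {
    double meanX = 0;
    double meanY = 0;
    for (int i = 0; i < xs.length; i++) {
      meanX += xs[i];
      meanY += ys[i];
    }
    meanX /= xs.length;
    meanY /= ys.length;
    double sxy = 0;
    double sxx = 0;
    for (int i = 0; i < xs.length; i++) {
      sxy += (xs[i] - meanX) * (ys[i] - meanY);
      sxx += (xs[i] - meanX) * (xs[i] - meanX);
    }
    double slope = sxy / sxx;
    return new double[] {slope, meanY - slope * meanX};
  }

  public static void main(String[] args) {
    float[][] data = makeData();
    double[] sgd = sgdOnePass(data[0], data[1]);
    double[] ols = leastSquares(data[0], data[1]);
    System.out.printf("One gradient-descent pass: weight=%.3f bias=%.3f%n", sgd[0], sgd[1]);
    System.out.printf("Least squares:             weight=%.3f bias=%.3f%n", ols[0], ols[1]);
  }
}
```

Under these assumptions the hand-written pass ends with a weight slightly below 10 and a bias noticeably above 2, while least squares lands close to 10 and 2. A single pass over ten points is not enough for gradient descent to converge, which is worth keeping in mind when reading the weight and bias that the TensorFlow example prints.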