├── .gitignore ├── .vscode ├── launch.json ├── settings.json └── tasks.json ├── README.md ├── build.sbt ├── project └── build.properties └── src └── main └── scala ├── Main.scala └── org └── tensorflow └── keras └── scala ├── Activation.scala ├── Callback.scala ├── Initializer.scala ├── Layer.scala ├── Layers.scala ├── Loss.scala ├── Metric.scala ├── Model.scala ├── Optimizer.scala ├── Pair.scala └── examples ├── FashionMNISTKeras.scala └── MNISTKeras.scala /.gitignore: -------------------------------------------------------------------------------- 1 | # General 2 | .DS_Store 3 | .AppleDouble 4 | .LSOverride 5 | .Spotlight-V100 6 | .Trashes 7 | *~ 8 | 9 | # IntelliJ 10 | *.iml 11 | *.ipr 12 | *.iws 13 | out/ 14 | .idea/ 15 | .idea_modules/ 16 | 17 | # Java 18 | *.class 19 | *.log 20 | .mtj.tmp/ 21 | *.jar 22 | *.war 23 | *.ear 24 | 25 | # sbt 26 | cache/ 27 | .history/ 28 | .lib/ 29 | dist/ 30 | target/ 31 | lib_managed/ 32 | src_managed/ 33 | project/boot/ 34 | project/plugins/project/ 35 | 36 | .metals/ 37 | .bloop/ -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhruvrajan/tensorflow-keras-scala/69303dc7bcc8f0fdd50b4c78d42e27b18abe7e98/.vscode/launch.json -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "files.exclude": { 3 | "**/.classpath": true, 4 | "**/.project": true, 5 | "**/.settings": true, 6 | "**/.factorypath": true 7 | } 8 | } -------------------------------------------------------------------------------- /.vscode/tasks.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhruvrajan/tensorflow-keras-scala/69303dc7bcc8f0fdd50b4c78d42e27b18abe7e98/.vscode/tasks.json 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Tensorflow Keras (Scala) 2 | ----- 3 | 4 | This repository contains a Scala wrapper for [Tensorflow Keras Java](https://github.com/dhruvrajan/tensorflow-keras-java), an implementation of 5 | the high-level Keras API for training deep learning models. Using the Scala wrapper allows 6 | for a more concise representation of Keras model training, with syntax very close to that of 7 | the Python API. 8 | 9 | *Python* 10 | ```python 11 | import tensorflow as tf 12 | 13 | model = tf.keras.models.Sequential([ 14 | tf.keras.layers.Flatten(input_shape=(28, 28)), 15 | tf.keras.layers.Dense(128, activation='relu', kernel_initializer="random_normal", bias_initializer="zeros"), 16 | tf.keras.layers.Dense(10, activation='softmax', kernel_initializer="random_normal", bias_initializer="zeros") 17 | ]) 18 | 19 | model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy', metrics=['accuracy']) 20 | 21 | (X_train, y_train), (X_val, y_val) = tf.keras.datasets.fashion_mnist.load_data() 22 | model.fit(X_train, y_train, validation_data=(X_val, y_val), epochs=10, batch_size=100) 23 | ``` 24 | 25 | *Scala*: 26 | ```scala 27 | package org.tensorflow.keras.scala.examples 28 | 29 | import java.lang.{Float => JFloat} 30 | 31 | import org.tensorflow.Graph 32 | import org.tensorflow.data.GraphLoader 33 | import org.tensorflow.keras.activations.Activations.{relu, softmax} 34 | import org.tensorflow.keras.datasets.FashionMNIST 35 | import org.tensorflow.keras.initializers.Initializers.{randomNormal, zeros} 36 | import org.tensorflow.keras.losses.Losses.sparseCategoricalCrossentropy 37 | import org.tensorflow.keras.metrics.Metrics.accuracy 38 | import org.tensorflow.keras.models.Sequential 39 | import org.tensorflow.keras.optimizers.Optimizers.sgd 40 | import org.tensorflow.keras.scala.Layers.{dense, flatten, input} 41 | import 
org.tensorflow.keras.scala.Model 42 | import org.tensorflow.keras.scala.Pair._ 43 | import org.tensorflow.op.Ops 44 | 45 | import scala.util.Using 46 | 47 | object FashionMNISTKeras { 48 | 49 | val model: Model[JFloat] = Sequential.of[JFloat]( 50 | classOf[JFloat], 51 | input(28, 28), 52 | flatten(), 53 | dense(128, activation = relu, kernelInitializer = randomNormal, biasInitializer = zeros), 54 | dense(10, activation = softmax, kernelInitializer = randomNormal, biasInitializer = zeros) 55 | ) 56 | 57 | def train(model: Model[JFloat]): Model[JFloat] = { 58 | Using.resource(new Graph()) { graph => { 59 | implicit val tf: Ops = Ops.create(graph) 60 | model.compile(optimizer = sgd, loss = sparseCategoricalCrossentropy, metrics = List(accuracy)) 61 | 62 | val (trainLoader, testLoader): (GraphLoader[JFloat], GraphLoader[JFloat]) = FashionMNIST.graphLoaders2D() 63 | // GraphLoader objects contain AutoCloseable `Tensors`. 64 | Using.resources(trainLoader, testLoader) { (train, test) => { 65 | model.fit(train, test, epochs = 10, batchSize = 100) 66 | }} 67 | }} 68 | 69 | model 70 | } 71 | 72 | def main(args: Array[String]): Unit = { 73 | train(model.self) 74 | } 75 | } 76 | ``` -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | name := "tensorflow-keras-scala" 2 | organization := "org.tensorflow" 3 | version := "0.0.1" 4 | scalaVersion := "2.13.1" 5 | scalacOptions ++= Seq("-unchecked", "-deprecation", "-feature") 6 | 7 | lazy val root = project.in(file(".")).settings( 8 | resolvers += Resolver.mavenLocal, 9 | libraryDependencies ++= Seq( 10 | "org.scalatest" %% "scalatest" % "3.2.0-M1" % Test, 11 | "org.tensorflow" % "libtensorflow_jni" % "1.13.1", 12 | "org.tensorflow" % "tensorflow-keras" % "0.0.1" 13 | ) 14 | ) -------------------------------------------------------------------------------- /project/build.properties: 
-------------------------------------------------------------------------------- 1 | sbt.version = 1.3.2 -------------------------------------------------------------------------------- /src/main/scala/Main.scala: -------------------------------------------------------------------------------- 1 | //import org.tensorflow.Graph 2 | //import org.tensorflow.data.GraphLoader 3 | //import org.tensorflow.keras.activations.Activations.{relu, softmax} 4 | //import org.tensorflow.keras.datasets.FashionMNIST 5 | //import org.tensorflow.keras.losses.Losses.sparseCategoricalCrossentropy 6 | //import org.tensorflow.keras.metrics.Metrics.accuracy 7 | //import org.tensorflow.keras.models.Sequential 8 | //import org.tensorflow.keras.optimizers.Optimizers.sgd 9 | //import org.tensorflow.keras.scala.{Layers, Model} 10 | //import org.tensorflow.op.Ops 11 | //import org.tensorflow.utils.Pair 12 | // 13 | //import scala.util.Using 14 | // 15 | //object Main { 16 | // val model: Model[java.lang.Float] = Sequential.of( 17 | // Layers.input(28, 28), 18 | // Layers.flatten(), 19 | // Layers.dense(128, activation = relu), 20 | // Layers.dense(10, activation = softmax) 21 | // ) 22 | // 23 | // def train(model: Model[java.lang.Float]): Model[java.lang.Float] = { 24 | // Using.resource(new Graph()) { graph => { 25 | // val tf: Ops = Ops.create(graph) 26 | // model.compile(tf, optimizer = sgd, loss = sparseCategoricalCrossentropy, metrics = List(accuracy)) 27 | // 28 | // val data: Pair[GraphLoader[java.lang.Float], GraphLoader[java.lang.Float]] = FashionMNIST.graphLoaders2D() 29 | // // GraphLoader objects contain AutoCloseable `Tensors`. 
30 | // Using.resources(data.first(), data.second()) { (train, test) => { 31 | // model.fit(tf, train, test, epochs = 10, batchSize = 100) 32 | // } 33 | // } 34 | // } 35 | // } 36 | // 37 | // model 38 | // } 39 | // 40 | // def main(args: Array[String]): Unit = { 41 | // train(model) 42 | // } 43 | //} 44 | -------------------------------------------------------------------------------- /src/main/scala/org/tensorflow/keras/scala/Activation.scala: -------------------------------------------------------------------------------- 1 | package org.tensorflow.keras.scala 2 | 3 | import org.tensorflow.keras.activations.{Activations, Activation => JActivation} 4 | 5 | import scala.language.implicitConversions 6 | 7 | 8 | case class Activation[T <: Number](self: JActivation[T]) 9 | 10 | object Activation { 11 | implicit def convert[T <: Number](a: Activations): Activation[T] = Activation(Activations.select(a)) 12 | 13 | implicit def convert[T <: Number](a: JActivation[T]): Activation[T] = Activation[T](a) 14 | } -------------------------------------------------------------------------------- /src/main/scala/org/tensorflow/keras/scala/Callback.scala: -------------------------------------------------------------------------------- 1 | package org.tensorflow.keras.scala 2 | 3 | import scala.language.implicitConversions 4 | import org.tensorflow.keras.callbacks.{Callback => JCallback, Callbacks} 5 | 6 | case class Callback(self: JCallback) 7 | 8 | object Callback { 9 | implicit def convert(callback: Callbacks): Callback = Callback(Callbacks.select(callback)) 10 | implicit def convert(callback: JCallback): Callback = Callback(callback) 11 | } 12 | 13 | 14 | -------------------------------------------------------------------------------- /src/main/scala/org/tensorflow/keras/scala/Initializer.scala: -------------------------------------------------------------------------------- 1 | package org.tensorflow.keras.scala 2 | 3 | import scala.language.implicitConversions 4 | import 
org.tensorflow.keras.initializers.{Initializers, Initializer => JInitializer} 5 | 6 | case class Initializer(self: JInitializer) 7 | 8 | object Initializer { 9 | implicit def convert(a: Initializers): Initializer = Initializer(Initializers.select(a)) 10 | implicit def convert(a: JInitializer): Initializer = Initializer(a) 11 | } -------------------------------------------------------------------------------- /src/main/scala/org/tensorflow/keras/scala/Layer.scala: -------------------------------------------------------------------------------- 1 | package org.tensorflow.keras.scala 2 | 3 | import org.tensorflow.{Operand, Shape => JShape} 4 | import org.tensorflow.keras.layers.{Layer => JLayer} 5 | import org.tensorflow.op.Ops 6 | 7 | case class Layer[T <: Number](self: JLayer[T]) { 8 | def build(inputShape: JShape, dtype: Class[T])(implicit tf: Ops) : Unit = { 9 | self.build(tf, inputShape, dtype) 10 | } 11 | 12 | def apply(inputs: Operand[T]*)(implicit tf: Ops): Operand[T] = { 13 | self.apply(tf, inputs: _*) 14 | } 15 | } 16 | 17 | object Layer { 18 | 19 | } -------------------------------------------------------------------------------- /src/main/scala/org/tensorflow/keras/scala/Layers.scala: -------------------------------------------------------------------------------- 1 | package org.tensorflow.keras.scala 2 | import org.tensorflow.keras.layers.{Dense, Flatten, Input, Layers => JLayers} 3 | 4 | 5 | object Layers { 6 | object defaults { 7 | lazy val dense: Dense.Options = Dense.Options.defaults() 8 | } 9 | 10 | def input[T <: Number](firstDim: Long, units: Long*): Input[T] = JLayers.input[T](firstDim, units: _*) 11 | 12 | def flatten[T <: Number](): Flatten[T] = JLayers.flatten[T] 13 | 14 | /** Creates a `Dense` layer; omitted arguments fall back to the matching getter on `Dense.Options.defaults()` (kernel from `getKernelInitializer`, bias from `getBiasInitializer`). */ def dense[T <: Number](units: Int, 15 | activation: Activation[T] = defaults.dense.getActivation[T], 16 | kernelInitializer: Initializer = defaults.dense.getKernelInitializer, 17 | biasInitializer: Initializer = defaults.dense.getBiasInitializer): Dense[T] = { 18 | 19 | 
JLayers.dense[T](units, activation.self, kernelInitializer.self, biasInitializer.self) 21 | } 22 | } -------------------------------------------------------------------------------- /src/main/scala/org/tensorflow/keras/scala/Loss.scala: -------------------------------------------------------------------------------- 1 | package org.tensorflow.keras.scala 2 | 3 | 4 | 5 | import scala.language.implicitConversions 6 | import org.tensorflow.keras.losses.{Losses, Loss => JLoss} 7 | 8 | case class Loss(self: JLoss) 9 | 10 | object Loss { 11 | implicit def convert(a: Losses): Loss = Loss(Losses.select(a)) 12 | 13 | implicit def convertJLoss(a: JLoss): Loss = Loss(a) 14 | } -------------------------------------------------------------------------------- /src/main/scala/org/tensorflow/keras/scala/Metric.scala: -------------------------------------------------------------------------------- 1 | package org.tensorflow.keras.scala 2 | 3 | import scala.language.implicitConversions 4 | import org.tensorflow.keras.metrics.{Metrics, Metric => JMetric} 5 | 6 | case class Metric(self: JMetric) 7 | 8 | object Metric { 9 | implicit def convert(a: Metrics): Metric = Metric(Metrics.select(a)) 10 | 11 | implicit def convert(a: JMetric): Metric = Metric(a) 12 | } -------------------------------------------------------------------------------- /src/main/scala/org/tensorflow/keras/scala/Model.scala: -------------------------------------------------------------------------------- 1 | package org.tensorflow.keras.scala 2 | 3 | import org.tensorflow.data.GraphLoader 4 | import org.tensorflow.keras.models.Model.{CompileOptions, FitOptions} 5 | import org.tensorflow.keras.models.{Model => JModel} 6 | import org.tensorflow.op.Ops 7 | 8 | import scala.jdk.javaapi.CollectionConverters 9 | import scala.language.implicitConversions 10 | 11 | case class Model[T <: java.lang.Number](self: JModel[T]) { 12 | 13 | object defaults { 14 | lazy val compile: CompileOptions = CompileOptions.defaults() 15 | lazy
val fit: FitOptions = FitOptions.defaults() 16 | } 17 | 18 | def compile(optimizer: Optimizer[T] = defaults.compile.getOptimizer[T], 19 | loss: Loss = defaults.compile.getLoss, 20 | metrics: Seq[Metric] 21 | = CollectionConverters.asScala(defaults.compile.getMetrics).toSeq.map(Metric.convert)) 22 | (implicit tf: Ops) : Unit = { 23 | val compileOptionsBuilder: JModel.CompileOptions.Builder = JModel.CompileOptions.builder() 24 | .setOptimizer(optimizer.self) 25 | .setLoss(loss.self) 26 | 27 | metrics.foreach((metric: Metric) => compileOptionsBuilder.addMetric(metric.self)) 28 | self.compile(tf, compileOptionsBuilder.build()); 29 | } 30 | 31 | def fit(train: GraphLoader[T], test: GraphLoader[T], 32 | epochs: Int = defaults.fit.getEpochs, 33 | batchSize: Int = defaults.fit.getBatchSize, 34 | callbacks: Seq[Callback] = 35 | CollectionConverters.asScala(defaults.fit.getCallbacks).toSeq.map(Callback.convert)) 36 | (implicit tf: Ops): Unit = { 37 | 38 | 39 | val fitOptionsBuilder: JModel.FitOptions.Builder = JModel.FitOptions.builder() 40 | .setEpochs(epochs) 41 | .setBatchSize(batchSize) 42 | 43 | callbacks.foreach((callback: Callback) => fitOptionsBuilder.addCallback(callback.self)) 44 | 45 | self.fit(tf,train, test, fitOptionsBuilder.build()) 46 | } 47 | 48 | } 49 | 50 | object Model { 51 | implicit def convert[T <: java.lang.Number](a: JModel[T]): Model[T] = Model[T](a) 52 | } -------------------------------------------------------------------------------- /src/main/scala/org/tensorflow/keras/scala/Optimizer.scala: -------------------------------------------------------------------------------- 1 | package org.tensorflow.keras.scala 2 | 3 | import org.tensorflow.keras.optimizers.{Optimizers, Optimizer => JOptimizer} 4 | 5 | import scala.language.implicitConversions 6 | 7 | case class Optimizer[T <: Number](self: JOptimizer[T]) 8 | 9 | object Optimizer { 10 | implicit def convert[T <: Number](a: Optimizers): Optimizer[T] = Optimizer[T](Optimizers.select(a)) 11 | 12 | 
implicit def convert[T <: Number](a: JOptimizer[T]): Optimizer[T] = Optimizer[T](a) 13 | } -------------------------------------------------------------------------------- /src/main/scala/org/tensorflow/keras/scala/Pair.scala: -------------------------------------------------------------------------------- 1 | package org.tensorflow.keras.scala 2 | 3 | import org.tensorflow.utils.{Pair => JPair} 4 | 5 | import scala.language.implicitConversions 6 | 7 | object Pair { 8 | implicit def convert[T, S](pair: JPair[T, S]): (T, S) = (pair.first(), pair.second()) 9 | } 10 | -------------------------------------------------------------------------------- /src/main/scala/org/tensorflow/keras/scala/examples/FashionMNISTKeras.scala: -------------------------------------------------------------------------------- 1 | package org.tensorflow.keras.scala.examples 2 | 3 | import java.lang.{Float => JFloat} 4 | 5 | import org.tensorflow.Graph 6 | import org.tensorflow.data.GraphLoader 7 | import org.tensorflow.keras.activations.Activations.{relu, softmax, sigmoid} 8 | import org.tensorflow.keras.callbacks.Callbacks 9 | import org.tensorflow.keras.datasets.FashionMNIST 10 | import org.tensorflow.keras.initializers.Initializers.{randomNormal, zeros} 11 | import org.tensorflow.keras.losses.Losses.sparseCategoricalCrossentropy 12 | import org.tensorflow.keras.metrics.Metrics.accuracy 13 | import org.tensorflow.keras.models.Sequential 14 | import org.tensorflow.keras.optimizers.Optimizers.sgd 15 | import org.tensorflow.keras.scala.Layers.{dense, flatten, input} 16 | import org.tensorflow.keras.scala.Model 17 | import org.tensorflow.keras.scala.Pair._ 18 | import org.tensorflow.op.Ops 19 | 20 | import scala.util.Using 21 | 22 | object FashionMNISTKeras { 23 | 24 | val model: Model[JFloat] = Sequential.of[JFloat](classOf[JFloat], 25 | input(28, 28), 26 | flatten(), 27 | dense(128, activation = relu, kernelInitializer = randomNormal, biasInitializer = zeros), 28 | dense(10, activation = 
softmax, kernelInitializer = randomNormal, biasInitializer = zeros), 29 | ) 30 | 31 | def train(model: Model[JFloat]): Model[JFloat] = { 32 | Using.resource(new Graph()) { graph => { 33 | implicit val tf: Ops = Ops.create(graph) 34 | model.compile(optimizer = sgd, loss = sparseCategoricalCrossentropy, metrics = List(accuracy)) 35 | 36 | val (trainLoader, testLoader): (GraphLoader[JFloat], GraphLoader[JFloat]) = FashionMNIST.graphLoaders2D() 37 | // GraphLoader objects contain AutoCloseable `Tensors`. 38 | Using.resources(trainLoader, testLoader) { (train, test) => { 39 | model.fit(train, test, epochs = 10, batchSize = 100) 40 | }} 41 | }} 42 | 43 | model 44 | } 45 | 46 | def main(args: Array[String]): Unit = { 47 | train(model.self) 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /src/main/scala/org/tensorflow/keras/scala/examples/MNISTKeras.scala: -------------------------------------------------------------------------------- 1 | //package org.tensorflow.keras.scala.examples 2 | //import org.tensorflow.Graph 3 | //import org.tensorflow.data.GraphLoader 4 | //import org.tensorflow.keras.activations.Activations.{relu, softmax} 5 | //import org.tensorflow.keras.datasets.MNIST 6 | //import org.tensorflow.keras.losses.Losses.sparseCategoricalCrossentropy 7 | //import org.tensorflow.keras.metrics.Metrics.accuracy 8 | //import org.tensorflow.keras.models.Sequential 9 | //import org.tensorflow.keras.optimizers.Optimizers.sgd 10 | //import org.tensorflow.keras.scala.{Layers, Model} 11 | //import org.tensorflow.op.Ops 12 | //import org.tensorflow.utils.Pair 13 | // 14 | //import scala.util.Using 15 | // 16 | //object MNISTKeras { 17 | // val model: Model[java.lang.Float] = Sequential.of( 18 | // Layers.input(28, 28), 19 | // Layers.flatten(), 20 | // Layers.dense(128, activation = relu), 21 | // Layers.dense(10, activation = softmax) 22 | // ) 23 | // 24 | // def train(model: Model[java.lang.Float]): 
Model[java.lang.Float] = { 25 | // Using.resource(new Graph()) { graph => { 26 | // val tf: Ops = Ops.create(graph) 27 | // // Compile Model 28 | // model.compile(tf, optimizer = sgd, loss = sparseCategoricalCrossentropy, metrics = List(accuracy)) 29 | // 30 | // val data: Pair[GraphLoader[java.lang.Float], GraphLoader[java.lang.Float]] = MNIST.graphLoaders2D() 31 | // // GraphLoader objects contain AutoCloseable `Tensors`. 32 | // Using.resources(data.first(), data.second()) { (train, test) => { 33 | // // Fit Model 34 | // model.fit(tf, train, test, epochs = 10, batchSize = 100) 35 | // } 36 | // } 37 | // } 38 | // } 39 | // // Output Model 40 | // model 41 | // } 42 | // 43 | // def main(args: Array[String]): Unit = { 44 | // train(model) 45 | // } 46 | //} 47 | --------------------------------------------------------------------------------