├── .circleci
├── config.yml
└── images
│ ├── Dockerfile.linux-cpu-x86_64
│ ├── Dockerfile.linux-gpu-x86_64
│ ├── docker_build_linux-cpu-x86_64.sh
│ ├── docker_build_linux-gpu-x86_64.sh
│ └── install.sh
├── .gitattributes
├── .gitignore
├── .jvmopts
├── .scalafmt.conf
├── LICENSE
├── README.md
├── RELEASE.md
├── build.sbt
├── docs
├── images
│ ├── logo.afdesign
│ ├── logo.pdf
│ ├── logo.png
│ ├── logo.svg
│ ├── logo_square.afdesign
│ └── logo_square.png
└── src
│ └── main
│ ├── paradox
│ ├── assets
│ │ ├── custom.css
│ │ └── images
│ │ │ ├── afosr_logo.gif
│ │ │ ├── cmu_logo.svg
│ │ │ ├── favicon.ico
│ │ │ ├── logo.png
│ │ │ ├── nsf_logo.svg
│ │ │ ├── tensorboard_mnist_example_plot.png
│ │ │ ├── tensorflow_logo.png
│ │ │ ├── tensorflow_logo_square.svg
│ │ │ └── tensorflow_logo_wide.svg
│ ├── contributing.md
│ ├── guides.md
│ ├── guides
│ │ ├── adding_ops.md
│ │ ├── estimators.md
│ │ ├── graph_construction.md
│ │ └── tensors.md
│ ├── index.md
│ ├── installation.md
│ ├── release_notes.md
│ └── release_notes
│ │ ├── 0.1.0.md
│ │ ├── 0.1.1.md
│ │ ├── 0.2.0.md
│ │ ├── 0.2.1.md
│ │ ├── 0.2.2.md
│ │ ├── 0.2.3.md
│ │ ├── 0.2.4.md
│ │ ├── 0.3.0.md
│ │ ├── 0.4.0.md
│ │ ├── 0.4.1.md
│ │ ├── 0.5.0.md
│ │ └── 0.5.1.md
│ └── scala
│ ├── AddingOps.scala
│ ├── Estimators.scala
│ ├── Index.scala
│ ├── Tensors.scala
│ └── installation.sh
├── modules
├── api
│ └── src
│ │ ├── main
│ │ ├── resources
│ │ │ └── logback.xml
│ │ └── scala
│ │ │ └── org
│ │ │ └── platanios
│ │ │ └── tensorflow
│ │ │ └── api
│ │ │ ├── Documentation.scala
│ │ │ ├── config
│ │ │ ├── CheckpointConfig.scala
│ │ │ ├── ClusterConfig.scala
│ │ │ ├── SummaryConfig.scala
│ │ │ └── TensorBoardConfig.scala
│ │ │ ├── core
│ │ │ ├── DeviceSpecification.scala
│ │ │ ├── Devices.scala
│ │ │ ├── Graph.scala
│ │ │ ├── Implicits.scala
│ │ │ ├── Indexer.scala
│ │ │ ├── Logging.scala
│ │ │ ├── Shape.scala
│ │ │ ├── client
│ │ │ │ ├── FeedMap.scala
│ │ │ │ ├── Implicits.scala
│ │ │ │ ├── Session.scala
│ │ │ │ ├── SessionConfig.scala
│ │ │ │ └── Timeline.scala
│ │ │ ├── distributed
│ │ │ │ ├── Protocol.scala
│ │ │ │ ├── ReplicaDevicePlacer.scala
│ │ │ │ └── Server.scala
│ │ │ ├── package.scala
│ │ │ └── types
│ │ │ │ ├── DataType.scala
│ │ │ │ ├── Implicits.scala
│ │ │ │ └── package.scala
│ │ │ ├── implicits
│ │ │ ├── Implicits.scala
│ │ │ ├── helpers
│ │ │ │ ├── DataTypeStructure.scala
│ │ │ │ ├── DataTypeToOutput.scala
│ │ │ │ ├── DataTypeToShape.scala
│ │ │ │ ├── OpStructure.scala
│ │ │ │ ├── OutputStructure.scala
│ │ │ │ ├── OutputToDataType.scala
│ │ │ │ ├── OutputToShape.scala
│ │ │ │ ├── OutputToTensor.scala
│ │ │ │ ├── PlaceholderSupport.scala
│ │ │ │ ├── ShapeStructure.scala
│ │ │ │ ├── TensorStructure.scala
│ │ │ │ ├── TensorToDataType.scala
│ │ │ │ ├── TensorToOutput.scala
│ │ │ │ ├── TensorToShape.scala
│ │ │ │ ├── Zero.scala
│ │ │ │ └── package.scala
│ │ │ └── ops
│ │ │ │ ├── BasicImplicits.scala
│ │ │ │ ├── ClipImplicits.scala
│ │ │ │ ├── ControlFlowImplicits.scala
│ │ │ │ ├── EmbeddingImplicits.scala
│ │ │ │ ├── Implicits.scala
│ │ │ │ ├── MathImplicits.scala
│ │ │ │ ├── NNImplicits.scala
│ │ │ │ ├── SparseImplicits.scala
│ │ │ │ ├── StatisticsImplicits.scala
│ │ │ │ └── TextImplicits.scala
│ │ │ ├── io
│ │ │ ├── CheckpointReader.scala
│ │ │ ├── CompressionType.scala
│ │ │ ├── DirectoryLoader.scala
│ │ │ ├── FileIO.scala
│ │ │ ├── Loader.scala
│ │ │ ├── NPY.scala
│ │ │ ├── TFRecordReader.scala
│ │ │ ├── TFRecordWriter.scala
│ │ │ └── events
│ │ │ │ ├── EventAccumulator.scala
│ │ │ │ ├── EventFileReader.scala
│ │ │ │ ├── EventFileWriter.scala
│ │ │ │ ├── EventMultiplexer.scala
│ │ │ │ ├── EventPluginUtilities.scala
│ │ │ │ ├── EventRecord.scala
│ │ │ │ ├── EventType.scala
│ │ │ │ ├── SummaryFileWriter.scala
│ │ │ │ └── SummaryFileWriterCache.scala
│ │ │ ├── learn
│ │ │ ├── ClipGradients.scala
│ │ │ ├── Configuration.scala
│ │ │ ├── Counter.scala
│ │ │ ├── Implicits.scala
│ │ │ ├── Mode.scala
│ │ │ ├── Model.scala
│ │ │ ├── ModelInstance.scala
│ │ │ ├── SessionCreator.scala
│ │ │ ├── SessionManager.scala
│ │ │ ├── SessionScaffold.scala
│ │ │ ├── SessionWrapper.scala
│ │ │ ├── StopCriteria.scala
│ │ │ ├── estimators
│ │ │ │ ├── Estimator.scala
│ │ │ │ ├── FileBasedEstimator.scala
│ │ │ │ ├── InMemoryEstimator.scala
│ │ │ │ └── package.scala
│ │ │ ├── hooks
│ │ │ │ ├── CheckpointSaver.scala
│ │ │ │ ├── Evaluator.scala
│ │ │ │ ├── Hook.scala
│ │ │ │ ├── HookTrigger.scala
│ │ │ │ ├── LossLogger.scala
│ │ │ │ ├── ModelDependentHook.scala
│ │ │ │ ├── NaNChecker.scala
│ │ │ │ ├── StepRateLogger.scala
│ │ │ │ ├── Stopper.scala
│ │ │ │ ├── SummarySaver.scala
│ │ │ │ ├── SummaryWriterHookAddOn.scala
│ │ │ │ ├── TensorBoardHook.scala
│ │ │ │ ├── TensorLogger.scala
│ │ │ │ ├── TimelineHook.scala
│ │ │ │ ├── TriggeredHook.scala
│ │ │ │ └── package.scala
│ │ │ ├── layers
│ │ │ │ ├── Activation.scala
│ │ │ │ ├── Basic.scala
│ │ │ │ ├── Embedding.scala
│ │ │ │ ├── Input.scala
│ │ │ │ ├── Layer.scala
│ │ │ │ ├── Loss.scala
│ │ │ │ ├── Math.scala
│ │ │ │ ├── NN.scala
│ │ │ │ ├── Summary.scala
│ │ │ │ ├── core
│ │ │ │ │ └── package.scala
│ │ │ │ ├── package.scala
│ │ │ │ └── rnn
│ │ │ │ │ ├── BidirectionalRNN.scala
│ │ │ │ │ ├── RNN.scala
│ │ │ │ │ ├── cell
│ │ │ │ │ ├── BasicLSTMCell.scala
│ │ │ │ │ ├── BasicRNNCell.scala
│ │ │ │ │ ├── DeviceWrapper.scala
│ │ │ │ │ ├── DropoutWrapper.scala
│ │ │ │ │ ├── GRUCell.scala
│ │ │ │ │ ├── LSTMCell.scala
│ │ │ │ │ ├── RNNCell.scala
│ │ │ │ │ ├── ResidualWrapper.scala
│ │ │ │ │ ├── StackedCell.scala
│ │ │ │ │ └── package.scala
│ │ │ │ │ └── package.scala
│ │ │ ├── models
│ │ │ │ └── RBM.scala
│ │ │ └── package.scala
│ │ │ ├── ops
│ │ │ ├── Callback.scala
│ │ │ ├── Cast.scala
│ │ │ ├── Checks.scala
│ │ │ ├── Clip.scala
│ │ │ ├── DataFlow.scala
│ │ │ ├── Documentation.scala
│ │ │ ├── Embedding.scala
│ │ │ ├── Files.scala
│ │ │ ├── Function.scala
│ │ │ ├── Gradients.scala
│ │ │ ├── Image.scala
│ │ │ ├── Input.scala
│ │ │ ├── Logging.scala
│ │ │ ├── NN.scala
│ │ │ ├── Op.scala
│ │ │ ├── OpSpecification.scala
│ │ │ ├── Output.scala
│ │ │ ├── OutputOps.scala
│ │ │ ├── Parsing.scala
│ │ │ ├── Queue.scala
│ │ │ ├── Random.scala
│ │ │ ├── Resources.scala
│ │ │ ├── Sets.scala
│ │ │ ├── Slot.scala
│ │ │ ├── Sparse.scala
│ │ │ ├── Statistics.scala
│ │ │ ├── Summary.scala
│ │ │ ├── TensorArray.scala
│ │ │ ├── Text.scala
│ │ │ ├── basic
│ │ │ │ ├── Basic.scala
│ │ │ │ ├── Constructors.scala
│ │ │ │ ├── Inplace.scala
│ │ │ │ ├── Manipulation.scala
│ │ │ │ └── Masking.scala
│ │ │ ├── control_flow
│ │ │ │ ├── CondContext.scala
│ │ │ │ ├── Context.scala
│ │ │ │ ├── ControlFlow.scala
│ │ │ │ ├── GradientLoopState.scala
│ │ │ │ ├── GradientState.scala
│ │ │ │ ├── WhileLoopContext.scala
│ │ │ │ └── package.scala
│ │ │ ├── data
│ │ │ │ ├── Data.scala
│ │ │ │ ├── Dataset.scala
│ │ │ │ ├── DatasetIterator.scala
│ │ │ │ ├── Experimental.scala
│ │ │ │ └── package.scala
│ │ │ ├── lookup
│ │ │ │ ├── IDLookupTableWithHashBuckets.scala
│ │ │ │ ├── Lookup.scala
│ │ │ │ ├── LookupTable.scala
│ │ │ │ ├── LookupTableInitializer.scala
│ │ │ │ ├── LookupTableTensorInitializer.scala
│ │ │ │ ├── LookupTableTextFileInitializer.scala
│ │ │ │ └── package.scala
│ │ │ ├── math
│ │ │ │ ├── Bitwise.scala
│ │ │ │ └── Math.scala
│ │ │ ├── metrics
│ │ │ │ ├── Accuracy.scala
│ │ │ │ ├── ConfusionMatrix.scala
│ │ │ │ ├── GroupedPrecision.scala
│ │ │ │ ├── MapMetric.scala
│ │ │ │ ├── Mean.scala
│ │ │ │ ├── Metric.scala
│ │ │ │ ├── PrecisionAtK.scala
│ │ │ │ └── package.scala
│ │ │ ├── package.scala
│ │ │ ├── rnn
│ │ │ │ ├── RNN.scala
│ │ │ │ ├── attention
│ │ │ │ │ ├── Attention.scala
│ │ │ │ │ ├── AttentionWrapperCell.scala
│ │ │ │ │ ├── BahdanauAttention.scala
│ │ │ │ │ ├── LuongAttention.scala
│ │ │ │ │ └── package.scala
│ │ │ │ ├── cell
│ │ │ │ │ ├── BasicLSTMCell.scala
│ │ │ │ │ ├── BasicRNNCell.scala
│ │ │ │ │ ├── DeviceWrapper.scala
│ │ │ │ │ ├── DropoutWrapper.scala
│ │ │ │ │ ├── GRUCell.scala
│ │ │ │ │ ├── LSTMCell.scala
│ │ │ │ │ ├── RNNCell.scala
│ │ │ │ │ ├── ResidualWrapper.scala
│ │ │ │ │ ├── StackedCell.scala
│ │ │ │ │ └── package.scala
│ │ │ │ └── package.scala
│ │ │ ├── training
│ │ │ │ ├── ExponentialMovingAverage.scala
│ │ │ │ ├── optimizers
│ │ │ │ │ ├── AMSGrad.scala
│ │ │ │ │ ├── AdaDelta.scala
│ │ │ │ │ ├── AdaGrad.scala
│ │ │ │ │ ├── Adafactor.scala
│ │ │ │ │ ├── Adam.scala
│ │ │ │ │ ├── GradientDescent.scala
│ │ │ │ │ ├── LazyAMSGrad.scala
│ │ │ │ │ ├── LazyAdam.scala
│ │ │ │ │ ├── Optimizer.scala
│ │ │ │ │ ├── RMSProp.scala
│ │ │ │ │ ├── YellowFin.scala
│ │ │ │ │ ├── Yogi.scala
│ │ │ │ │ ├── package.scala
│ │ │ │ │ └── schedules
│ │ │ │ │ │ ├── ComposedSchedule.scala
│ │ │ │ │ │ ├── CosineDecay.scala
│ │ │ │ │ │ ├── CycleLinear10xDecay.scala
│ │ │ │ │ │ ├── ExponentialDecay.scala
│ │ │ │ │ │ ├── FixedSchedule.scala
│ │ │ │ │ │ ├── RSqrtDecay.scala
│ │ │ │ │ │ ├── Schedule.scala
│ │ │ │ │ │ ├── WarmUpExponentialSchedule.scala
│ │ │ │ │ │ ├── WarmUpLinearSchedule.scala
│ │ │ │ │ │ └── package.scala
│ │ │ │ └── package.scala
│ │ │ └── variables
│ │ │ │ ├── Initializer.scala
│ │ │ │ ├── Regularizer.scala
│ │ │ │ ├── Reuse.scala
│ │ │ │ ├── Saver.scala
│ │ │ │ ├── Variable.scala
│ │ │ │ ├── VariableLike.scala
│ │ │ │ ├── VariableScope.scala
│ │ │ │ ├── VariableScopeStore.scala
│ │ │ │ ├── VariableStore.scala
│ │ │ │ └── package.scala
│ │ │ ├── package.scala
│ │ │ ├── tensors
│ │ │ ├── Context.scala
│ │ │ ├── Implicits.scala
│ │ │ ├── Tensor.scala
│ │ │ ├── TensorOps.scala
│ │ │ ├── ops
│ │ │ │ ├── Basic.scala
│ │ │ │ ├── Cast.scala
│ │ │ │ ├── Math.scala
│ │ │ │ ├── NN.scala
│ │ │ │ ├── Random.scala
│ │ │ │ └── package.scala
│ │ │ └── package.scala
│ │ │ └── utilities
│ │ │ ├── ByteCodable.scala
│ │ │ ├── CRC32C.scala
│ │ │ ├── Coding.scala
│ │ │ ├── Collections.scala
│ │ │ ├── DefaultsTo.scala
│ │ │ ├── Disposer.scala
│ │ │ ├── NativeHandleWrapper.scala
│ │ │ ├── Proto.scala
│ │ │ ├── Reservoir.scala
│ │ │ └── package.scala
│ │ └── test
│ │ └── scala
│ │ └── org
│ │ └── platanios
│ │ └── tensorflow
│ │ └── api
│ │ ├── core
│ │ ├── DataTypeSpec.scala
│ │ ├── DeviceSpecificationSpec.scala
│ │ ├── GraphSpec.scala
│ │ ├── IndexerSpec.scala
│ │ ├── SessionSpec.scala
│ │ ├── ShapeSpec.scala
│ │ └── client
│ │ │ ├── FeedableSuite.scala
│ │ │ └── FetchableSuite.scala
│ │ ├── implicits
│ │ └── helpers
│ │ │ └── OpStructureSuite.scala
│ │ ├── io
│ │ ├── DirectoryLoaderSuite.scala
│ │ └── events
│ │ │ └── EventFileReaderSuite.scala
│ │ ├── ops
│ │ ├── BasicSpec.scala
│ │ ├── CallbackSuite.scala
│ │ ├── FunctionSuite.scala
│ │ ├── GradientsSuite.scala
│ │ ├── NNSpec.scala
│ │ ├── OpSpec.scala
│ │ ├── TextSuite.scala
│ │ ├── VariableSpec.scala
│ │ ├── control_flow
│ │ │ └── ControlFlowSuite.scala
│ │ ├── data
│ │ │ ├── DatasetSuite.scala
│ │ │ └── FilterDatasetSuite.scala
│ │ └── training
│ │ │ └── optimizers
│ │ │ └── GradientDescentSpec.scala
│ │ ├── tensors
│ │ └── TensorSuite.scala
│ │ └── utilities
│ │ ├── CRC32CSuite.scala
│ │ └── ReservoirSuite.scala
├── data
│ └── src
│ │ ├── main
│ │ ├── resources
│ │ │ └── logback.xml
│ │ └── scala
│ │ │ └── org
│ │ │ └── platanios
│ │ │ └── tensorflow
│ │ │ └── data
│ │ │ ├── Loader.scala
│ │ │ ├── XCLoader.scala
│ │ │ ├── image
│ │ │ ├── CIFARLoader.scala
│ │ │ ├── MNISTLoader.scala
│ │ │ └── STL10Loader.scala
│ │ │ ├── models
│ │ │ └── ObjectDetectionModelLoader.scala
│ │ │ ├── text
│ │ │ └── PTBLoader.scala
│ │ │ └── utilities
│ │ │ ├── CompressedFiles.scala
│ │ │ └── Split.scala
│ │ └── test
│ │ └── scala
│ │ └── org
│ │ └── platanios
│ │ └── tensorflow
│ │ └── data
│ │ └── image
│ │ └── MNISTLoaderSpec.scala
├── examples
│ └── src
│ │ └── main
│ │ ├── resources
│ │ ├── logback.xml
│ │ └── python2scala
│ │ │ ├── MetaGraphDef.txt
│ │ │ ├── checkpoint
│ │ │ ├── linear-regression.data-00000-of-00001
│ │ │ ├── linear-regression.index
│ │ │ ├── linear-regression.meta
│ │ │ ├── linearRegression.py
│ │ │ ├── virgin-linear-regression.data-00000-of-00001
│ │ │ ├── virgin-linear-regression.index
│ │ │ └── virgin-linear-regression.meta
│ │ └── scala
│ │ └── org
│ │ └── platanios
│ │ └── tensorflow
│ │ └── examples
│ │ ├── CIFAR.scala
│ │ ├── LinearRegression.scala
│ │ ├── MNIST.scala
│ │ ├── RNNTutorialUsingPTB.scala
│ │ ├── STL10.scala
│ │ ├── package.scala
│ │ └── python2scala
│ │ └── LinearRegressionFromRestoredPythonModel.scala
├── jni
│ └── src
│ │ ├── main
│ │ ├── native
│ │ │ ├── CMakeLists.txt
│ │ │ ├── c_api_internal.cc
│ │ │ ├── checkpoint_reader.cc
│ │ │ ├── checkpoint_reader.h
│ │ │ ├── checkpoint_reader_internal.cc
│ │ │ ├── checkpoint_reader_internal.h
│ │ │ ├── exception.h
│ │ │ ├── function.cc
│ │ │ ├── function.h
│ │ │ ├── generated
│ │ │ │ ├── tensor_basic_ops.cc
│ │ │ │ ├── tensor_basic_ops.h
│ │ │ │ ├── tensor_math_ops.cc
│ │ │ │ ├── tensor_math_ops.h
│ │ │ │ ├── tensor_nn_ops.cc
│ │ │ │ ├── tensor_nn_ops.h
│ │ │ │ ├── tensor_random_ops.cc
│ │ │ │ ├── tensor_random_ops.h
│ │ │ │ ├── tensor_sparse_ops.cc
│ │ │ │ ├── tensor_sparse_ops.h
│ │ │ │ ├── tensor_text_ops.cc
│ │ │ │ └── tensor_text_ops.h
│ │ │ ├── graph.cc
│ │ │ ├── graph.h
│ │ │ ├── include
│ │ │ │ └── tensorflow
│ │ │ │ │ ├── c
│ │ │ │ │ ├── c_api.h
│ │ │ │ │ ├── c_api_experimental.h
│ │ │ │ │ ├── c_api_macros.h
│ │ │ │ │ ├── eager
│ │ │ │ │ │ ├── c_api.h
│ │ │ │ │ │ ├── c_api_experimental.h
│ │ │ │ │ │ └── dlpack.h
│ │ │ │ │ ├── tensor_interface.h
│ │ │ │ │ ├── tf_attrtype.h
│ │ │ │ │ ├── tf_datatype.h
│ │ │ │ │ ├── tf_file_statistics.h
│ │ │ │ │ ├── tf_status.h
│ │ │ │ │ ├── tf_tensor.h
│ │ │ │ │ └── tf_tstring.h
│ │ │ │ │ └── core
│ │ │ │ │ └── platform
│ │ │ │ │ ├── ctstring.h
│ │ │ │ │ └── ctstring_internal.h
│ │ │ ├── op.cc
│ │ │ ├── op.h
│ │ │ ├── ops
│ │ │ │ ├── beam_search_ops.cc
│ │ │ │ ├── beam_search_ops.h
│ │ │ │ ├── beam_search_ops_gpu.cu.cc
│ │ │ │ ├── jvm_callback_op.cc
│ │ │ │ └── jvm_callback_op.h
│ │ │ ├── python_api.cc
│ │ │ ├── python_api.h
│ │ │ ├── server.cc
│ │ │ ├── server.h
│ │ │ ├── session.cc
│ │ │ ├── session.h
│ │ │ ├── tensor.cc
│ │ │ ├── tensor.h
│ │ │ ├── tensorflow.cc
│ │ │ ├── tensorflow.h
│ │ │ └── utilities.h
│ │ ├── resources
│ │ │ ├── logback.xml
│ │ │ └── ops.pbtxt
│ │ └── scala
│ │ │ └── org
│ │ │ └── platanios
│ │ │ └── tensorflow
│ │ │ └── jni
│ │ │ ├── CheckpointReader.scala
│ │ │ ├── Function.scala
│ │ │ ├── Graph.scala
│ │ │ ├── Op.scala
│ │ │ ├── ScalaCallbacksRegistry.scala
│ │ │ ├── Server.scala
│ │ │ ├── Session.scala
│ │ │ ├── Tensor.scala
│ │ │ ├── TensorFlow.scala
│ │ │ ├── TensorFlowException.scala
│ │ │ └── generated
│ │ │ └── tensors
│ │ │ ├── Basic.scala
│ │ │ ├── Math.scala
│ │ │ ├── NN.scala
│ │ │ ├── Random.scala
│ │ │ ├── Sparse.scala
│ │ │ └── Text.scala
│ │ └── test
│ │ └── scala
│ │ └── org
│ │ └── platanios
│ │ └── tensorflow
│ │ └── jni
│ │ └── TensorFlowSpec.scala
└── proto
│ └── src
│ └── main
│ └── proto
│ ├── allocation_description.proto
│ ├── any.proto
│ ├── api_def.proto
│ ├── attr_value.proto
│ ├── autotuning.proto
│ ├── bfc_memory_map.proto
│ ├── checkpoint_state.proto
│ ├── cluster.proto
│ ├── config.proto
│ ├── control_flow.proto
│ ├── cost_graph.proto
│ ├── critical_section.proto
│ ├── debug.proto
│ ├── debug_event.proto
│ ├── device_attributes.proto
│ ├── device_filters.proto
│ ├── device_properties.proto
│ ├── duration.proto
│ ├── eager_service.proto
│ ├── error_codes.proto
│ ├── event.proto
│ ├── example.proto
│ ├── feature.proto
│ ├── function.proto
│ ├── graph.proto
│ ├── graph_debug_info.proto
│ ├── graph_transfer_info.proto
│ ├── kernel_def.proto
│ ├── log_memory.proto
│ ├── master.proto
│ ├── master_service.proto
│ ├── meta_graph.proto
│ ├── named_tensor.proto
│ ├── node_def.proto
│ ├── op_def.proto
│ ├── queue_runner.proto
│ ├── reader_base.proto
│ ├── remote_tensor_handle.proto
│ ├── replay_log.proto
│ ├── resource_handle.proto
│ ├── rewriter_config.proto
│ ├── saved_model.proto
│ ├── saved_object_graph.proto
│ ├── saver.proto
│ ├── source_context.proto
│ ├── step_stats.proto
│ ├── struct.proto
│ ├── summary.proto
│ ├── tensor.proto
│ ├── tensor_bundle.proto
│ ├── tensor_description.proto
│ ├── tensor_shape.proto
│ ├── tensor_slice.proto
│ ├── tensorflow_server.proto
│ ├── trackable_object_graph.proto
│ ├── transport_options.proto
│ ├── type.proto
│ ├── types.proto
│ ├── variable.proto
│ ├── verifier_config.proto
│ ├── versions.proto
│ ├── worker.proto
│ └── worker_service.proto
├── project
├── BuildTool.scala
├── JniCrossPackage.scala
├── JniJavah.scala
├── JniNative.scala
├── JniPackage.scala
├── OpGenerator.scala
├── TensorFlowGenerateTensorOps.scala
├── TensorFlowNativePackage.scala
├── build.properties
└── plugins.sbt
└── version.sbt
/.circleci/config.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | jobs:
4 |   build-linux-x86_64:
5 |     docker:
6 |       - image: eaplatanios/tensorflow_scala:linux-cpu-x86_64-0.6.0
7 |     working_directory: ~/repository
8 |     environment:
9 |       # CircleCI does not interpolate variables inside `environment` values, so a
10 |       # `:$LD_LIBRARY_PATH` suffix would be passed through as literal text.
11 |       LD_LIBRARY_PATH: /usr/lib
12 |       # NOTE(review): the CI image (built from Dockerfile.linux-cpu-x86_64) sets
13 |       # JAVA_HOME to Java 11; confirm Java 8 is installed there or drop this override.
14 |       JAVA_HOME: /usr/lib/jvm/java-8-openjdk-amd64
15 |       JVM_OPTS: -Xmx3200m
16 |       TERM: dumb
17 |     steps:
18 |       - checkout
19 |       # Each `run` step starts a fresh shell, so a `cd` in one step does not carry
20 |       # over to the next. Install the native TensorFlow libraries in a single step,
21 |       # working in a scratch directory outside of the checked-out repository.
22 |       - run:
23 |           name: Install the TensorFlow native libraries
24 |           command: |
25 |             mkdir -p /tmp/downloads
26 |             cd /tmp/downloads
27 |             wget https://oss.sonatype.org/service/local/repositories/snapshots/content/org/platanios/tensorflow_2.13/0.6.0-SNAPSHOT/tensorflow_2.13-0.6.0-SNAPSHOT-linux.jar
28 |             jar xf tensorflow_2.13-0.6.0-SNAPSHOT-linux.jar
29 |             mv libtensorflow.so /usr/lib/libtensorflow.so
30 |             mv libtensorflow_framework.so /usr/lib/libtensorflow_framework.so
31 |             ln -s /usr/lib/libtensorflow.so /usr/lib/libtensorflow.so.2
32 |             ln -s /usr/lib/libtensorflow.so /usr/lib/libtensorflow.so.2.4.0
33 |             ln -s /usr/lib/libtensorflow_framework.so /usr/lib/libtensorflow_framework.so.2
34 |             ln -s /usr/lib/libtensorflow_framework.so /usr/lib/libtensorflow_framework.so.2.4.0
35 |       # The restore and save keys must be identical; otherwise the saved
36 |       # dependency cache is never restored.
37 |       - restore_cache:
38 |           keys:
39 |             - v0.6.0-dependencies-{{ checksum "build.sbt" }}
40 |       - run: cat /dev/null | sbt +test:compile
41 |       - save_cache:
42 |           paths:
43 |             - ~/.ivy2
44 |           key: v0.6.0-dependencies-{{ checksum "build.sbt" }}
45 |       - run: cat /dev/null | sbt +test:test
46 |
47 | workflows:
48 |   version: 2
49 |   test:
50 |     jobs:
51 |       - build-linux-x86_64
52 |
--------------------------------------------------------------------------------
/.circleci/images/Dockerfile.linux-cpu-x86_64:
--------------------------------------------------------------------------------
1 | FROM ubuntu:18.04
2 |
3 | # Copy and run the install script. DEBIAN_FRONTEND is declared as a build-time
4 | # ARG so that apt does not prompt during the install.
5 | COPY .circleci/images/install.sh /install.sh
6 | ARG DEBIAN_FRONTEND=noninteractive
7 | RUN /install.sh
8 |
9 | # Set up MPI.
10 | ENV TF_NEED_MPI=1
11 |
12 | ENV JAVA_HOME=/usr/lib/jvm/java-11-openjdk-amd64
13 |
14 | # COPY . /tensorflow_scala
15 |
--------------------------------------------------------------------------------
/.circleci/images/Dockerfile.linux-gpu-x86_64:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:11.0.3-cudnn8-devel-ubuntu18.04
2 |
3 | # In the Ubuntu 18.04 images, cuDNN is installed in the system paths. Copy its
4 | # headers and libraries so that they are also available under /usr/local/cuda.
5 | # cuDNN 8 splits its API across several headers (cudnn.h includes
6 | # cudnn_version.h, cudnn_ops_infer.h, etc.), so copy all of them.
7 | RUN cp -P /usr/include/cudnn*.h /usr/local/cuda/include
8 | RUN cp -P /usr/lib/x86_64-linux-gnu/libcudnn* /usr/local/cuda/lib64
9 |
10 | # Copy and run the install script.
11 | COPY .circleci/images/install.sh /install.sh
12 | ARG DEBIAN_FRONTEND=noninteractive
13 | RUN /install.sh
14 |
15 | # Set up MPI.
16 | ENV TF_NEED_MPI=1
17 |
18 | # Make the CUPTI libraries, which ship in the CUDA extras directory, loadable.
19 | ENV LD_LIBRARY_PATH=/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH
20 |
21 | # Configure the build for our CUDA configuration. CUDA 11 dropped support for
22 | # compute capability 3.0 (sm_30), so target 3.5 and 7.0 (the defaults used by
23 | # TensorFlow 2.4's configure script) instead.
24 | ENV TF_NEED_CUDA=1
25 | ENV TF_CUDA_COMPUTE_CAPABILITIES=3.5,7.0
26 |
27 | ENV JAVA_HOME=/usr/lib/jvm/java-11-openjdk-amd64
28 |
29 | # COPY . /tensorflow_scala
30 |
--------------------------------------------------------------------------------
/.circleci/images/docker_build_linux-cpu-x86_64.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Builds the linux-cpu-x86_64 CircleCI Docker image. The Docker build context is
3 | # the directory two levels above this script, so that the Dockerfile can COPY
4 | # files such as .circleci/images/install.sh.
5 | set -euo pipefail
6 |
7 | SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
8 | DOCKER_CONTEXT_PATH="$(realpath "${SCRIPT_DIR}/../../")"
9 |
10 | DOCKER_IMAGE="eaplatanios/tensorflow_scala:linux-cpu-x86_64-0.6.0"
11 | DOCKER_FILE=".circleci/images/Dockerfile.linux-cpu-x86_64"
12 |
13 | docker build \
14 |   -t "${DOCKER_IMAGE}" \
15 |   -f "${DOCKER_CONTEXT_PATH}/${DOCKER_FILE}" \
16 |   "${DOCKER_CONTEXT_PATH}"
17 |
--------------------------------------------------------------------------------
/.circleci/images/docker_build_linux-gpu-x86_64.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Builds the linux-gpu-x86_64 CircleCI Docker image. The Docker build context is
3 | # the directory two levels above this script, so that the Dockerfile can COPY
4 | # files such as .circleci/images/install.sh.
5 | set -euo pipefail
6 |
7 | SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
8 | DOCKER_CONTEXT_PATH="$(realpath "${SCRIPT_DIR}/../../")"
9 |
10 | DOCKER_IMAGE="eaplatanios/tensorflow_scala:linux-gpu-x86_64-0.6.0"
11 | DOCKER_FILE=".circleci/images/Dockerfile.linux-gpu-x86_64"
12 |
13 | docker build \
14 |   -t "${DOCKER_IMAGE}" \
15 |   -f "${DOCKER_CONTEXT_PATH}/${DOCKER_FILE}" \
16 |   "${DOCKER_CONTEXT_PATH}"
17 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | modules/jni/src/main/native/include/**/* linguist-vendored=true
2 | modules/jni/src/main/native/generated/**/* linguist-generated=true
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | **/.DS_Store
2 |
3 | # Project Specific
4 | temp/
5 | datasets/
6 |
7 | # Scala Specific
8 | *.class
9 | *.log
10 |
11 | # SBT Specific
12 | .cache
13 | .history
14 | .lib/
15 | dist/*
16 | target/
17 | lib_managed/
18 | src_managed/
19 | project/boot/
20 | project/plugins/project/
21 |
22 | # Bloop Specific
23 | .bloop
24 | .bsp
25 |
26 | # Metals Specific
27 | .metals
28 |
29 | # Hydra Specific
30 | .hydra
31 |
32 | **/src/main/native/lib/**/*
33 | **/src/main/generated/**/*
34 |
35 | # IntelliJ Specific
36 | **/.idea/**/*
37 |
38 | # CLion Specific
39 | **/cmake-build-debug/**/*
40 |
41 | # VS Code Specific
42 | **/.vscode/**/*
43 |
--------------------------------------------------------------------------------
/.jvmopts:
--------------------------------------------------------------------------------
1 | -Dfile.encoding=UTF8
2 | -Xms1G
3 | -Xmx4G
4 | -Xss8M
5 | -XX:ReservedCodeCacheSize=512M
6 | -XX:MaxDirectMemorySize=16G
7 | -XX:MaxMetaspaceSize=1G
8 |
--------------------------------------------------------------------------------
/docs/images/logo.afdesign:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eaplatanios/tensorflow_scala/39053eefe0859bde4c48c8da20212d6919deb38e/docs/images/logo.afdesign
--------------------------------------------------------------------------------
/docs/images/logo.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eaplatanios/tensorflow_scala/39053eefe0859bde4c48c8da20212d6919deb38e/docs/images/logo.pdf
--------------------------------------------------------------------------------
/docs/images/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eaplatanios/tensorflow_scala/39053eefe0859bde4c48c8da20212d6919deb38e/docs/images/logo.png
--------------------------------------------------------------------------------
/docs/images/logo_square.afdesign:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eaplatanios/tensorflow_scala/39053eefe0859bde4c48c8da20212d6919deb38e/docs/images/logo_square.afdesign
--------------------------------------------------------------------------------
/docs/images/logo_square.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eaplatanios/tensorflow_scala/39053eefe0859bde4c48c8da20212d6919deb38e/docs/images/logo_square.png
--------------------------------------------------------------------------------
/docs/src/main/paradox/assets/images/afosr_logo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eaplatanios/tensorflow_scala/39053eefe0859bde4c48c8da20212d6919deb38e/docs/src/main/paradox/assets/images/afosr_logo.gif
--------------------------------------------------------------------------------
/docs/src/main/paradox/assets/images/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eaplatanios/tensorflow_scala/39053eefe0859bde4c48c8da20212d6919deb38e/docs/src/main/paradox/assets/images/favicon.ico
--------------------------------------------------------------------------------
/docs/src/main/paradox/assets/images/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eaplatanios/tensorflow_scala/39053eefe0859bde4c48c8da20212d6919deb38e/docs/src/main/paradox/assets/images/logo.png
--------------------------------------------------------------------------------
/docs/src/main/paradox/assets/images/tensorboard_mnist_example_plot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eaplatanios/tensorflow_scala/39053eefe0859bde4c48c8da20212d6919deb38e/docs/src/main/paradox/assets/images/tensorboard_mnist_example_plot.png
--------------------------------------------------------------------------------
/docs/src/main/paradox/assets/images/tensorflow_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eaplatanios/tensorflow_scala/39053eefe0859bde4c48c8da20212d6919deb38e/docs/src/main/paradox/assets/images/tensorflow_logo.png
--------------------------------------------------------------------------------
/docs/src/main/paradox/assets/images/tensorflow_logo_square.svg:
--------------------------------------------------------------------------------
1 |
2 |
50 |
--------------------------------------------------------------------------------
/docs/src/main/paradox/contributing.md:
--------------------------------------------------------------------------------
1 | # Contributing
2 |
3 | It would be awesome if people could contribute to this library. Given
4 | its scope and its early state, before I settle on the API for some of
5 | the features, I would really appreciate contributions on the following:
6 |
7 | - **Unit Tests:** Currently unit tests are missing for a big part of
8 | the library and it would be extremely useful if we had those.
9 | - **Examples:** Examples of code using the library would be great and
10 | would also make issues come up early so they can be fixed.
11 |
--------------------------------------------------------------------------------
/docs/src/main/paradox/guides/adding_ops.md:
--------------------------------------------------------------------------------
1 | # Adding Support for New Ops
2 |
3 | TensorFlow graphs are constructed by building ops that
4 | receive tensors as input and produce tensors as output.
5 | Internally, multiple *kernels* are registered for each op,
6 | which are implementations of the op for different
7 | architectures (e.g., CPU kernels, CUDA GPU kernels, etc.).
8 | TensorFlow Scala offers the
9 | @scaladoc[Op.Builder](org.platanios.tensorflow.api.ops.Op.Builder)
10 | interface to allow users to create arbitrary ops that the
11 | TensorFlow runtime supports.
12 |
13 | For example, the implentation of `tf.add(x, y)` in
14 | TensorFlow Scala looks like this:
15 |
16 | @@snip [AddingOps.scala](/docs/src/main/scala/AddingOps.scala) { #add_op_example }
17 |
--------------------------------------------------------------------------------
/docs/src/main/paradox/guides/graph_construction.md:
--------------------------------------------------------------------------------
1 | # Graph Construction
2 |
3 | The low-level API can be used to define computations that will be executed at a later point, and potentially execute
4 | them. It can also be used to create custom layers for the [Learn API](#neural-networks-2). The main type of object
5 | underlying the low-level API is the [`Output`][output], which represents the value of a [`Tensor`][tensor] that has not
6 | yet been computed. Its name comes from the fact that it represents the *output* of some computation. An
7 | [`Output`][output] object thus represents a partially defined computation that will eventually produce a value. Core
8 | TensorFlow programs work by first building a graph of [`Output`][output] objects, detailing how each output is computed
9 | based on the other available outputs, and then by running parts of this graph to achieve the desired results.
10 |
11 | Similar to a [`Tensor`][tensor], each element in an [`Output`][output] has the same data type, and the data type is
12 | always known. However, the shape of an [`Output`][output] might be only partially known. Most operations produce tensors
13 | of fully-known shapes if the shapes of their inputs are also fully known, but in some cases it's only possible to find
14 | the shape of a tensor at graph execution time.
15 |
16 | It is important to understand the main concepts underlying the core API:
17 |
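18 | - **Tensor:** A concrete, already-computed multi-dimensional array of values with a known data type and shape.
19 | - **Output:** A symbolic handle to the value of a tensor that has not been computed yet; it is produced by an op in a graph.
20 | - **Sparse Output:** An output that represents a sparse tensor, stored as separate indices, values, and dense shape.
21 | - **Placeholder:** An output whose value is not computed by the graph, but is instead fed in when the graph is executed.
22 | - **Variable:** A mutable, stateful value that persists across graph executions (e.g., model parameters).
23 | - **Graph:** The collection of ops, and of the outputs flowing between them, that defines a computation.
24 | - **Session:** The runtime object that executes (parts of) a graph and returns the computed tensor values.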
25 |
26 | With the exception of [`Variable`][variable]s, the values of outputs are immutable, which means that in the context of a
27 | single execution, each output has only a single value. However, evaluating the same output twice can result in different
28 | values. For example, an output may be the result of reading data from disk or of generating a random number.
29 |
30 | ## Graph
31 |
32 |
33 | ## Working with Outputs
34 |
35 |
36 | ### Evaluating Outputs
37 |
38 |
39 | ### Printing Outputs
40 |
41 |
42 | ### Logging
43 |
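44 | Logging in the native TensorFlow library can be controlled by setting the `TF_CPP_MIN_LOG_LEVEL` environment variable to the minimum severity of the messages that get logged:
45 |
46 | - `0`: Info level (default; all messages are logged).
47 | - `1`: Warning level (info messages are filtered out).
48 | - `2`: Error level (info and warning messages are filtered out).
49 | - `3`: Fatal level (only fatal messages are logged).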
50 |
--------------------------------------------------------------------------------
/docs/src/main/paradox/release_notes.md:
--------------------------------------------------------------------------------
1 | # Release Notes
2 |
3 | @@toc { depth=1 }
4 |
5 | @@@ index
6 |
7 | - [0.5.1](release_notes/0.5.1.md)
8 | - [0.5.0](release_notes/0.5.0.md)
9 | - [0.4.1](release_notes/0.4.1.md)
10 | - [0.4.0](release_notes/0.4.0.md)
11 | - [0.3.0](release_notes/0.3.0.md)
12 | - [0.2.4](release_notes/0.2.4.md)
13 | - [0.2.3](release_notes/0.2.3.md)
14 | - [0.2.2](release_notes/0.2.2.md)
15 | - [0.2.1](release_notes/0.2.1.md)
16 | - [0.2.0](release_notes/0.2.0.md)
17 | - [0.1.1](release_notes/0.1.1.md)
18 | - [0.1.0](release_notes/0.1.0.md)
19 |
20 | @@@
21 |
--------------------------------------------------------------------------------
/docs/src/main/paradox/release_notes/0.1.0.md:
--------------------------------------------------------------------------------
1 | # Release 0.1.0
2 |
3 | This is the first official release of TensorFlow for Scala. The library
4 | website will soon be updated with information about the functionality
5 | supported by this API. Most of the main TensorFlow Python API
6 | functionality is already supported.
7 |
--------------------------------------------------------------------------------
/docs/src/main/paradox/release_notes/0.1.1.md:
--------------------------------------------------------------------------------
1 | # Release 0.1.1
2 |
3 | This release fixes the following bugs:
4 |
5 | - Issue with the packaged pre-compiled TensorFlow binaries that
6 | affected Linux platforms.
7 | - Learn API bug where the shared name of input iterators was being
8 | set incorrectly.
9 |
10 | I also switched to using CircleCI for continuous integration, instead
11 | of TravisCI.
12 |
--------------------------------------------------------------------------------
/docs/src/main/paradox/release_notes/0.2.0.md:
--------------------------------------------------------------------------------
1 | # Release 0.2.0
2 |
3 | In this release we have:
4 |
5 | - Added support for incremental compilation.
6 | - Added support for [Horovod](https://github.com/uber/horovod).
7 | - Added support for timelines to allow for easy profiling of
8 | TensorFlow graphs.
9 | - Fixed a major memory leak (issue #87).
10 | - Updated the JNI bindings to be compatible with the TensorFlow
11 | 1.9.0 release.
12 | - Added support for obtaining the list of available devices from
13 | within Scala.
14 | - Fixed bugs for some control flow ops.
15 | - Added support for `tf.cases`.
16 | - Added support for the RMSProp optimizer, the lazy Adam optimizer,
17 | the [AMSGrad](https://openreview.net/pdf?id=ryQu7f-RZ) optimizer,
18 | the lazy AMSGrad optimizer, and the
19 | [YellowFin](https://arxiv.org/pdf/1706.03471.pdf) optimizer.
20 | - Added more learning rate decay schemes:
21 | - Cosine decay.
22 | - Cycle-linear 10x decay.
23 | - Square-root decay.
24 | - More warm-up decay schedules.
25 | - Added support for dataset interleave ops.
26 | - Fixed some bugs related to variable scopes and variable sharing.
27 | - Fixed some bugs related to functional ops.
28 | - Added support for some new image-related ops, under the namespace
29 | `tf.image`.
30 | - Improved consistency for the creation of initializer ops.
31 | - Added support for the `tf.initializer` op creation context.
32 | - Exposed part of the `TensorArray` API.
33 | - Exposed `tf.Op.Builder` in the public API.
34 | - Improvements to the learn API:
35 | - Refactored `mode` into an implicit argument.
36 | - Improved the evaluator hook.
37 | - Removed the layer creation context mechanism, to be refactored
38 | later. It was causing some issues due to bad design and unclear
39 | semantics. The plan is to implement this, in the near future, as
40 | wrapper creation context layers.
41 | - Improved the `Model` class.
42 | - Fixed a bug that was causing some issues related to inference
43 | hooks in the in-memory estimator.
44 | - Improved logging.
45 | - Added support for reading and writing numpy (i.e., `.npy`) files.
46 | - Added a logo. :)
47 |
--------------------------------------------------------------------------------
/docs/src/main/paradox/release_notes/0.2.1.md:
--------------------------------------------------------------------------------
1 | # Release 0.2.1
2 |
3 | In this release we have fixed an issue related to the packaging and
4 | distribution of the pre-compiled TensorFlow shared libraries.
5 |
--------------------------------------------------------------------------------
/docs/src/main/paradox/release_notes/0.2.2.md:
--------------------------------------------------------------------------------
1 | # Release 0.2.2
2 |
3 | In this release we have updated the precompiled TensorFlow binaries
4 | distributed with this library.
5 |
--------------------------------------------------------------------------------
/docs/src/main/paradox/release_notes/0.2.3.md:
--------------------------------------------------------------------------------
1 | # Release 0.2.3
2 |
3 | Added compatibility with TensorFlow 1.9-rc1.
4 |
--------------------------------------------------------------------------------
/docs/src/main/paradox/release_notes/0.2.4.md:
--------------------------------------------------------------------------------
1 | # Release 0.2.4
2 |
3 | Fixed an issue with the packaged pre-compiled TensorFlow binaries that
4 | affected Linux platforms.
5 |
--------------------------------------------------------------------------------
/docs/src/main/paradox/release_notes/0.3.0.md:
--------------------------------------------------------------------------------
1 | # Release 0.3.0
2 |
3 | With this release we have finally added support for static data type
4 | information for tensors (not for symbolic tensors yet though -- for now
5 | we effectively have support for a statically-typed version of `numpy`
6 | for Scala). This is an important milestone and contributes significantly
7 | to type safety, which can help catch errors at compile time, rather than
8 | runtime. For example:
9 |
10 | ```scala
11 | val t1 = Tensor(0.5, 1) // The inferred type is Tensor[FLOAT64].
12 | val t2 = Tensor(1, 2) // The inferred type is Tensor[INT32].
13 | val t3 = t1 + t2 // The inferred type is Tensor[FLOAT64].
14 | val t4 = t3.isNaN // The inferred type is Tensor[BOOLEAN].
15 | val t5 = t3.any() // Fails at compile-time because `any()` is only
16 | // supported for Tensor[BOOLEAN].
17 | ```
18 |
19 | Other new features include:
20 |
21 | - Improvements to the high-level learn API:
22 | - Layers can now provide and use their own parameter generator, and
23 | can also access the current training step
24 | (using `Layer.currentStep`).
25 | - Layers now support `.map(...)`.
26 | - Added support for batch normalization.
27 | - Added support for `tf.logSigmoid` and `tf.lrn`.
28 | - Added support for the following new metrics:
29 | - Grouped precision.
30 | - Precision-at-k.
31 | - `data` module:
32 | - Added support for loading the extreme classification repository
33 | datasets (i.e., `data.XCLoader`).
34 | - Added support for randomly splitting datasets.
35 |
--------------------------------------------------------------------------------
/docs/src/main/paradox/release_notes/0.4.0.md:
--------------------------------------------------------------------------------
1 | # Release 0.4.0
2 |
3 | This is a major release with a lot of new features related to static
4 | types for tensors and ops. The graph construction API is now
5 | statically-typed, thus enabling much better type safety than before.
6 |
7 | Tensors and outputs are now statically-typed and the types used are the
8 | Scala types that correspond to the tensors' TensorFlow data types. For
9 | example:
10 |
11 | ```scala
12 | val t1 = Tensor(0.5, 1) // The inferred type is Tensor[Double].
13 | val t2 = Tensor(1, 2) // The inferred type is Tensor[Int].
14 | val t3 = t1 + t2 // The inferred type is Tensor[Double].
15 | val t4 = t3.isNaN // The inferred type is Tensor[Boolean].
16 | val t5 = t3.any() // Fails at compile-time because `any()` is only
17 | // supported for Tensor[Boolean].
18 | ```
19 |
20 | A similar situation now applies to `Output`s. `Op`s are also typed and
21 | so is the auto-differentiation implementation.
22 |
23 | This resulted in major simplifications in the data pipeline and the high-level
24 | learn API. Datasets and dataset iterators no longer "carry" `T`, `V`,
25 | `D`, and `S` types with them, but rather just the type of the
26 | elements they contain/produce.
27 |
28 | A new type trait called `TF` is also introduced that denotes supported
29 | Scala types in TensorFlow (e.g., `TF[Int]` and `TF[Float]`). Similarly,
30 | some more type traits are introduced to denote type constraints for
31 | various ops (e.g., `IsIntOrUInt[Int]`, `IsIntOrUInt[Long]`,
32 | `IsFloatOrDouble[Float]`, etc.). These type traits are powered by a
33 | general implementation of union types for Scala.
34 |
35 | Other new features include:
36 |
37 | - `data` module:
38 | - Added support for the `mapAndBatch` transformation.
39 |
--------------------------------------------------------------------------------
/docs/src/main/paradox/release_notes/0.4.1.md:
--------------------------------------------------------------------------------
1 | # Release 0.4.1
2 |
3 | Fixed the precompiled TensorFlow binaries, and also added the following
4 | new features:
5 |
6 | - `io` module:
7 | - Added support for a new `TFRecordWriter`.
8 | - `ops` module:
9 | - Added a new ops namespace, `sparse`, that includes all sparse ops.
10 | - Added support for `sparse.reorder` and `sparse.merge`.
11 | - Added support for parsing TF records.
12 | - `data` module:
13 | - Added support for `Dataset.shuffleAndRepeat`.
14 | - `optimizers` module:
15 | - Added support for the Adafactor optimizer.
16 | - Renamed `SqrtDecay` to `RSqrtDecay`, which is a more accurate name.
17 | - `math` module:
18 | - Added support for `batchGather`.
19 | - Added support for bitwise ops.
20 | - `rnn` module:
21 | - Simplified the attention mechanisms functionality so that it is
22 | now not required to tile memory tensors for beam search outside
23 | the beam search decoder.
24 | - Moved the `seq2seq` module to a separate repository (that of
25 | [Symphony Machine Translation](https://github.com/eaplatanios/symphony-mt)).
26 |
--------------------------------------------------------------------------------
/docs/src/main/paradox/release_notes/0.5.0.md:
--------------------------------------------------------------------------------
1 | # Release 0.5.0
2 |
3 | This release introduces support for TensorFlow 2.0.
4 |
--------------------------------------------------------------------------------
/docs/src/main/paradox/release_notes/0.5.1.md:
--------------------------------------------------------------------------------
1 | # Release 0.5.1
2 |
3 | This release introduces support for TensorFlow 2.2 and
4 | Scala 2.13, and drops support for Scala 2.11. The
5 | distributed precompiled binaries for this version will only
6 | work with CUDA 10.1 on Linux. Finally, this release also
7 | brings improved support for implicit derivations in some
8 | cases where case classes over tensors are used.
9 |
--------------------------------------------------------------------------------
/docs/src/main/scala/AddingOps.scala:
--------------------------------------------------------------------------------
1 | import org.platanios.tensorflow.api._
2 |
3 | object AddingOps {
4 |   // #add_op_example
5 |   /** Creates an `Add` op computing `x + y` and registers [[addGradient]] as its gradient function. */
6 |   def add[T: TF : IsNotQuantized](
7 |       x: Output[T],
8 |       y: Output[T],
9 |       name: String = "Add"
10 |   ): Output[T] = {
11 |     val builder = Op.Builder[(Output[T], Output[T]), Output[T]](
12 |       opType = "Add",
13 |       name = name,
14 |       input = (x, y))
15 |     builder
16 |       .setGradientFn(addGradient(_, _)(TF[T], IsNotQuantized[T]))
17 |       .build()
18 |       .output
19 |   }
20 |
21 |   /** Gradient of `Add`: sums the upstream gradient over the axes along which each input was broadcast,
22 |     * and then reshapes each sum back to the shape of the corresponding input.
23 |     */
24 |   protected def addGradient[T: TF : IsNotQuantized](
25 |       op: Op[(Output[T], Output[T]), Output[T]],
26 |       outputGradient: Output[T]
27 |   ): (Output[T], Output[T]) = {
28 |     val (x, y) = op.input
29 |     val xShape = tf.shape(x)
30 |     val yShape = tf.shape(y)
31 |     val (xAxes, yAxes) = tf.broadcastGradientArguments(xShape, yShape)
32 |     val xGradient = tf.reshape(tf.sum(outputGradient, xAxes), xShape)
33 |     val yGradient = tf.reshape(tf.sum(outputGradient, yAxes), yShape)
34 |     (xGradient, yGradient)
35 |   }
36 |   // #add_op_example
37 | }
38 |
--------------------------------------------------------------------------------
/docs/src/main/scala/Estimators.scala:
--------------------------------------------------------------------------------
1 | import org.platanios.tensorflow.api._
2 | import org.platanios.tensorflow.api.ops.metrics.Metric
3 | import org.platanios.tensorflow.api.tf.learn._
4 |
5 | // #inference_model
6 | /** A [[Model]] that can build the ops used for inference. */
7 | trait InferenceModel[In, Out] extends Model {
8 |   def buildInferOps(): Model.InferOps[In, Out]
9 | }
10 | // #inference_model
11 |
12 | // #trainable_models
13 | /** An [[InferenceModel]] that can additionally build training and evaluation ops. */
14 | trait TrainableModel[In, TrainIn, Out, TrainOut, Loss, EvalIn]
15 |     extends InferenceModel[In, Out] {
16 |   def buildTrainOps(): Model.TrainOps[TrainIn, TrainOut, Loss]
17 |
18 |   def buildEvalOps(
19 |       metrics: Seq[Metric[EvalIn, Output[Float]]]
20 |   ): Model.EvalOps[TrainIn, Out]
21 | }
22 |
23 | /** A [[TrainableModel]] whose training input is `(In, TrainIn)` and whose evaluation input is `(Out, (In, TrainIn))`. */
24 | trait SupervisedTrainableModel[In, TrainIn, Out, TrainOut, Loss]
25 |     extends TrainableModel[In, (In, TrainIn), Out, TrainOut, Loss, (Out, (In, TrainIn))] {
26 |   override def buildTrainOps(): Model.TrainOps[(In, TrainIn), TrainOut, Loss]
27 |
28 |   override def buildEvalOps(
29 |       metrics: Seq[Metric[(Out, (In, TrainIn)), Output[Float]]]
30 |   ): Model.EvalOps[(In, TrainIn), Out]
31 | }
32 |
33 | /** A [[TrainableModel]] whose training input is `In` and whose evaluation input is `Out`. */
34 | trait UnsupervisedTrainableModel[In, Out, Loss]
35 |     extends TrainableModel[In, In, Out, Out, Loss, Out] {
36 |   override def buildTrainOps(): Model.TrainOps[In, Out, Loss]
37 |
38 |   override def buildEvalOps(
39 |       metrics: Seq[Metric[Out, Output[Float]]]
40 |   ): Model.EvalOps[In, Out]
41 | }
42 | // #trainable_models
43 |
--------------------------------------------------------------------------------
/docs/src/main/scala/installation.sh:
--------------------------------------------------------------------------------
1 | # Command snippets for the installation guide. The `//`-prefixed marker lines
2 | # delimit the named snippets and are not meant to be executed.
3 | // #clone_repository
4 | git clone https://github.com/tensorflow/tensorflow.git
5 | cd tensorflow
6 | git checkout 1aaa68d93c6b2f4151446eb211399b4330c96a09
7 | // #clone_repository
8 |
9 | // #compile_tf
10 | ./configure
11 | bazel build --config=opt --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=0 //tensorflow:libtensorflow.so
12 | // #compile_tf
13 |
14 | // #apt_get_install_protobuf
15 | apt-get install protobuf-compiler
16 | // #apt_get_install_protobuf
17 |
18 | // #brew_install_protobuf
19 | brew install protobuf
20 | // #brew_install_protobuf
21 |
--------------------------------------------------------------------------------
/modules/api/src/main/resources/logback.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | %d{yyyy-MM-dd HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/Documentation.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api
17 |
18 | /** Mixes the documentation of the API sub-packages into a single trait (currently only `ops.Documentation`).
19 |   *
20 |   * @author Emmanouil Antonios Platanios
21 |   */
22 | private[api] trait Documentation
23 |     extends ops.Documentation
24 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/config/SummaryConfig.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.config
17 |
18 | /** Configuration describing how often summaries are saved while training models.
19 |   *
20 |   * @author Emmanouil Antonios Platanios
21 |   */
22 | sealed trait SummaryConfig
23 |
24 | /** Disables summary saving altogether. */
25 | case object NoSummaries extends SummaryConfig
26 |
27 | /** Saves summaries every `steps` training steps.
28 |   *
29 |   * @param steps Number of steps between consecutive summary saves (must be non-negative; defaults to `1000`).
30 |   */
31 | case class StepBasedSummaries(steps: Int = 1000) extends SummaryConfig {
32 |   require(0 <= steps, s"'steps' (set to $steps) needs to be a non-negative integer.")
33 | }
34 |
35 | /** Saves summaries every `seconds` seconds.
36 |   *
37 |   * @param seconds Number of seconds between consecutive summary saves (must be non-negative; defaults to `600`).
38 |   */
39 | case class TimeBasedSummaries(seconds: Int = 600) extends SummaryConfig {
40 |   require(0 <= seconds, s"'seconds' (set to $seconds) needs to be a non-negative integer.")
41 | }
42 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/config/TensorBoardConfig.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.config
17 |
18 | import java.nio.file.Path
19 |
20 | /** Settings for launching a TensorBoard process, which can be used when training using estimators.
21 |   *
22 |   * @param logDir         Directory containing the logs and summaries that TensorBoard should use.
23 |   * @param host           Host passed to TensorBoard via `--host` (defaults to `localhost`).
24 |   * @param port           Port passed to TensorBoard via `--port` (defaults to `6006`).
25 |   * @param reloadInterval Interval, in seconds, at which the backend reloads more data (defaults to `5`).
26 |   *
27 |   * @author Emmanouil Antonios Platanios
28 |   */
29 | case class TensorBoardConfig(
30 |     logDir: Path,
31 |     host: String = "localhost",
32 |     port: Int = 6006,
33 |     reloadInterval: Int = 5
34 | ) {
35 |   /** Builder for the `tensorboard` command line; `logDir` is passed as an absolute path. */
36 |   private[api] val processBuilder: ProcessBuilder = {
37 |     val flags = Seq(
38 |       "--logdir" -> logDir.toAbsolutePath.toString,
39 |       "--host" -> host,
40 |       "--port" -> port.toString,
41 |       "--reload_interval" -> reloadInterval.toString)
42 |     val command = "tensorboard" +: flags.flatMap { case (flag, value) => Seq(flag, value) }
43 |     new ProcessBuilder(command: _*)
44 |   }
45 | }
46 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/core/Devices.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.core
17 |
18 | import org.platanios.tensorflow.api.core.client.SessionConfig
19 | import org.platanios.tensorflow.jni.{Session => NativeSession}
20 | import org.platanios.tensorflow.proto.DeviceAttributes
21 |
22 | /** Helper methods for querying the devices visible to TensorFlow.
23 |   *
24 |   * @author Emmanouil Antonios Platanios
25 |   */
26 | object Devices {
27 |   /** Lists the devices available to this local process.
28 |     *
29 |     * @param  sessionConfig Optional session configuration; when provided, it is serialized and passed to the native
30 |     *                       device query, and otherwise `null` is passed instead.
31 |     * @return Parsed [[DeviceAttributes]] for every device reported by the native library.
32 |     */
33 |   def local(sessionConfig: Option[SessionConfig] = None): Seq[DeviceAttributes] = {
34 |     val serializedConfig: Array[Byte] = sessionConfig match {
35 |       case Some(config) => config.toConfigProto.toByteArray
36 |       case None => null
37 |     }
38 |     NativeSession.deviceList(serializedConfig).toSeq.map(serialized => DeviceAttributes.parseFrom(serialized))
39 |   }
40 | }
41 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/core/Implicits.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.core
17 |
18 | /** Core implicit conversions: those of the `client` and `types` packages, plus conversions from integers to
19 |   * indexing constructs.
20 |   *
21 |   * @author Emmanouil Antonios Platanios
22 |   */
23 | private[api] trait Implicits extends client.Implicits with types.Implicits {
24 |   // TODO: [INDEXERS] Add begin mask support (not simple).
25 |
26 |   /** Converts an integer into an [[Index]]. */
27 |   implicit def intToIndex(index: Int): Index =
28 |     Index(index = index)
29 |
30 |   /** Wraps an integer in an [[IndexerConstructionWithOneNumber]], presumably so richer indexers can start from it. */
31 |   implicit def intToIndexerConstruction(n: Int): IndexerConstructionWithOneNumber =
32 |     IndexerConstructionWithOneNumber(n)
33 | }
34 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/core/Logging.scala:
--------------------------------------------------------------------------------
1 | package org.platanios.tensorflow.api.core
2 |
3 | import org.platanios.tensorflow.jni.TensorFlow
4 |
5 | object Logging {
6 |
7 |   /** Represents the TensorFlow logging level. */
8 |   sealed trait Level {
9 |     private[Logging] def value: Int
10 |   }
11 |
12 |   case object DEBUG extends Level {
13 |     override private[Logging] def value = 0
14 |   }
15 |
16 |   case object INFO extends Level {
17 |     override private[Logging] def value = 1
18 |   }
19 |
20 |   case object WARNING extends Level {
21 |     override private[Logging] def value = 2
22 |   }
23 |
24 |   case object ERROR extends Level {
25 |     override private[Logging] def value = 3
26 |   }
27 |
28 |   /** Sets the current TensorFlow logging [[Level]] by passing its numeric value, as a string, to the native library.
29 |     *
30 |     * NOTE(review): TensorFlow documents `TF_CPP_MIN_LOG_LEVEL=1` as hiding only INFO messages (i.e., showing
31 |     * warnings and above), whereas this object names `1` `INFO`; confirm the level names against the native side.
32 |     */
33 |   def setLoggingLevel(level: Level): Unit = {
34 |     TensorFlow.setLogLevel(level.value.toString)
35 |   }
36 |
37 |   /** Returns the current TensorFlow logging [[Level]].
38 |     *
39 |     * The value reported by the native library is parsed as an integer, ignoring surrounding whitespace. A missing
40 |     * (`null`) or blank value maps to [[DEBUG]]. Because the value may be set outside of this object and therefore
41 |     * need not be one of `0`-`3`, values below `0` map to [[DEBUG]] and values above `3` map to [[ERROR]].
42 |     *
43 |     * @throws IllegalStateException If the reported value is not an integer.
44 |     */
45 |   def currentLoggingLevel: Level = {
46 |     val rawLevel = TensorFlow.getLogLevel
47 |     if (rawLevel == null || rawLevel.trim.isEmpty) {
48 |       DEBUG
49 |     } else {
50 |       val numericLevel =
51 |         try {
52 |           rawLevel.trim.toInt
53 |         } catch {
54 |           case _: NumberFormatException =>
55 |             throw new IllegalStateException(s"Unsupported TensorFlow log level value: '$rawLevel'.")
56 |         }
57 |       numericLevel match {
58 |         case n if n <= 0 => DEBUG
59 |         case 1 => INFO
60 |         case 2 => WARNING
61 |         case _ => ERROR
62 |       }
63 |     }
64 |   }
65 | }
66 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/core/client/Implicits.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.core.client
17 |
18 | /** Implicits of the `client` package, currently just those inherited from [[FeedMap.Implicits]].
19 |   *
20 |   * @author Emmanouil Antonios Platanios
21 |   */
22 | private[core] trait Implicits extends FeedMap.Implicits
23 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/core/distributed/Protocol.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.core.distributed
17 |
18 | /** Communication protocols supported by [[Server]]s.
19 |   *
20 |   * @author Emmanouil Antonios Platanios
21 |   */
22 | sealed trait Protocol {
23 |   /** Name identifying this protocol (e.g., `"grpc"` for [[GRPC]]). */
24 |   val name: String
25 | }
26 |
27 | /** The gRPC communication protocol. */
28 | case object GRPC extends Protocol { override val name: String = "grpc" }
29 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/package.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.implicits
17 |
18 | import org.platanios.tensorflow.api.core.Shape
19 | import org.platanios.tensorflow.api.core.types.{DataType, Variant}
20 | import org.platanios.tensorflow.api.ops.Output
21 | import org.platanios.tensorflow.api.ops.data.Dataset
22 |
23 | package object helpers { // Shared type aliases and a helper dataset used by the implicit helpers.
24 | type SparseDataType[T] = (DataType[Long], DataType[T], DataType[Long]) // Presumably (indices, values, dense shape) data types -- TODO confirm.
25 | type IndexedSlicesDataType[T] = (DataType[Int], DataType[T], DataType[Int]) // Presumably (indices, values, dense shape) data types -- TODO confirm.
26 | type SparseShape = (Shape, Shape, Shape) // Three shapes, presumably one per component of a sparse structure -- TODO confirm.
27 |
28 | // TODO: [FUNCTIONS] !!! Find a better way to deal with this for use in the reduce function of the "GroupByWindowDataset".
29 | /** Dataset wrapping an existing variant `handle`; output data types and shapes are stored as `Any` and cast on access. */
30 | case class VariantDataset[T: OutputStructure] protected(
31 | handle: Output[Variant],
32 | private val _outputDataTypes: Any = null, // Returned by `outputDataTypes` via an unchecked cast; `null` if not provided.
33 | private val _outputShapes: Any = null // Returned by `outputShapes` via an unchecked cast; `null` if not provided.
34 | ) extends Dataset[T] {
35 | override val name: String = "VariantDataset"
36 |
37 | override def createHandle[D, S]()(implicit
38 | evOutputToDataType: OutputToDataType.Aux[T, D],
39 | evOutputToShape: OutputToShape.Aux[T, S]
40 | ): Output[Variant] = {
41 | handle // Reuses the wrapped handle; no new op is created here.
42 | }
43 |
44 | override def outputDataTypes[D](implicit ev: OutputToDataType.Aux[T, D]): D = {
45 | _outputDataTypes.asInstanceOf[D] // Unchecked (erased) cast: a mismatched `D` only fails where the result is used.
46 | }
47 |
48 | override def outputShapes[S](implicit ev: OutputToShape.Aux[T, S]): S = {
49 | _outputShapes.asInstanceOf[S] // Unchecked (erased) cast: a mismatched `S` only fails where the result is used.
50 | }
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/ops/ControlFlowImplicits.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.implicits.ops
17 |
18 | import org.platanios.tensorflow.api.ops.UntypedOp
19 |
20 | /** Adds control-flow-context queries to [[UntypedOp]]s. */
21 | trait ControlFlowImplicits {
22 |   implicit class ControlFlowOps(val op: UntypedOp) {
23 |     /** Returns `true` if `op` has a control flow context that contains a cond context. */
24 |     def isInCond: Boolean =
25 |       op.controlFlowContext.exists(_.condContext.isDefined)
26 |
27 |     /** Returns `true` if `op` has a control flow context that contains a while loop context. */
28 |     def isInWhileLoop: Boolean =
29 |       op.controlFlowContext.exists(_.whileLoopContext().isDefined)
30 |
31 |     /** Returns `true` if `op` has its `_XlaCompile` boolean attribute set, or lives in an XLA control flow context.
32 |       * If reading the attribute throws an [[IllegalArgumentException]] (presumably because it is absent), the
33 |       * attribute counts as `false`.
34 |       */
35 |     def isInXLAContext: Boolean = {
36 |       val markedForXla =
37 |         try op.booleanAttribute("_XlaCompile")
38 |         catch { case _: IllegalArgumentException => false }
39 |       markedForXla || op.controlFlowContext.exists(_.xlaContext.isDefined)
40 |     }
41 |   }
42 | }
43 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/ops/EmbeddingImplicits.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.implicits.ops
17 |
18 | import org.platanios.tensorflow.api.core.types.{IsNotQuantized, TF}
19 | import org.platanios.tensorflow.api.ops.Embedding.{OutputParameters, VariableParameters}
20 | import org.platanios.tensorflow.api.ops.{EmbeddingMap, EmbeddingParameters, Output}
21 | import org.platanios.tensorflow.api.ops.variables.Variable
22 |
23 | /** Implicit conversions from the supported embedding parameter representations to [[EmbeddingMap]]s. */
24 | trait EmbeddingImplicits {
25 |   /** Treats a single parameter partition as a one-partition [[EmbeddingMap]]. */
26 |   implicit def singlePartitionEmbeddingMap[T: TF](parameters: EmbeddingParameters[T]): EmbeddingMap[T] =
27 |     EmbeddingMap(Seq(parameters))
28 |
29 |   /** Builds an [[EmbeddingMap]] from a sequence of parameter partitions. */
30 |   implicit def multiplePartitionsEmbeddingMap[T: TF](parameters: Seq[EmbeddingParameters[T]]): EmbeddingMap[T] =
31 |     EmbeddingMap(parameters)
32 |
33 |   /** Uses an [[Output]] as embedding parameters by wrapping it in [[OutputParameters]]. */
34 |   implicit def outputToEmbeddingMap[T: TF : IsNotQuantized](parameters: Output[T]): EmbeddingMap[T] =
35 |     OutputParameters(parameters)
36 |
37 |   /** Uses a [[Variable]] as embedding parameters by wrapping it in [[VariableParameters]]. */
38 |   implicit def variableToEmbeddingMap[T: TF : IsNotQuantized](parameters: Variable[T]): EmbeddingMap[T] =
39 |     VariableParameters(parameters)
40 | }
41 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/io/CompressionType.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.io
17 |
/** Supported compression types. Each type is identified by a string `name`, which is also its string
  * representation.
  *
  * @author Emmanouil Antonios Platanios
  */
sealed trait CompressionType {
  /** Identifier of this compression type (the empty string for no compression). */
  val name: String

  override def toString: String = name
}

/** No compression; its name is the empty string. */
case object NoCompression extends CompressionType { override val name: String = "" }

/** ZLIB compression. */
case object ZLIBCompression extends CompressionType { override val name: String = "ZLIB" }

/** GZIP compression. */
case object GZIPCompression extends CompressionType { override val name: String = "GZIP" }
39 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/io/Loader.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.io
17 |
/** Minimal abstraction over data loaders.
  *
  * @tparam T Type of the loaded entries.
  */
trait Loader[T] {
  /** Loads the entries.
    *
    * @return Iterator over the loaded entries.
    */
  def load(): Iterator[T]
}
23 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/io/TFRecordWriter.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.io
17 |
18 | import org.platanios.tensorflow.api.utilities.{CRC32C, Coding}
19 | import org.platanios.tensorflow.proto.Example
20 |
21 | import java.io.BufferedOutputStream
22 | import java.nio.file.{Files, Path, StandardOpenOption}
23 |
/** Helper used to write `Example` protocol buffers to TensorFlow record files.
  *
  * The file is opened when the writer is constructed, using `CREATE_NEW` (so construction fails if a file already
  * exists at `filePath`) and `APPEND`.
  *
  * @param filePath TensorFlow record file path.
  *
  * @author Emmanouil Antonios Platanios
  */
case class TFRecordWriter(filePath: Path) {
  /** Buffered stream over the record file. */
  protected var fileStream: BufferedOutputStream = {
    new BufferedOutputStream(Files.newOutputStream(
      filePath, StandardOpenOption.CREATE_NEW, StandardOpenOption.APPEND))
  }

  /** Appends `example` to the TensorFlow records file.
    *
    * Format of a single record:
    * {{{
    *   uint64 length
    *   uint32 masked crc of length
    *   byte   data[length]
    *   uint32 masked crc of data
    * }}}
    */
  def write(example: Example): Unit = {
    val recordBytes = example.toByteArray
    val encLength = Coding.encodeFixedInt64(recordBytes.length)
    val encLengthMaskedCrc = Coding.encodeFixedInt32(CRC32C.mask(CRC32C.value(encLength)))
    val encDataMaskedCrc = Coding.encodeFixedInt32(CRC32C.mask(CRC32C.value(recordBytes)))
    // Write the four fields straight into the buffered stream. Concatenating them with `++` first would allocate
    // intermediate arrays and copy the (potentially large) serialized record twice per call. The bytes written are
    // the same either way.
    fileStream.write(encLength)
    fileStream.write(encLengthMaskedCrc)
    fileStream.write(recordBytes)
    fileStream.write(encDataMaskedCrc)
  }

  /** Pushes the buffered bytes of previously written examples to the underlying file. */
  def flush(): Unit = {
    fileStream.flush()
  }

  /** Closes the current TensorFlow records file. `BufferedOutputStream.close()` flushes any buffered bytes before
    * closing the underlying stream.
    */
  def close(): Unit = {
    fileStream.close()
  }
}
60 |
/** Factory for [[TFRecordWriter]] instances. */
object TFRecordWriter {
  /** Creates a writer for the TensorFlow record file located at `filePath`. */
  def apply(filePath: Path): TFRecordWriter = new TFRecordWriter(filePath)
}
66 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/io/events/EventType.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.io.events
17 |
/** Enumeration of event kinds, represented as case objects that extend this sealed trait.
  *
  * @author Emmanouil Antonios Platanios
  */
sealed trait EventType

/** Scalar events. */
case object ScalarEventType extends EventType

/** Image events. */
case object ImageEventType extends EventType

/** Audio events. */
case object AudioEventType extends EventType

/** Histogram events. */
case object HistogramEventType extends EventType

/** Compressed histogram events. */
case object CompressedHistogramEventType extends EventType

/** Tensor events. */
case object TensorEventType extends EventType

/** Graph events. */
case object GraphEventType extends EventType

/** Meta-graph events. */
case object MetaGraphEventType extends EventType

/** Run metadata events. */
case object RunMetadataEventType extends EventType
31 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/io/events/SummaryFileWriterCache.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.io.events
17 |
18 | import org.platanios.tensorflow.api.ops.Op
19 | import org.platanios.tensorflow.api.core.Graph
20 |
21 | import java.nio.file.Path
22 |
23 | import scala.collection.mutable
24 |
/** Cache for summary file writers, which caches one writer per directory.
  *
  * All accesses to the underlying map are synchronized on it.
  *
  * @author Emmanouil Antonios Platanios
  */
object SummaryFileWriterCache {
  private[this] val cache: mutable.Map[Path, SummaryFileWriter] = mutable.HashMap.empty[Path, SummaryFileWriter]

  /** Returns the summary file writer responsible for the specified directory. If no writer is cached for it yet, a
    * new one is created and cached.
    *
    * @param directory Directory whose writer to return. It is used as-is as the cache key; it is not normalized. So
    *                  distinct `Path` values that refer to the same directory may map to distinct writers.
    * @param graph     Graph passed to the writer when a new one is created. It is ignored when a writer for
    *                  `directory` is already cached.
    * @return Summary file writer for `directory`.
    */
  def get(directory: Path, graph: Graph = Op.currentGraph): SummaryFileWriter = cache synchronized {
    cache.getOrElseUpdate(directory, SummaryFileWriter(directory, graph))
  }

  /** Closes and removes all cached summary writers. Currently only used for testing.
    *
    * Every writer is closed, even if closing some of them fails, and the cache is always emptied. If any `close()`
    * call fails, the first failure is rethrown once all writers have been processed. Any later failures are attached
    * to it as suppressed exceptions.
    */
  private[io] def clear(): Unit = cache synchronized {
    import scala.util.control.NonFatal
    // Make sure all the writers are closed.
    // Otherwise, open file handles may hang around, blocking deletions on Windows.
    var firstFailure: Throwable = null
    cache.values.foreach { writer =>
      try {
        writer.close()
      } catch {
        case NonFatal(e) =>
          if (firstFailure == null)
            firstFailure = e
          else if (e ne firstFailure) // `addSuppressed` rejects self-suppression.
            firstFailure.addSuppressed(e)
      }
    }
    cache.clear()
    if (firstFailure != null)
      throw firstFailure
  }
}
45 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/learn/Mode.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.learn
17 |
/** Mode in which a learner uses a model: training, evaluation, or inference (i.e., prediction).
  *
  * @author Emmanouil Antonios Platanios
  */
sealed trait Mode {
  /** Whether this mode is the training mode. */
  val isTraining: Boolean
}

/** Mode used while training a model. */
case object TRAINING extends Mode { override val isTraining: Boolean = true }

/** Mode used while evaluating a model. */
case object EVALUATION extends Mode { override val isTraining: Boolean = false }

/** Mode used while performing inference (prediction) with a model. */
case object INFERENCE extends Mode { override val isTraining: Boolean = false }
38 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/learn/ModelInstance.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.learn
17 |
18 | import org.platanios.tensorflow.api.core.types.{IsFloatOrDouble, TF}
19 | import org.platanios.tensorflow.api.ops.{Output, OutputLike, UntypedOp}
20 | import org.platanios.tensorflow.api.ops.data.DatasetIterator
21 | import org.platanios.tensorflow.api.ops.variables.Variable
22 |
23 | // TODO: [LEARN] What about "trainOutput"?
24 |
/** Represents an instance of a constructed model. Such instances are constructed by estimators and passed on to
  * model-dependent hooks.
  *
  * Apart from `model` and `configuration`, all fields are optional and default to `None`.
  *
  * @param model                 Trainable model that this instance was constructed from.
  * @param configuration         Configuration associated with this instance.
  * @param trainInputIterator    Iterator over the training input dataset, if available.
  * @param trainInput            Training input, if available.
  * @param output                Model output, if available.
  * @param trainOutput           Training output of the model, if available.
  * @param loss                  Loss, if available.
  * @param gradientsAndVariables Gradients paired with their corresponding variables, if available.
  * @param trainOp               Training op, if available.
  *
  * @author Emmanouil Antonios Platanios
  */
case class ModelInstance[In, TrainIn, Out, TrainOut, Loss: TF : IsFloatOrDouble, EvalIn](
    model: TrainableModel[In, TrainIn, Out, TrainOut, Loss, EvalIn],
    configuration: Configuration,
    trainInputIterator: Option[DatasetIterator[TrainIn]] = None,
    trainInput: Option[TrainIn] = None,
    output: Option[Out] = None,
    trainOutput: Option[TrainOut] = None,
    loss: Option[Output[Loss]] = None,
    gradientsAndVariables: Option[Seq[(OutputLike[Loss], Variable[Any])]] = None,
    trainOp: Option[UntypedOp] = None
)
40 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/learn/estimators/package.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.learn
17 |
/** Type and value aliases for the estimators defined in this package, grouped in the `API` trait (with `API` also
  * available as a concrete object).
  *
  * @author Emmanouil Antonios Platanios
  */
package object estimators {
  private[api] trait API {
    type Estimator[In, TrainIn, Out, TrainOut, Loss, EvalIn] =
      estimators.Estimator[In, TrainIn, Out, TrainOut, Loss, EvalIn]
    type InMemoryEstimator[In, TrainIn, Out, TrainOut, Loss, EvalIn] =
      estimators.InMemoryEstimator[In, TrainIn, Out, TrainOut, Loss, EvalIn]
    type FileBasedEstimator[In, TrainIn, Out, TrainOut, Loss, EvalIn] =
      estimators.FileBasedEstimator[In, TrainIn, Out, TrainOut, Loss, EvalIn]

    val Estimator: estimators.Estimator.type = estimators.Estimator
    val InMemoryEstimator: estimators.InMemoryEstimator.type = estimators.InMemoryEstimator
    val FileBasedEstimator: estimators.FileBasedEstimator.type = estimators.FileBasedEstimator
  }

  private[api] object API extends API
}
34 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/learn/hooks/ModelDependentHook.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.learn.hooks
17 |
18 | import org.platanios.tensorflow.api.learn.ModelInstance
19 |
/** Represents hooks that may depend on the constructed model.
  *
  * This trait offers the `modelInstance` field that sub-classes can access and that contains information specific to
  * the created model. It is only updated when the model graph is constructed (i.e., it is not updated while recovering
  * failed sessions).
  *
  * For example, a hook that logs the loss function value depends on the created loss op, and an evaluation hook may
  * depend on multiple ops created as part of the model.
  *
  * @author Emmanouil Antonios Platanios
  */
trait ModelDependentHook[In, TrainIn, Out, TrainOut, Loss, EvalIn] extends Hook {
  /** Current model instance. It is `null` until `setModelInstance` has been called. */
  protected var modelInstance: ModelInstance[In, TrainIn, Out, TrainOut, Loss, EvalIn] = _

  /** This method will be called by estimators at graph construction time, before `begin()`. It will **not** be called
    * again if a session fails and is recovered.
    *
    * @param modelInstance Model instance to store in the `modelInstance` field.
    */
  private[learn] final def setModelInstance(
      modelInstance: ModelInstance[In, TrainIn, Out, TrainOut, Loss, EvalIn]
  ): Unit = {
    this.modelInstance = modelInstance
  }
}
42 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/learn/layers/Embedding.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.learn.layers
17 |
18 | import org.platanios.tensorflow.api._
19 | import org.platanios.tensorflow.api.core.types.{IsNotQuantized, TF}
20 | import org.platanios.tensorflow.api.learn.{Mode, layers}
21 | import org.platanios.tensorflow.api.ops
22 | import org.platanios.tensorflow.api.ops.Embedding.OutputParameters
23 | import org.platanios.tensorflow.api.ops.{EmbeddingMap, Output}
24 | import org.platanios.tensorflow.api.tensors.Tensor
25 |
/** Companion of the `Embedding` layer, which also provides the aliases for it that are grouped in the `API` trait. */
object Embedding {
  /** Aliases for the embedding layer type and its companion object. */
  private[layers] trait API {
    type Embedding[T] = layers.Embedding[T]

    val Embedding: layers.Embedding.type = layers.Embedding
  }

  /** Concrete instance of the `API` trait. */
  object API extends API
}
35 |
/** Embedding layer that looks up the rows of a single embedding parameter of shape
  * `[vocabularySize, embeddingSize]` for the provided integer indices.
  *
  * @param name              Name of this layer (also passed as the name of the lookup op).
  * @param vocabularySize    Number of rows of the embedding parameter.
  * @param embeddingSize     Number of columns of the embedding parameter.
  * @param partitionStrategy Partitioning strategy passed on to `ops.Embedding.embeddingLookup`.
  * @param transformFn       Optional transformation (`null` if not provided) passed on to
  *                          `ops.Embedding.embeddingLookup`.
  * @param maxNorm           Optional tensor (`null` if not provided). When provided, it is converted to a constant and
  *                          passed on to `ops.Embedding.embeddingLookup`.
  */
case class Embedding[T: TF : IsNotQuantized](
    override val name: String,
    vocabularySize: Int,
    embeddingSize: Int,
    partitionStrategy: ops.Embedding.PartitionStrategy = ops.Embedding.ModStrategy,
    transformFn: Output[T] => Output[T] = null,
    maxNorm: Tensor[T] = null
) extends Layer[Output[Int], Output[T]](name) {
  override val layerType: String = "Embedding"

  override def forwardWithoutContext(
      input: Output[Int]
  )(implicit mode: Mode): Output[T] = {
    // The embedding matrix is obtained as a single, unpartitioned parameter.
    val embeddingParameter = getParameter[T]("EmbeddingMap", Shape(vocabularySize, embeddingSize))
    val parameters = EmbeddingMap(Seq(OutputParameters(embeddingParameter)))
    // `null` marks "not provided", both for this layer's argument and for the lookup op's argument.
    val maxNormOutput: Output[T] = if (maxNorm != null) ops.basic.Basic.constant(maxNorm) else null
    ops.Embedding.embeddingLookup(parameters, input, partitionStrategy, transformFn, maxNormOutput, name)
  }
}
56 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/learn/layers/core/package.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.learn.layers
17 |
18 | import org.platanios.tensorflow.api.core.types.{IsHalfOrFloatOrDouble, TF}
19 | import org.platanios.tensorflow.api.ops.Output
20 |
/**
  * @author Emmanouil Antonios Platanios
  */
package object core {
  private[layers] trait API {
    /** Creates a multi-layer perceptron (MLP).
      *
      * If `hiddenLayers` is empty, the MLP is a single linear layer named `s"$name/Linear"` with `outputSize` units.
      * Otherwise, hidden layer `i` consists of a linear layer with `hiddenLayers(i)` units and an activation layer,
      * with a dropout layer between them for every hidden layer except the first. The hidden layers are followed by
      * a linear output layer named `s"$name/OutputLayer/Linear"` with `outputSize` units. All layer names are
      * prefixed with `name`.
      *
      * @param name         Name prefix for all layers of the MLP.
      * @param hiddenLayers Number of units of each hidden layer.
      * @param outputSize   Number of units of the output layer.
      * @param activation   Function creating an activation layer, given its name. If `null`, `ReLU[T](layerName, 0.1f)`
      *                     is used.
      * @param dropout      Dropout rate. The dropout layers are constructed with `1 - dropout` (presumably their keep
      *                     probability; confirm against `Dropout`).
      * @return Layer implementing the MLP.
      */
    def MLP[T: TF : IsHalfOrFloatOrDouble](
        name: String,
        hiddenLayers: Seq[Int],
        outputSize: Int,
        activation: String => Layer[Output[T], Output[T]] = null,
        dropout: Float = 0.0f
    ): Layer[Output[T], Output[T]] = {
      if (hiddenLayers.isEmpty) {
        Linear(s"$name/Linear", outputSize)
      } else {
        val activationWithDefault = {
          if (activation == null)
            (layerName: String) => ReLU[T](layerName, 0.1f)
          else
            activation
        }
        val firstSize = hiddenLayers.head
        var layer = Linear(s"$name/Layer0/Linear", firstSize) >> activationWithDefault(s"$name/Layer0/Activation")
        hiddenLayers.zipWithIndex.tail.foreach { case (size, index) =>
          layer = layer >>
              Linear(s"$name/Layer$index/Linear", size) >>
              Dropout(s"$name/Layer$index/Dropout", 1 - dropout) >>
              activationWithDefault(s"$name/Layer$index/Activation")
        }
        // Like every other layer, the output layer is prefixed with `name`. Otherwise, multiple MLPs constructed in
        // the same scope would all use the same output layer name.
        layer >> Linear(s"$name/OutputLayer/Linear", outputSize)
      }
    }
  }
}
55 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/learn/layers/rnn/cell/RNNCell.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.learn.layers.rnn.cell
17 |
18 | import org.platanios.tensorflow.api.implicits.helpers.OutputToShape
19 | import org.platanios.tensorflow.api.learn.Mode
20 | import org.platanios.tensorflow.api.learn.layers.Layer
21 | import org.platanios.tensorflow.api.ops
22 | import org.platanios.tensorflow.api.ops.variables.VariableScope
23 |
/** Base class for RNN cell layers. Subclasses construct an `ops.rnn.cell.RNNCell`, and this layer applies it to its
  * input tuple.
  *
  * @param name Name scope (also acting as variable scope) for this layer.
  *
  * @author Emmanouil Antonios Platanios
  */
abstract class RNNCell[Out, State, OutShape, StateShape](
    override val name: String
)(implicit
    val evOutputToShapeOut: OutputToShape.Aux[Out, OutShape],
    val evOutputToShapeState: OutputToShape.Aux[State, StateShape]
) extends Layer[Tuple[Out, State], Tuple[Out, State]](name) {
  /** Creates the underlying cell without entering any variable scope. */
  def createCellWithoutContext(
      mode: Mode,
      inputShape: OutShape
  ): ops.rnn.cell.RNNCell[Out, State, OutShape, StateShape]

  /** Creates the underlying cell. When `name` is not `null`, this happens inside a pure variable scope named `name`. */
  final def createCell(
      mode: Mode,
      inputShape: OutShape
  ): ops.rnn.cell.RNNCell[Out, State, OutShape, StateShape] = {
    if (name == null) {
      createCellWithoutContext(mode, inputShape)
    } else {
      VariableScope.scope(name, isPure = true) {
        createCellWithoutContext(mode, inputShape)
      }
    }
  }

  /** Creates the cell for the shape of `input.output` and applies it to `input`. */
  override final def forwardWithoutContext(
      input: Tuple[Out, State]
  )(implicit mode: Mode): Tuple[Out, State] = {
    val inputShape = evOutputToShapeOut.shape(input.output)
    val cell = createCellWithoutContext(mode, inputShape)
    cell.forward(input)
  }
}
59 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/learn/layers/rnn/package.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.learn.layers
17 |
/** Aliases for the RNN layers of this package, grouped in the `API` trait, which also inherits the aliases of
  * `rnn.cell.API`.
  *
  * @author Emmanouil Antonios Platanios
  */
package object rnn {
  private[layers] trait API extends rnn.cell.API {
    type RNN[Out, State, OutShape, StateShape] =
      rnn.RNN[Out, State, OutShape, StateShape]
    type BidirectionalRNN[Out, State, OutShape, StateShape] =
      rnn.BidirectionalRNN[Out, State, OutShape, StateShape]

    val RNN: rnn.RNN.type = rnn.RNN
    val BidirectionalRNN: rnn.BidirectionalRNN.type = rnn.BidirectionalRNN
  }
}
31 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Documentation.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.ops
17 |
/** Groups together documentation related to constructing symbolic ops, by mixing in the `Documentation` traits of
  * the individual op modules.
  *
  * The order of the mixins below determines the trait linearization and must be preserved.
  *
  * @author Emmanouil Antonios Platanios
  */
private[api] trait Documentation
    extends basic.Basic.Documentation
        with Cast.Documentation
        with Checks.Documentation
        with Clip.Documentation
        with Embedding.Documentation
        with Image.Documentation
        with Logging.Documentation
        with math.Math.Documentation
        with NN.Documentation
        with Parsing.Documentation
        with Random.Documentation
        with Sets.Documentation
        with Sparse.Documentation
        with Statistics.Documentation
        with Summary.Documentation
        with Text.Documentation
        with control_flow.ControlFlow.Documentation
40 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Input.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.ops
17 |
18 | import org.platanios.tensorflow.api.core.Graph
19 | import org.platanios.tensorflow.api.core.types.DataType
20 | import org.platanios.tensorflow.api.utilities.using
21 | import org.platanios.tensorflow.jni.{Op => NativeOp}
22 |
/** Wrapper around an op meant to represent one of its inputs. Actual op inputs have type [[Output]] since they
  * represent outputs of other ops. Currently, [[Input]] is only useful for representing consumers of an [[Op]]'s
  * outputs.
  *
  * @param op    Op whose input this class represents.
  * @param index Input index.
  *
  * @author Emmanouil Antonios Platanios
  */
final case class Input[T] private[ops](
    op: UntypedOp,
    index: Int
) {
  /** Name of this op input, formatted as `"<op name>:<index>"`. */
  lazy val name: String = {
    s"${op.name}:$index"
  }

  /** Data type of this op input, queried from the native graph for input `index` of `op`. */
  lazy val dataType: DataType[T] = {
    // `graph.reference` is held (via `using`) for the duration of the native call.
    using(graph.reference) { r =>
      // The native call yields an untyped data type; the cast to `DataType[T]` is unchecked (erased).
      DataType.fromCValue(
        NativeOp.inputDataType(r.nativeHandle, op.nativeHandle, index)
      ).asInstanceOf[DataType[T]]
    }
  }

  /** Graph where the op belongs. */
  def graph: Graph = {
    op.graph
  }

  override def toString: String = {
    s"Op.Input(name = $name, dataType = $dataType)"
  }
}
59 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/OpSpecification.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.ops
17 |
/** Simple record holding an op's name, type, and device.
  *
  * @param name   Op name.
  * @param opType Op type.
  * @param device Op device string.
  *
  * @author Emmanouil Antonios Platanios
  */
final case class OpSpecification(name: String, opType: String, device: String)
25 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/control_flow/package.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.ops
17 |
/** Exposes the control flow ops through the `API` trait, which simply extends `ControlFlow`.
  *
  * @author Emmanouil Antonios Platanios
  */
package object control_flow {
  private[ops] trait API
      extends ControlFlow
}
24 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/data/package.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.ops
17 |
18 | /** Package object for the dataset-related ops.
19 |   *
20 |   * @author Emmanouil Antonios Platanios
21 |   */
22 | package object data {
23 |   /** Data API (only visible within the `ops` package). It combines `Data` with the dataset iterator API and
24 |     * re-exports the dataset types of this package.
25 |     */
26 |   private[ops] trait API
27 |     extends Data
28 |       with data.DatasetIterator.API {
29 |     // Type aliases re-exported from this package.
30 |     type DatasetIterator[T] = data.DatasetIterator[T]
31 |     type Dataset[T] = data.Dataset[T]
32 |   }
33 | }
34 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/lookup/LookupTableInitializer.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.ops.lookup
17 |
18 | import org.platanios.tensorflow.api.core.types.{DataType, TF}
19 | import org.platanios.tensorflow.api.ops.UntypedOp
20 |
21 | /** Base class for lookup table initializers.
22 |   *
23 |   * @param  keysDataType   Data type of the table keys.
24 |   * @param  valuesDataType Data type of the table values.
25 |   * @tparam K              Key type.
26 |   * @tparam V              Value type.
27 |   *
28 |   * @author Emmanouil Antonios Platanios
29 |   */
30 | abstract class LookupTableInitializer[K: TF, V: TF](
31 |     val keysDataType: DataType[K],
32 |     val valuesDataType: DataType[V]
33 | ) {
34 |   /** Creates an op that initializes `table` and returns it.
35 |     *
36 |     * @param  table Table to initialize.
37 |     * @param  name  Name for the created op.
38 |     * @return Initialization op for `table`.
39 |     */
40 |   def initialize(
41 |       table: InitializableLookupTable[K, V],
42 |       name: String = "Initialize"
43 |   )(implicit evVTF: TF[V]): UntypedOp
44 | }
45 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/lookup/LookupTableTensorInitializer.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.ops.lookup
17 |
18 | import org.platanios.tensorflow.api.core.Graph
19 | import org.platanios.tensorflow.api.core.types.{Resource, TF}
20 | import org.platanios.tensorflow.api.ops.{Op, Output, UntypedOp}
21 |
22 | /** Lookup table initializer backed by two tensors: one holding the table keys and one holding the corresponding
23 |   * values.
24 |   *
25 |   * @param keys   Tensor containing the table keys.
26 |   * @param values Tensor containing the table values.
27 |   *
28 |   * @author Emmanouil Antonios Platanios
29 |   */
30 | class LookupTableTensorInitializer[K: TF, V: TF] protected (
31 |     val keys: Output[K],
32 |     val values: Output[V]
33 | ) extends LookupTableInitializer(keys.dataType, values.dataType) {
34 |   /** Creates an `InitializeTableV2` op that feeds `keys` and `values` into `table`. The op is added to the current
35 |     * graph's `TABLE_INITIALIZERS` collection and then returned.
36 |     *
37 |     * @param  table Table to initialize.
38 |     * @param  name  Name used both for the enclosing name scope and for the created op.
39 |     * @return Created initialization op for `table`.
40 |     */
41 |   override def initialize(
42 |       table: InitializableLookupTable[K, V],
43 |       name: String = "Initialize"
44 |   )(implicit evVTF: TF[V]): UntypedOp = Op.nameScope(name) {
45 |     val initializationOp = Op.Builder[(Output[Resource], Output[K], Output[V]), Unit](
46 |       opType = "InitializeTableV2",
47 |       name = name,
48 |       input = (table.handle, keys, values)
49 |     ).build().asUntyped
50 |     // Register the op in the graph's table initializers collection before handing it back.
51 |     Op.currentGraph.addToCollection(Graph.Keys.TABLE_INITIALIZERS)(initializationOp)
52 |     initializationOp
53 |   }
54 | }
55 |
56 | object LookupTableTensorInitializer {
57 |   /** Creates a tensor-backed lookup table initializer from `keys` and `values`. */
58 |   def apply[K: TF, V: TF](keys: Output[K], values: Output[V]): LookupTableTensorInitializer[K, V] =
59 |     new LookupTableTensorInitializer[K, V](keys, values)
60 | }
61 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/lookup/package.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.ops
17 |
18 | /** Package object for lookup tables and their initializers.
19 |   *
20 |   * @author Emmanouil Antonios Platanios
21 |   */
22 | package object lookup {
23 |   /** Lookup API (only visible within the `ops` package). It mixes in `Lookup` and re-exports the lookup table types,
24 |     * the initializer types, `TextFileFieldExtractor`, and the `TextFile*` objects of this package.
25 |     */
26 |   private[ops] trait API extends Lookup {
27 |     // Lookup tables.
28 |     type LookupTable[K, V] = lookup.LookupTable[K, V]
29 |     type HashTable[K, V] = lookup.HashTable[K, V]
30 |     type IDLookupTableWithHashBuckets[K] = lookup.IDLookupTableWithHashBuckets[K]
31 |     val HashTable: lookup.HashTable.type = lookup.HashTable
32 |     val IDLookupTableWithHashBuckets: lookup.IDLookupTableWithHashBuckets.type = lookup.IDLookupTableWithHashBuckets
33 |
34 |     // Lookup table initializers.
35 |     type LookupTableInitializer[K, V] = lookup.LookupTableInitializer[K, V]
36 |     type LookupTableTensorInitializer[K, V] = lookup.LookupTableTensorInitializer[K, V]
37 |     type LookupTableTextFileInitializer[K, V] = lookup.LookupTableTextFileInitializer[K, V]
38 |     val LookupTableTensorInitializer: lookup.LookupTableTensorInitializer.type =
39 |       lookup.LookupTableTensorInitializer
40 |     val LookupTableTextFileInitializer: lookup.LookupTableTextFileInitializer.type =
41 |       lookup.LookupTableTextFileInitializer
42 |
43 |     // Text file field extraction.
44 |     type TextFileFieldExtractor[K] = lookup.TextFileFieldExtractor[K]
45 |     val TextFileLineNumber: lookup.TextFileLineNumber.type = lookup.TextFileLineNumber
46 |     val TextFileWholeLine: lookup.TextFileWholeLine.type = lookup.TextFileWholeLine
47 |     val TextFileColumn: lookup.TextFileColumn.type = lookup.TextFileColumn
48 |   }
49 | }
50 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/rnn/cell/RNNCell.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.ops.rnn.cell
17 |
18 | import org.platanios.tensorflow.api.implicits.helpers.{OutputToShape, Zero}
19 | import org.platanios.tensorflow.api.ops.Output
20 |
21 | /** Base class for recurrent neural network (RNN) cells.
22 |   *
23 |   * A cell maps a `Tuple[Out, State]` to a new `Tuple[Out, State]` through `forward` (also available as `apply`).
24 |   *
25 |   * @tparam Out        Output type of the cell.
26 |   * @tparam State      State type of the cell.
27 |   * @tparam OutShape   Shape type associated with `Out` (through `OutputToShape`).
28 |   * @tparam StateShape Shape type associated with `State` (through `OutputToShape`).
29 |   *
30 |   * @author Emmanouil Antonios Platanios
31 |   */
32 | abstract class RNNCell[Out, State, OutShape, StateShape](implicit
33 |     val evOutputToShapeOut: OutputToShape.Aux[Out, OutShape],
34 |     val evOutputToShapeState: OutputToShape.Aux[State, StateShape]
35 | ) {
36 |   /** Shape of the cell output. */
37 |   def outputShape: OutShape
38 |
39 |   /** Shape of the cell state. */
40 |   def stateShape: StateShape
41 |
42 |   /** Creates a zero-valued cell output for `batchSize`, delegating to the implicit `Zero` instance with
43 |     * `outputShape`.
44 |     */
45 |   def zeroOutput(
46 |       batchSize: Output[Int],
47 |       name: String = "ZeroOutput"
48 |   )(implicit evZero: Zero.Aux[Out, OutShape]): Out =
49 |     evZero.zero(batchSize, outputShape, name)
50 |
51 |   /** Creates a zero-valued cell state for `batchSize`, delegating to the implicit `Zero` instance with
52 |     * `stateShape`.
53 |     */
54 |   def zeroState(
55 |       batchSize: Output[Int],
56 |       name: String = "ZeroState"
57 |   )(implicit evZero: Zero.Aux[State, StateShape]): State =
58 |     evZero.zero(batchSize, stateShape, name)
59 |
60 |   /** Applies this cell to `input`. Implementations may throw an `IllegalArgumentException`. */
61 |   @throws[IllegalArgumentException]
62 |   def forward(input: Tuple[Out, State]): Tuple[Out, State]
63 |
64 |   /** Same as `forward`, so that a cell can be applied like a function. */
65 |   def apply(input: Tuple[Out, State]): Tuple[Out, State] = forward(input)
66 | }
67 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/rnn/package.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.ops
17 |
18 | /** Package object for the recurrent neural network (RNN) ops.
19 |   *
20 |   * @author Emmanouil Antonios Platanios
21 |   */
22 | package object rnn {
23 |   /** RNN API (only visible within the `ops` package), combining the attention and cell APIs with `RNN`.
24 |     *
25 |     * The mixin order determines the trait linearization, so it should not be changed casually.
26 |     */
27 |   private[ops] trait API extends attention.API with cell.API with RNN
28 | }
29 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/training/optimizers/package.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.ops.training
17 |
18 | /** Package object for the optimizers.
19 |   *
20 |   * @author Emmanouil Antonios Platanios
21 |   */
22 | package object optimizers {
23 |   /** Optimizers API (only visible within the `training` package), which also includes the schedules API. For each
24 |     * concrete optimizer, both its type and its companion object are re-exported.
25 |     */
26 |   private[training] trait API extends schedules.API {
27 |     type Optimizer = optimizers.Optimizer
28 |
29 |     type AdaDelta = optimizers.AdaDelta
30 |     val AdaDelta: optimizers.AdaDelta.type = optimizers.AdaDelta
31 |     type Adafactor = optimizers.Adafactor
32 |     val Adafactor: optimizers.Adafactor.type = optimizers.Adafactor
33 |     type AdaGrad = optimizers.AdaGrad
34 |     val AdaGrad: optimizers.AdaGrad.type = optimizers.AdaGrad
35 |     type RMSProp = optimizers.RMSProp
36 |     val RMSProp: optimizers.RMSProp.type = optimizers.RMSProp
37 |     type Adam = optimizers.Adam
38 |     val Adam: optimizers.Adam.type = optimizers.Adam
39 |     type AMSGrad = optimizers.AMSGrad
40 |     val AMSGrad: optimizers.AMSGrad.type = optimizers.AMSGrad
41 |     type GradientDescent = optimizers.GradientDescent
42 |     val GradientDescent: optimizers.GradientDescent.type = optimizers.GradientDescent
43 |     type LazyAdam = optimizers.LazyAdam
44 |     val LazyAdam: optimizers.LazyAdam.type = optimizers.LazyAdam
45 |     type LazyAMSGrad = optimizers.LazyAMSGrad
46 |     val LazyAMSGrad: optimizers.LazyAMSGrad.type = optimizers.LazyAMSGrad
47 |     type YellowFin = optimizers.YellowFin
48 |     val YellowFin: optimizers.YellowFin.type = optimizers.YellowFin
49 |     type Yogi = optimizers.Yogi
50 |     val Yogi: optimizers.Yogi.type = optimizers.Yogi
51 |   }
52 | }
53 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/training/optimizers/schedules/FixedSchedule.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.ops.training.optimizers.schedules
17 |
18 | import org.platanios.tensorflow.api.core.types.{TF, IsIntOrLong}
19 | import org.platanios.tensorflow.api.ops.Output
20 | import org.platanios.tensorflow.api.ops.variables.Variable
21 |
22 | /** Dummy scheduling method representing no schedule being used: it returns the provided value unchanged and ignores
23 |   * the optimization step. Useful as a default value for `Schedule`-valued function arguments.
24 |   *
25 |   * @tparam T Element type of the scheduled value.
26 |   *
27 |   * @author Emmanouil Antonios Platanios
28 |   */
29 | case class FixedSchedule[T]() extends Schedule[T] {
30 |   /** Returns `value` unchanged, ignoring `step`.
31 |     *
32 |     * This schedule never reads `step`, so it never throws, even when `step` is empty.
33 |     *
34 |     * @param  value Value to (not) change.
35 |     * @param  step  Option containing the current iteration in the optimization loop; ignored.
36 |     * @return `value`, as is.
37 |     */
38 |   @throws[IllegalArgumentException]
39 |   override def apply[I: TF : IsIntOrLong](
40 |       value: Output[T],
41 |       step: Option[Variable[I]]
42 |   ): Output[T] = value
43 | }
44 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/training/optimizers/schedules/package.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.ops.training.optimizers
17 |
18 | /** Package object for the optimization schedules.
19 |   *
20 |   * @author Emmanouil Antonios Platanios
21 |   */
22 | package object schedules {
23 |   /** Schedules API (only visible within the `optimizers` package). For each concrete schedule, both its type and
24 |     * its companion object are re-exported.
25 |     */
26 |   private[optimizers] trait API {
27 |     type Schedule[T] = schedules.Schedule[T]
28 |
29 |     type FixedSchedule[T] = schedules.FixedSchedule[T]
30 |     val FixedSchedule: schedules.FixedSchedule.type = schedules.FixedSchedule
31 |     type CosineDecay = schedules.CosineDecay
32 |     val CosineDecay: schedules.CosineDecay.type = schedules.CosineDecay
33 |     type CycleLinear10xDecay = schedules.CycleLinear10xDecay
34 |     val CycleLinear10xDecay: schedules.CycleLinear10xDecay.type = schedules.CycleLinear10xDecay
35 |     type ExponentialDecay = schedules.ExponentialDecay
36 |     val ExponentialDecay: schedules.ExponentialDecay.type = schedules.ExponentialDecay
37 |     type LuongExponentialDecay = schedules.LuongExponentialDecay
38 |     val LuongExponentialDecay: schedules.LuongExponentialDecay.type = schedules.LuongExponentialDecay
39 |     type RSqrtDecay = schedules.RSqrtDecay
40 |     val RSqrtDecay: schedules.RSqrtDecay.type = schedules.RSqrtDecay
41 |     type WarmUpExponentialSchedule = schedules.WarmUpExponentialSchedule
42 |     val WarmUpExponentialSchedule: schedules.WarmUpExponentialSchedule.type = schedules.WarmUpExponentialSchedule
43 |     type WarmUpLinearSchedule = schedules.WarmUpLinearSchedule
44 |     val WarmUpLinearSchedule: schedules.WarmUpLinearSchedule.type = schedules.WarmUpLinearSchedule
45 |
46 |     // TODO: Piecewise constant.
47 |     // TODO: Polynomial.
48 |     // TODO: Natural exp.
49 |     // TODO: Inverse time.
50 |   }
51 | }
52 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/training/package.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.ops
17 |
18 | /** Package object for the training utilities.
19 |   *
20 |   * @author Emmanouil Antonios Platanios
21 |   */
22 | package object training {
23 |   /** Training API (only visible within the `ops` package): the optimizers API plus `ExponentialMovingAverage`. */
24 |   private[ops] trait API extends optimizers.API {
25 |     type ExponentialMovingAverage = training.ExponentialMovingAverage
26 |     val ExponentialMovingAverage: training.ExponentialMovingAverage.type =
27 |       training.ExponentialMovingAverage
28 |   }
29 | }
30 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/variables/Regularizer.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.ops.variables
17 |
18 | import org.platanios.tensorflow.api.ops.Output
19 |
20 | /** Variable regularizer: a function that takes a tensor representing the value of a variable and returns another
21 |   * tensor representing the corresponding regularizer value.
22 |   *
23 |   * @author Emmanouil Antonios Platanios
24 |   */
25 | trait Regularizer {
26 |   /** Computes the regularizer value for `value`.
27 |     *
28 |     * @param  value Tensor holding the variable value.
29 |     * @tparam T     Element type of `value`.
30 |     * @return Tensor holding the regularizer value, with the same element type as `value`.
31 |     */
32 |   def apply[T](value: Output[T]): Output[T]
33 | }
34 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/variables/Reuse.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.ops.variables
17 |
18 | /** Enumeration of the possible variable reuse options, used by variable scopes and variable stores.
19 |   *
20 |   * The supported options are:
21 |   *   - [[ReuseExistingOnly]]: Reuse existing variables only and throw an exception if no appropriate variable exists.
22 |   *   - [[CreateNewOnly]]: Create new variables only and throw an exception if a variable with the same name exists.
23 |   *   - [[ReuseOrCreateNew]]: Reuse existing variables or create new ones, if no variable with the provided name exists.
24 |   *
25 |   * The options that may reuse existing variables additionally extend [[ReuseAllowed]].
26 |   *
27 |   * @author Emmanouil Antonios Platanios
28 |   */
29 | sealed trait Reuse
30 |
31 | /** Marker trait for the variable reuse modes that allow reusing existing variables. */
32 | sealed trait ReuseAllowed extends Reuse
33 |
34 | /** Create new variables only and throw an exception if a variable with the same name exists. */
35 | case object CreateNewOnly extends Reuse
36 |
37 | /** Reuse existing variables only and throw an exception if no appropriate variable exists. */
38 | case object ReuseExistingOnly extends ReuseAllowed
39 |
40 | /** Reuse existing variables or create new ones, if no variable with the provided name exists. */
41 | case object ReuseOrCreateNew extends ReuseAllowed
42 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/variables/VariableScopeStore.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.ops.variables
17 |
18 | import org.platanios.tensorflow.api.ops.Op
19 |
20 | /** Store for the current variable scope and for the variable scope use counts.
21 |   *
22 |   * NOTE(review): this store was previously described as thread-local. The instance in use is obtained through
23 |   * `Op.currentGraph.variableScopeStore.value`; confirm thread-locality there.
24 |   *
25 |   * @author Emmanouil Antonios Platanios
26 |   */
27 | case class VariableScopeStore private[api]() {
28 |   /** Current variable scope (initially one created with the `CreateNewOnly` reuse mode). */
29 |   private[api] var scope: VariableScope = VariableScope(CreateNewOnly)
30 |
31 |   /** Map with variable scope names as keys and the corresponding use counts as values. */
32 |   private[api] var variableScopeCounts: Map[String, Int] = Map.empty[String, Int]
33 |
34 |   /** Records one more use of the variable scope named `scope`. */
35 |   private[api] def enterVariableScope(scope: String): Unit = {
36 |     val incrementedCount = variableScopeCount(scope) + 1
37 |     variableScopeCounts = variableScopeCounts.updated(scope, incrementedCount)
38 |   }
39 |
40 |   /** Drops the use counts of all scopes nested under `scope`, i.e., those whose names start with `scope` followed by
41 |     * a `/`. The count of `scope` itself is kept.
42 |     */
43 |   private[api] def closeVariableSubScopes(scope: String): Unit = {
44 |     val subScopePrefix = s"$scope/"
45 |     variableScopeCounts = variableScopeCounts.filterNot { case (name, _) => name.startsWith(subScopePrefix) }
46 |   }
47 |
48 |   /** Returns the use count of the provided scope in this variable store.
49 |     *
50 |     * @param  scope Variable scope name.
51 |     * @return Number of usages of the provided variable scope name in this store (`0` if it was never entered).
52 |     */
53 |   private[api] def variableScopeCount(scope: String): Int = variableScopeCounts.getOrElse(scope, 0)
54 | }
55 |
56 | object VariableScopeStore {
57 |   /** Returns the variable scope store of the current graph. */
58 |   def current: VariableScopeStore = Op.currentGraph.variableScopeStore.value
59 | }
60 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Cast.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.tensors.ops
17 |
18 | import org.platanios.tensorflow.api.core.types._
19 | import org.platanios.tensorflow.api.tensors._
20 | import org.platanios.tensorflow.jni.generated.tensors.{Math => NativeTensorOpsMath}
21 |
22 | /** Contains functions for executing cast-related ops on tensors, backed by the native `Math` tensor ops.
23 |   *
24 |   * @author Emmanouil Antonios Platanios
25 |   */
26 | trait Cast {
27 |   /** $OpDocCastCast
28 |     *
29 |     * @group CastOps
30 |     *
31 |     * @param  input    Tensor to cast.
32 |     * @param  truncate Flag passed through to the native cast op as its `truncate` argument.
33 |     * @tparam R        Target data type.
34 |     * @return Result as a new tensor, or `input` itself if it already has the target data type.
35 |     */
36 |   private[tensors] def cast[T, R: TF, TL[TT] <: TensorLike[TT]](
37 |       input: TL[T],
38 |       truncate: Boolean = false
39 |   )(implicit ev: TensorOps.Aux[TL, T]): TL[R] = {
40 |     val targetDataType = implicitly[TF[R]].dataType
41 |     if (input.dataType != targetDataType) {
42 |       ev.applyUnary(input, tensor => {
43 |         val castHandle = NativeTensorOpsMath.cast(
44 |           executionContext.value.nativeHandle, tensor.nativeHandle, targetDataType.cValue, truncate)
45 |         Tensor.fromNativeHandle[R](castHandle)
46 |       })
47 |     } else {
48 |       // No conversion needed: the runtime data type already equals the one of `R`, so this cast is only nominal.
49 |       input.asInstanceOf[TL[R]]
50 |     }
51 |   }
52 |
53 |   // TODO: [OPS] saturateCast
54 |
55 |   /** $OpDocCastBitcast
56 |     *
57 |     * @group CastOps
58 |     *
59 |     * @param  input Input tensor.
60 |     * @tparam R     Target data type.
61 |     * @return Result as a new tensor.
62 |     */
63 |   private[tensors] def bitcast[T: IsNumeric, R: TF, TL[TT] <: TensorLike[TT]](
64 |       input: TL[T]
65 |   )(implicit ev: TensorOps.Aux[TL, T]): TL[R] = {
66 |     val targetDataType = implicitly[TF[R]].dataType
67 |     ev.applyUnary(input, tensor => {
68 |       val bitcastHandle = NativeTensorOpsMath.bitcast(
69 |         executionContext.value.nativeHandle, tensor.nativeHandle, targetDataType.cValue)
70 |       Tensor.fromNativeHandle[R](bitcastHandle)
71 |     })
72 |   }
73 | }
74 |
75 | object Cast extends Cast
76 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Random.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.tensors.ops
17 |
18 | import org.platanios.tensorflow.api.core.types.TF
19 | import org.platanios.tensorflow.api.ops.Op
20 | import org.platanios.tensorflow.api.tensors._
21 | import org.platanios.tensorflow.jni.generated.tensors.{Random => NativeTensorOpsRandom}
22 |
23 | /** Contains functions for executing ops related to random numbers and tensors.
24 |   *
25 |   * @author Emmanouil Antonios Platanios
26 |   */
27 | trait Random {
28 |   /** $OpDocRandomRandomShuffle
29 |     *
30 |     * @group RandomOps
31 |     * @param  value Tensor to be shuffled.
32 |     * @param  seed  Optional random seed, used to generate a random seed pair for the random number generator, when
33 |     *               combined with the graph-level seed. Missing seeds are passed to the native op as `0`.
34 |     * @return Result as a new tensor.
35 |     */
36 |   def randomShuffle[T: TF](
37 |       value: Tensor[T],
38 |       seed: Option[Int] = None
39 |   ): Tensor[T] = {
40 |     val (graphSeed, opSeed) = Op.currentGraphRandomSeed(seed)
41 |     val contextHandle = executionContext.value.nativeHandle
42 |     val shuffledHandle = NativeTensorOpsRandom.randomShuffle(
43 |       contextHandle, value.nativeHandle, graphSeed.getOrElse(0).toLong, opSeed.getOrElse(0).toLong)
44 |     Tensor.fromNativeHandle[T](shuffledHandle)
45 |   }
46 | }
47 |
48 | object Random extends Random
49 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/package.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.tensors
17 |
18 | /** Package object for the tensor ops.
19 |   *
20 |   * @author Emmanouil Antonios Platanios
21 |   */
22 | package object ops {
23 |   /** Tensor ops API (only visible within the `api` package), combining the `Basic`, `Cast`, `Math`, `NN`, and
24 |     * `Random` tensor ops. The mixin order determines the trait linearization.
25 |     */
26 |   private[api] trait API extends Basic with Cast with Math with NN with Random
27 | }
28 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/package.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api
17 |
18 | import scala.util.DynamicVariable
19 |
/** Package object for the tensor API.
  *
  * @author Emmanouil Antonios Platanios
  */
package object tensors {
  /** Dynamically-scoped `Context` (presumably consulted when executing tensor ops eagerly — confirm).
    *
    * Its initial value is a `Context` built from `core.defaultSessionConfig`. Being a `scala.util.DynamicVariable`,
    * it can be overridden for the extent of a block using `withValue`.
    */
  private[api] val executionContext: DynamicVariable[Context] = {
    val defaultContext = Context(Some(core.defaultSessionConfig))
    new DynamicVariable[Context](defaultContext)
  }

  /** Tensor-level API trait, exposing everything provided by `tensors.ops.API`. */
  private[api] trait API extends tensors.ops.API
}
31 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/utilities/Collections.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.utilities
17 |
/** Contains helper functions for manipulating collections.
  *
  * @author Emmanouil Antonios Platanios
  */
object Collections {
  /** Segments a sequence according to a provided sequence of segment lengths.
    *
    * For example:
    * {{{
    *   val xs = Seq(3, 5, 2, 77, 12, 45, 78, 21, 89, 1, 0, -1, 123)
    *   val n = Seq(3, 1, 0, 2, 5, 2)
    *   segment(xs, n) = Seq(Seq(3, 5, 2), Seq(77), Seq(), Seq(12, 45), Seq(78, 21, 89, 1, 0), Seq(-1, 123))
    * }}}
    *
    * Exactly one segment is returned per entry of `n`. No exception is thrown if the provided segment lengths do not
    * match the original sequence length:
    *   - elements of `xs` that are left over once `n` is exhausted are dropped, and
    *   - once `xs` is exhausted, the segment during which it ran out is shorter than requested and every later
    *     segment is empty.
    *
    * @param  xs Sequence to segment.
    * @param  n  Segment lengths. Non-positive lengths produce empty segments.
    * @return Sequence (a `List`) containing the segments of `xs`.
    */
  def segment[V](xs: Seq[V], n: Seq[Int]): Seq[Seq[V]] = {
    val segments = List.newBuilder[Seq[V]]
    var remaining = xs
    // Iterate (rather than recurse) so that long sequences of segment lengths cannot overflow the stack, and stop as
    // soon as `n` is exhausted (the previous implementation called `n.head` on an empty `n` when `xs` had leftovers).
    n.foreach { length =>
      val (current, rest) = remaining.splitAt(length)
      segments += current
      remaining = rest
    }
    segments.result()
  }
}
48 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/utilities/Proto.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.utilities
17 |
18 | import org.platanios.tensorflow.api.io.FileIO
19 |
20 | import com.google.protobuf.GeneratedMessageV3
21 |
22 | import java.nio.file.{Files, Path}
23 |
/** Contains helper functions for working with ProtoBuf.
  *
  * @author Emmanouil Antonios Platanios
  */
object Proto {
  /** Writes `message` to the specified file.
    *
    * `directory` is created if it does not already exist, unless it is a Google Cloud Storage (`gs:`) path. Text-format
    * messages are written through `FileIO.writeStringToFileAtomic`. Binary messages are written directly to
    * `directory/filename`, and the output stream is always closed, even if serialization fails.
    *
    * @param  directory Directory in which to write the file.
    * @param  filename  Name of the file.
    * @param  message   ProtoBuf message to write.
    * @param  asText    Boolean value indicating whether to serialize the ProtoBuf message in the human-friendly text
    *                   format, or in the more efficient binary format.
    * @return Path of the written file.
    */
  def write(directory: Path, filename: String, message: GeneratedMessageV3, asText: Boolean = false): Path = {
    // GCS does not have the concept of a directory at the moment.
    if (!Files.exists(directory) && !directory.startsWith("gs:"))
      Files.createDirectories(directory)
    val filePath = directory.resolve(filename)
    if (asText) {
      FileIO.writeStringToFileAtomic(filePath, message.toString)
    } else {
      // `writeTo` neither flushes nor closes the stream it is given, so we close it ourselves to avoid leaking a file
      // handle (the previous implementation never closed it).
      val outputStream = Files.newOutputStream(filePath)
      try {
        message.writeTo(outputStream)
      } finally {
        outputStream.close()
      }
    }
    filePath
  }

  /** Trait that all ProtoBuf-serializable objects should extend. */
  trait Serializable {
    /** Converts this object to its corresponding ProtoBuf object.
      *
      * @return ProtoBuf object corresponding to this object.
      */
    def toProto: GeneratedMessageV3
  }
}
60 |
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/utilities/package.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api
17 |
/** Package object holding general-purpose utilities.
  *
  * @author Emmanouil Antonios Platanios
  */
package object utilities {
  /** Mixin for objects whose resources are released by invoking a `closeFn` function. */
  trait Closeable {
    /** Function that [[close]] delegates to. */
    protected val closeFn: () => Unit

    /** Releases the native resources associated with this object. */
    def close(): Unit = closeFn.apply()
  }

  /** Loan pattern: applies `block` to `resource` and returns its result, closing `resource` afterwards (unless it is
    * `null`), regardless of whether `block` completes normally or throws.
    */
  def using[T <: Closeable, R](resource: T)(block: T => R): R = {
    try block(resource)
    finally Option(resource).foreach(_.close())
  }
}
38 |
--------------------------------------------------------------------------------
/modules/api/src/test/scala/org/platanios/tensorflow/api/core/SessionSpec.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.core
17 |
18 | import org.platanios.tensorflow.api._
19 |
20 | import org.scalatest.flatspec.AnyFlatSpec
21 | import org.scalatest.matchers.should.Matchers
22 |
/** Tests for running a [[Session]].
  *
  * @author Emmanouil Antonios Platanios
  */
class SessionSpec extends AnyFlatSpec with Matchers {
  "Session run fetch by name" should "return the correct result" in {
    val graph = Graph()
    try {
      tf.createWith(graph = graph) {
        // Y = 1 - A * X^T, with A = [[2, 3]] and X later fed as [[5, 7]], so Y = 1 - (2 * 5 + 3 * 7) = -30.
        val a = tf.constant(Tensor(Tensor(2, 3)), name = "A")
        val x = tf.placeholder[Int](Shape(1, 2), name = "X")
        tf.subtract(tf.constant(1), tf.matmul(a = a, b = x, transposeB = true), name = "Y")
      }
      val session = Session(graph = graph)
      try {
        // Look the feed and fetch outputs up by name, which is what this test exercises.
        val feeds = Map(graph.getOutputByName("X:0").asInstanceOf[Output[Int]] -> Tensor(Tensor(5, 7)))
        val fetches = graph.getOutputByName("Y:0").asInstanceOf[Output[Int]]
        val output = session.run(feeds, fetches)
        val expectedResult = Tensor(Tensor(-30))
        assert(output.scalar == expectedResult.scalar)
      } finally {
        // Close the session before the graph it was created for; previously the session was never closed.
        session.close()
      }
    } finally {
      // Release the graph even if graph construction, execution, or the assertion fails.
      graph.close()
    }
  }
}
43 |
--------------------------------------------------------------------------------
/modules/api/src/test/scala/org/platanios/tensorflow/api/ops/BasicSpec.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.ops
17 |
18 | import org.scalatest.flatspec.AnyFlatSpec
19 | import org.scalatest.matchers.should.Matchers
20 |
/** Specification for the basic ops.
  *
  * NOTE(review): the only test in this spec is currently commented out, so the spec runs no tests. The disabled test
  * uses `ArrayOps.constant`, `DataType`, and `Tensor.create`, and it references an `array` value that is not defined
  * anywhere in this file, so it would need to be ported to the current API before it can be re-enabled.
  *
  * @author Emmanouil Antonios Platanios
  */
class BasicSpec extends AnyFlatSpec with Matchers {
  //  "'ArrayOps.constant'" must "create a constant op when provided a Tensor of the same data type and shape" in {
  //    // DataType.Int32 Tensor
  //    val tensor1 = Tensor(Tensor(Tensor(2, 3), Tensor(0, 0), Tensor(5, 7)),
  //                         Tensor(Tensor(1, 23), Tensor(4, -5), Tensor(7, 9)),
  //                         Tensor(Tensor(56, 1), Tensor(-2, -4), Tensor(-7, -9)))
  //    val constant1 = ArrayOps.constant(tensor1)
  //    val constantValue1 = constant1.value()
  //    assert(tensor1.get(1, 1, 1) === -5)
  //    assert(constant1.shape === Shape(3, 3, 2))
  //    assert(constant1.dataType === DataType.Int32)
  //    assert(constantValue1.shape === Shape(3, 3, 2))
  //    assert(constantValue1.dataType === DataType.Int32)
  //    assert(constantValue1(1, 1, 1) === -5)
  //    tensor1.close()
  //
  //    // DataType.Float64 Tensor
  //    val tensor2 = Tensor.create(array, dataType = DataType.Float64)
  //    val constant2 = ArrayOps.constant(tensor2)
  //    val constantValue2 = constant2.value()
  //    assert(tensor2(1, 1, 1) === -5.0)
  //    assert(constant2.shape === Shape(3, 3, 2))
  //    assert(constant2.dataType === DataType.Float64)
  //    assert(constantValue2.shape === Shape(3, 3, 2))
  //    assert(constantValue2.dataType === DataType.Float64)
  //    assert(constantValue2(1, 1, 1) === -5.0)
  //    tensor2.close()
  //  }
}
53 |
--------------------------------------------------------------------------------
/modules/api/src/test/scala/org/platanios/tensorflow/api/ops/NNSpec.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.ops
17 |
18 | import org.platanios.tensorflow.api._
19 |
20 | import org.scalatest.matchers.should.Matchers
21 | import org.scalatestplus.junit.JUnitSuite
22 | import org.junit.Test
23 |
/** Tests for the neural network ops.
  *
  * @author Emmanouil Antonios Platanios
  */
class NNSpec extends JUnitSuite with Matchers {
  @Test def testLogSoftmax(): Unit = {
    // 3x3x2 input; the expected values below are listed in row-major order.
    val tensor = Tensor(Tensor(Tensor(2, 3), Tensor(0, 0), Tensor(5, 7)),
                        Tensor(Tensor(1, 23), Tensor(4, -5), Tensor(7, 9)),
                        Tensor(Tensor(56, 1), Tensor(-2, -4), Tensor(-7, -9)))
    val constant = tf.constant(tensor).toFloat
    val logSoftmaxLastAxis = tf.logSoftmax(constant, axis = -1)
    val logSoftmaxPenultimateAxis = tf.logSoftmax(constant, axis = 1)
    val session = Session()
    try {
      assertApproximatelyEqual(
        session.run(fetches = logSoftmaxLastAxis).toArray,
        Array(
          -1.3132616f, -0.31326163f, -0.6931472f, -0.6931472f, -2.126928f, -0.12692805f,
          -22.0f, 0.0f, -1.23374e-4f, -9.000123f, -2.126928f, -0.12692805f, 0.0f, -55.0f,
          -0.12692805f, -2.126928f, -0.12692805f, -2.126928f,
        ),
      )
      assertApproximatelyEqual(
        session.run(fetches = logSoftmaxPenultimateAxis).toArray,
        Array(
          -3.0549853f, -4.019045f, -5.054985f, -7.019045f, -0.054985214f, -0.019044992f,
          -6.0509458f, -8.344647e-7f, -3.0509458f, -28.0f, -0.05094571f, -14.000001f, 0.0f,
          -0.0067604627f, -58.0f, -5.0067606f, -63.0f, -10.006761f,
        ),
      )
    } finally {
      // The session was previously never closed.
      session.close()
    }
  }

  /** Asserts that `x` and `y` have the same length and are element-wise equal within `tolerance`.
    *
    * @param x         Actual values.
    * @param y         Expected values.
    * @param tolerance Maximum allowed absolute difference between corresponding elements.
    */
  def assertApproximatelyEqual(x: Array[Float], y: Array[Float], tolerance: Float = 1e-6f): Unit = {
    // `zip` silently truncates to the shorter array, so check the lengths explicitly first.
    assert(x.length == y.length, s"Expected ${y.length} elements, but got ${x.length}.")
    x.zip(y).foreach { case (xElement, yElement) =>
      assert(xElement === yElement +- tolerance)
    }
  }
}
60 |
--------------------------------------------------------------------------------
/modules/api/src/test/scala/org/platanios/tensorflow/api/ops/data/FilterDatasetSuite.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.ops.data
17 |
18 | import org.platanios.tensorflow.api.core.client.Session
19 | import org.platanios.tensorflow.api.core.{Graph, Shape}
20 | import org.platanios.tensorflow.api.implicits.Implicits._
21 | import org.platanios.tensorflow.api.ops.Op
22 | import org.platanios.tensorflow.api.ops.math.Math
23 | import org.platanios.tensorflow.api.tensors.Tensor
24 | import org.platanios.tensorflow.api.utilities.using
25 |
26 | import org.junit.Test
27 | import org.scalatestplus.junit.JUnitSuite
28 |
/** Tests for filtered datasets.
  *
  * @author Emmanouil Antonios Platanios
  */
class FilterDatasetSuite extends JUnitSuite {
  @Test def testFilterRange(): Unit = using(Graph()) { graph =>
    Op.createWith(graph) {
      // Keep only the range elements `x` for which `x % 3 != 2`.
      val dataset = Data.datasetFromRange(0, 100).filter(x => {
        Math.notEqual(Math.mod(x, 3L), 2L)
      })
      val iterator = dataset.createInitializableIterator()
      val initOp = iterator.initializer
      val nextOutput = iterator.next()
      assert(nextOutput.shape == Shape.scalar())
      val session = Session()
      try {
        session.run(targets = initOp)
        assert(session.run(fetches = nextOutput) == (0L: Tensor[Long]))
        assert(session.run(fetches = nextOutput) == (1L: Tensor[Long]))
        assert(session.run(fetches = nextOutput) == (3L: Tensor[Long]))
        assert(session.run(fetches = nextOutput) == (4L: Tensor[Long]))
        assert(session.run(fetches = nextOutput) == (6L: Tensor[Long]))
      } finally {
        // Close the session before `using` closes the graph; previously the session was never closed.
        session.close()
      }
    }
  }
}
52 |
--------------------------------------------------------------------------------
/modules/api/src/test/scala/org/platanios/tensorflow/api/utilities/CRC32CSuite.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.api.utilities
17 |
18 | import org.scalatestplus.junit.JUnitSuite
19 | import org.junit.Test
20 |
/** Tests for the CRC32C checksum utilities.
  *
  * @author Emmanouil Antonios Platanios
  */
class CRC32CSuite extends JUnitSuite {
  @Test def testValue(): Unit = {
    // Known-answer vectors from RFC 3720, section B.4.
    val allZeros = Array.fill[Byte](32)(0x00.toByte)
    val allOnes = Array.fill[Byte](32)(0xff.toByte)
    val ascending = (0 until 32).map(_.toByte).toArray
    val descending = (31 to 0 by -1).map(_.toByte).toArray
    val iscsiReadCommand = Array(
      0x01, 0xc0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
      0x00, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00,
      0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x18, 0x28, 0x00, 0x00, 0x00,
      0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00
    ).map(_.toByte)

    val knownAnswers = Seq(
      allZeros -> 0x8a9136aa,
      allOnes -> 0x62a8ab43,
      ascending -> 0x46dd794e,
      descending -> 0x113fdb5c,
      iscsiReadCommand -> 0xd9963a56)
    for ((input, expected) <- knownAnswers)
      assert(CRC32C.value(input) === expected)

    // Distinct inputs should yield distinct checksums.
    assert(CRC32C.value("a") !== CRC32C.value("foo"))
  }

  @Test def testExtend(): Unit = {
    // Extending the checksum of a prefix must equal the checksum of the whole string.
    val prefixChecksum = CRC32C.value("hello ")
    assert(CRC32C.value("hello world") === CRC32C.extend(prefixChecksum, "world"))
  }

  @Test def testMask(): Unit = {
    val crc = CRC32C.value("foo")
    val masked = CRC32C.mask(crc)
    val doublyMasked = CRC32C.mask(masked)
    // Masking changes the value, and unmasking must invert it exactly (once or twice).
    assert(crc !== masked)
    assert(crc !== doublyMasked)
    assert(crc === CRC32C.unmask(masked))
    assert(crc === CRC32C.unmask(CRC32C.unmask(doublyMasked)))
  }
}
55 |
--------------------------------------------------------------------------------
/modules/data/src/main/resources/logback.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | %d{yyyy-MM-dd HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
--------------------------------------------------------------------------------
/modules/data/src/main/scala/org/platanios/tensorflow/data/utilities/CompressedFiles.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.data.utilities
17 |
18 | import org.apache.commons.compress.archivers.tar.TarArchiveInputStream
19 | import org.apache.commons.compress.utils.IOUtils
20 |
21 | import java.io.{File, FileOutputStream, InputStream}
22 | import java.nio.file.{Files, Path}
23 | import java.util.zip.GZIPInputStream
24 |
/** Helpers for extracting (optionally gzip-compressed) tar archives.
  *
  * @author Emmanouil Antonios Platanios
  */
object CompressedFiles {
  /** Extracts the gzip-compressed tar archive stored at `tgzFilePath` into `destinationPath`.
    *
    * @param tgzFilePath     Path to the gzip-compressed tar archive.
    * @param destinationPath Directory into which the archive's file entries are extracted.
    * @param bufferSize      Size (in bytes) of the buffer used when copying each entry's contents.
    */
  def decompressTGZ(tgzFilePath: Path, destinationPath: Path, bufferSize: Int = 8192): Unit = {
    decompressTGZStream(Files.newInputStream(tgzFilePath), destinationPath, bufferSize)
  }

  /** Extracts the (uncompressed) tar archive stored at `tarFilePath` into `destinationPath`.
    *
    * @param tarFilePath     Path to the tar archive.
    * @param destinationPath Directory into which the archive's file entries are extracted.
    * @param bufferSize      Size (in bytes) of the buffer used when copying each entry's contents.
    */
  def decompressTar(tarFilePath: Path, destinationPath: Path, bufferSize: Int = 8192): Unit = {
    decompressTarStream(Files.newInputStream(tarFilePath), destinationPath, bufferSize)
  }

  /** Extracts a gzip-compressed tar archive read from `tgzStream` into `destinationPath`.
    *
    * `tgzStream` is closed by this method, including when the gzip header turns out to be invalid.
    *
    * @param tgzStream       Stream containing the gzip-compressed tar archive.
    * @param destinationPath Directory into which the archive's file entries are extracted.
    * @param bufferSize      Size (in bytes) of the buffer used when copying each entry's contents.
    */
  def decompressTGZStream(tgzStream: InputStream, destinationPath: Path, bufferSize: Int = 8192): Unit = {
    // The `GZIPInputStream` constructor reads (and validates) the gzip header, so it can throw; make sure the
    // underlying stream is not leaked in that case.
    val gzipStream =
      try {
        new GZIPInputStream(tgzStream)
      } catch {
        case t: Throwable =>
          tgzStream.close()
          throw t
      }
    decompressTarStream(gzipStream, destinationPath, bufferSize)
  }

  /** Extracts a tar archive read from `tarStream` into `destinationPath`.
    *
    * Only non-directory entries are written; their parent directories are created as needed. Entries whose names
    * would resolve to a location outside of `destinationPath` (e.g., names containing `..`) are rejected. The archive
    * stream is closed when this method returns, whether normally or by throwing.
    *
    * @param tarStream       Stream containing the tar archive.
    * @param destinationPath Directory into which the archive's file entries are extracted.
    * @param bufferSize      Size (in bytes) of the buffer used when copying each entry's contents.
    * @throws java.io.IOException If an entry would be extracted outside of `destinationPath`, or if reading the
    *                             archive or writing an entry fails.
    */
  def decompressTarStream(tarStream: InputStream, destinationPath: Path, bufferSize: Int = 8192): Unit = {
    val destinationRoot = destinationPath.toAbsolutePath.normalize()
    val inputStream = new TarArchiveInputStream(tarStream)
    try {
      var entry = inputStream.getNextTarEntry
      while (entry != null) {
        if (!entry.isDirectory) {
          // Join the entry name onto the destination the same way `java.io.File` does, and then reject entries whose
          // normalized path escapes the destination directory ("zip slip"), since archives may come from untrusted
          // sources (e.g., downloaded data sets).
          val currentFile = new File(destinationRoot.toString, entry.getName).toPath.normalize()
          if (!currentFile.startsWith(destinationRoot))
            throw new java.io.IOException(
              s"Tar entry '${entry.getName}' resolves to '$currentFile', which is outside of '$destinationRoot'.")
          Files.createDirectories(currentFile.getParent)
          // Close each extracted file as soon as it is written, instead of leaking one file handle per entry.
          val outputStream = new FileOutputStream(currentFile.toFile)
          try {
            IOUtils.copy(inputStream, outputStream, bufferSize)
          } finally {
            outputStream.close()
          }
        }
        entry = inputStream.getNextTarEntry
      }
    } finally {
      inputStream.close()
    }
  }
}
57 |
--------------------------------------------------------------------------------
/modules/data/src/test/scala/org/platanios/tensorflow/data/image/MNISTLoaderSpec.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.data.image
17 |
18 | import org.scalatest.flatspec.AnyFlatSpec
19 |
/** Specification for the MNIST data set loader.
  *
  * NOTE(review): this test is currently a no-op. Its body is commented out, so it always passes. The disabled code
  * loads MNIST from a hard-coded, machine-specific directory; it would need a portable location (e.g., a temporary
  * directory), and presumably network access to download the data, before it can be re-enabled.
  *
  * @author Emmanouil Antonios Platanios
  */
class MNISTLoaderSpec extends AnyFlatSpec {
  //  val directory: Path = Paths.get("/Users/Anthony/Development/GitHub/tensorflow_scala/temp/data/mnist")
  //  Files.createDirectories(directory)

  "The MNIST data set loader" must "work" in {
    //    val dataSet = MNISTLoader.load(directory)
    //    val label0 = dataSet.trainLabels.summarize(10)
    //    print(dataSet)
  }
}
33 |
--------------------------------------------------------------------------------
/modules/examples/src/main/resources/logback.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | %d{yyyy-MM-dd HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
--------------------------------------------------------------------------------
/modules/examples/src/main/resources/python2scala/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "linear-regression"
2 | all_model_checkpoint_paths: "virgin-linear-regression"
3 | all_model_checkpoint_paths: "linear-regression"
4 |
--------------------------------------------------------------------------------
/modules/examples/src/main/resources/python2scala/linear-regression.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eaplatanios/tensorflow_scala/39053eefe0859bde4c48c8da20212d6919deb38e/modules/examples/src/main/resources/python2scala/linear-regression.data-00000-of-00001
--------------------------------------------------------------------------------
/modules/examples/src/main/resources/python2scala/linear-regression.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eaplatanios/tensorflow_scala/39053eefe0859bde4c48c8da20212d6919deb38e/modules/examples/src/main/resources/python2scala/linear-regression.index
--------------------------------------------------------------------------------
/modules/examples/src/main/resources/python2scala/linear-regression.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eaplatanios/tensorflow_scala/39053eefe0859bde4c48c8da20212d6919deb38e/modules/examples/src/main/resources/python2scala/linear-regression.meta
--------------------------------------------------------------------------------
/modules/examples/src/main/resources/python2scala/virgin-linear-regression.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eaplatanios/tensorflow_scala/39053eefe0859bde4c48c8da20212d6919deb38e/modules/examples/src/main/resources/python2scala/virgin-linear-regression.data-00000-of-00001
--------------------------------------------------------------------------------
/modules/examples/src/main/resources/python2scala/virgin-linear-regression.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eaplatanios/tensorflow_scala/39053eefe0859bde4c48c8da20212d6919deb38e/modules/examples/src/main/resources/python2scala/virgin-linear-regression.index
--------------------------------------------------------------------------------
/modules/examples/src/main/resources/python2scala/virgin-linear-regression.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eaplatanios/tensorflow_scala/39053eefe0859bde4c48c8da20212d6919deb38e/modules/examples/src/main/resources/python2scala/virgin-linear-regression.meta
--------------------------------------------------------------------------------
/modules/jni/src/main/native/checkpoint_reader.h:
--------------------------------------------------------------------------------
1 | /* DO NOT EDIT THIS FILE - it is machine generated */
2 | #include
3 | /* Header for class org_platanios_tensorflow_jni_CheckpointReader__ */
4 |
5 | #ifndef _Included_org_platanios_tensorflow_jni_CheckpointReader__
6 | #define _Included_org_platanios_tensorflow_jni_CheckpointReader__
7 | #ifdef __cplusplus
8 | extern "C" {
9 | #endif
10 | /*
11 | * Class: org_platanios_tensorflow_jni_CheckpointReader__
12 | * Method: newCheckpointReader
13 | * Signature: (Ljava/lang/String;)J
14 | */
15 | JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_CheckpointReader_00024_newCheckpointReader
16 | (JNIEnv *, jobject, jstring);
17 |
18 | /*
19 | * Class: org_platanios_tensorflow_jni_CheckpointReader__
20 | * Method: debugString
21 | * Signature: (J)Ljava/lang/String;
22 | */
23 | JNIEXPORT jstring JNICALL Java_org_platanios_tensorflow_jni_CheckpointReader_00024_debugString
24 | (JNIEnv *, jobject, jlong);
25 |
26 | /*
27 | * Class: org_platanios_tensorflow_jni_CheckpointReader__
28 | * Method: hasTensor
29 | * Signature: (JLjava/lang/String;)B
30 | */
31 | JNIEXPORT jboolean JNICALL Java_org_platanios_tensorflow_jni_CheckpointReader_00024_hasTensor
32 | (JNIEnv *, jobject, jlong, jstring);
33 |
34 | /*
35 | * Class: org_platanios_tensorflow_jni_CheckpointReader__
36 | * Method: getTensor
37 | * Signature: (JLjava/lang/String;)J
38 | */
39 | JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_CheckpointReader_00024_getTensor
40 | (JNIEnv *, jobject, jlong, jstring);
41 |
42 | /*
43 | * Class: org_platanios_tensorflow_jni_CheckpointReader__
44 | * Method: variableShapes
45 | * Signature: (J)Lorg/platanios/tensorflow/jni/VariableShapes;
46 | */
47 | JNIEXPORT jobject JNICALL Java_org_platanios_tensorflow_jni_CheckpointReader_00024_variableShapes
48 | (JNIEnv *, jobject, jlong);
49 |
50 | /*
51 | * Class: org_platanios_tensorflow_jni_CheckpointReader__
52 | * Method: variableDataTypes
53 | * Signature: (J)Lorg/platanios/tensorflow/jni/VariableDataTypes;
54 | */
55 | JNIEXPORT jobject JNICALL Java_org_platanios_tensorflow_jni_CheckpointReader_00024_variableDataTypes
56 | (JNIEnv *, jobject, jlong);
57 |
58 | /*
59 | * Class: org_platanios_tensorflow_jni_CheckpointReader__
60 | * Method: delete
61 | * Signature: (J)V
62 | */
63 | JNIEXPORT void JNICALL Java_org_platanios_tensorflow_jni_CheckpointReader_00024_delete
64 | (JNIEnv *, jobject, jlong);
65 |
66 | #ifdef __cplusplus
67 | }
68 | #endif
69 | #endif
70 |
--------------------------------------------------------------------------------
/modules/jni/src/main/native/function.h:
--------------------------------------------------------------------------------
1 | /* DO NOT EDIT THIS FILE - it is machine generated */
2 | #include
3 | /* Header for class org_platanios_tensorflow_jni_Function__ */
4 |
5 | #ifndef _Included_org_platanios_tensorflow_jni_Function__
6 | #define _Included_org_platanios_tensorflow_jni_Function__
7 | #ifdef __cplusplus
8 | extern "C" {
9 | #endif
10 |
11 | /*
12 | * Class: org_platanios_tensorflow_jni_Function__
13 | * Method: graphToFunction
14 | * Signature: (JLjava/lang/String;Z[J[J[I[J[I[Ljava/lang/String;)J
15 | */
16 | JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_Function_00024_graphToFunction
17 | (JNIEnv *, jobject, jlong, jstring, jboolean, jlongArray, jlongArray, jintArray, jlongArray, jintArray, jobjectArray);
18 |
19 | /*
20 | * Class: org_platanios_tensorflow_jni_Function__
21 | * Method: copyToGraph
22 | * Signature: (JJJ)V
23 | */
24 | JNIEXPORT void JNICALL Java_org_platanios_tensorflow_jni_Function_00024_copyToGraph
25 | (JNIEnv *, jobject, jlong, jlong, jlong);
26 |
27 | /*
28 | * Class: org_platanios_tensorflow_jni_Function__
29 | * Method: toFunctionDef
30 | * Signature: (J)[B
31 | */
32 | JNIEXPORT jbyteArray JNICALL Java_org_platanios_tensorflow_jni_Function_00024_toFunctionDef
33 | (JNIEnv *, jobject, jlong);
34 |
35 | /*
36 | * Class: org_platanios_tensorflow_jni_Function__
37 | * Method: delete
38 | * Signature: (J)V
39 | */
40 | JNIEXPORT void JNICALL Java_org_platanios_tensorflow_jni_Function_00024_delete
41 | (JNIEnv *, jobject, jlong);
42 |
43 | #ifdef __cplusplus
44 | }
45 | #endif
46 | #endif
47 |
--------------------------------------------------------------------------------
/modules/jni/src/main/native/generated/tensor_random_ops.h:
--------------------------------------------------------------------------------
1 | /* DO NOT EDIT THIS FILE - it is machine generated */
2 | #include <jni.h>
3 | /* Header for class org_platanios_tensorflow_jni_generated_tensors_Random__ (Scala object org.platanios.tensorflow.jni.generated.tensors.Random; "_00024" in the symbol names below is the JNI escape for '$') */
4 |
5 | #ifndef _Included_org_platanios_tensorflow_jni_generated_tensors_Random__
6 | #define _Included_org_platanios_tensorflow_jni_generated_tensors_Random__
7 | #ifdef __cplusplus
8 | extern "C" {
9 | #endif
10 | /*
11 | * Class: org_platanios_tensorflow_jni_generated_tensors_Random__
12 | * Method: randomShuffle
13 | * Signature: (JJJJ)J
14 | */
15 | JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Random_00024_randomShuffle
16 | (JNIEnv *, jobject, jlong, jlong, jlong, jlong);
17 |
18 | /*
19 | * Class: org_platanios_tensorflow_jni_generated_tensors_Random__
20 | * Method: randomUniform
21 | * Signature: (JJIJJ)J
22 | */
23 | JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Random_00024_randomUniform
24 | (JNIEnv *, jobject, jlong, jlong, jint, jlong, jlong);
25 |
26 | /*
27 | * Class: org_platanios_tensorflow_jni_generated_tensors_Random__
28 | * Method: randomUniformInt
29 | * Signature: (JJJJJJ)J
30 | */
31 | JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Random_00024_randomUniformInt
32 | (JNIEnv *, jobject, jlong, jlong, jlong, jlong, jlong, jlong);
33 |
34 | /*
35 | * Class: org_platanios_tensorflow_jni_generated_tensors_Random__
36 | * Method: randomStandardNormal
37 | * Signature: (JJIJJ)J
38 | */
39 | JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Random_00024_randomStandardNormal
40 | (JNIEnv *, jobject, jlong, jlong, jint, jlong, jlong);
41 |
42 | /*
43 | * Class: org_platanios_tensorflow_jni_generated_tensors_Random__
44 | * Method: truncatedNormal
45 | * Signature: (JJIJJ)J
46 | */
47 | JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Random_00024_truncatedNormal
48 | (JNIEnv *, jobject, jlong, jlong, jint, jlong, jlong);
49 |
50 | #ifdef __cplusplus
51 | }
52 | #endif
53 | #endif
54 |
--------------------------------------------------------------------------------
/modules/jni/src/main/native/generated/tensor_sparse_ops.h:
--------------------------------------------------------------------------------
1 | /* DO NOT EDIT THIS FILE - it is machine generated */
2 | #include <jni.h>
3 | /* Header for class org_platanios_tensorflow_jni_generated_tensors_Sparse__ (Scala object org.platanios.tensorflow.jni.generated.tensors.Sparse; "_00024" in the symbol names below is the JNI escape for '$') */
4 |
5 | #ifndef _Included_org_platanios_tensorflow_jni_generated_tensors_Sparse__
6 | #define _Included_org_platanios_tensorflow_jni_generated_tensors_Sparse__
7 | #ifdef __cplusplus
8 | extern "C" {
9 | #endif
10 | /*
11 | * Class: org_platanios_tensorflow_jni_generated_tensors_Sparse__
12 | * Method: sparseToDense
13 | * Signature: (JJJJJZ)J
14 | */
15 | JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Sparse_00024_sparseToDense
16 | (JNIEnv *, jobject, jlong, jlong, jlong, jlong, jlong, jboolean);
17 |
18 | #ifdef __cplusplus
19 | }
20 | #endif
21 | #endif
22 |
--------------------------------------------------------------------------------
/modules/jni/src/main/native/graph.h:
--------------------------------------------------------------------------------
1 | /* DO NOT EDIT THIS FILE - it is machine generated */
2 | #include <jni.h>
3 | /* Header for class org_platanios_tensorflow_jni_Graph__ (Scala object org.platanios.tensorflow.jni.Graph; "_00024" in the symbol names below is the JNI escape for '$') */
4 |
5 | #ifndef _Included_org_platanios_tensorflow_jni_Graph__
6 | #define _Included_org_platanios_tensorflow_jni_Graph__
7 | #ifdef __cplusplus
8 | extern "C" {
9 | #endif
10 |
11 | /*
12 | * Class: org_platanios_tensorflow_jni_Graph__
13 | * Method: allocate
14 | * Signature: ()J
15 | */
16 | JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_Graph_00024_allocate
17 | (JNIEnv *, jobject);
18 |
19 | /*
20 | * Class: org_platanios_tensorflow_jni_Graph__
21 | * Method: delete
22 | * Signature: (J)V
23 | */
24 | JNIEXPORT void JNICALL Java_org_platanios_tensorflow_jni_Graph_00024_delete
25 | (JNIEnv *, jobject, jlong);
26 |
27 | /*
28 | * Class: org_platanios_tensorflow_jni_Graph__
29 | * Method: findOp
30 | * Signature: (JLjava/lang/String;)J
31 | */
32 | JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_Graph_00024_findOp
33 | (JNIEnv *, jobject, jlong, jstring);
34 |
35 | /*
36 | * Class: org_platanios_tensorflow_jni_Graph__
37 | * Method: ops
38 | * Signature: (J)[J
39 | */
40 | JNIEXPORT jlongArray JNICALL Java_org_platanios_tensorflow_jni_Graph_00024_ops
41 | (JNIEnv *, jobject, jlong);
42 |
43 | /*
44 | * Class: org_platanios_tensorflow_jni_Graph__
45 | * Method: addGradients
46 | * Signature: (J[Lorg/platanios/tensorflow/jni/Output;[Lorg/platanios/tensorflow/jni/Output;[Lorg/platanios/tensorflow/jni/Output;)[Lorg/platanios/tensorflow/jni/Output;
47 | */
48 | JNIEXPORT jobjectArray JNICALL Java_org_platanios_tensorflow_jni_Graph_00024_addGradients
49 | (JNIEnv *, jobject, jlong, jobjectArray, jobjectArray, jobjectArray);
50 |
51 | /*
52 | * Class: org_platanios_tensorflow_jni_Graph__
53 | * Method: importGraphDef
54 | * Signature: (J[BLjava/lang/String;[Ljava/lang/String;[I[J[I[Ljava/lang/String;[J[J)V
55 | */
56 | JNIEXPORT void JNICALL Java_org_platanios_tensorflow_jni_Graph_00024_importGraphDef
57 | (JNIEnv *, jobject, jlong, jbyteArray, jstring, jobjectArray, jintArray, jlongArray, jintArray, jobjectArray, jlongArray, jlongArray);
58 |
59 | /*
60 | * Class: org_platanios_tensorflow_jni_Graph__
61 | * Method: toGraphDef
62 | * Signature: (J)[B
63 | */
64 | JNIEXPORT jbyteArray JNICALL Java_org_platanios_tensorflow_jni_Graph_00024_toGraphDef
65 | (JNIEnv *, jobject, jlong);
66 |
67 | #ifdef __cplusplus
68 | }
69 | #endif
70 | #endif
71 |
--------------------------------------------------------------------------------
/modules/jni/src/main/native/include/tensorflow/c/c_api_macros.h:
--------------------------------------------------------------------------------
1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | #ifndef TENSORFLOW_C_C_API_MACROS_H_
17 | #define TENSORFLOW_C_C_API_MACROS_H_
18 | // TF_CAPI_EXPORT annotates exported C API declarations (e.g. `TF_CAPI_EXPORT extern ...`).
19 | #ifdef SWIG
20 | #define TF_CAPI_EXPORT
21 | #else
22 | #if defined(_WIN32) // Windows: dllexport while building the library (TF_COMPILE_LIBRARY), dllimport otherwise.
23 | #ifdef TF_COMPILE_LIBRARY
24 | #define TF_CAPI_EXPORT __declspec(dllexport)
25 | #else
26 | #define TF_CAPI_EXPORT __declspec(dllimport)
27 | #endif // TF_COMPILE_LIBRARY
28 | #else // Other platforms: give the symbol default (public) visibility.
29 | #define TF_CAPI_EXPORT __attribute__((visibility("default")))
30 | #endif // _WIN32
31 | #endif // SWIG
32 |
33 | // TF_Bool is the C API's boolean type (a macro expanding to unsigned char), while TF_BOOL is
34 | // the datatype for boolean tensors.
35 | #ifndef TF_Bool
36 | #define TF_Bool unsigned char
37 | #endif // TF_Bool
38 |
39 | // Macro used to calculate struct size for maintaining ABI stability across
40 | // different struct implementations. Yields the byte offset just past MEMBER in TYPE; offsetof requires <stddef.h>.
41 | #ifndef TF_OFFSET_OF_END
42 | #define TF_OFFSET_OF_END(TYPE, MEMBER) \
43 | (offsetof(TYPE, MEMBER) + sizeof(((TYPE *)0)->MEMBER))
44 | #endif // TF_OFFSET_OF_END
45 |
46 | #endif // TENSORFLOW_C_C_API_MACROS_H_
47 |
--------------------------------------------------------------------------------
/modules/jni/src/main/native/include/tensorflow/c/eager/dlpack.h:
--------------------------------------------------------------------------------
1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | #ifndef TENSORFLOW_C_EAGER_DLPACK_H_
17 | #define TENSORFLOW_C_EAGER_DLPACK_H_
18 |
19 | #include "tensorflow/c/eager/c_api.h"
20 | // C++-only header: the declarations below live in namespace tensorflow.
21 | namespace tensorflow {
22 |
23 | // PyCapsule name used for DLPack tensors.
24 | const char* const kDlTensorCapsuleName = "dltensor";
25 |
26 | // Converts an eager tensor handle to DLPack (DLManagedTensor*) and returns it as
27 | // a void* for further PyCapsule construction.
28 | TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h,
29 | TF_Status* status);
30 |
31 | // Converts a DLPack tensor (DLManagedTensor*, passed as void*) to an eager tensor handle.
32 | TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm,
33 | TF_Status* status,
34 | TFE_Context* ctx);
35 |
36 | // Calls the deleter of a DLManagedTensor; used in the PyCapsule destructor.
37 | TF_CAPI_EXPORT extern void TFE_CallDLManagedTensorDeleter(void* dlm_ptr);
38 | } // namespace tensorflow
39 |
40 | #endif // TENSORFLOW_C_EAGER_DLPACK_H_
41 |
--------------------------------------------------------------------------------
/modules/jni/src/main/native/include/tensorflow/c/tf_attrtype.h:
--------------------------------------------------------------------------------
1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 | #ifndef TENSORFLOW_C_TF_ATTRTYPE_H_
16 | #define TENSORFLOW_C_TF_ATTRTYPE_H_
17 | // C-compatible header: wrapped in extern "C" when compiled as C++.
18 | #ifdef __cplusplus
19 | extern "C" {
20 | #endif
21 |
22 | // TF_AttrType describes the type of the value of an attribute on an operation.
23 | typedef enum TF_AttrType { // Values are explicit; presumably they must stay stable for ABI compatibility.
24 | TF_ATTR_STRING = 0,
25 | TF_ATTR_INT = 1,
26 | TF_ATTR_FLOAT = 2,
27 | TF_ATTR_BOOL = 3,
28 | TF_ATTR_TYPE = 4,
29 | TF_ATTR_SHAPE = 5,
30 | TF_ATTR_TENSOR = 6,
31 | TF_ATTR_PLACEHOLDER = 7,
32 | TF_ATTR_FUNC = 8,
33 | } TF_AttrType;
34 |
35 | #ifdef __cplusplus
36 | } /* end extern "C" */
37 | #endif
38 |
39 | #endif // TENSORFLOW_C_TF_ATTRTYPE_H_
40 |
--------------------------------------------------------------------------------
/modules/jni/src/main/native/include/tensorflow/c/tf_file_statistics.h:
--------------------------------------------------------------------------------
1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | #ifndef TENSORFLOW_C_TF_FILE_STATISTICS_H_
17 | #define TENSORFLOW_C_TF_FILE_STATISTICS_H_
18 |
19 | #include <stdbool.h>  // `bool` for is_directory when included from C; harmless in C++.
20 | #include <stdint.h>   // int64_t
21 |
22 | // Basic metadata about a file system entry (file or directory).
23 | typedef struct TF_FileStatistics {
24 | // The length of the file in bytes.
25 | int64_t length;
26 | // The last modified time in nanoseconds.
27 | int64_t mtime_nsec;
28 | // Whether the name refers to a directory.
29 | bool is_directory;
30 | } TF_FileStatistics;
31 |
32 | // TODO(mihaimaruseac): `tensorflow::FileStatistics` from
33 | // `core/platform/file_statistics.h` is a duplicate of this so maybe try to
34 | // remove duplication later?
35 |
36 | #endif // TENSORFLOW_C_TF_FILE_STATISTICS_H_
37 |
--------------------------------------------------------------------------------
/modules/jni/src/main/native/include/tensorflow/c/tf_tstring.h:
--------------------------------------------------------------------------------
1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 | #ifndef TENSORFLOW_C_TF_TSTRING_H_
16 | #define TENSORFLOW_C_TF_TSTRING_H_
17 | // C API entry point for TF_TString; the definitions come from ctstring.h.
18 | #include "tensorflow/core/platform/ctstring.h"
19 |
20 | #endif // TENSORFLOW_C_TF_TSTRING_H_
21 |
--------------------------------------------------------------------------------
/modules/jni/src/main/native/ops/beam_search_ops.h:
--------------------------------------------------------------------------------
1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | #ifndef TENSORFLOW_CONTRIB_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_
17 | #define TENSORFLOW_CONTRIB_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_
18 |
19 | #include "tensorflow/core/framework/tensor_types.h"
20 | #include "tensorflow/core/platform/types.h"
21 |
22 | namespace tensorflow {
23 | class OpKernelContext;
24 | // Forward declaration suffices: the functor only takes OpKernelContext by pointer.
25 | namespace functor {
26 | // Functor behind the GatherTree op, specialized per `Device` (presumably CPU/GPU kernels elsewhere).
27 | template <typename Device, typename T>
28 | struct GatherTree {
29 | void operator()(OpKernelContext* ctx, const Device& d,
30 | typename TTypes<T, 3>::ConstTensor step_ids,
31 | typename TTypes<T, 3>::ConstTensor parent_ids,
32 | TTypes<int32>::ConstVec max_sequence_lengths,
33 | const T end_token, typename TTypes<T, 3>::Tensor beams);
34 | };
35 |
36 | } // namespace functor
37 | } // namespace tensorflow
38 |
39 | #endif // TENSORFLOW_CONTRIB_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_
40 |
--------------------------------------------------------------------------------
/modules/jni/src/main/native/ops/jvm_callback_op.h:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | #ifndef TENSORFLOW_JVM_CALLBACK_OP_H_
17 | #define TENSORFLOW_JVM_CALLBACK_OP_H_
18 | // NOTE(review): the four system #include targets below are missing in this copy (JNIEnv etc. presumably need <jni.h>, std::vector needs <vector>) — restore from the original source.
19 | #include
20 | #include
21 | #include
22 | #include
23 |
24 | #include "tensorflow/c/c_api.h"
25 | #include "tensorflow/c/kernels.h"
26 |
27 | // A call to the registered JVM function: the state needed to invoke the JVM-side callback for one op execution.
28 | struct JVMCall {
29 | JNIEnv* env; // Per the JNI spec, a JNIEnv is only valid on the thread that obtained it.
30 | jclass registry; // Class declaring the callback entry point (presumably ScalaCallbacksRegistry — confirm).
31 | jmethodID call_method_id; // ID of the method on `registry` that is invoked.
32 |
33 | TF_OpKernelContext* ctx; // Presumably the kernel context of the op execution that triggered this call.
34 |
35 | // True if and only if this op has been placed on a GPU.
36 | bool gpu;
37 |
38 | // Passed to the JVM to call the function registered with this ID (presumably the registry token).
39 | int id;
40 |
41 | // Inputs and outputs of this function invocation. NOTE(review): the element type is missing in this copy (presumably TF_Tensor*) — confirm.
42 | std::vector inputs;
43 | std::vector outputs;
44 | };
45 |
46 | #endif // TENSORFLOW_JVM_CALLBACK_OP_H_
47 |
--------------------------------------------------------------------------------
/modules/jni/src/main/native/server.h:
--------------------------------------------------------------------------------
1 | /* DO NOT EDIT THIS FILE - it is machine generated */
2 | #include <jni.h>
3 | /* Header for class org_platanios_tensorflow_jni_Server__ (Scala object org.platanios.tensorflow.jni.Server; "_00024" in the symbol names below is the JNI escape for '$') */
4 |
5 | #ifndef _Included_org_platanios_tensorflow_jni_Server__
6 | #define _Included_org_platanios_tensorflow_jni_Server__
7 | #ifdef __cplusplus
8 | extern "C" {
9 | #endif
10 | /*
11 | * Class: org_platanios_tensorflow_jni_Server__
12 | * Method: newServer
13 | * Signature: ([B)J
14 | */
15 | JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_Server_00024_newServer
16 | (JNIEnv *, jobject, jbyteArray);
17 |
18 | /*
19 | * Class: org_platanios_tensorflow_jni_Server__
20 | * Method: target
21 | * Signature: (J)Ljava/lang/String;
22 | */
23 | JNIEXPORT jstring JNICALL Java_org_platanios_tensorflow_jni_Server_00024_target
24 | (JNIEnv *, jobject, jlong);
25 |
26 | /*
27 | * Class: org_platanios_tensorflow_jni_Server__
28 | * Method: startServer
29 | * Signature: (J)V
30 | */
31 | JNIEXPORT void JNICALL Java_org_platanios_tensorflow_jni_Server_00024_startServer
32 | (JNIEnv *, jobject, jlong);
33 |
34 | /*
35 | * Class: org_platanios_tensorflow_jni_Server__
36 | * Method: stopServer
37 | * Signature: (J)V
38 | */
39 | JNIEXPORT void JNICALL Java_org_platanios_tensorflow_jni_Server_00024_stopServer
40 | (JNIEnv *, jobject, jlong);
41 |
42 | /*
43 | * Class: org_platanios_tensorflow_jni_Server__
44 | * Method: joinServer
45 | * Signature: (J)V
46 | */
47 | JNIEXPORT void JNICALL Java_org_platanios_tensorflow_jni_Server_00024_joinServer
48 | (JNIEnv *, jobject, jlong);
49 |
50 | /*
51 | * Class: org_platanios_tensorflow_jni_Server__
52 | * Method: deleteServer
53 | * Signature: (J)V
54 | */
55 | JNIEXPORT void JNICALL Java_org_platanios_tensorflow_jni_Server_00024_deleteServer
56 | (JNIEnv *, jobject, jlong);
57 |
58 | #ifdef __cplusplus
59 | }
60 | #endif
61 | #endif
62 |
--------------------------------------------------------------------------------
/modules/jni/src/main/native/session.h:
--------------------------------------------------------------------------------
1 | /* DO NOT EDIT THIS FILE - it is machine generated */
2 | #include <jni.h>
3 | /* Header for class org_platanios_tensorflow_jni_Session__ (Scala object org.platanios.tensorflow.jni.Session; "_00024" in the symbol names below is the JNI escape for '$') */
4 |
5 | #ifndef _Included_org_platanios_tensorflow_jni_Session__
6 | #define _Included_org_platanios_tensorflow_jni_Session__
7 | #ifdef __cplusplus
8 | extern "C" {
9 | #endif
10 | /*
11 | * Class: org_platanios_tensorflow_jni_Session__
12 | * Method: allocate
13 | * Signature: (JLjava/lang/String;[B)J
14 | */
15 | JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_Session_00024_allocate
16 | (JNIEnv *, jobject, jlong, jstring, jbyteArray);
17 |
18 | /*
19 | * Class: org_platanios_tensorflow_jni_Session__
20 | * Method: delete
21 | * Signature: (J)V
22 | */
23 | JNIEXPORT void JNICALL Java_org_platanios_tensorflow_jni_Session_00024_delete
24 | (JNIEnv *, jobject, jlong);
25 |
26 | /*
27 | * Class: org_platanios_tensorflow_jni_Session__
28 | * Method: run
29 | * Signature: (J[B[J[J[I[J[I[JZ[J)[B
30 | */
31 | JNIEXPORT jbyteArray JNICALL Java_org_platanios_tensorflow_jni_Session_00024_run
32 | (JNIEnv *, jobject, jlong, jbyteArray, jlongArray, jlongArray, jintArray, jlongArray, jintArray, jlongArray, jboolean, jlongArray);
33 |
34 | ///*
35 | // * Class: org_platanios_tensorflow_jni_Session__
36 | // * Method: extend
37 | // * Signature: (J)V
38 | // */
39 | //JNIEXPORT void JNICALL Java_org_platanios_tensorflow_jni_Session_00024_extend
40 | // (JNIEnv *, jobject, jlong);
41 |
42 | /*
43 | * Class: org_platanios_tensorflow_jni_Session__
44 | * Method: deviceList
45 | * Signature: ([B)[[B
46 | */
47 | JNIEXPORT jobjectArray JNICALL Java_org_platanios_tensorflow_jni_Session_00024_deviceList
48 | (JNIEnv *, jobject, jbyteArray);
49 |
50 | #ifdef __cplusplus
51 | }
52 | #endif
53 | #endif
54 |
--------------------------------------------------------------------------------
/modules/jni/src/main/resources/logback.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | %d{yyyy-MM-dd HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
--------------------------------------------------------------------------------
/modules/jni/src/main/scala/org/platanios/tensorflow/jni/CheckpointReader.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.jni
17 |
18 | /** JNI bindings for reading TensorFlow checkpoints. Readers are identified by `Long` native handles.
19 | * @author Emmanouil Antonios Platanios
20 | */
21 | object CheckpointReader {
22 | TensorFlow.load() // Runs once, when this object is first accessed (presumably loads the native library).
23 |
24 | @native def newCheckpointReader(filePattern: String): Long // Returns a reader handle for `filePattern`.
25 | @native def debugString(handle: Long): String
26 | @native def hasTensor(handle: Long, name: String): Boolean // Whether the checkpoint contains tensor `name`.
27 | @native def getTensor(handle: Long, name: String): Long // Presumably a native tensor handle — confirm who owns/deletes it.
28 | @native def variableShapes(handle: Long): VariableShapes
29 | @native def variableDataTypes(handle: Long): VariableDataTypes // Data types as `Int` codes (presumably TF_DataType values).
30 | @native def delete(handle: Long): Unit // Presumably releases the reader; `handle` must not be used afterwards.
31 | }
32 |
33 | case class VariableShapes(variables: Array[String], shapes: Array[Array[Long]]) // Returned by CheckpointReader.variableShapes; presumably shapes(i) belongs to variables(i). NOTE: Array fields make equals/hashCode reference-based.
34 | case class VariableDataTypes(variables: Array[String], dataTypes: Array[Int]) // Returned by CheckpointReader.variableDataTypes; presumably dataTypes(i) belongs to variables(i). NOTE: Array fields make equals/hashCode reference-based.
35 |
--------------------------------------------------------------------------------
/modules/jni/src/main/scala/org/platanios/tensorflow/jni/Function.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.jni
17 |
18 | /** JNI bindings for creating TensorFlow functions from graphs and for managing `Long` function handles.
19 | * @author Emmanouil Antonios Platanios
20 | */
21 | object Function {
22 | TensorFlow.load() // Runs once, when this object is first accessed (presumably loads the native library).
23 |
24 | @native def graphToFunction( // Returns a native function handle built from the body graph.
25 | fnBodyGraphHandle: Long,
26 | fnName: String,
27 | appendHashToFnName: Boolean,
28 | opHandles: Array[Long],
29 | inputOpHandles: Array[Long], // Presumably paired element-wise with inputOpIndices (op, output index).
30 | inputOpIndices: Array[Int],
31 | outputOpHandles: Array[Long], // Presumably paired element-wise with outputOpIndices.
32 | outputOpIndices: Array[Int],
33 | outputNames: Array[String]
34 | ): Long
35 |
36 | @native def copyToGraph(graphHandle: Long, functionHandle: Long, gradientHandle: Long): Unit // NOTE(review): confirm how "no gradient" is encoded (0?).
37 | @native def toFunctionDef(handle: Long): Array[Byte] // Presumably a serialized FunctionDef protocol buffer.
38 | @native def delete(handle: Long): Unit
39 | }
40 |
--------------------------------------------------------------------------------
/modules/jni/src/main/scala/org/platanios/tensorflow/jni/Graph.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.jni
17 |
18 | /** JNI bindings for TensorFlow graphs; graphs and ops are referenced through `Long` native handles.
19 | * @author Emmanouil Antonios Platanios
20 | */
21 | object Graph {
22 | TensorFlow.load() // Runs once, when this object is first accessed (presumably loads the native library).
23 |
24 | @native def allocate(): Long // Returns a handle to a new graph.
25 | @native def delete(handle: Long): Unit
26 | @native def findOp(handle: Long, name: String): Long // NOTE(review): the value returned when no op is named `name` is not visible here (0?) — confirm.
27 | @native def ops(handle: Long): Array[Long]
28 |
29 | @native def addGradients( // Presumably adds ops computing gradients of `y` w.r.t. `x`, seeded with `dx`.
30 | handle: Long,
31 | y: Array[Output],
32 | x: Array[Output],
33 | dx: Array[Output]
34 | ): Array[Output]
35 |
36 | @native def importGraphDef( // Imports `graphDef` (presumably a serialized GraphDef) under `prefix`.
37 | handle: Long,
38 | graphDef: Array[Byte],
39 | prefix: String,
40 | inputsMapSourceOpNames: Array[String], // The four inputsMap* arrays presumably describe parallel (source, destination) output pairs.
41 | inputsMapSourceOutputIndices: Array[Int],
42 | inputsMapDestinationOpHandles: Array[Long],
43 | inputsMapDestinationOutputIndices: Array[Int],
44 | controlDependenciesMapSourceOpNames: Array[String], // Presumably paired element-wise with the next array.
45 | controlDependenciesMapDestinationOpHandles: Array[Long],
46 | controlDependenciesOpHandles: Array[Long]): Unit
47 |
48 | @native def toGraphDef(handle: Long): Array[Byte] // Presumably a serialized GraphDef protocol buffer.
49 | }
50 |
--------------------------------------------------------------------------------
/modules/jni/src/main/scala/org/platanios/tensorflow/jni/ScalaCallbacksRegistry.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.jni
17 |
18 | import scala.collection.mutable
19 |
20 | /** Keeps a map from unique tokens (i.e., integer IDs) to Scala functions (i.e., callbacks), each of which takes an
21 | * array with handles of tensors as input and returns an array with handles of tensors as output.
22 | *
23 | * Every access to the underlying (non-thread-safe) mutable map, including `size`, is guarded by this object's
24 | * monitor. Callbacks themselves are invoked outside of the lock, so a long-running callback does not block other
25 | * threads from registering, de-registering, or invoking callbacks.
26 | *
27 | * @author Emmanouil Antonios Platanios
28 | */
29 | object ScalaCallbacksRegistry {
30 | private[this] var uniqueId = 0
31 | private[this] val callbacks = mutable.Map.empty[Int, Array[Long] => Array[Long]]
32 |
33 | /** Number of callbacks currently registered. */
34 | def size: Int = this synchronized {
35 | // Read under the lock: `mutable.Map` is not thread-safe, and without synchronization a concurrent
36 | // `register`/`deregister` may be observed partially or not at all.
37 | callbacks.size
38 | }
39 |
40 | /** Registers the provided callback function and returns a unique token to use when creating ops invoking it.
41 | *
42 | * @param function Callback that maps an array of input tensor handles to an array of output tensor handles.
43 | * @return Token identifying `function` in this registry.
44 | */
45 | def register(function: Array[Long] => Array[Long]): Int = this synchronized {
46 | val token = uniqueId
47 | callbacks.update(token, function)
48 | uniqueId += 1
49 | token
50 | }
51 |
52 | /** De-registers (i.e., removes from this registry) the function that corresponds to the provided token. Does
53 | * nothing if no function is registered under `token`.
54 | */
55 | def deregister(token: Int): Unit = this synchronized {
56 | callbacks.remove(token)
57 | }
58 |
59 | /** Invokes the callback identified by `token` using the provided input arguments.
60 | *
61 | * @throws NoSuchElementException If no callback is registered under `token`.
62 | */
63 | def call(token: Int, inputs: Array[Long]): Array[Long] = {
64 | // Look the callback up under the lock, but run it outside of it.
65 | val callback = this synchronized callbacks(token)
66 | callback(inputs)
67 | }
68 | }
69 |
--------------------------------------------------------------------------------
/modules/jni/src/main/scala/org/platanios/tensorflow/jni/Server.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.jni
17 |
18 | /** JNI bindings for TensorFlow servers, which are referenced through `Long` native handles.
19 | * @author Emmanouil Antonios Platanios
20 | */
21 | object Server {
22 | TensorFlow.load() // Runs once, when this object is first accessed (presumably loads the native library).
23 |
24 | @native def newServer(serverDef: Array[Byte]): Long // `serverDef`: presumably a serialized ServerDef protocol buffer.
25 | @native def target(serverHandle: Long): String // Presumably the target string that sessions use to connect to this server.
26 | @native def startServer(serverHandle: Long): Unit
27 | @native def stopServer(serverHandle: Long): Unit
28 | @native def joinServer(serverHandle: Long): Unit // NOTE(review): presumably blocks until the server shuts down — confirm.
29 | @native def deleteServer(serverHandle: Long): Unit
30 | }
31 |
--------------------------------------------------------------------------------
/modules/jni/src/main/scala/org/platanios/tensorflow/jni/Tensor.scala:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | package org.platanios.tensorflow.jni
17 |
18 | import java.nio.ByteBuffer
19 |
20 | /** JNI bindings for TensorFlow tensors: C API tensors, eager tensors and ops, and string helpers.
21 | * @author Emmanouil Antonios Platanios
22 | */
23 | object Tensor {
24 | TensorFlow.load() // Runs once, when this object is first accessed (presumably loads the native library).
25 |
26 | //region TensorFlow C Tensors
27 | // Data types are passed as `Int` codes (presumably TF_DataType values); tensors as `Long` native handles.
28 | @native def allocate(dataType: Int, shape: Array[Long], numBytes: Long): Long
29 | @native def fromBuffer(dataType: Int, shape: Array[Long], numBytes: Long, buffer: ByteBuffer): Long // NOTE(review): confirm whether the native side copies `buffer` or aliases it.
30 | @native def dataType(handle: Long): Int
31 | @native def shape(handle: Long): Array[Long]
32 | @native def buffer(handle: Long): ByteBuffer // NOTE(review): presumably a view over native tensor memory; do not use it after `delete` — confirm.
33 | @native def delete(handle: Long): Unit
34 |
35 | //endregion TensorFlow C Tensors
36 |
37 | //region TensorFlow Eager Tensors
38 | // Eager contexts, eager tensor handles, and eager ops each have their own delete call below.
39 | @native def eagerAllocateContext(configProto: Array[Byte]): Long
40 | @native def eagerDeleteContext(handle: Long): Unit
41 | @native def eagerAllocate(tensorHandle: Long): Long
42 | @native def eagerDataType(handle: Long): Int
43 | @native def eagerShape(handle: Long): Array[Long]
44 | @native def eagerDevice(handle: Long): String
45 | @native def eagerDelete(handle: Long): Unit
46 | @native def eagerResolve(handle: Long): Long // Presumably returns a C tensor handle holding the eager tensor's value.
47 | @native def eagerCopyToDevice(handle: Long, contextHandle: Long, device: String): Long
48 |
49 | @native def eagerNewOp(contextHandle: Long, opOrFunctionName: String): Long
50 | @native def eagerDeleteOp(opHandle: Long): Unit
51 | @native def eagerSetOpDevice(opHandle: Long, device: String): Unit
52 |
53 | //endregion TensorFlow Eager Tensors
54 |
55 | //region String Helpers
56 | // Helpers for string tensor elements (presumably TF_TString-encoded; see tf_tstring.h).
57 | @native def setStringBytes(stringBytes: Array[Byte], buffer: ByteBuffer): Int // NOTE(review): meaning of the returned Int is not visible here — confirm.
58 | @native def getStringBytes(buffer: ByteBuffer): Array[Byte]
59 | @native def tfStringSize(): Int // Presumably the native size, in bytes, of one TF_TString element.
60 |
61 | //endregion String Helpers
62 | }
63 |
--------------------------------------------------------------------------------
/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Random.scala:
--------------------------------------------------------------------------------
1 | /* DO NOT EDIT THIS FILE - it is machine generated */
2 |
3 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
6 | * use this file except in compliance with the License. You may obtain a copy of
7 | * the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14 | * License for the specific language governing permissions and limitations under
15 | * the License.
16 | */
17 |
18 | package org.platanios.tensorflow.jni.generated.tensors
19 |
20 | import org.platanios.tensorflow.jni.TensorFlow
21 |
/** JNI bindings for TensorFlow's random-number ops.
  *
  * Each method takes an eager context handle first. The `Long` results are presumably handles to the
  * op's output tensor (TODO confirm against the native implementation). Bindings are linked by name and
  * signature, so their declaration order here is irrelevant.
  */
object Random {
  // Ensure the native library is loaded when this object is first initialized.
  TensorFlow.load()

  // Permutation of an existing tensor (`value` is presumably a tensor handle).
  @native def randomShuffle(contextHandle: Long, value: Long, seed: Long, seed2: Long): Long

  // Samplers whose output shape is given by `shape` (presumably a tensor handle) and whose element type
  // is given by the `dtype` code.
  @native def randomStandardNormal(contextHandle: Long, shape: Long, dtype: Int, seed: Long, seed2: Long): Long
  @native def randomUniform(contextHandle: Long, shape: Long, dtype: Int, seed: Long, seed2: Long): Long
  @native def truncatedNormal(contextHandle: Long, shape: Long, dtype: Int, seed: Long, seed2: Long): Long

  // Integer sampler bounded by `minval` and `maxval` (presumably tensor handles).
  @native def randomUniformInt(
      contextHandle: Long,
      shape: Long,
      minval: Long,
      maxval: Long,
      seed: Long,
      seed2: Long
  ): Long
}
31 |
--------------------------------------------------------------------------------
/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Sparse.scala:
--------------------------------------------------------------------------------
1 | /* DO NOT EDIT THIS FILE - it is machine generated */
2 |
3 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
6 | * use this file except in compliance with the License. You may obtain a copy of
7 | * the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14 | * License for the specific language governing permissions and limitations under
15 | * the License.
16 | */
17 |
18 | package org.platanios.tensorflow.jni.generated.tensors
19 |
20 | import org.platanios.tensorflow.jni.TensorFlow
21 |
/** JNI binding for TensorFlow's sparse-to-dense conversion op.
  *
  * The `Long` arguments after `contextHandle` and the `Long` result are presumably tensor handles
  * (TODO confirm against the native implementation).
  */
object Sparse {
  // Ensure the native library is loaded when this object is first initialized.
  TensorFlow.load()

  @native def sparseToDense(
      contextHandle: Long,
      sparse_indices: Long,
      output_shape: Long,
      sparse_values: Long,
      default_value: Long,
      validate_indices: Boolean
  ): Long
}
27 |
--------------------------------------------------------------------------------
/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Text.scala:
--------------------------------------------------------------------------------
1 | /* DO NOT EDIT THIS FILE - it is machine generated */
2 |
3 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
6 | * use this file except in compliance with the License. You may obtain a copy of
7 | * the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14 | * License for the specific language governing permissions and limitations under
15 | * the License.
16 | */
17 |
18 | package org.platanios.tensorflow.jni.generated.tensors
19 |
20 | import org.platanios.tensorflow.jni.TensorFlow
21 |
/** JNI bindings for TensorFlow's string ops.
  *
  * Each method takes an eager context handle first. `Long` arguments and results are presumably tensor
  * handles (TODO confirm against the native implementation).
  */
object Text {
  // Ensure the native library is loaded when this object is first initialized.
  TensorFlow.load()

  // Joining and splitting. `stringSplit` returns several handles, presumably one per op output.
  @native def stringJoin(contextHandle: Long, inputs: Array[Long], separator: Array[Byte]): Long
  @native def stringSplit(contextHandle: Long, input: Long, delimiter: Long, skip_empty: Boolean): Array[Long]

  // Base64 encoding and decoding.
  @native def encodeBase64(contextHandle: Long, input: Long, pad: Boolean): Long
  @native def decodeBase64(contextHandle: Long, input: Long): Long

  // Hashing strings into `num_buckets` buckets.
  @native def stringToHashBucket(contextHandle: Long, string_tensor: Long, num_buckets: Long): Long
  @native def stringToHashBucketFast(contextHandle: Long, input: Long, num_buckets: Long): Long
  @native def stringToHashBucketStrong(
      contextHandle: Long,
      input: Long,
      num_buckets: Long,
      key: Array[Long]
  ): Long
}
33 |
--------------------------------------------------------------------------------
/modules/proto/src/main/proto/allocation_description.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package org.platanios.tensorflow.proto;
4 |
5 | option cc_enable_arenas = true;
6 | option java_outer_classname = "AllocationDescriptionProtos";
7 | option java_multiple_files = true;
8 | option java_package = "org.platanios.tensorflow.proto";
9 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/allocation_description_go_proto";
10 |
// Describes one buffer allocation: bytes requested and allocated, the allocator
// that served it, and (when known) its identifier and address.
message AllocationDescription {
  // Total number of bytes requested.
  int64 requested_bytes = 1;

  // Total number of bytes allocated, if known.
  int64 allocated_bytes = 2;

  // Name of the allocator used.
  string allocator_name = 3;

  // Identifier of the allocated buffer, if known.
  int64 allocation_id = 4;

  // Set if this tensor only has one remaining reference.
  bool has_single_reference = 5;

  // Address of the allocation.
  uint64 ptr = 6;
}
30 |
--------------------------------------------------------------------------------
/modules/proto/src/main/proto/autotuning.proto:
--------------------------------------------------------------------------------
1 | // This file defines protos that store the results of autotuning various
2 | // operations.
3 | //
4 | // They are in proto format because we want to log them structured. They offer
5 | // tremendous statistical, testing, and debugging value.
6 | syntax = "proto3";
7 |
8 | package org.platanios.tensorflow.proto;
9 |
10 | import "any.proto";
11 | import "duration.proto";
12 |
13 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto";
14 |
// cuDNN library version, recorded in AutotuningLog.cudnn_version.
message CudnnVersion {
  int32 major = 1;
  int32 minor = 2;
  int32 patch = 3;
}

// Device compute capability (major.minor), recorded in
// AutotuningLog.compute_capability.
message ComputeCapability {
  int32 major = 1;
  int32 minor = 2;
}

// Result of autotuning one algorithm: the algorithm key (convolution or GEMM),
// its scratch-memory requirement and run time, and a failure record.
message AutotuneResult {
  // Category of failure stored in FailureResult.kind.
  enum FailureKind {
    UNKNOWN = 0;
    REDZONE_MODIFIED = 1;
    WRONG_RESULT = 2;
  }

  // Describes why an autotuning attempt failed.
  message FailureResult {
    FailureKind kind = 1;
    string msg = 2;

    // For kind == WRONG_RESULT, this field indicates the reference
    // configuration that we compared against.
    //
    // Note that the reference algorithm isn't always correct. However,
    // empirically it's more correct, as it's "algo 0", less fancy than the
    // compared one.
    oneof key {
      ConvKey reference_conv = 11;
      GemmKey reference_gemm = 12;
    }

    // NOTE(review): undocumented; presumably the address of the buffer in
    // which the failure was detected. TODO confirm.
    int64 buffer_address = 13;
  }

  // Identifies a convolution algorithm and whether tensor ops were enabled.
  message ConvKey {
    int64 algorithm = 1;
    bool tensor_ops_enabled = 2;
  }

  // Identifies a GEMM algorithm.
  message GemmKey {
    int64 algorithm = 1;
  }

  int64 scratch_bytes = 8;
  // NOTE(review): this type name is resolved relative to this file's package,
  // so it relies on the imported "duration.proto" declaring a matching
  // Duration message. Confirm.
  tensorflow.proto.Duration run_time = 9;

  // Presumably populated only when the attempt failed.
  FailureResult failure = 7;

  // The algorithm this result is for.
  oneof key {
    ConvKey conv = 5;
    GemmKey gemm = 6;
  }

  // Next ID: 14
}

// Autotuning log entry: the results for every algorithm tried, plus the
// library versions and device they were measured on.
message AutotuningLog {
  // NOTE(review): presumably the instruction/op that was autotuned.
  // `google.Any` is not the well-known `google.protobuf.Any` name; confirm
  // that the imported "any.proto" declares it.
  google.Any instr = 1;

  // Records all auto-tuning results per algorithm.
  repeated AutotuneResult results = 2;

  CudnnVersion cudnn_version = 3;
  ComputeCapability compute_capability = 4;

  // stream_executor::DeviceDescription::pci_bus_id.
  string device_pci_bus_id = 5;

  string blas_version = 6;

  // Next ID: 7
}
89 |
--------------------------------------------------------------------------------
/modules/proto/src/main/proto/bfc_memory_map.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package org.platanios.tensorflow.proto;
4 |
5 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto";
6 |
// Some of the data from AllocatorStats.
message MemAllocatorStats {
  int64 num_allocs = 1;
  int64 bytes_in_use = 2;
  int64 peak_bytes_in_use = 3;
  int64 largest_alloc_size = 4;
  float fragmentation_metric = 5;
}

// One memory chunk: its address and sizes, the bin it belongs to, the op name
// recorded for it, and its usage/bookkeeping counters.
message MemChunk {
  uint64 address = 1;
  int64 size = 2;
  int64 requested_size = 3;
  int32 bin = 4;
  string op_name = 5;
  uint64 freed_at_count = 6;
  uint64 action_count = 7;
  bool in_use = 8;
  uint64 step_id = 9;
}

// Per-bin totals: bytes and chunks in use versus the totals held by the bin.
message BinSummary {
  int32 bin = 1;
  int64 total_bytes_in_use = 2;
  int64 total_bytes_in_bin = 3;
  int64 total_chunks_in_use = 4;
  int64 total_chunks_in_bin = 5;
}

// An (action_count, size) sample; presumably the memory size observed at that
// action count. TODO confirm.
message SnapShot {
  uint64 action_count = 1;
  int64 size = 2;
}

// A memory dump for one named allocator: per-bin summaries, individual chunks,
// snapshots, and summary statistics.
message MemoryDump {
  string allocator_name = 1;
  repeated BinSummary bin_summary = 2;
  repeated MemChunk chunk = 3;
  repeated SnapShot snap_shot = 4;
  MemAllocatorStats stats = 5;
}
48 |
--------------------------------------------------------------------------------
/modules/proto/src/main/proto/checkpoint_state.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package org.platanios.tensorflow.proto;
4 | option cc_enable_arenas = true;
5 |
6 | option java_package = "org.platanios.tensorflow.proto";
7 | option java_outer_classname = "CheckpointStateProto";
8 |
// Protocol buffer representing the checkpoint state.
message CheckpointState {
  // Path to the most recent model checkpoint.
  string model_checkpoint_path = 1;

  // Paths to all not-yet-deleted model checkpoints, sorted from oldest to
  // newest. Note that the value of model_checkpoint_path should be the last
  // item in this list.
  repeated string all_model_checkpoint_paths = 2;
}
18 |
--------------------------------------------------------------------------------
/modules/proto/src/main/proto/critical_section.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package org.platanios.tensorflow.proto;
4 |
5 | option cc_enable_arenas = true;
6 | option java_outer_classname = "CriticalSectionProtos";
7 | option java_multiple_files = true;
8 | option java_package = "org.platanios.tensorflow.proto";
9 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto";
10 |
// Protocol buffer representing a CriticalSection.
message CriticalSectionDef {
  // Name of the critical section handle.
  string critical_section_name = 1;
}

// Protocol buffer representing a CriticalSection execution.
message CriticalSectionExecutionDef {
  // Name of the critical section handle in which the operation executes.
  string execute_in_critical_section_name = 1;
  // Whether this operation requires exclusive access to its resources
  // (i.e., no other CriticalSections may request the same resources).
  bool exclusive_resource_access = 2;
}
25 |
--------------------------------------------------------------------------------
/modules/proto/src/main/proto/device_attributes.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package org.platanios.tensorflow.proto;
4 |
5 | option cc_enable_arenas = true;
6 | option java_outer_classname = "DeviceAttributesProtos";
7 | option java_multiple_files = true;
8 | option java_package = "org.platanios.tensorflow.proto";
9 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/device_attributes_go_proto";
10 |
// A link from a device to another device, identified by device_id.
// NOTE(review): the meaning and units of `type` and `strength` are not
// documented here. Confirm against the producer.
message InterconnectLink {
  int32 device_id = 1;
  string type = 2;
  int32 strength = 3;
}

// The interconnect links from a device to other devices.
message LocalLinks {
  repeated InterconnectLink link = 1;
}

// Locality information for a device.
message DeviceLocality {
  // Optional bus locality of the device. The default value of 0 means
  // no specific locality. Specific localities are indexed from 1.
  int32 bus_id = 1;

  // Optional NUMA locality of the device.
  int32 numa_node = 2;

  // Optional local interconnect links to other devices.
  LocalLinks links = 3;
}

// Attributes describing one device in a cluster.
message DeviceAttributes {
  // Fully specified name of the device within a cluster.
  string name = 1;

  // String representation of device_type.
  string device_type = 2;

  // NOTE(review): field number 3 is not used by this message. If it was
  // retired, avoid reusing it (consider declaring `reserved 3;`).

  // Memory capacity of the device in bytes.
  int64 memory_limit = 4;

  // Platform-specific data about the device that may be useful
  // for supporting efficient data transfers.
  DeviceLocality locality = 5;

  // A device is assigned a globally unique number each time it is
  // initialized. "incarnation" should never be 0.
  fixed64 incarnation = 6;

  // String representation of the physical device that this device maps to.
  string physical_device_desc = 7;
}
54 |
--------------------------------------------------------------------------------
/modules/proto/src/main/proto/device_properties.proto:
--------------------------------------------------------------------------------
1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | syntax = "proto3";
17 |
18 | package org.platanios.tensorflow.proto;
19 |
20 | option cc_enable_arenas = true;
21 | option java_outer_classname = "DevicePropertiesProtos";
22 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto";
23 |
// Hardware and software properties of a device: type, vendor, model, core,
// cache and memory characteristics, and the tool/library environment.
message DeviceProperties {
  // Device type (CPU, GPU, ...)
  string type = 1;
  // Vendor (Intel, nvidia, ...)
  string vendor = 2;
  // Model (Haswell, K40, ...)
  string model = 3;
  // Core frequency in MHz
  int64 frequency = 4;
  // Number of cores
  int64 num_cores = 5;
  // Version of the tools and libraries used with this device (e.g. gcc 4.9,
  // cudnn 5.1), presumably keyed by tool/library name with the version as the
  // value. A proto3 map must name its key and value types; a bare `map` field
  // is rejected by protoc.
  map<string, string> environment = 6;
  // Number of registers per core.
  int64 num_registers = 7;
  // L1 cache size in bytes
  int64 l1_cache_size = 8;
  // L2 cache size in bytes
  int64 l2_cache_size = 9;
  // L3 cache size in bytes
  int64 l3_cache_size = 10;
  // Shared memory size per multiprocessor in bytes. This field is
  // applicable to GPUs only.
  int64 shared_memory_size_per_multiprocessor = 11;
  // Memory size in bytes
  int64 memory_size = 12;
  // Memory bandwidth in KB/s
  int64 bandwidth = 13;
}

// A device's properties paired with its name.
message NamedDevice {
  string name = 1;
  DeviceProperties properties = 2;
}
59 |
--------------------------------------------------------------------------------
/modules/proto/src/main/proto/graph.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package org.platanios.tensorflow.proto;
4 |
5 | import "function.proto";
6 | import "node_def.proto";
7 | import "versions.proto";
8 |
9 | option cc_enable_arenas = true;
10 | option java_outer_classname = "GraphProtos";
11 | option java_multiple_files = true;
12 | option java_package = "org.platanios.tensorflow.proto";
13 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/graph_go_proto";
14 |
// Represents the graph of operations.
message GraphDef {
  // The nodes of the graph.
  repeated NodeDef node = 1;

  // Compatibility versions of the graph. See core/public/version.h for version
  // history. The GraphDef version is distinct from the TensorFlow version, and
  // each release of TensorFlow will support a range of GraphDef versions.
  VersionDef versions = 4;

  // Deprecated single version field; use versions above instead. Since all
  // GraphDef changes before "versions" was introduced were forward
  // compatible, this field is entirely ignored.
  int32 version = 3 [deprecated = true];

  // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET.
  //
  // "library" provides user-defined functions.
  //
  // Naming:
  //   * library.function.name are in a flat namespace.
  //     NOTE: We may need to change it to be hierarchical to support
  //     different orgs. E.g.,
  //     { "/google/nn", { ... }},
  //     { "/google/vision", { ... }}
  //     { "/org_foo/module_bar", { ... }}
  //     map<string, FunctionDefLibrary> named_lib;
  //   * If node[i].op is the name of one function in "library",
  //     node[i] is deemed as a function call. Otherwise, node[i].op
  //     must be a primitive operation supported by the runtime.
  //
  //
  // Function call semantics:
  //
  //   * The callee may start execution as soon as some of its inputs
  //     are ready. The caller may want to use Tuple() mechanism to
  //     ensure all inputs are ready at the same time.
  //
  //   * The consumer of return values may start executing as soon as
  //     the return values the consumer depends on are ready. The
  //     consumer may want to use Tuple() mechanism to ensure the
  //     consumer does not start until all return values of the callee
  //     function are ready.
  FunctionDefLibrary library = 2;
}
59 |
--------------------------------------------------------------------------------
/modules/proto/src/main/proto/graph_debug_info.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package org.platanios.tensorflow.proto;
4 |
5 | option cc_enable_arenas = true;
6 | option java_outer_classname = "GraphDebugInfoProtos";
7 | option java_multiple_files = true;
8 | option java_package = "org.platanios.tensorflow.proto";
9 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto";
10 |
// Debug information for a graph: the source file names, and for each node
// (scoped by its containing function) the stack trace recorded for it.
message GraphDebugInfo {
  // This represents a file/line location in the source code.
  message FileLineCol {
    // File name index, which can be used to retrieve the file name string from
    // `files`. The value should be between 0 and (len(files)-1).
    int32 file_index = 1;

    // Line number in the file.
    int32 line = 2;

    // Column number in the file line.
    int32 col = 3;

    // Name of the function containing the file line.
    string func = 4;

    // Source code contained in this file line.
    string code = 5;
  }

  // This represents a stack trace, which is an ordered list of `FileLineCol`.
  message StackTrace {
    // Each line in the stack trace.
    repeated FileLineCol file_line_cols = 1;
  }

  // This stores all the source code file names and can be indexed by the
  // `file_index`.
  repeated string files = 1;

  // This maps a node name to a stack trace in the source code.
  // The map key is a mangling of the containing function and op name with
  // syntax:
  //   op.name '@' func_name
  // For ops in the top-level graph, the func_name is the empty string.
  // Note that op names are restricted to a small number of characters which
  // exclude '@', making it impossible to collide keys of this form. Function
  // names accept a much wider set of characters.
  // It would be preferable to avoid mangling and use a tuple key of (op.name,
  // func_name), but this is not supported with protocol buffers.
  // (A proto3 map must name its key and value types; a bare `map` field is
  // rejected by protoc.)
  map<string, StackTrace> traces = 2;
}
53 |
--------------------------------------------------------------------------------
/modules/proto/src/main/proto/graph_transfer_info.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package org.platanios.tensorflow.proto;
4 |
5 | import "types.proto";
6 |
7 | option cc_enable_arenas = true;
8 | option java_outer_classname = "GraphTransferInfoProto";
9 | option java_multiple_files = true;
10 | option java_package = "org.platanios.tensorflow.proto";
11 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/graph_transfer_info_go_proto";
12 |
// Reference to one output port of a node, used as an input to another node.
message GraphTransferNodeInput {
  int32 node_id = 1;
  int32 output_port = 2;
}
// Description of one node of the transferred graph.
message GraphTransferNodeInfo {
  string name = 1;
  int32 node_id = 2;
  string type_name = 3;
  int32 soc_op_id = 4;
  int32 padding_id = 5;
  int32 input_count = 6;
  int32 output_count = 7;
}
// A constant node with its shape, data type, and raw data bytes.
message GraphTransferConstNodeInfo {
  string name = 1;
  int32 node_id = 2;
  repeated int64 shape = 3;
  bytes data = 4;
  DataType dtype = 5;
}
// The inputs of one node, identified by node id.
message GraphTransferNodeInputInfo {
  int32 node_id = 1;
  repeated GraphTransferNodeInput node_input = 2;
}
// Per-output maximum byte sizes of one node, identified by node id.
message GraphTransferNodeOutputInfo {
  int32 node_id = 1;
  repeated int32 max_byte_size = 2;
}
// Name, shape, and data type of a graph input node.
message GraphTransferGraphInputNodeInfo {
  string name = 1;
  repeated int64 shape = 2;
  DataType dtype = 3;
}

// Name, shape, and data type of a graph output node.
message GraphTransferGraphOutputNodeInfo {
  string name = 1;
  repeated int64 shape = 2;
  DataType dtype = 3;
}

// Description of a graph prepared for transfer to an execution destination
// (see `Destination`): its nodes, constant nodes, per-node inputs and outputs,
// and the graph-level input and output nodes.
message GraphTransferInfo {
  enum Destination {
    NOP = 0;
    HEXAGON = 1;
  }

  repeated GraphTransferNodeInfo node_info = 1;
  repeated GraphTransferConstNodeInfo const_node_info = 2;
  repeated GraphTransferNodeInputInfo node_input_info = 3;
  repeated GraphTransferNodeOutputInfo node_output_info = 4;
  // Input node parameters of the transferred graph.
  repeated GraphTransferGraphInputNodeInfo graph_input_node_info = 5;
  // Output node parameters of the transferred graph.
  repeated GraphTransferGraphOutputNodeInfo graph_output_node_info = 6;
  // Destination of the graph transfer.
  Destination destination = 7;
}
72 |
--------------------------------------------------------------------------------
/modules/proto/src/main/proto/kernel_def.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package org.platanios.tensorflow.proto;
4 |
5 | import "attr_value.proto";
6 |
7 | option cc_enable_arenas = true;
8 | option java_outer_classname = "KernelDefProtos";
9 | option java_multiple_files = true;
10 | option java_package = "org.platanios.tensorflow.proto";
11 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/kernel_def_go_proto";
12 |
// Describes a kernel registered for an op on a given device type, along with
// its attr constraints, host-memory arguments, label, and priority.
message KernelDef {
  // Must match the name of an Op.
  string op = 1;

  // Type of device this kernel runs on.
  string device_type = 2;

  // Restricts the values an op attr may take for this kernel.
  message AttrConstraint {
    // Name of an attr from the Op.
    string name = 1;

    // A list of values that this kernel supports for this attr.
    // Like OpDef.AttrDef.allowed_values, except for kernels instead of Ops.
    AttrValue allowed_values = 2;
  }
  repeated AttrConstraint constraint = 3;

  // Names of the Op's input_/output_args that reside in host memory
  // instead of device memory.
  repeated string host_memory_arg = 4;

  // This allows experimental kernels to be registered for an op that
  // won't be used unless the user specifies a "_kernel" attr with
  // value matching this.
  string label = 5;

  // Prioritization of the kernel amongst different devices. By default we
  // assume the priority is 0. The higher the priority the better. By default
  // (i.e. if this is not set), we prefer GPU kernels over CPU.
  int32 priority = 6;
}

// A collection of KernelDefs.
message KernelList {
  repeated KernelDef kernel = 1;
}
49 |
--------------------------------------------------------------------------------
/modules/proto/src/main/proto/named_tensor.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package org.platanios.tensorflow.proto;
4 |
5 | import "tensor.proto";
6 |
7 | option cc_enable_arenas = true;
8 | option java_outer_classname = "NamedTensorProtos";
9 | option java_multiple_files = true;
10 | option java_package = "org.platanios.tensorflow.proto";
11 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto";
12 |
// A pair of tensor name and tensor values.
message NamedTensorProto {
  // Name of the tensor.
  string name = 1;

  // The client can populate a TensorProto using a `tensorflow::Tensor`, or
  // directly using the protobuf field accessors.
  //
  // The client specifies whether the returned tensor values should be
  // filled tensor fields (float_val, int_val, etc.) or encoded in a
  // compact form in tensor.tensor_content.
  TensorProto tensor = 2;
}
26 |
--------------------------------------------------------------------------------
/modules/proto/src/main/proto/queue_runner.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package org.platanios.tensorflow.proto;
4 |
5 | import "error_codes.proto";
6 |
7 | option cc_enable_arenas = true;
8 | option java_outer_classname = "QueueRunnerProtos";
9 | option java_multiple_files = true;
10 | option java_package = "org.platanios.tensorflow.proto";
11 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto";
12 |
// Protocol buffer representing a QueueRunner.
message QueueRunnerDef {
  // Queue name.
  string queue_name = 1;

  // A list of enqueue operations.
  repeated string enqueue_op_name = 2;

  // The operation to run to close the queue.
  string close_op_name = 3;

  // The operation to run to cancel the queue.
  string cancel_op_name = 4;

  // A list of exception types that, if raised during enqueue operations, are
  // considered to signal a safely closed queue.
  repeated error.Code queue_closed_exception_types = 5;
}
31 |
--------------------------------------------------------------------------------
/modules/proto/src/main/proto/reader_base.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package org.platanios.tensorflow.proto;
4 |
5 | option cc_enable_arenas = true;
6 | option java_outer_classname = "ReaderBaseProtos";
7 | option java_multiple_files = true;
8 | option java_package = "org.platanios.tensorflow.proto";
9 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/reader_base_go_proto";
10 |
// For serializing and restoring the state of ReaderBase; see reader_base.h
// for details.
message ReaderBaseState {
  // NOTE(review): presumably work-progress counters plus the current work
  // item; their exact semantics are defined by ReaderBase. Confirm in
  // reader_base.h.
  int64 work_started = 1;
  int64 work_finished = 2;
  int64 num_records_produced = 3;
  bytes current_work = 4;
}
19 |
--------------------------------------------------------------------------------
/modules/proto/src/main/proto/remote_tensor_handle.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package org.platanios.tensorflow.proto;
4 |
5 | import "tensor_shape.proto";
6 | import "types.proto";
7 |
8 | option cc_enable_arenas = true;
9 | option java_outer_classname = "RemoteTensorHandleProtos";
10 | option java_multiple_files = true;
11 | option java_package = "org.platanios.tensorflow.proto";
12 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto";
13 |
// Data type and shape of a remote resource variable.
message ResourceDtypeAndShape {
  DataType dtype = 1;
  TensorShapeProto shape = 2;
}

// Identifies a remote tensor by the operation that produced it and the output
// index, together with its devices and data type.
message RemoteTensorHandle {
  // The ID of the operation that produced this tensor.
  int64 op_id = 1;
  // The index into the outputs of the operation that produced this tensor.
  int32 output_num = 2;
  // NOTE(review): the comments on `device` and `op_device` read as swapped
  // relative to the field names (op_device = device of the op?). Confirm
  // against the producer before relying on them.
  // Device of the operation that produced this tensor. Cannot be empty.
  // For multi-device functions, it's the default device passed to placer.
  string device = 3;
  // Device where the tensor is located. Can be empty if the operation producing
  // this tensor is a multi-device function.
  string op_device = 4;
  // Tensor type.
  DataType dtype = 5;
  // Optional data types and shapes of a remote resource variable.
  repeated ResourceDtypeAndShape resource_dtypes_and_shapes = 6;
}
35 |
--------------------------------------------------------------------------------
/modules/proto/src/main/proto/replay_log.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package org.platanios.tensorflow.proto;
4 |
5 | import "master.proto";
6 |
7 | option cc_enable_arenas = true;
8 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto";
9 |
// Records the creation of a new replay session. We record the device listing
// here to capture the state of the cluster.
message NewReplaySession {
  ListDevicesResponse devices = 1;
  string session_handle = 2;
}

// One recorded master request together with its response and its start and
// end times (presumably in microseconds, per the `_us` suffix).
message ReplayOp {
  double start_time_us = 31;
  double end_time_us = 32;

  // The recorded request.
  oneof op {
    CreateSessionRequest create_session = 1;
    ExtendSessionRequest extend_session = 2;
    PartialRunSetupRequest partial_run_setup = 3;
    RunStepRequest run_step = 4;
    CloseSessionRequest close_session = 5;
    ListDevicesRequest list_devices = 6;
    ResetRequest reset_request = 7;
    MakeCallableRequest make_callable = 8;
    RunCallableRequest run_callable = 9;
    ReleaseCallableRequest release_callable = 10;
    NewReplaySession new_replay_session = 11;
  }

  // The response to the recorded request.
  oneof response {
    CreateSessionResponse create_session_response = 21;
    ExtendSessionResponse extend_session_response = 22;
    PartialRunSetupResponse partial_run_setup_response = 23;
    RunStepResponse run_step_response = 24;
    CloseSessionResponse close_session_response = 25;
    ListDevicesResponse list_devices_response = 26;
    ResetResponse reset_request_response = 27;
    MakeCallableResponse make_callable_response = 28;
    RunCallableResponse run_callable_response = 29;
    ReleaseCallableResponse release_callable_response = 30;
  }
}
48 |
--------------------------------------------------------------------------------
/modules/proto/src/main/proto/resource_handle.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package org.platanios.tensorflow.proto;
4 |
5 | import "tensor_shape.proto";
6 | import "types.proto";
7 |
8 | option cc_enable_arenas = true;
9 | option java_outer_classname = "ResourceHandle";
10 | option java_multiple_files = true;
11 | option java_package = "org.platanios.tensorflow.proto";
12 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/resource_handle_go_proto";
13 |
14 | // Protocol buffer representing a handle to a TensorFlow resource. Handles are
15 | // not valid across executions, but can be serialized back and forth within
16 | // a single run.
17 | message ResourceHandleProto {
18 | // Unique name for the device containing the resource.
19 | string device = 1;
20 |
21 | // Container in which this resource is placed.
22 | string container = 2;
23 |
24 | // Unique name of this resource.
25 | string name = 3;
26 |
27 | // Hash code for the type of the resource. Only valid on the same device
28 | // and within the same execution.
29 | uint64 hash_code = 4;
30 |
31 | // Debug-only: the name of the type pointed to by this handle, if
32 | // available.
33 | string maybe_type_name = 5;
34 |
35 | // Protocol buffer representing a pair of (data type, tensor shape).
36 | message DtypeAndShape {
37 | DataType dtype = 1;
38 | TensorShapeProto shape = 2;
39 | }
40 |
41 | // Data types and shapes for the underlying resource.
42 | repeated DtypeAndShape dtypes_and_shapes = 6;
43 |
44 | // A set of devices containing the resource. If empty, the resource only
45 | // exists on `device`.
46 | repeated string allowed_devices = 7;
47 | }
48 |
--------------------------------------------------------------------------------
/modules/proto/src/main/proto/saved_model.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package org.platanios.tensorflow.proto;
4 |
5 | import "meta_graph.proto";
6 |
7 | option cc_enable_arenas = true;
8 | option java_outer_classname = "SavedModelProtos";
9 | option java_multiple_files = true;
10 | option java_package = "org.platanios.tensorflow.proto";
11 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto";
12 |
13 | // SavedModel is the high-level serialization format for TensorFlow models.
14 | // TODO: add links to the SavedModel format documentation (similar to session_bundle).
15 | message SavedModel {
16 | // The schema version of the SavedModel instance. Used for versioning when
17 | // making future changes to the specification/implementation. Initial value
18 | // at release will be 1.
19 | int64 saved_model_schema_version = 1;
20 |
21 | // One or more MetaGraphs (a convention only: `repeated` also permits zero).
22 | repeated MetaGraphDef meta_graphs = 2;
23 | }
24 |
--------------------------------------------------------------------------------
/modules/proto/src/main/proto/saver.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package org.platanios.tensorflow.proto;
4 |
5 | option cc_enable_arenas = true;
6 | option java_outer_classname = "SaverProtos";
7 | option java_multiple_files = true;
8 | option java_package = "org.platanios.tensorflow.proto";
9 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto";
10 |
11 | // Protocol buffer representing the configuration of a Saver.
12 | message SaverDef {
13 | // The name of the tensor in which to specify the filename when saving or
14 | // restoring a model checkpoint.
15 | string filename_tensor_name = 1;
16 |
17 | // Used when saving a model checkpoint. NOTE(review): documented upstream as "the operation to run", but the field holds a tensor name — confirm.
18 | string save_tensor_name = 2;
19 |
20 | // The operation to run when restoring a model checkpoint.
21 | string restore_op_name = 3;
22 |
23 | // Maximum number of checkpoints to keep. If 0, no checkpoints are deleted.
24 | int32 max_to_keep = 4;
25 |
26 | // Whether to shard the save files, one per device that has Variable nodes.
27 | bool sharded = 5;
28 |
29 | // How often to keep an additional checkpoint. If not specified (0 in proto3), only the last
30 | // "max_to_keep" checkpoints are kept; if specified, in addition to keeping
31 | // the last "max_to_keep" checkpoints, an additional checkpoint will be kept
32 | // for every n hours of training.
33 | float keep_checkpoint_every_n_hours = 6;
34 |
35 | // A version number that identifies a different on-disk checkpoint format.
36 | // Usually, each subclass of BaseSaverBuilder works with a particular
37 | // version/format. However, it is possible that the same builder may be
38 | // upgraded to support a newer checkpoint format in the future.
39 | enum CheckpointFormatVersion {
40 | // Internal legacy format.
41 | LEGACY = 0;
42 | // Deprecated format: tf.Saver() which works with tensorflow::table::Table.
43 | V1 = 1;
44 | // Current format: more efficient.
45 | V2 = 2;
46 | }
47 | CheckpointFormatVersion version = 7; // An unset field reads as LEGACY (0) in proto3.
48 | }
49 |
--------------------------------------------------------------------------------
/modules/proto/src/main/proto/source_context.proto:
--------------------------------------------------------------------------------
1 | // Protocol Buffers - Google's data interchange format
2 | // Copyright 2008 Google Inc. All rights reserved.
3 | // https://developers.google.com/protocol-buffers/
4 | //
5 | // Redistribution and use in source and binary forms, with or without
6 | // modification, are permitted provided that the following conditions are
7 | // met:
8 | //
9 | // * Redistributions of source code must retain the above copyright
10 | // notice, this list of conditions and the following disclaimer.
11 | // * Redistributions in binary form must reproduce the above
12 | // copyright notice, this list of conditions and the following disclaimer
13 | // in the documentation and/or other materials provided with the
14 | // distribution.
15 | // * Neither the name of Google Inc. nor the names of its
16 | // contributors may be used to endorse or promote products derived from
17 | // this software without specific prior written permission.
18 | //
19 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
31 | syntax = "proto3";
32 |
33 | package org.platanios.tensorflow.proto;
34 |
35 | option csharp_namespace = "Google.Protobuf.WellKnownTypes";
36 | option java_package = "org.platanios.tensorflow.proto";
37 | option java_outer_classname = "SourceContextProto";
38 | option java_multiple_files = true;
39 | option objc_class_prefix = "GPB";
40 | option go_package = "google.golang.org/genproto/protobuf/source_context;source_context";
41 |
42 | // `SourceContext` represents information about the source of a
43 | // protobuf element, like the file in which it is defined.
44 | message SourceContext { // Mirrors google.protobuf.SourceContext, re-declared in package org.platanios.tensorflow.proto.
45 | // The path-qualified name of the .proto file that contained the associated
46 | // protobuf element. For example: `"google/protobuf/source_context.proto"`.
47 | string file_name = 1;
48 | }
49 |
--------------------------------------------------------------------------------
/modules/proto/src/main/proto/tensor_bundle.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package org.platanios.tensorflow.proto;
4 |
5 | import "tensor_shape.proto";
6 | import "tensor_slice.proto";
7 | import "types.proto";
8 | import "versions.proto";
9 |
10 | option cc_enable_arenas = true;
11 | option java_outer_classname = "TensorBundleProtos";
12 | option java_multiple_files = true;
13 | option java_package = "org.platanios.tensorflow.proto";
14 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto";
15 |
16 | // Protos used in the tensor bundle module (tf/core/util/tensor_bundle/).
17 |
18 | // Special header that is associated with a bundle.
19 | //
20 | // TODO(zongheng,zhifengc): maybe in the future, we can add information about
21 | // which binary produced this checkpoint, timestamp, etc. Sometimes, this can be
22 | // valuable debugging information. And if needed, it can be used as defensive
23 | // information ensuring that the reader (binary version) of the checkpoint and the writer
24 | // (binary version) match within a certain range, etc.
25 | message BundleHeaderProto {
26 | // Number of data files in the bundle.
27 | int32 num_shards = 1;
28 |
29 | // An enum indicating the endianness of the platform that produced this
30 | // bundle. A bundle can only be read by a platform with matching endianness.
31 | // Defaults to LITTLE (the proto3 zero value), as most modern platforms are little-endian.
32 | //
33 | // Affects the binary tensor data bytes only, not the metadata in protobufs.
34 | enum Endianness {
35 | LITTLE = 0;
36 | BIG = 1;
37 | }
38 | Endianness endianness = 2;
39 |
40 | // Versioning of the tensor bundle format.
41 | VersionDef version = 3;
42 | }
43 |
44 | // Describes the metadata related to a checkpointed tensor.
45 | message BundleEntryProto {
46 | // The tensor dtype and shape.
47 | DataType dtype = 1;
48 | TensorShapeProto shape = 2;
49 | // The binary content of the tensor lies in:
50 | // File "shard_id": bytes [offset, offset + size).
51 | int32 shard_id = 3;
52 | int64 offset = 4;
53 | int64 size = 5;
54 |
55 | // The CRC32C checksum of the tensor bytes.
56 | fixed32 crc32c = 6;
57 |
58 | // Iff non-empty, this entry represents a partitioned tensor. The previous
59 | // fields are interpreted as follows:
60 | //
61 | // "dtype", "shape": describe the full tensor.
62 | // "shard_id", "offset", "size", "crc32c": all IGNORED.
63 | // This information for each slice can be looked up in the slice's own
64 | // BundleEntryProto, keyed by its "slice_name".
65 | repeated TensorSliceProto slices = 7;
66 | }
67 |
--------------------------------------------------------------------------------
/modules/proto/src/main/proto/tensor_description.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package org.platanios.tensorflow.proto;
4 |
5 | import "allocation_description.proto";
6 | import "tensor_shape.proto";
7 | import "types.proto";
8 |
9 | option cc_enable_arenas = true;
10 | option java_outer_classname = "TensorDescriptionProtos";
11 | option java_multiple_files = true;
12 | option java_package = "org.platanios.tensorflow.proto";
13 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/tensor_description_go_proto";
14 |
15 | message TensorDescription { // Describes a tensor's element type, shape, and allocation.
16 | // Data type of tensor elements.
17 | DataType dtype = 1;
18 |
19 | // Shape of the tensor.
20 | TensorShapeProto shape = 2;
21 |
22 | // Information about the size and allocator used for the data (tag 3 is not used in this message).
23 | AllocationDescription allocation_description = 4;
24 | }
25 |
--------------------------------------------------------------------------------
/modules/proto/src/main/proto/tensor_shape.proto:
--------------------------------------------------------------------------------
1 | // Protocol buffer representing the shape of tensors.
2 |
3 | syntax = "proto3";
4 |
5 | package org.platanios.tensorflow.proto;
6 |
7 | option cc_enable_arenas = true;
8 | option java_outer_classname = "TensorShapeProtos";
9 | option java_multiple_files = true;
10 | option java_package = "org.platanios.tensorflow.proto";
11 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/tensor_shape_go_proto";
12 |
13 | // Dimensions of a tensor.
14 | message TensorShapeProto {
15 | // One dimension of the tensor.
16 | message Dim {
17 | // Size of the tensor in that dimension.
18 | // This value must be >= -1, but values of -1 are reserved for "unknown"
19 | // shapes (values of -1 mean "unknown" dimension). Certain wrappers
20 | // that work with TensorShapeProto may fail at runtime when deserializing
21 | // a TensorShapeProto containing a dim value of -1.
22 | int64 size = 1;
23 |
24 | // Optional name of the tensor dimension.
25 | string name = 2;
26 | };
27 |
28 | // Dimensions of the tensor, such as {"input", 30}, {"output", 40}
29 | // for a 30 x 40 2D tensor. If an entry has size -1, this
30 | // corresponds to a dimension of unknown size. The names are
31 | // optional.
32 | //
33 | // The order of entries in "dim" matters: It indicates the layout of the
34 | // values in the tensor's in-memory representation.
35 | //
36 | // The first entry in "dim" is the outermost dimension used to layout the
37 | // values, the last entry is the innermost dimension. This matches the
38 | // in-memory layout of RowMajor Eigen tensors.
39 | //
40 | // If "dim.size()" > 0, "unknown_rank" must be false.
41 | repeated Dim dim = 2; // Tag 1 is not used by this message.
42 |
43 | // If true, the number of dimensions in the shape is unknown.
44 | //
45 | // If true, "dim.size()" must be 0.
46 | bool unknown_rank = 3;
47 | };
48 |
--------------------------------------------------------------------------------
/modules/proto/src/main/proto/tensor_slice.proto:
--------------------------------------------------------------------------------
1 | // Protocol buffer representing slices of a tensor.
2 |
3 | syntax = "proto3";
4 |
5 | package org.platanios.tensorflow.proto;
6 |
7 | option cc_enable_arenas = true;
8 | option java_outer_classname = "TensorSliceProtos";
9 | option java_multiple_files = true;
10 | option java_package = "org.platanios.tensorflow.proto";
11 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/tensor_slice_go_proto";
12 |
13 | // Can only be interpreted if you know the corresponding TensorShape.
14 | message TensorSliceProto {
15 | // Extent of the slice in one dimension.
16 | message Extent {
17 | // Either both or neither attribute must be set. When neither is set, the
18 | // extent means: all data in that dimension.
19 |
20 | // Start index of the slice, starting at 0.
21 | int64 start = 1;
22 |
23 | // Length of the slice: if the length is missing or -1 we will
24 | // interpret this as "everything in this dimension". We use
25 | // "oneof" to preserve information about whether the length is
26 | // present without changing the serialization format from the
27 | // prior proto2 version of this proto.
28 | oneof has_length {
29 | int64 length = 2;
30 | }
31 | }
32 |
33 | // Extent of the slice in all tensor dimensions.
34 | //
35 | // Must have one entry for each dimension of the tensor that this
36 | // slice belongs to. The order of sizes is the same as the order of
37 | // dimensions in the TensorShape.
38 | repeated Extent extent = 1;
39 | }
40 |
--------------------------------------------------------------------------------
/modules/proto/src/main/proto/tensorflow_server.proto:
--------------------------------------------------------------------------------
1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | syntax = "proto3";
17 |
18 | package org.platanios.tensorflow.proto;
19 |
20 | import "cluster.proto";
21 | import "config.proto";
22 | import "device_filters.proto";
23 |
24 | option cc_enable_arenas = true;
25 | option java_outer_classname = "ServerProtos";
26 | option java_multiple_files = true;
27 | option java_package = "org.platanios.tensorflow.proto";
28 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto";
29 |
30 | // Defines the configuration of a single TensorFlow server.
31 | message ServerDef {
32 | // The cluster of which this server is a member.
33 | ClusterDef cluster = 1;
34 |
35 | // The name of the job of which this server is a member.
36 | //
37 | // NOTE(mrry): The `cluster` field must contain a `JobDef` with a `name` field
38 | // that matches this name.
39 | string job_name = 2;
40 |
41 | // The task index of this server in its job.
42 | //
43 | // NOTE: The `cluster` field must contain a `JobDef` with a matching `name`
44 | // and a mapping in its `tasks` field for this index.
45 | int32 task_index = 3;
46 |
47 | // The default configuration for sessions that run on this server.
48 | ConfigProto default_session_config = 4;
49 |
50 | // The protocol to be used by this server.
51 | //
52 | // Acceptable values include, for example: "grpc", "grpc+verbs".
53 | string protocol = 5;
54 |
55 | // The server port. If not set (0 in proto3), the port is identified from the job_name.
56 | int32 port = 6;
57 |
58 | // Device filters for remote tasks in the cluster.
59 | // NOTE: This is an experimental feature and only effective in TensorFlow 2.x.
60 | ClusterDeviceFilters cluster_device_filters = 7;
61 | }
62 |
--------------------------------------------------------------------------------
/modules/proto/src/main/proto/trackable_object_graph.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package org.platanios.tensorflow.proto;
4 |
5 | option cc_enable_arenas = true;
6 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto";
7 |
8 | // A TensorBundle addition that saves extra information about the objects that
9 | // own variables, allowing for more robust checkpoint loading into modified
10 | // programs.
11 |
12 | message TrackableObjectGraph { // Objects are stored flat in `nodes`; edges refer to them by index.
13 | message TrackableObject {
14 | message ObjectReference {
15 | // An index into `TrackableObjectGraph.nodes`, indicating the object
16 | // being referenced.
17 | int32 node_id = 1;
18 | // A user-provided name for the edge.
19 | string local_name = 2;
20 | }
21 |
22 | message SerializedTensor {
23 | // A name for the Tensor. Simple variables have only one
24 | // `SerializedTensor` named "VARIABLE_VALUE" by convention. This value may
25 | // be restored on object creation as an optimization.
26 | string name = 1;
27 | // The full name of the variable/tensor, if applicable. Used to allow
28 | // name-based loading of checkpoints which were saved using an
29 | // object-based API. Should match the checkpoint key which would have been
30 | // assigned by tf.train.Saver.
31 | string full_name = 2;
32 | // The generated name of the Tensor in the checkpoint.
33 | string checkpoint_key = 3;
34 | // Whether checkpoints should be considered as matching even without this
35 | // value restored. Used for non-critical values which don't affect the
36 | // TensorFlow graph, such as layer configurations.
37 | bool optional_restore = 4;
38 | }
39 |
40 | message SlotVariableReference {
41 | // An index into `TrackableObjectGraph.nodes`, indicating the
42 | // variable object this slot was created for.
43 | int32 original_variable_node_id = 1;
44 | // The name of the slot (e.g. "m"/"v").
45 | string slot_name = 2;
46 | // An index into `TrackableObjectGraph.nodes`, indicating the
47 | // `Object` with the value of the slot variable.
48 | int32 slot_variable_node_id = 3;
49 | }
50 |
51 | // Objects which this object depends on.
52 | repeated ObjectReference children = 1;
53 | // Serialized data specific to this object.
54 | repeated SerializedTensor attributes = 2;
55 | // Slot variables owned by this object.
56 | repeated SlotVariableReference slot_variables = 3;
57 | }
58 |
59 | repeated TrackableObject nodes = 1; // Indexed by `node_id` and the `*_node_id` fields above.
60 | }
61 |
--------------------------------------------------------------------------------
/modules/proto/src/main/proto/transport_options.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package org.platanios.tensorflow.proto;
4 |
5 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto";
6 |
7 | // Extra data needed on a non-RDMA RecvBufResponse.
8 | message RecvBufRespExtra {
9 | repeated bytes tensor_content = 1; // Tensor payload as a sequence of byte chunks; presumably concatenated by the receiver — confirm.
10 | }
11 |
--------------------------------------------------------------------------------
/modules/proto/src/main/proto/verifier_config.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package org.platanios.tensorflow.proto;
4 |
5 | option cc_enable_arenas = true;
6 | option java_outer_classname = "VerifierConfigProtos";
7 | option java_multiple_files = true;
8 | option java_package = "org.platanios.tensorflow.proto";
9 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto";
10 |
11 | // The config for graph verifiers.
12 | message VerifierConfig {
13 | enum Toggle { // DEFAULT (0) is what an unset Toggle field reads as in proto3.
14 | DEFAULT = 0;
15 | ON = 1;
16 | OFF = 2;
17 | }
18 |
19 | // Deadline for completion of all verification, i.e. all verifiers toggled ON
20 | // must complete execution within this time (the `_in_ms` suffix suggests milliseconds).
21 | int64 verification_timeout_in_ms = 1;
22 |
23 | // Perform structural validation on a TensorFlow graph. Default is OFF (an unset field reads as DEFAULT).
24 | Toggle structure_verifier = 2;
25 |
26 | // Next tag: 3
27 | }
28 |
--------------------------------------------------------------------------------
/modules/proto/src/main/proto/versions.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package org.platanios.tensorflow.proto;
4 |
5 | option cc_enable_arenas = true;
6 | option java_outer_classname = "VersionsProtos";
7 | option java_multiple_files = true;
8 | option java_package = "org.platanios.tensorflow.proto";
9 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/versions_go_proto";
10 |
11 | // Version information for a piece of serialized data.
12 | //
13 | // There are different types of versions for each type of data
14 | // (GraphDef, etc.), but they all have the same common shape
15 | // described here.
16 | //
17 | // Each consumer has "consumer" and "min_producer" versions (specified
18 | // elsewhere). A consumer is allowed to consume this data if all of the following hold:
19 | //
20 | // producer >= min_producer
21 | // consumer >= min_consumer
22 | // consumer not in bad_consumers
23 | //
24 | message VersionDef {
25 | // The version of the code that produced this data.
26 | int32 producer = 1;
27 |
28 | // Any consumer below this version is not allowed to consume this data.
29 | int32 min_consumer = 2;
30 |
31 | // Specific consumer versions which are disallowed (e.g. due to bugs).
32 | repeated int32 bad_consumers = 3;
33 | }
34 |
--------------------------------------------------------------------------------
/project/build.properties:
--------------------------------------------------------------------------------
1 | sbt.version=1.4.7
2 |
--------------------------------------------------------------------------------
/project/plugins.sbt:
--------------------------------------------------------------------------------
1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 | * use this file except in compliance with the License. You may obtain a copy of
5 | * the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 | * License for the specific language governing permissions and limitations under
13 | * the License.
14 | */
15 |
16 | logLevel := Level.Warn // Log level for this meta-build (project/) definition.
17 |
18 | libraryDependencies ++= Seq( // Libraries on the meta-build classpath, usable from build code under project/.
19 | "ch.qos.logback" % "logback-classic" % "1.2.3", // NOTE(review): 1.2.3 is affected by CVE-2021-42550 (logback <= 1.2.7); consider upgrading.
20 | "org.ow2.asm" % "asm" % "6.2.1",
21 | // The following is needed to automatically generate the eager ops.
22 | "org.tensorflow" % "proto" % "1.15.0")
23 |
24 | addSbtPlugin("ch.epfl.scala" % "sbt-bloop" % "1.2.5")
25 | addSbtPlugin("com.github.sbt" % "sbt-protobuf" % "0.7.0")
26 |
27 | // Plugins used for the documentation website.
28 | addSbtPlugin("com.lightbend.paradox" % "sbt-paradox" % "0.8.0")
29 | addSbtPlugin("io.github.jonas" % "sbt-paradox-material-theme" % "0.6.0")
30 | addSbtPlugin("com.typesafe.sbt" % "sbt-site" % "1.3.2")
31 | addSbtPlugin("com.typesafe.sbt" % "sbt-ghpages" % "0.6.3")
32 | addSbtPlugin("com.thoughtworks.sbt-api-mappings" % "sbt-api-mappings" % "latest.release") // NOTE(review): dynamic revision makes the build non-reproducible; consider pinning.
33 |
34 | // Packaging and publishing related plugins.
35 | addSbtPlugin("com.github.gseitz" % "sbt-release" % "1.0.13")
36 | addSbtPlugin("com.jsuereth" % "sbt-pgp" % "2.0.0")
37 | addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "3.9.2")
38 |
--------------------------------------------------------------------------------
/version.sbt:
--------------------------------------------------------------------------------
1 | ThisBuild / version := "0.6.6-SNAPSHOT"
2 |
--------------------------------------------------------------------------------