├── .circleci ├── config.yml └── images │ ├── Dockerfile.linux-cpu-x86_64 │ ├── Dockerfile.linux-gpu-x86_64 │ ├── docker_build_linux-cpu-x86_64.sh │ ├── docker_build_linux-gpu-x86_64.sh │ └── install.sh ├── .gitattributes ├── .gitignore ├── .jvmopts ├── .scalafmt.conf ├── LICENSE ├── README.md ├── RELEASE.md ├── build.sbt ├── docs ├── images │ ├── logo.afdesign │ ├── logo.pdf │ ├── logo.png │ ├── logo.svg │ ├── logo_square.afdesign │ └── logo_square.png └── src │ └── main │ ├── paradox │ ├── assets │ │ ├── custom.css │ │ └── images │ │ │ ├── afosr_logo.gif │ │ │ ├── cmu_logo.svg │ │ │ ├── favicon.ico │ │ │ ├── logo.png │ │ │ ├── nsf_logo.svg │ │ │ ├── tensorboard_mnist_example_plot.png │ │ │ ├── tensorflow_logo.png │ │ │ ├── tensorflow_logo_square.svg │ │ │ └── tensorflow_logo_wide.svg │ ├── contributing.md │ ├── guides.md │ ├── guides │ │ ├── adding_ops.md │ │ ├── estimators.md │ │ ├── graph_construction.md │ │ └── tensors.md │ ├── index.md │ ├── installation.md │ ├── release_notes.md │ └── release_notes │ │ ├── 0.1.0.md │ │ ├── 0.1.1.md │ │ ├── 0.2.0.md │ │ ├── 0.2.1.md │ │ ├── 0.2.2.md │ │ ├── 0.2.3.md │ │ ├── 0.2.4.md │ │ ├── 0.3.0.md │ │ ├── 0.4.0.md │ │ ├── 0.4.1.md │ │ ├── 0.5.0.md │ │ └── 0.5.1.md │ └── scala │ ├── AddingOps.scala │ ├── Estimators.scala │ ├── Index.scala │ ├── Tensors.scala │ └── installation.sh ├── modules ├── api │ └── src │ │ ├── main │ │ ├── resources │ │ │ └── logback.xml │ │ └── scala │ │ │ └── org │ │ │ └── platanios │ │ │ └── tensorflow │ │ │ └── api │ │ │ ├── Documentation.scala │ │ │ ├── config │ │ │ ├── CheckpointConfig.scala │ │ │ ├── ClusterConfig.scala │ │ │ ├── SummaryConfig.scala │ │ │ └── TensorBoardConfig.scala │ │ │ ├── core │ │ │ ├── DeviceSpecification.scala │ │ │ ├── Devices.scala │ │ │ ├── Graph.scala │ │ │ ├── Implicits.scala │ │ │ ├── Indexer.scala │ │ │ ├── Logging.scala │ │ │ ├── Shape.scala │ │ │ ├── client │ │ │ │ ├── FeedMap.scala │ │ │ │ ├── Implicits.scala │ │ │ │ ├── Session.scala │ │ │ │ ├── 
SessionConfig.scala │ │ │ │ └── Timeline.scala │ │ │ ├── distributed │ │ │ │ ├── Protocol.scala │ │ │ │ ├── ReplicaDevicePlacer.scala │ │ │ │ └── Server.scala │ │ │ ├── package.scala │ │ │ └── types │ │ │ │ ├── DataType.scala │ │ │ │ ├── Implicits.scala │ │ │ │ └── package.scala │ │ │ ├── implicits │ │ │ ├── Implicits.scala │ │ │ ├── helpers │ │ │ │ ├── DataTypeStructure.scala │ │ │ │ ├── DataTypeToOutput.scala │ │ │ │ ├── DataTypeToShape.scala │ │ │ │ ├── OpStructure.scala │ │ │ │ ├── OutputStructure.scala │ │ │ │ ├── OutputToDataType.scala │ │ │ │ ├── OutputToShape.scala │ │ │ │ ├── OutputToTensor.scala │ │ │ │ ├── PlaceholderSupport.scala │ │ │ │ ├── ShapeStructure.scala │ │ │ │ ├── TensorStructure.scala │ │ │ │ ├── TensorToDataType.scala │ │ │ │ ├── TensorToOutput.scala │ │ │ │ ├── TensorToShape.scala │ │ │ │ ├── Zero.scala │ │ │ │ └── package.scala │ │ │ └── ops │ │ │ │ ├── BasicImplicits.scala │ │ │ │ ├── ClipImplicits.scala │ │ │ │ ├── ControlFlowImplicits.scala │ │ │ │ ├── EmbeddingImplicits.scala │ │ │ │ ├── Implicits.scala │ │ │ │ ├── MathImplicits.scala │ │ │ │ ├── NNImplicits.scala │ │ │ │ ├── SparseImplicits.scala │ │ │ │ ├── StatisticsImplicits.scala │ │ │ │ └── TextImplicits.scala │ │ │ ├── io │ │ │ ├── CheckpointReader.scala │ │ │ ├── CompressionType.scala │ │ │ ├── DirectoryLoader.scala │ │ │ ├── FileIO.scala │ │ │ ├── Loader.scala │ │ │ ├── NPY.scala │ │ │ ├── TFRecordReader.scala │ │ │ ├── TFRecordWriter.scala │ │ │ └── events │ │ │ │ ├── EventAccumulator.scala │ │ │ │ ├── EventFileReader.scala │ │ │ │ ├── EventFileWriter.scala │ │ │ │ ├── EventMultiplexer.scala │ │ │ │ ├── EventPluginUtilities.scala │ │ │ │ ├── EventRecord.scala │ │ │ │ ├── EventType.scala │ │ │ │ ├── SummaryFileWriter.scala │ │ │ │ └── SummaryFileWriterCache.scala │ │ │ ├── learn │ │ │ ├── ClipGradients.scala │ │ │ ├── Configuration.scala │ │ │ ├── Counter.scala │ │ │ ├── Implicits.scala │ │ │ ├── Mode.scala │ │ │ ├── Model.scala │ │ │ ├── ModelInstance.scala │ │ │ ├── 
SessionCreator.scala │ │ │ ├── SessionManager.scala │ │ │ ├── SessionScaffold.scala │ │ │ ├── SessionWrapper.scala │ │ │ ├── StopCriteria.scala │ │ │ ├── estimators │ │ │ │ ├── Estimator.scala │ │ │ │ ├── FileBasedEstimator.scala │ │ │ │ ├── InMemoryEstimator.scala │ │ │ │ └── package.scala │ │ │ ├── hooks │ │ │ │ ├── CheckpointSaver.scala │ │ │ │ ├── Evaluator.scala │ │ │ │ ├── Hook.scala │ │ │ │ ├── HookTrigger.scala │ │ │ │ ├── LossLogger.scala │ │ │ │ ├── ModelDependentHook.scala │ │ │ │ ├── NaNChecker.scala │ │ │ │ ├── StepRateLogger.scala │ │ │ │ ├── Stopper.scala │ │ │ │ ├── SummarySaver.scala │ │ │ │ ├── SummaryWriterHookAddOn.scala │ │ │ │ ├── TensorBoardHook.scala │ │ │ │ ├── TensorLogger.scala │ │ │ │ ├── TimelineHook.scala │ │ │ │ ├── TriggeredHook.scala │ │ │ │ └── package.scala │ │ │ ├── layers │ │ │ │ ├── Activation.scala │ │ │ │ ├── Basic.scala │ │ │ │ ├── Embedding.scala │ │ │ │ ├── Input.scala │ │ │ │ ├── Layer.scala │ │ │ │ ├── Loss.scala │ │ │ │ ├── Math.scala │ │ │ │ ├── NN.scala │ │ │ │ ├── Summary.scala │ │ │ │ ├── core │ │ │ │ │ └── package.scala │ │ │ │ ├── package.scala │ │ │ │ └── rnn │ │ │ │ │ ├── BidirectionalRNN.scala │ │ │ │ │ ├── RNN.scala │ │ │ │ │ ├── cell │ │ │ │ │ ├── BasicLSTMCell.scala │ │ │ │ │ ├── BasicRNNCell.scala │ │ │ │ │ ├── DeviceWrapper.scala │ │ │ │ │ ├── DropoutWrapper.scala │ │ │ │ │ ├── GRUCell.scala │ │ │ │ │ ├── LSTMCell.scala │ │ │ │ │ ├── RNNCell.scala │ │ │ │ │ ├── ResidualWrapper.scala │ │ │ │ │ ├── StackedCell.scala │ │ │ │ │ └── package.scala │ │ │ │ │ └── package.scala │ │ │ ├── models │ │ │ │ └── RBM.scala │ │ │ └── package.scala │ │ │ ├── ops │ │ │ ├── Callback.scala │ │ │ ├── Cast.scala │ │ │ ├── Checks.scala │ │ │ ├── Clip.scala │ │ │ ├── DataFlow.scala │ │ │ ├── Documentation.scala │ │ │ ├── Embedding.scala │ │ │ ├── Files.scala │ │ │ ├── Function.scala │ │ │ ├── Gradients.scala │ │ │ ├── Image.scala │ │ │ ├── Input.scala │ │ │ ├── Logging.scala │ │ │ ├── NN.scala │ │ │ ├── Op.scala │ │ │ ├── 
OpSpecification.scala │ │ │ ├── Output.scala │ │ │ ├── OutputOps.scala │ │ │ ├── Parsing.scala │ │ │ ├── Queue.scala │ │ │ ├── Random.scala │ │ │ ├── Resources.scala │ │ │ ├── Sets.scala │ │ │ ├── Slot.scala │ │ │ ├── Sparse.scala │ │ │ ├── Statistics.scala │ │ │ ├── Summary.scala │ │ │ ├── TensorArray.scala │ │ │ ├── Text.scala │ │ │ ├── basic │ │ │ │ ├── Basic.scala │ │ │ │ ├── Constructors.scala │ │ │ │ ├── Inplace.scala │ │ │ │ ├── Manipulation.scala │ │ │ │ └── Masking.scala │ │ │ ├── control_flow │ │ │ │ ├── CondContext.scala │ │ │ │ ├── Context.scala │ │ │ │ ├── ControlFlow.scala │ │ │ │ ├── GradientLoopState.scala │ │ │ │ ├── GradientState.scala │ │ │ │ ├── WhileLoopContext.scala │ │ │ │ └── package.scala │ │ │ ├── data │ │ │ │ ├── Data.scala │ │ │ │ ├── Dataset.scala │ │ │ │ ├── DatasetIterator.scala │ │ │ │ ├── Experimental.scala │ │ │ │ └── package.scala │ │ │ ├── lookup │ │ │ │ ├── IDLookupTableWithHashBuckets.scala │ │ │ │ ├── Lookup.scala │ │ │ │ ├── LookupTable.scala │ │ │ │ ├── LookupTableInitializer.scala │ │ │ │ ├── LookupTableTensorInitializer.scala │ │ │ │ ├── LookupTableTextFileInitializer.scala │ │ │ │ └── package.scala │ │ │ ├── math │ │ │ │ ├── Bitwise.scala │ │ │ │ └── Math.scala │ │ │ ├── metrics │ │ │ │ ├── Accuracy.scala │ │ │ │ ├── ConfusionMatrix.scala │ │ │ │ ├── GroupedPrecision.scala │ │ │ │ ├── MapMetric.scala │ │ │ │ ├── Mean.scala │ │ │ │ ├── Metric.scala │ │ │ │ ├── PrecisionAtK.scala │ │ │ │ └── package.scala │ │ │ ├── package.scala │ │ │ ├── rnn │ │ │ │ ├── RNN.scala │ │ │ │ ├── attention │ │ │ │ │ ├── Attention.scala │ │ │ │ │ ├── AttentionWrapperCell.scala │ │ │ │ │ ├── BahdanauAttention.scala │ │ │ │ │ ├── LuongAttention.scala │ │ │ │ │ └── package.scala │ │ │ │ ├── cell │ │ │ │ │ ├── BasicLSTMCell.scala │ │ │ │ │ ├── BasicRNNCell.scala │ │ │ │ │ ├── DeviceWrapper.scala │ │ │ │ │ ├── DropoutWrapper.scala │ │ │ │ │ ├── GRUCell.scala │ │ │ │ │ ├── LSTMCell.scala │ │ │ │ │ ├── RNNCell.scala │ │ │ │ │ ├── ResidualWrapper.scala 
│ │ │ │ │ ├── StackedCell.scala │ │ │ │ │ └── package.scala │ │ │ │ └── package.scala │ │ │ ├── training │ │ │ │ ├── ExponentialMovingAverage.scala │ │ │ │ ├── optimizers │ │ │ │ │ ├── AMSGrad.scala │ │ │ │ │ ├── AdaDelta.scala │ │ │ │ │ ├── AdaGrad.scala │ │ │ │ │ ├── Adafactor.scala │ │ │ │ │ ├── Adam.scala │ │ │ │ │ ├── GradientDescent.scala │ │ │ │ │ ├── LazyAMSGrad.scala │ │ │ │ │ ├── LazyAdam.scala │ │ │ │ │ ├── Optimizer.scala │ │ │ │ │ ├── RMSProp.scala │ │ │ │ │ ├── YellowFin.scala │ │ │ │ │ ├── Yogi.scala │ │ │ │ │ ├── package.scala │ │ │ │ │ └── schedules │ │ │ │ │ │ ├── ComposedSchedule.scala │ │ │ │ │ │ ├── CosineDecay.scala │ │ │ │ │ │ ├── CycleLinear10xDecay.scala │ │ │ │ │ │ ├── ExponentialDecay.scala │ │ │ │ │ │ ├── FixedSchedule.scala │ │ │ │ │ │ ├── RSqrtDecay.scala │ │ │ │ │ │ ├── Schedule.scala │ │ │ │ │ │ ├── WarmUpExponentialSchedule.scala │ │ │ │ │ │ ├── WarmUpLinearSchedule.scala │ │ │ │ │ │ └── package.scala │ │ │ │ └── package.scala │ │ │ └── variables │ │ │ │ ├── Initializer.scala │ │ │ │ ├── Regularizer.scala │ │ │ │ ├── Reuse.scala │ │ │ │ ├── Saver.scala │ │ │ │ ├── Variable.scala │ │ │ │ ├── VariableLike.scala │ │ │ │ ├── VariableScope.scala │ │ │ │ ├── VariableScopeStore.scala │ │ │ │ ├── VariableStore.scala │ │ │ │ └── package.scala │ │ │ ├── package.scala │ │ │ ├── tensors │ │ │ ├── Context.scala │ │ │ ├── Implicits.scala │ │ │ ├── Tensor.scala │ │ │ ├── TensorOps.scala │ │ │ ├── ops │ │ │ │ ├── Basic.scala │ │ │ │ ├── Cast.scala │ │ │ │ ├── Math.scala │ │ │ │ ├── NN.scala │ │ │ │ ├── Random.scala │ │ │ │ └── package.scala │ │ │ └── package.scala │ │ │ └── utilities │ │ │ ├── ByteCodable.scala │ │ │ ├── CRC32C.scala │ │ │ ├── Coding.scala │ │ │ ├── Collections.scala │ │ │ ├── DefaultsTo.scala │ │ │ ├── Disposer.scala │ │ │ ├── NativeHandleWrapper.scala │ │ │ ├── Proto.scala │ │ │ ├── Reservoir.scala │ │ │ └── package.scala │ │ └── test │ │ └── scala │ │ └── org │ │ └── platanios │ │ └── tensorflow │ │ └── api │ │ ├── core │ │ ├── 
DataTypeSpec.scala │ │ ├── DeviceSpecificationSpec.scala │ │ ├── GraphSpec.scala │ │ ├── IndexerSpec.scala │ │ ├── SessionSpec.scala │ │ ├── ShapeSpec.scala │ │ └── client │ │ │ ├── FeedableSuite.scala │ │ │ └── FetchableSuite.scala │ │ ├── implicits │ │ └── helpers │ │ │ └── OpStructureSuite.scala │ │ ├── io │ │ ├── DirectoryLoaderSuite.scala │ │ └── events │ │ │ └── EventFileReaderSuite.scala │ │ ├── ops │ │ ├── BasicSpec.scala │ │ ├── CallbackSuite.scala │ │ ├── FunctionSuite.scala │ │ ├── GradientsSuite.scala │ │ ├── NNSpec.scala │ │ ├── OpSpec.scala │ │ ├── TextSuite.scala │ │ ├── VariableSpec.scala │ │ ├── control_flow │ │ │ └── ControlFlowSuite.scala │ │ ├── data │ │ │ ├── DatasetSuite.scala │ │ │ └── FilterDatasetSuite.scala │ │ └── training │ │ │ └── optimizers │ │ │ └── GradientDescentSpec.scala │ │ ├── tensors │ │ └── TensorSuite.scala │ │ └── utilities │ │ ├── CRC32CSuite.scala │ │ └── ReservoirSuite.scala ├── data │ └── src │ │ ├── main │ │ ├── resources │ │ │ └── logback.xml │ │ └── scala │ │ │ └── org │ │ │ └── platanios │ │ │ └── tensorflow │ │ │ └── data │ │ │ ├── Loader.scala │ │ │ ├── XCLoader.scala │ │ │ ├── image │ │ │ ├── CIFARLoader.scala │ │ │ ├── MNISTLoader.scala │ │ │ └── STL10Loader.scala │ │ │ ├── models │ │ │ └── ObjectDetectionModelLoader.scala │ │ │ ├── text │ │ │ └── PTBLoader.scala │ │ │ └── utilities │ │ │ ├── CompressedFiles.scala │ │ │ └── Split.scala │ │ └── test │ │ └── scala │ │ └── org │ │ └── platanios │ │ └── tensorflow │ │ └── data │ │ └── image │ │ └── MNISTLoaderSpec.scala ├── examples │ └── src │ │ └── main │ │ ├── resources │ │ ├── logback.xml │ │ └── python2scala │ │ │ ├── MetaGraphDef.txt │ │ │ ├── checkpoint │ │ │ ├── linear-regression.data-00000-of-00001 │ │ │ ├── linear-regression.index │ │ │ ├── linear-regression.meta │ │ │ ├── linearRegression.py │ │ │ ├── virgin-linear-regression.data-00000-of-00001 │ │ │ ├── virgin-linear-regression.index │ │ │ └── virgin-linear-regression.meta │ │ └── scala │ │ └── org │ │ 
└── platanios │ │ └── tensorflow │ │ └── examples │ │ ├── CIFAR.scala │ │ ├── LinearRegression.scala │ │ ├── MNIST.scala │ │ ├── RNNTutorialUsingPTB.scala │ │ ├── STL10.scala │ │ ├── package.scala │ │ └── python2scala │ │ └── LinearRegressionFromRestoredPythonModel.scala ├── jni │ └── src │ │ ├── main │ │ ├── native │ │ │ ├── CMakeLists.txt │ │ │ ├── c_api_internal.cc │ │ │ ├── checkpoint_reader.cc │ │ │ ├── checkpoint_reader.h │ │ │ ├── checkpoint_reader_internal.cc │ │ │ ├── checkpoint_reader_internal.h │ │ │ ├── exception.h │ │ │ ├── function.cc │ │ │ ├── function.h │ │ │ ├── generated │ │ │ │ ├── tensor_basic_ops.cc │ │ │ │ ├── tensor_basic_ops.h │ │ │ │ ├── tensor_math_ops.cc │ │ │ │ ├── tensor_math_ops.h │ │ │ │ ├── tensor_nn_ops.cc │ │ │ │ ├── tensor_nn_ops.h │ │ │ │ ├── tensor_random_ops.cc │ │ │ │ ├── tensor_random_ops.h │ │ │ │ ├── tensor_sparse_ops.cc │ │ │ │ ├── tensor_sparse_ops.h │ │ │ │ ├── tensor_text_ops.cc │ │ │ │ └── tensor_text_ops.h │ │ │ ├── graph.cc │ │ │ ├── graph.h │ │ │ ├── include │ │ │ │ └── tensorflow │ │ │ │ │ ├── c │ │ │ │ │ ├── c_api.h │ │ │ │ │ ├── c_api_experimental.h │ │ │ │ │ ├── c_api_macros.h │ │ │ │ │ ├── eager │ │ │ │ │ │ ├── c_api.h │ │ │ │ │ │ ├── c_api_experimental.h │ │ │ │ │ │ └── dlpack.h │ │ │ │ │ ├── tensor_interface.h │ │ │ │ │ ├── tf_attrtype.h │ │ │ │ │ ├── tf_datatype.h │ │ │ │ │ ├── tf_file_statistics.h │ │ │ │ │ ├── tf_status.h │ │ │ │ │ ├── tf_tensor.h │ │ │ │ │ └── tf_tstring.h │ │ │ │ │ └── core │ │ │ │ │ └── platform │ │ │ │ │ ├── ctstring.h │ │ │ │ │ └── ctstring_internal.h │ │ │ ├── op.cc │ │ │ ├── op.h │ │ │ ├── ops │ │ │ │ ├── beam_search_ops.cc │ │ │ │ ├── beam_search_ops.h │ │ │ │ ├── beam_search_ops_gpu.cu.cc │ │ │ │ ├── jvm_callback_op.cc │ │ │ │ └── jvm_callback_op.h │ │ │ ├── python_api.cc │ │ │ ├── python_api.h │ │ │ ├── server.cc │ │ │ ├── server.h │ │ │ ├── session.cc │ │ │ ├── session.h │ │ │ ├── tensor.cc │ │ │ ├── tensor.h │ │ │ ├── tensorflow.cc │ │ │ ├── tensorflow.h │ │ │ └── utilities.h │ 
│ ├── resources │ │ │ ├── logback.xml │ │ │ └── ops.pbtxt │ │ └── scala │ │ │ └── org │ │ │ └── platanios │ │ │ └── tensorflow │ │ │ └── jni │ │ │ ├── CheckpointReader.scala │ │ │ ├── Function.scala │ │ │ ├── Graph.scala │ │ │ ├── Op.scala │ │ │ ├── ScalaCallbacksRegistry.scala │ │ │ ├── Server.scala │ │ │ ├── Session.scala │ │ │ ├── Tensor.scala │ │ │ ├── TensorFlow.scala │ │ │ ├── TensorFlowException.scala │ │ │ └── generated │ │ │ └── tensors │ │ │ ├── Basic.scala │ │ │ ├── Math.scala │ │ │ ├── NN.scala │ │ │ ├── Random.scala │ │ │ ├── Sparse.scala │ │ │ └── Text.scala │ │ └── test │ │ └── scala │ │ └── org │ │ └── platanios │ │ └── tensorflow │ │ └── jni │ │ └── TensorFlowSpec.scala └── proto │ └── src │ └── main │ └── proto │ ├── allocation_description.proto │ ├── any.proto │ ├── api_def.proto │ ├── attr_value.proto │ ├── autotuning.proto │ ├── bfc_memory_map.proto │ ├── checkpoint_state.proto │ ├── cluster.proto │ ├── config.proto │ ├── control_flow.proto │ ├── cost_graph.proto │ ├── critical_section.proto │ ├── debug.proto │ ├── debug_event.proto │ ├── device_attributes.proto │ ├── device_filters.proto │ ├── device_properties.proto │ ├── duration.proto │ ├── eager_service.proto │ ├── error_codes.proto │ ├── event.proto │ ├── example.proto │ ├── feature.proto │ ├── function.proto │ ├── graph.proto │ ├── graph_debug_info.proto │ ├── graph_transfer_info.proto │ ├── kernel_def.proto │ ├── log_memory.proto │ ├── master.proto │ ├── master_service.proto │ ├── meta_graph.proto │ ├── named_tensor.proto │ ├── node_def.proto │ ├── op_def.proto │ ├── queue_runner.proto │ ├── reader_base.proto │ ├── remote_tensor_handle.proto │ ├── replay_log.proto │ ├── resource_handle.proto │ ├── rewriter_config.proto │ ├── saved_model.proto │ ├── saved_object_graph.proto │ ├── saver.proto │ ├── source_context.proto │ ├── step_stats.proto │ ├── struct.proto │ ├── summary.proto │ ├── tensor.proto │ ├── tensor_bundle.proto │ ├── tensor_description.proto │ ├── tensor_shape.proto │ ├── 
tensor_slice.proto │ ├── tensorflow_server.proto │ ├── trackable_object_graph.proto │ ├── transport_options.proto │ ├── type.proto │ ├── types.proto │ ├── variable.proto │ ├── verifier_config.proto │ ├── versions.proto │ ├── worker.proto │ └── worker_service.proto ├── project ├── BuildTool.scala ├── JniCrossPackage.scala ├── JniJavah.scala ├── JniNative.scala ├── JniPackage.scala ├── OpGenerator.scala ├── TensorFlowGenerateTensorOps.scala ├── TensorFlowNativePackage.scala ├── build.properties └── plugins.sbt └── version.sbt /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | jobs: 4 | build-linux-x86_64: 5 | docker: 6 | - image: eaplatanios/tensorflow_scala:linux-cpu-x86_64-0.6.0 7 | working_directory: ~/repository 8 | environment: 9 | LD_LIBRARY_PATH: /usr/lib:$LD_LIBRARY_PATH 10 | JAVA_HOME: /usr/lib/jvm/java-8-openjdk-amd64 11 | JVM_OPTS: -Xmx3200m 12 | TERM: dumb 13 | steps: 14 | - checkout 15 | - run: mkdir downloads && cd downloads 16 | - run: wget https://oss.sonatype.org/service/local/repositories/snapshots/content/org/platanios/tensorflow_2.13/0.6.0-SNAPSHOT/tensorflow_2.13-0.6.0-SNAPSHOT-linux.jar 17 | - run: jar xf tensorflow_2.13-0.6.0-SNAPSHOT-linux.jar 18 | - run: mv libtensorflow.so /usr/lib/libtensorflow.so 19 | - run: mv libtensorflow_framework.so /usr/lib/libtensorflow_framework.so 20 | - run: ln -s /usr/lib/libtensorflow.so /usr/lib/libtensorflow.so.2 21 | - run: ln -s /usr/lib/libtensorflow.so /usr/lib/libtensorflow.so.2.4.0 22 | - run: ln -s /usr/lib/libtensorflow_framework.so /usr/lib/libtensorflow_framework.so.2 23 | - run: ln -s /usr/lib/libtensorflow_framework.so /usr/lib/libtensorflow_framework.so.2.4.0 24 | - run: cd ..
25 | - restore_cache: 26 | keys: 27 | - v0.5.1-dependencies-{{ checksum "build.sbt" }} 28 | - run: cat /dev/null | sbt +test:compile 29 | - save_cache: 30 | paths: 31 | - ~/.ivy2 32 | key: v0.5.1-dependencies-{{ checksum "build.sbt" }} 33 | - run: cat /dev/null | sbt +test:test 34 | 35 | workflows: 36 | version: 2 37 | test: 38 | jobs: 39 | - build-linux-x86_64 40 | -------------------------------------------------------------------------------- /.circleci/images/Dockerfile.linux-cpu-x86_64: -------------------------------------------------------------------------------- 1 | FROM ubuntu:18.04 2 | 3 | # Copy and run the install script. 4 | COPY .circleci/images/install.sh /install.sh 5 | ARG DEBIAN_FRONTEND=noninteractive 6 | RUN /install.sh 7 | 8 | # Set up MPI 9 | ENV TF_NEED_MPI 1 10 | 11 | ENV JAVA_HOME=/usr/lib/jvm/java-11-openjdk-amd64 12 | 13 | # COPY . /tensorflow_scala 14 | -------------------------------------------------------------------------------- /.circleci/images/Dockerfile.linux-gpu-x86_64: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.0.3-cudnn8-devel-ubuntu18.04 2 | 3 | # In the Ubuntu 18.04 images, cudnn is placed in system paths. Copy them to 4 | # /usr/local/cuda 5 | RUN cp -P /usr/include/cudnn.h /usr/local/cuda/include 6 | RUN cp -P /usr/lib/x86_64-linux-gnu/libcudnn* /usr/local/cuda/lib64 7 | 8 | # Copy and run the install script. 9 | COPY .circleci/images/install.sh /install.sh 10 | ARG DEBIAN_FRONTEND=noninteractive 11 | RUN /install.sh 12 | 13 | # Set up MPI 14 | ENV TF_NEED_MPI 1 15 | 16 | # Make the CUPTI libraries visible to the dynamic linker. 17 | ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH 18 | 19 | # Configure the build for our CUDA configuration. 20 | ENV TF_NEED_CUDA 1 21 | ENV TF_CUDA_COMPUTE_CAPABILITIES 3.5 22 | 23 | ENV JAVA_HOME=/usr/lib/jvm/java-11-openjdk-amd64 24 | 25 | # COPY .
/tensorflow_scala 26 | -------------------------------------------------------------------------------- /.circleci/images/docker_build_linux-cpu-x86_64.sh: -------------------------------------------------------------------------------- 1 | SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" 2 | DOCKER_CONTEXT_PATH="$(realpath "${SCRIPT_DIR}/../../")" 3 | 4 | DOCKER_IMAGE="eaplatanios/tensorflow_scala:linux-cpu-x86_64-0.6.0" 5 | DOCKER_FILE=".circleci/images/Dockerfile.linux-cpu-x86_64" 6 | 7 | docker build \ 8 | -t "${DOCKER_IMAGE}" \ 9 | -f "${DOCKER_CONTEXT_PATH}/${DOCKER_FILE}" \ 10 | "${DOCKER_CONTEXT_PATH}" 11 | -------------------------------------------------------------------------------- /.circleci/images/docker_build_linux-gpu-x86_64.sh: -------------------------------------------------------------------------------- 1 | SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" 2 | DOCKER_CONTEXT_PATH="$(realpath "${SCRIPT_DIR}/../../")" 3 | 4 | DOCKER_IMAGE="eaplatanios/tensorflow_scala:linux-gpu-x86_64-0.6.0" 5 | DOCKER_FILE=".circleci/images/Dockerfile.linux-gpu-x86_64" 6 | 7 | docker build \ 8 | -t "${DOCKER_IMAGE}" \ 9 | -f "${DOCKER_CONTEXT_PATH}/${DOCKER_FILE}" \ 10 | "${DOCKER_CONTEXT_PATH}" 11 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | modules/jni/src/main/native/include/**/* linguist-vendored=true 2 | modules/jni/src/main/native/generated/**/* linguist-generated=true 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/.DS_Store 2 | 3 | # Project Specific 4 | temp/ 5 | datasets/ 6 | 7 | # Scala Specific 8 | *.class 9 | *.log 10 | 11 | # SBT Specific 12 | .cache 13 | .history 14 | .lib/ 15 | dist/* 16 | target/ 17 | lib_managed/ 18 | src_managed/ 19 | project/boot/ 20 |
project/plugins/project/ 21 | 22 | # Bloop Specific 23 | .bloop 24 | .bsp 25 | 26 | # Metals Specific 27 | .metals 28 | 29 | # Hydra Specific 30 | .hydra 31 | 32 | **/src/main/native/lib/**/* 33 | **/src/main/generated/**/* 34 | 35 | # IntelliJ Specific 36 | **/.idea/**/* 37 | 38 | # CLion Specific 39 | **/cmake-build-debug/**/* 40 | 41 | # VS Code Specific 42 | **/.vscode/**/* 43 | -------------------------------------------------------------------------------- /.jvmopts: -------------------------------------------------------------------------------- 1 | -Dfile.encoding=UTF8 2 | -Xms1G 3 | -Xmx4G 4 | -Xss8M 5 | -XX:ReservedCodeCacheSize=512M 6 | -XX:MaxDirectMemorySize=16G 7 | -XX:MaxMetaspaceSize=1G 8 | -------------------------------------------------------------------------------- /docs/images/logo.afdesign: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eaplatanios/tensorflow_scala/39053eefe0859bde4c48c8da20212d6919deb38e/docs/images/logo.afdesign -------------------------------------------------------------------------------- /docs/images/logo.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eaplatanios/tensorflow_scala/39053eefe0859bde4c48c8da20212d6919deb38e/docs/images/logo.pdf -------------------------------------------------------------------------------- /docs/images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eaplatanios/tensorflow_scala/39053eefe0859bde4c48c8da20212d6919deb38e/docs/images/logo.png -------------------------------------------------------------------------------- /docs/images/logo_square.afdesign: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eaplatanios/tensorflow_scala/39053eefe0859bde4c48c8da20212d6919deb38e/docs/images/logo_square.afdesign 
-------------------------------------------------------------------------------- /docs/images/logo_square.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eaplatanios/tensorflow_scala/39053eefe0859bde4c48c8da20212d6919deb38e/docs/images/logo_square.png -------------------------------------------------------------------------------- /docs/src/main/paradox/assets/images/afosr_logo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eaplatanios/tensorflow_scala/39053eefe0859bde4c48c8da20212d6919deb38e/docs/src/main/paradox/assets/images/afosr_logo.gif -------------------------------------------------------------------------------- /docs/src/main/paradox/assets/images/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eaplatanios/tensorflow_scala/39053eefe0859bde4c48c8da20212d6919deb38e/docs/src/main/paradox/assets/images/favicon.ico -------------------------------------------------------------------------------- /docs/src/main/paradox/assets/images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eaplatanios/tensorflow_scala/39053eefe0859bde4c48c8da20212d6919deb38e/docs/src/main/paradox/assets/images/logo.png -------------------------------------------------------------------------------- /docs/src/main/paradox/assets/images/tensorboard_mnist_example_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eaplatanios/tensorflow_scala/39053eefe0859bde4c48c8da20212d6919deb38e/docs/src/main/paradox/assets/images/tensorboard_mnist_example_plot.png -------------------------------------------------------------------------------- /docs/src/main/paradox/assets/images/tensorflow_logo.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/eaplatanios/tensorflow_scala/39053eefe0859bde4c48c8da20212d6919deb38e/docs/src/main/paradox/assets/images/tensorflow_logo.png -------------------------------------------------------------------------------- /docs/src/main/paradox/assets/images/tensorflow_logo_square.svg: -------------------------------------------------------------------------------- 1 | 2 | 13 | 15 | 17 | 18 | 20 | image/svg+xml 21 | 23 | 24 | 25 | 26 | 27 | 30 | 32 | 37 | 42 | 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /docs/src/main/paradox/contributing.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | It would be awesome if people could contribute to this library. Given 4 | its scope and its early state, before I settle on the API for some of 5 | the features, I would really appreciate contributions on the following: 6 | 7 | - **Unit Tests:** Currently unit tests are missing for a big part of 8 | the library and it would be extremely useful if we had those. 9 | - **Examples:** Examples of code using the library would be great and 10 | would also make issues come up early so they can be fixed. 11 | -------------------------------------------------------------------------------- /docs/src/main/paradox/guides/adding_ops.md: -------------------------------------------------------------------------------- 1 | # Adding Support for New Ops 2 | 3 | TensorFlow graphs are constructed by building ops that 4 | receive tensors as input and produce tensors as output. 5 | Internally, multiple *kernels* are registered for each op, 6 | which are implementations of the op for different 7 | architectures (e.g., CPU kernels, CUDA GPU kernels, etc.).
8 | TensorFlow Scala offers the 9 | @scaladoc[Op.Builder](org.platanios.tensorflow.api.ops.Op.Builder) 10 | interface to allow users to create arbitrary ops that the 11 | TensorFlow runtime supports. 12 | 13 | For example, the implementation of `tf.add(x, y)` in 14 | TensorFlow Scala looks like this: 15 | 16 | @@snip [AddingOps.scala](/docs/src/main/scala/AddingOps.scala) { #add_op_example } 17 | -------------------------------------------------------------------------------- /docs/src/main/paradox/guides/graph_construction.md: -------------------------------------------------------------------------------- 1 | # Graph Construction 2 | 3 | The low level API can be used to define computations that will be executed at a later point, and potentially execute 4 | them. It can also be used to create custom layers for the [Learn API](#neural-networks-2). The main type of object 5 | underlying the low level API is the [`Output`][output], which represents the value of a [`Tensor`][tensor] that has not 6 | yet been computed. Its name comes from the fact that it represents the *output* of some computation. An 7 | [`Output`][output] object thus represents a partially defined computation that will eventually produce a value. Core 8 | TensorFlow programs work by first building a graph of [`Output`][output] objects, detailing how each output is computed 9 | based on the other available outputs, and then by running parts of this graph to achieve the desired results. 10 | 11 | Similar to a [`Tensor`][tensor], each element in an [`Output`][output] has the same data type, and the data type is 12 | always known. However, the shape of an [`Output`][output] might be only partially known. Most operations produce tensors 13 | of fully-known shapes if the shapes of their inputs are also fully known, but in some cases it's only possible to find 14 | the shape of a tensor at graph execution time.
15 | 16 | It is important to understand the main concepts underlying the core API: 17 | 18 | - **Tensor:** 19 | - **Output:** 20 | - **Sparse Output:** 21 | - **Placeholder:** 22 | - **Variable:** 23 | - **Graph:** 24 | - **Session:** 25 | 26 | With the exception of [`Variable`][variable]s, the value of outputs is immutable, which means that in the context of a 27 | single execution, outputs only have a single value. However, evaluating the same output twice can result in different 28 | values. For example, that output may be the result of reading data from disk, or generating a random number. 29 | 30 | ## Graph 31 | 32 | 33 | ## Working with Outputs 34 | 35 | 36 | ### Evaluating Outputs 37 | 38 | 39 | ### Printing Outputs 40 | 41 | 42 | ### Logging 43 | 44 | Logging in the native TensorFlow library can be controlled by setting the `TF_CPP_MIN_LOG_LEVEL` environment variable, which sets the minimum severity of messages that get printed: 45 | 46 | - `0`: All messages are logged, including INFO (default). 47 | - `1`: INFO messages are filtered out (WARNING and above are logged). 48 | - `2`: INFO and WARNING messages are filtered out (ERROR and above are logged). 49 | - `3`: INFO, WARNING, and ERROR messages are filtered out (only FATAL is logged).
50 | -------------------------------------------------------------------------------- /docs/src/main/paradox/release_notes.md: -------------------------------------------------------------------------------- 1 | # Release Notes 2 | 3 | @@toc { depth=1 } 4 | 5 | @@@ index 6 | 7 | - [0.5.1](release_notes/0.5.1.md) 8 | - [0.5.0](release_notes/0.5.0.md) 9 | - [0.4.1](release_notes/0.4.1.md) 10 | - [0.4.0](release_notes/0.4.0.md) 11 | - [0.3.0](release_notes/0.3.0.md) 12 | - [0.2.4](release_notes/0.2.4.md) 13 | - [0.2.3](release_notes/0.2.3.md) 14 | - [0.2.2](release_notes/0.2.2.md) 15 | - [0.2.1](release_notes/0.2.1.md) 16 | - [0.2.0](release_notes/0.2.0.md) 17 | - [0.1.1](release_notes/0.1.1.md) 18 | - [0.1.0](release_notes/0.1.0.md) 19 | 20 | @@@ 21 | -------------------------------------------------------------------------------- /docs/src/main/paradox/release_notes/0.1.0.md: -------------------------------------------------------------------------------- 1 | # Release 0.1.0 2 | 3 | This is the first official release of TensorFlow for Scala. The library 4 | website will soon be updated with information about the functionality 5 | supported by this API. Most of the main TensorFlow Python API 6 | functionality is already supported. 7 | -------------------------------------------------------------------------------- /docs/src/main/paradox/release_notes/0.1.1.md: -------------------------------------------------------------------------------- 1 | # Release 0.1.1 2 | 3 | This release fixes the following bugs: 4 | 5 | - Issue with the packaged pre-compiled TensorFlow binaries that 6 | affected Linux platforms. 7 | - Learn API bug where the shared name of input iterators was being 8 | set incorrectly. 9 | 10 | I also switched to using CircleCI for continuous integration, instead 11 | of TravisCI. 
12 | -------------------------------------------------------------------------------- /docs/src/main/paradox/release_notes/0.2.0.md: -------------------------------------------------------------------------------- 1 | # Release 0.2.0 2 | 3 | In this release we have: 4 | 5 | - Added support for incremental compilation. 6 | - Added support for [Horovod](https://github.com/uber/horovod). 7 | - Added support for timelines to allow for easy profiling of 8 | TensorFlow graphs. 9 | - Fixed a major memory leak (issue #87). 10 | - Updated the JNI bindings to be compatible with the TensorFlow 11 | 1.9.0 release. 12 | - Added support for obtaining the list of available devices from 13 | within Scala. 14 | - Fixed bugs for some control flow ops. 15 | - Added support for `tf.cases`. 16 | - Added support for the RMSProp optimizer, the lazy Adam optimizer, 17 | the [AMSGrad](https://openreview.net/pdf?id=ryQu7f-RZ) optimizer, 18 | the lazy AMSGrad optimizer, and the 19 | [YellowFin](https://arxiv.org/pdf/1706.03471.pdf) optimizer. 20 | - Added more learning rate decay schemes: 21 | - Cosine decay. 22 | - Cycle-linear 10x decay. 23 | - Square-root decay. 24 | - More warm-up decay schedules. 25 | - Added support for dataset interleave ops. 26 | - Fixed some bugs related to variable scopes and variable sharing. 27 | - Fixed some bugs related to functional ops. 28 | - Added support for some new image-related ops, under the namespace 29 | `tf.image`. 30 | - Improved consistency for the creation of initializer ops. 31 | - Added support for the `tf.initializer` op creation context. 32 | - Exposed part of the `TensorArray` API. 33 | - Exposed `tf.Op.Builder` in the public API. 34 | - Improvements to the learn API: 35 | - Refactored `mode` into an implicit argument. 36 | - Improved the evaluator hook. 37 | - Removed the layer creation context mechanism, to be refactored 38 | later. It was causing some issues due to bad design and unclear 39 | semantics. 
The plan is to implement this, in the near future, as 40 | wrapper creation context layers. 41 | - Improved the `Model` class. 42 | - Fixed a bug that was causing some issues related to inference 43 | hooks in the in-memory estimator. 44 | - Improved logging. 45 | - Added support for reading and writing numpy (i.e., `.npy`) files. 46 | - Added a logo. :) 47 | -------------------------------------------------------------------------------- /docs/src/main/paradox/release_notes/0.2.1.md: -------------------------------------------------------------------------------- 1 | # Release 0.2.1 2 | 3 | In this release we have fixed an issue related to the packaging and 4 | distributing of the pre-compiled TensorFlow shared libraries. 5 | -------------------------------------------------------------------------------- /docs/src/main/paradox/release_notes/0.2.2.md: -------------------------------------------------------------------------------- 1 | # Release 0.2.2 2 | 3 | In this release we have updated the precompiled TensorFlow binaries 4 | distributed with this library. 5 | -------------------------------------------------------------------------------- /docs/src/main/paradox/release_notes/0.2.3.md: -------------------------------------------------------------------------------- 1 | # Release 0.2.3 2 | 3 | Added compatibility with TensorFlow 1.9-rc1. 4 | -------------------------------------------------------------------------------- /docs/src/main/paradox/release_notes/0.2.4.md: -------------------------------------------------------------------------------- 1 | # Release 0.2.4 2 | 3 | Fixed an issue with the packaged pre-compiled TensorFlow binaries that 4 | affected Linux platforms. 
5 | -------------------------------------------------------------------------------- /docs/src/main/paradox/release_notes/0.3.0.md: -------------------------------------------------------------------------------- 1 | # Release 0.3.0 2 | 3 | With this release we have finally added support for static data type 4 | information for tensors (not for symbolic tensors yet though -- for now 5 | we effectively have support for a statically-typed version of `numpy` 6 | for Scala). This is an important milestone and contributes significantly 7 | to type safety, which can help catch errors at compile time, rather than 8 | runtime. For example: 9 | 10 | ```scala 11 | val t1 = Tensor(0.5, 1) // The inferred type is Tensor[FLOAT64]. 12 | val t2 = Tensor(1, 2) // The inferred type is Tensor[INT32]. 13 | val t3 = t1 + t2 // The inferred type is Tensor[FLOAT64]. 14 | val t4 = t3.isNaN // The inferred type is Tensor[BOOLEAN]. 15 | val t5 = t3.any() // Fails at compile-time because `any()` is only 16 | // supported for Tensor[BOOLEAN]. 17 | ``` 18 | 19 | Other new features include: 20 | 21 | - Improvements to the high-level learn API: 22 | - Layers can now provide and use their own parameter generator, and 23 | can also access the current training step 24 | (using `Layer.currentStep`). 25 | - Layers now support `.map(...)`. 26 | - Added support for batch normalization. 27 | - Added support for `tf.logSigmoid` and `tf.lrn`. 28 | - Added support for the following new metrics: 29 | - Grouped precision. 30 | - Precision-at-k. 31 | - `data` module: 32 | - Added support for loading the extreme classification repository 33 | datasets (i.e., `data.XCLoader`). 34 | - Added support for randomly splitting datasets. 
35 | -------------------------------------------------------------------------------- /docs/src/main/paradox/release_notes/0.4.0.md: -------------------------------------------------------------------------------- 1 | # Release 0.4.0 2 | 3 | This is a major release with a lot of new features related to static 4 | types for tensors and ops. The graph construction API is now 5 | statically-typed, thus enabling much better type safety than before. 6 | 7 | Tensors and outputs are now statically-typed and the types used are the 8 | Scala types that correspond to the tensors' TensorFlow data types. For 9 | example: 10 | 11 | ```scala 12 | val t1 = Tensor(0.5, 1) // The inferred type is Tensor[Double]. 13 | val t2 = Tensor(1, 2) // The inferred type is Tensor[Int]. 14 | val t3 = t1 + t2 // The inferred type is Tensor[Double]. 15 | val t4 = t3.isNaN // The inferred type is Tensor[Boolean]. 16 | val t5 = t3.any() // Fails at compile-time because `any()` is only 17 | // supported for Tensor[Boolean]. 18 | ``` 19 | 20 | A similar situation now applies to `Output`s. `Op`s are also typed and 21 | so is the auto-differentiation implementation. 22 | 23 | This resulted in major simplifications in the data pipeline and the 24 | high-level learn API. Datasets and dataset iterators no longer "carry" `T`, `V`, 25 | `D`, and `S` types with them, but rather just the type of the 26 | elements they contain/produce. 27 | 28 | A new type trait called `TF` is also introduced that denotes supported 29 | Scala types in TensorFlow (e.g., `TF[Int]` and `TF[Float]`). Similarly, 30 | some more type traits are introduced to denote type constraints for 31 | various ops (e.g., `IsIntOrUInt[Int]`, `IsIntOrUInt[Long]`, 32 | `IsFloatOrDouble[Float]`, etc.). These type traits are powered by a 33 | general implementation of union types for Scala. 34 | 35 | Other new features include: 36 | 37 | - `data` module: 38 | - Added support for the `mapAndBatch` transformation.
39 | -------------------------------------------------------------------------------- /docs/src/main/paradox/release_notes/0.4.1.md: -------------------------------------------------------------------------------- 1 | # Release 0.4.1 2 | 3 | Fixed the precompiled TensorFlow binaries, and also added the following 4 | new features: 5 | 6 | - `io` module: 7 | - Added support for a new `TFRecordWriter`. 8 | - `ops` module: 9 | - Added a new ops namespace, `sparse`, that includes all sparse ops. 10 | - Added support for `sparse.reorder` and `sparse.merge`. 11 | - Added support for parsing TF records. 12 | - `data` module: 13 | - Added support for `Dataset.shuffleAndRepeat`. 14 | - `optimizers` module: 15 | - Added support for the Adafactor optimizer. 16 | - Renamed `SqrtDecay` to `RSqrtDecay`, which is a more appropriate name. 17 | - `math` module: 18 | - Added support for `batchGather`. 19 | - Added support for bitwise ops. 20 | - `rnn` module: 21 | - Simplified the attention mechanisms functionality so that it is 22 | no longer required to tile memory tensors for beam search outside 23 | the beam search decoder. 24 | - Moved the `seq2seq` module to a separate repository (that of 25 | [Symphony Machine Translation](https://github.com/eaplatanios/symphony-mt)). 26 | -------------------------------------------------------------------------------- /docs/src/main/paradox/release_notes/0.5.0.md: -------------------------------------------------------------------------------- 1 | # Release 0.5.0 2 | 3 | This release introduces support for TensorFlow 2.0. 4 | -------------------------------------------------------------------------------- /docs/src/main/paradox/release_notes/0.5.1.md: -------------------------------------------------------------------------------- 1 | # Release 0.5.1 2 | 3 | This release introduces support for TensorFlow 2.2 and 4 | Scala 2.13 and drops support for Scala 2.11.
The 5 | distributed precompiled binaries for this version will only 6 | work with CUDA 10.1 on Linux. Finally, this release also 7 | brings improved support for implicit derivations in some 8 | cases where case classes over tensors are used. 9 | -------------------------------------------------------------------------------- /docs/src/main/scala/AddingOps.scala: -------------------------------------------------------------------------------- 1 | import org.platanios.tensorflow.api._ 2 | 3 | object AddingOps { 4 | // #add_op_example 5 | def add[T: TF : IsNotQuantized]( 6 | x: Output[T], 7 | y: Output[T], 8 | name: String = "Add" 9 | ): Output[T] = { 10 | Op.Builder[(Output[T], Output[T]), Output[T]]( 11 | opType = "Add", 12 | name = name, 13 | input = (x, y) 14 | ).setGradientFn(addGradient(_, _)(TF[T], IsNotQuantized[T])) 15 | .build().output 16 | } 17 | // Sums `outputGradient` over the axes along which each input was broadcast and reshapes the result back to that input's shape. 18 | protected def addGradient[T: TF : IsNotQuantized]( 19 | op: Op[(Output[T], Output[T]), Output[T]], 20 | outputGradient: Output[T] 21 | ): (Output[T], Output[T]) = { 22 | val xShape = tf.shape(op.input._1) 23 | val yShape = tf.shape(op.input._2) 24 | val (rx, ry) = tf.broadcastGradientArguments(xShape, yShape) 25 | (tf.reshape(tf.sum(outputGradient, rx), xShape), 26 | tf.reshape(tf.sum(outputGradient, ry), yShape)) 27 | } 28 | // #add_op_example 29 | } 30 | -------------------------------------------------------------------------------- /docs/src/main/scala/Estimators.scala: -------------------------------------------------------------------------------- 1 | import org.platanios.tensorflow.api._ 2 | import org.platanios.tensorflow.api.ops.metrics.Metric 3 | import org.platanios.tensorflow.api.tf.learn._ 4 | 5 | // #inference_model 6 | trait InferenceModel[In, Out] extends Model { 7 | def buildInferOps(): Model.InferOps[In, Out] 8 | } 9 | // #inference_model 10 | 11 | // #trainable_models 12 | trait TrainableModel[In, TrainIn, Out, TrainOut, Loss, EvalIn] extends InferenceModel[In, Out] { 13 | def
buildTrainOps(): Model.TrainOps[TrainIn, TrainOut, Loss] 14 | def buildEvalOps(metrics: Seq[Metric[EvalIn, Output[Float]]]): Model.EvalOps[TrainIn, Out] 15 | } 16 | 17 | trait SupervisedTrainableModel[In, TrainIn, Out, TrainOut, Loss] extends TrainableModel[In, (In, TrainIn), Out, TrainOut, Loss, (Out, (In, TrainIn))] { 18 | override def buildTrainOps(): Model.TrainOps[(In, TrainIn), TrainOut, Loss] 19 | override def buildEvalOps(metrics: Seq[Metric[(Out, (In, TrainIn)), Output[Float]]]): Model.EvalOps[(In, TrainIn), Out] 20 | } 21 | 22 | trait UnsupervisedTrainableModel[In, Out, Loss] extends TrainableModel[In, In, Out, Out, Loss, Out] { 23 | override def buildTrainOps(): Model.TrainOps[In, Out, Loss] 24 | override def buildEvalOps(metrics: Seq[Metric[Out, Output[Float]]]): Model.EvalOps[In, Out] 25 | } 26 | // #trainable_models 27 | -------------------------------------------------------------------------------- /docs/src/main/scala/installation.sh: -------------------------------------------------------------------------------- 1 | // #clone_repository 2 | git clone https://github.com/tensorflow/tensorflow.git 3 | cd tensorflow 4 | git checkout 1aaa68d93c6b2f4151446eb211399b4330c96a09 5 | // #clone_repository 6 | 7 | // #compile_tf 8 | ./configure 9 | bazel build --config=opt --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=0 //tensorflow:libtensorflow.so 10 | // #compile_tf 11 | 12 | // #apt_get_install_protobuf 13 | apt-get install protobuf-compiler 14 | // #apt_get_install_protobuf 15 | 16 | // #brew_install_protobuf 17 | brew install protobuf 18 | // #brew_install_protobuf 19 | -------------------------------------------------------------------------------- /modules/api/src/main/resources/logback.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | %d{yyyy-MM-dd HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 |
-------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/Documentation.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License. 14 | */ 15 | 16 | package org.platanios.tensorflow.api 17 | 18 | /** Groups together documentation from various sub-packages (currently, it only mixes in `ops.Documentation`). 19 | * 20 | * @author Emmanouil Antonios Platanios 21 | */ 22 | private[api] trait Documentation extends ops.Documentation 23 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/config/SummaryConfig.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License. 14 | */ 15 | 16 | package org.platanios.tensorflow.api.config 17 | 18 | /** Summary configuration used while training models. 19 | * Since this trait is sealed, its only implementations are [[NoSummaries]], [[StepBasedSummaries]], and [[TimeBasedSummaries]]. 20 | * @author Emmanouil Antonios Platanios 21 | */ 22 | sealed trait SummaryConfig 23 | 24 | /** Summary configuration for not saving any summaries. */ 25 | case object NoSummaries extends SummaryConfig 26 | 27 | /** Summary configuration for step-based summaries (i.e., summaries every `n` steps). 28 | * 29 | * @param steps Save summaries every this many steps. 30 | * @throws IllegalArgumentException If `steps` is negative. 31 | */ 32 | case class StepBasedSummaries(steps: Int = 1000) extends SummaryConfig { 33 | require(steps >= 0, s"'steps' (set to $steps) needs to be a non-negative integer.") 34 | } 35 | 36 | /** Summary configuration for time-based summaries (i.e., summaries every `n` seconds). 37 | * 38 | * @param seconds Save summaries every this many seconds. 39 | * @throws IllegalArgumentException If `seconds` is negative. 40 | */ 41 | case class TimeBasedSummaries(seconds: Int = 600) extends SummaryConfig { 42 | require(seconds >= 0, s"'seconds' (set to $seconds) needs to be a non-negative integer.") 43 | } 44 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/config/TensorBoardConfig.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License. 14 | */ 15 | 16 | package org.platanios.tensorflow.api.config 17 | 18 | import java.nio.file.Path 19 | 20 | /** Configuration for launching a TensorBoard instance (e.g., while training models using estimators). 21 | * 22 | * @param logDir Directory with the logs and summaries that TensorBoard should read (passed on as an absolute path). 23 | * @param host Host on which the TensorBoard service listens. 24 | * @param port Port on which the TensorBoard service listens. 25 | * @param reloadInterval Interval, in seconds, at which the TensorBoard backend reloads data. 26 | * 27 | * @author Emmanouil Antonios Platanios 28 | */ 29 | case class TensorBoardConfig( 30 | logDir: Path, 31 | host: String = "localhost", 32 | port: Int = 6006, 33 | reloadInterval: Int = 5 34 | ) { 35 | // Command line for the `tensorboard` executable; the builder is only prepared here and no process is started. 36 | private[api] val processBuilder: ProcessBuilder = { 37 | val command = Seq("tensorboard", "--logdir", logDir.toAbsolutePath.toString, "--host", host, "--port", port.toString, 38 | "--reload_interval", reloadInterval.toString) 39 | new ProcessBuilder(command: _*) 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/core/Devices.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License.
14 | */ 15 | 16 | package org.platanios.tensorflow.api.core 17 | 18 | import org.platanios.tensorflow.api.core.client.SessionConfig 19 | import org.platanios.tensorflow.jni.{Session => NativeSession} 20 | import org.platanios.tensorflow.proto.DeviceAttributes 21 | 22 | /** Helper methods for querying the devices that are available to TensorFlow. 23 | * 24 | * @author Emmanouil Antonios Platanios 25 | */ 26 | object Devices { 27 | /** Lists the devices that are available to this local process. 28 | * 29 | * @param sessionConfig Optional session configuration; when provided, it is serialized and passed along to the native device query. 30 | * @return Attributes of each device available to this local process, parsed from the native device list. 31 | */ 32 | def local(sessionConfig: Option[SessionConfig] = None): Seq[DeviceAttributes] = { 33 | val serializedConfig = sessionConfig.map(_.toConfigProto.toByteArray).orNull 34 | NativeSession.deviceList(serializedConfig).toSeq.map(DeviceAttributes.parseFrom) 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/core/Implicits.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License.
14 | */ 15 | 16 | package org.platanios.tensorflow.api.core 17 | 18 | /** Groups the `client` and `types` implicits, and adds implicit conversions from `Int` to indexing constructs. 19 | * @author Emmanouil Antonios Platanios 20 | */ 21 | private[api] trait Implicits 22 | extends client.Implicits 23 | with types.Implicits { 24 | // TODO: [INDEXERS] Add begin mask support (not simple). 25 | 26 | implicit def intToIndex(index: Int): Index = Index(index = index) 27 | 28 | implicit def intToIndexerConstruction(n: Int): IndexerConstructionWithOneNumber = { 29 | IndexerConstructionWithOneNumber(n) 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/core/Logging.scala: -------------------------------------------------------------------------------- 1 | package org.platanios.tensorflow.api.core 2 | 3 | import org.platanios.tensorflow.jni.TensorFlow 4 | 5 | object Logging { 6 | 7 | /** Represents the TensorFlow logging level. */ 8 | sealed trait Level { 9 | private[Logging] def value: Int 10 | } 11 | 12 | case object DEBUG extends Level { 13 | override private[Logging] def value = 0 14 | } 15 | 16 | case object INFO extends Level { 17 | override private[Logging] def value = 1 18 | } 19 | 20 | case object WARNING extends Level { 21 | override private[Logging] def value = 2 22 | } 23 | 24 | case object ERROR extends Level { 25 | override private[Logging] def value = 3 26 | } 27 | 28 | /** Sets the current TensorFlow logging [[Level]] by passing its numeric value, as a string, to `TensorFlow.setLogLevel`. */ 29 | def setLoggingLevel(level: Level): Unit = { 30 | TensorFlow.setLogLevel(level.value.toString) 31 | } 32 | 33 | /** Returns the current TensorFlow logging [[Level]], as reported by `TensorFlow.getLogLevel`.
34 | * 35 | * A missing (`null`) or non-numeric value is treated as level `0` (i.e., [[DEBUG]]), and numeric values outside 36 | * of the supported `[0, 3]` range are clamped to the closest supported level, instead of throwing an error. 37 | */ 38 | def currentLoggingLevel: Level = { 39 | // NOTE(review): the value presumably comes from the user-settable `TF_CPP_MIN_LOG_LEVEL` environment variable 40 | // (TODO confirm), so it may hold arbitrary strings; it is parsed leniently rather than assuming only 0 to 3. 41 | val level = Option(TensorFlow.getLogLevel).flatMap(l => scala.util.Try(l.trim.toInt).toOption).getOrElse(0) 42 | if (level <= 0) DEBUG 43 | else if (level == 1) INFO 44 | else if (level == 2) WARNING 45 | else ERROR 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/core/client/Implicits.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License. 14 | */ 15 | 16 | package org.platanios.tensorflow.api.core.client 17 | 18 | /** Groups the implicits of the `client` package (currently, only those in `FeedMap.Implicits`). 19 | * @author Emmanouil Antonios Platanios 20 | */ 21 | private[core] trait Implicits 22 | extends FeedMap.Implicits 23 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/core/distributed/Protocol.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License.
You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License. 14 | */ 15 | 16 | package org.platanios.tensorflow.api.core.distributed 17 | 18 | /** Trait used to represent supported communication protocols for [[Server]]s. 19 | * Each protocol is identified by its `name` (e.g., `"grpc"` for [[GRPC]]). 20 | * @author Emmanouil Antonios Platanios 21 | */ 22 | sealed trait Protocol { 23 | val name: String 24 | } 25 | 26 | /** GRPC communication protocol. */ 27 | case object GRPC extends Protocol { 28 | override val name: String = "grpc" 29 | } 30 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/helpers/package.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License.
14 | */ 15 | 16 | package org.platanios.tensorflow.api.implicits 17 | 18 | import org.platanios.tensorflow.api.core.Shape 19 | import org.platanios.tensorflow.api.core.types.{DataType, Variant} 20 | import org.platanios.tensorflow.api.ops.Output 21 | import org.platanios.tensorflow.api.ops.data.Dataset 22 | 23 | package object helpers { 24 | type SparseDataType[T] = (DataType[Long], DataType[T], DataType[Long]) 25 | type IndexedSlicesDataType[T] = (DataType[Int], DataType[T], DataType[Int]) 26 | type SparseShape = (Shape, Shape, Shape) 27 | 28 | // TODO: [FUNCTIONS] !!! Find a better way to deal with this for use in the reduce function of the "GroupByWindowDataset". 29 | // `VariantDataset` wraps an existing variant `handle`; its output data types and shapes are stored untyped (`null` by default) and are cast, unchecked, to whatever types callers request. 30 | case class VariantDataset[T: OutputStructure] protected( 31 | handle: Output[Variant], 32 | private val _outputDataTypes: Any = null, 33 | private val _outputShapes: Any = null 34 | ) extends Dataset[T] { 35 | override val name: String = "VariantDataset" 36 | 37 | override def createHandle[D, S]()(implicit 38 | evOutputToDataType: OutputToDataType.Aux[T, D], 39 | evOutputToShape: OutputToShape.Aux[T, S] 40 | ): Output[Variant] = { 41 | handle 42 | } 43 | 44 | override def outputDataTypes[D](implicit ev: OutputToDataType.Aux[T, D]): D = { 45 | _outputDataTypes.asInstanceOf[D] 46 | } 47 | 48 | override def outputShapes[S](implicit ev: OutputToShape.Aux[T, S]): S = { 49 | _outputShapes.asInstanceOf[S] 50 | } 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/ops/ControlFlowImplicits.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License.
You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License. 14 | */ 15 | 16 | package org.platanios.tensorflow.api.implicits.ops 17 | 18 | import org.platanios.tensorflow.api.ops.UntypedOp 19 | 20 | trait ControlFlowImplicits { 21 | implicit class ControlFlowOps(val op: UntypedOp) { 22 | /** Checks whether this op belongs to a cond (i.e., conditional) control flow context. */ 23 | def isInCond: Boolean = { 24 | op.controlFlowContext.exists(context => context.condContext.isDefined) 25 | } 26 | 27 | /** Checks whether this op belongs to a while loop control flow context. */ 28 | def isInWhileLoop: Boolean = { 29 | op.controlFlowContext.exists(context => context.whileLoopContext().isDefined) 30 | } 31 | 32 | /** Checks whether this op is marked for XLA compilation (via its `_XlaCompile` attribute) or belongs to an XLA control flow context. */ 33 | def isInXLAContext: Boolean = { 34 | // An `IllegalArgumentException` from the attribute lookup (presumably raised when the attribute is missing) counts as "not marked". 35 | val markedForXLA = try op.booleanAttribute("_XlaCompile") catch { case _: IllegalArgumentException => false } 36 | if (markedForXLA) true 37 | else op.controlFlowContext.exists(context => context.xlaContext.isDefined) 38 | } 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/implicits/ops/EmbeddingImplicits.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License.
You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License. 14 | */ 15 | 16 | package org.platanios.tensorflow.api.implicits.ops 17 | 18 | import org.platanios.tensorflow.api.core.types.{IsNotQuantized, TF} 19 | import org.platanios.tensorflow.api.ops.Embedding.{OutputParameters, VariableParameters} 20 | import org.platanios.tensorflow.api.ops.{EmbeddingMap, EmbeddingParameters, Output} 21 | import org.platanios.tensorflow.api.ops.variables.Variable 22 | /** Implicit conversions to [[EmbeddingMap]]s from a single set of embedding parameters, a sequence of them, an `Output`, or a `Variable`. */ 23 | trait EmbeddingImplicits { 24 | implicit def singlePartitionEmbeddingMap[T: TF]( 25 | parameters: EmbeddingParameters[T] 26 | ): EmbeddingMap[T] = { 27 | EmbeddingMap(Seq(parameters)) 28 | } 29 | 30 | implicit def multiplePartitionsEmbeddingMap[T: TF]( 31 | parameters: Seq[EmbeddingParameters[T]] 32 | ): EmbeddingMap[T] = { 33 | EmbeddingMap(parameters) 34 | } 35 | // NOTE(review): the two conversions below return `OutputParameters`/`VariableParameters` where an `EmbeddingMap[T]` is expected, presumably relying on `singlePartitionEmbeddingMap` -- TODO confirm. 36 | implicit def outputToEmbeddingMap[T: TF : IsNotQuantized]( 37 | parameters: Output[T] 38 | ): EmbeddingMap[T] = { 39 | OutputParameters(parameters) 40 | } 41 | 42 | implicit def variableToEmbeddingMap[T: TF : IsNotQuantized]( 43 | parameters: Variable[T] 44 | ): EmbeddingMap[T] = { 45 | VariableParameters(parameters) 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/io/CompressionType.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License. 14 | */ 15 | 16 | package org.platanios.tensorflow.api.io 17 | 18 | /** Represents a compression type. 19 | * Each type is identified by its `name` (empty for [[NoCompression]]), which is also what `toString` returns. 20 | * @author Emmanouil Antonios Platanios 21 | */ 22 | sealed trait CompressionType { 23 | val name: String 24 | 25 | override def toString: String = name 26 | } 27 | 28 | case object NoCompression extends CompressionType { 29 | override val name: String = "" 30 | } 31 | 32 | case object ZLIBCompression extends CompressionType { 33 | override val name: String = "ZLIB" 34 | } 35 | 36 | case object GZIPCompression extends CompressionType { 37 | override val name: String = "GZIP" 38 | } 39 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/io/Loader.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package org.platanios.tensorflow.api.io

/** Simple trait for representing data loaders of arbitrary types.
  *
  * @tparam T Type of the loaded entries.
  */
trait Loader[T] {
  /** Loads entries and returns an iterator over them. */
  def load(): Iterator[T]
}

--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/io/TFRecordWriter.scala:
--------------------------------------------------------------------------------
/* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package org.platanios.tensorflow.api.io

import org.platanios.tensorflow.api.utilities.{CRC32C, Coding}
import org.platanios.tensorflow.proto.Example

import java.io.BufferedOutputStream
import java.nio.file.{Files, Path, StandardOpenOption}

/** Helper used to write `Example` protocol buffers to TensorFlow record files.
  *
  * The underlying file is opened as soon as the writer is constructed, using the `CREATE_NEW` and `APPEND` options.
  * Construction therefore fails with an `IOException` (typically a `FileAlreadyExistsException`) if a file already
  * exists at `filePath`.
  *
  * @param  filePath TensorFlow record file path.
  *
  * @author Emmanouil Antonios Platanios
  */
case class TFRecordWriter(filePath: Path) {
  protected var fileStream: BufferedOutputStream = {
    new BufferedOutputStream(Files.newOutputStream(
      filePath, StandardOpenOption.CREATE_NEW, StandardOpenOption.APPEND))
  }

  /** Appends `example` to the TensorFlow records file.
    *
    * The record goes through a buffered stream, so it may not reach the file until `flush()` or `close()` is called.
    */
  def write(example: Example): Unit = {
    val recordBytes = example.toByteArray
    // Format of a single record:
    //   uint64 length
    //   uint32 masked crc of length
    //   byte   data[length]
    //   uint32 masked crc of data
    val encLength = Coding.encodeFixedInt64(recordBytes.length)
    val encLengthMaskedCrc = Coding.encodeFixedInt32(CRC32C.mask(CRC32C.value(encLength)))
    val encDataMaskedCrc = Coding.encodeFixedInt32(CRC32C.mask(CRC32C.value(recordBytes)))
    // Write the pieces straight into the buffered stream, in record order. Concatenating them first (with `++`)
    // would copy the (potentially large) payload into several intermediate arrays for no benefit.
    fileStream.write(encLength)
    fileStream.write(encLengthMaskedCrc)
    fileStream.write(recordBytes)
    fileStream.write(encDataMaskedCrc)
  }

  /** Flushes any buffered records to the underlying file output stream. */
  def flush(): Unit = {
    fileStream.flush()
  }

  /** Flushes any buffered records and then closes the current TensorFlow records file. */
  def close(): Unit = {
    fileStream.close()
  }
}

object TFRecordWriter {
  /** Creates a writer for `filePath`. The file is created immediately and must not already exist. */
  def apply(filePath: Path): TFRecordWriter = {
    new TFRecordWriter(filePath)
  }
}

--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/io/events/EventType.scala:
--------------------------------------------------------------------------------
/* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License.
You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package org.platanios.tensorflow.api.io.events

/** Sealed set of event types; the case objects defined below are its only members.
  *
  * NOTE(review): presumably each member identifies one kind of data stored in TensorFlow event files (scalars,
  * images, audio, histograms, tensors, graphs, meta-graphs, run metadata). Confirm against the event readers.
  *
  * @author Emmanouil Antonios Platanios
  */
sealed trait EventType

case object ScalarEventType extends EventType
case object ImageEventType extends EventType
case object AudioEventType extends EventType
case object HistogramEventType extends EventType
case object CompressedHistogramEventType extends EventType
case object TensorEventType extends EventType
case object GraphEventType extends EventType
case object MetaGraphEventType extends EventType
case object RunMetadataEventType extends EventType

--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/io/events/SummaryFileWriterCache.scala:
--------------------------------------------------------------------------------
/* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package org.platanios.tensorflow.api.io.events

import org.platanios.tensorflow.api.ops.Op
import org.platanios.tensorflow.api.core.Graph

import java.nio.file.Path

import scala.collection.mutable

/** Cache of summary file writers that holds at most one writer per directory.
  *
  * Every access to the underlying map is synchronized on the map itself.
  *
  * @author Emmanouil Antonios Platanios
  */
object SummaryFileWriterCache {
  private[this] val cache: mutable.Map[Path, SummaryFileWriter] = mutable.HashMap.empty[Path, SummaryFileWriter]

  /** Returns the summary file writer for `directory`.
    *
    * A writer is created (for `graph`) and cached only the first time a directory is requested. Later calls return the
    * cached writer and ignore `graph`.
    */
  def get(directory: Path, graph: Graph = Op.currentGraph): SummaryFileWriter = {
    cache.synchronized {
      cache.getOrElseUpdate(directory, SummaryFileWriter(directory, graph))
    }
  }

  /** Closes and removes every cached summary writer. Currently only used for testing. */
  private[io] def clear(): Unit = {
    cache.synchronized {
      // Close each writer before forgetting it, so that no open file handles linger around (on Windows, open handles
      // block deletion of the underlying files).
      for (writer <- cache.values)
        writer.close()
      cache.clear()
    }
  }
}

--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/learn/Mode.scala:
--------------------------------------------------------------------------------
/* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License.
You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package org.platanios.tensorflow.api.learn

/** Mode a model is in while being used by a learner: training, evaluation, or inference (prediction).
  *
  * @author Emmanouil Antonios Platanios
  */
sealed trait Mode {
  /** `true` only for [[TRAINING]]; `false` for [[EVALUATION]] and [[INFERENCE]]. */
  val isTraining: Boolean
}

/** Training mode. */
case object TRAINING extends Mode {
  override val isTraining: Boolean = true
}

/** Evaluation mode. */
case object EVALUATION extends Mode {
  override val isTraining: Boolean = false
}

/** Inference (prediction) mode. */
case object INFERENCE extends Mode {
  override val isTraining: Boolean = false
}

--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/learn/ModelInstance.scala:
--------------------------------------------------------------------------------
/* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
14 | */ 15 | 16 | package org.platanios.tensorflow.api.learn 17 | 18 | import org.platanios.tensorflow.api.core.types.{IsFloatOrDouble, TF} 19 | import org.platanios.tensorflow.api.ops.{Output, OutputLike, UntypedOp} 20 | import org.platanios.tensorflow.api.ops.data.DatasetIterator 21 | import org.platanios.tensorflow.api.ops.variables.Variable 22 | 23 | // TODO: [LEARN] What about "trainOutput"? 24 | 25 | /** Represents an instance of a constructed model. Such instances are constructed by estimators and passed on to 26 | * model-dependent hooks. 27 | * 28 | * @author Emmanouil Antonios Platanios 29 | */ 30 | case class ModelInstance[In, TrainIn, Out, TrainOut, Loss: TF : IsFloatOrDouble, EvalIn]( 31 | model: TrainableModel[In, TrainIn, Out, TrainOut, Loss, EvalIn], 32 | configuration: Configuration, 33 | trainInputIterator: Option[DatasetIterator[TrainIn]] = None, 34 | trainInput: Option[TrainIn] = None, 35 | output: Option[Out] = None, 36 | trainOutput: Option[TrainOut] = None, 37 | loss: Option[Output[Loss]] = None, 38 | gradientsAndVariables: Option[Seq[(OutputLike[Loss], Variable[Any])]] = None, 39 | trainOp: Option[UntypedOp] = None) 40 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/learn/estimators/package.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package org.platanios.tensorflow.api.learn

/** Estimator package. Its `API` trait/object re-export the estimator types and their companion objects.
  *
  * @author Emmanouil Antonios Platanios
  */
package object estimators {
  /** Type aliases and companion-object aliases for the estimator classes. */
  private[api] trait API {
    type Estimator[In, TrainIn, Out, TrainOut, Loss, EvalIn] =
      estimators.Estimator[In, TrainIn, Out, TrainOut, Loss, EvalIn]
    type InMemoryEstimator[In, TrainIn, Out, TrainOut, Loss, EvalIn] =
      estimators.InMemoryEstimator[In, TrainIn, Out, TrainOut, Loss, EvalIn]
    type FileBasedEstimator[In, TrainIn, Out, TrainOut, Loss, EvalIn] =
      estimators.FileBasedEstimator[In, TrainIn, Out, TrainOut, Loss, EvalIn]

    val Estimator         : estimators.Estimator.type          = estimators.Estimator
    val InMemoryEstimator : estimators.InMemoryEstimator.type  = estimators.InMemoryEstimator
    val FileBasedEstimator: estimators.FileBasedEstimator.type = estimators.FileBasedEstimator
  }

  private[api] object API extends API
}

--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/learn/hooks/ModelDependentHook.scala:
--------------------------------------------------------------------------------
/* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
14 | */ 15 | 16 | package org.platanios.tensorflow.api.learn.hooks 17 | 18 | import org.platanios.tensorflow.api.learn.ModelInstance 19 | 20 | /** Represents hooks that may dependent on the constructed model. 21 | * 22 | * This class offers the `modelInstance` field that sub-classes can access and that contains information specific to 23 | * the created model. It is only updated when the model graph is constructed (i.e., it is not updated while recovering 24 | * failed sessions). 25 | * 26 | * For example, a hook that logs the loss function value depends on the created loss op, or an evaluation hook may 27 | * depends on multiple ops created as part of the model. 28 | * 29 | * @author Emmanouil Antonios Platanios 30 | */ 31 | trait ModelDependentHook[In, TrainIn, Out, TrainOut, Loss, EvalIn] extends Hook { 32 | protected var modelInstance: ModelInstance[In, TrainIn, Out, TrainOut, Loss, EvalIn] = _ 33 | 34 | /** This method will be called by estimators at graph construction time, before `begin()`. It will **not** be called 35 | * again if a session fails and is recovered. */ 36 | private[learn] final def setModelInstance( 37 | modelInstance: ModelInstance[In, TrainIn, Out, TrainOut, Loss, EvalIn] 38 | ): Unit = { 39 | this.modelInstance = modelInstance 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/learn/layers/Embedding.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. 
You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package org.platanios.tensorflow.api.learn.layers

import org.platanios.tensorflow.api._
import org.platanios.tensorflow.api.core.types.{IsNotQuantized, TF}
import org.platanios.tensorflow.api.learn.{Mode, layers}
import org.platanios.tensorflow.api.ops
import org.platanios.tensorflow.api.ops.Embedding.OutputParameters
import org.platanios.tensorflow.api.ops.{EmbeddingMap, Output}
import org.platanios.tensorflow.api.tensors.Tensor

object Embedding {
  /** Exposes the [[layers.Embedding]] layer type and its companion object. */
  private[layers] trait API {
    type Embedding[T] = layers.Embedding[T]

    val Embedding: layers.Embedding.type = layers.Embedding
  }

  object API extends API
}

/** Layer that maps integer ids to embeddings through `ops.Embedding.embeddingLookup`.
  *
  * The layer obtains a parameter named `"EmbeddingMap"` with shape `[vocabularySize, embeddingSize]` (via
  * `getParameter`). It wraps that parameter as a single-partition [[EmbeddingMap]] and looks up the input ids in it.
  *
  * @param  name              Name of this layer; also passed as the name argument of `embeddingLookup`.
  * @param  vocabularySize    Number of rows of the embedding parameter.
  * @param  embeddingSize     Number of columns of the embedding parameter.
  * @param  partitionStrategy Partition strategy passed to `embeddingLookup`.
  * @param  transformFn       Optional function passed to `embeddingLookup` (`null` when not provided).
  * @param  maxNorm           Optional tensor. When non-null it is converted to a constant and passed to
  *                           `embeddingLookup`; otherwise `null` is passed.
  */
case class Embedding[T: TF : IsNotQuantized](
    override val name: String,
    vocabularySize: Int,
    embeddingSize: Int,
    partitionStrategy: ops.Embedding.PartitionStrategy = ops.Embedding.ModStrategy,
    transformFn: Output[T] => Output[T] = null,
    maxNorm: Tensor[T] = null
) extends Layer[Output[Int], Output[T]](name) {
  override val layerType: String = "Embedding"

  override def forwardWithoutContext(
      input: Output[Int]
  )(implicit mode: Mode): Output[T] = {
    val parameters = getParameter[T]("EmbeddingMap", Shape(vocabularySize, embeddingSize))
    val partitions = EmbeddingMap(Seq(OutputParameters(parameters)))
    ops.Embedding.embeddingLookup(
      partitions,
      input,
      partitionStrategy,
      transformFn,
      if (maxNorm != null) ops.basic.Basic.constant(maxNorm) else null,
      name)
  }
}
--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/learn/layers/core/package.scala:
--------------------------------------------------------------------------------
/* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package org.platanios.tensorflow.api.learn.layers

import org.platanios.tensorflow.api.core.types.{IsHalfOrFloatOrDouble, TF}
import org.platanios.tensorflow.api.ops.Output

/**
  * @author Emmanouil Antonios Platanios
  */
package object core {
  private[layers] trait API {
    /** Creates a multi-layer perceptron (MLP) layer.
      *
      * Every layer created here is named under the `name` prefix:
      *   - If `hiddenLayers` is empty, the MLP is a single linear layer named `$name/Linear`.
      *   - Otherwise, hidden layer `0` is `Linear >> activation`. Every later hidden layer `i` is
      *     `Linear >> Dropout >> activation`. The stack ends with a linear output layer named
      *     `$name/OutputLayer/Linear`.
      *
      * @param  name         Name prefix for all layers of the MLP.
      * @param  hiddenLayers Sizes of the hidden layers, in order.
      * @param  outputSize   Size of the output layer.
      * @param  activation   Function that creates an activation layer given its name. When `null`,
      *                      `ReLU[T](layerName, 0.1f)` is used.
      * @param  dropout      Dropout rate. `1 - dropout` is passed to each `Dropout` layer. Note that no dropout layer
      *                      follows the first hidden layer.
      * @return Created MLP layer.
      */
    def MLP[T: TF : IsHalfOrFloatOrDouble](
        name: String,
        hiddenLayers: Seq[Int],
        outputSize: Int,
        activation: String => Layer[Output[T], Output[T]] = null,
        dropout: Float = 0.0f
    ): Layer[Output[T], Output[T]] = {
      if (hiddenLayers.isEmpty) {
        Linear(s"$name/Linear", outputSize)
      } else {
        val activationWithDefault = {
          if (activation == null)
            (layerName: String) => ReLU[T](layerName, 0.1f)
          else
            activation
        }
        val size = hiddenLayers.head
        var layer = Linear(s"$name/Layer0/Linear", size) >> activationWithDefault(s"$name/Layer0/Activation")
        hiddenLayers.zipWithIndex.tail.foreach { case (layerSize, index) =>
          layer = layer >>
              Linear(s"$name/Layer$index/Linear", layerSize) >>
              Dropout(s"$name/Layer$index/Dropout", 1 - dropout) >>
              activationWithDefault(s"$name/Layer$index/Activation")
        }
        // The output layer is scoped under `name`, like every other layer above. Without the prefix, two MLPs created
        // in the same scope would both use "OutputLayer/Linear" and collide. Note that this renames the output-layer
        // variables relative to checkpoints written by earlier versions.
        layer >> Linear(s"$name/OutputLayer/Linear", outputSize)
      }
    }
  }
}

--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/learn/layers/rnn/cell/RNNCell.scala:
--------------------------------------------------------------------------------
/* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package org.platanios.tensorflow.api.learn.layers.rnn.cell

import org.platanios.tensorflow.api.implicits.helpers.OutputToShape
import org.platanios.tensorflow.api.learn.Mode
import org.platanios.tensorflow.api.learn.layers.Layer
import org.platanios.tensorflow.api.ops
import org.platanios.tensorflow.api.ops.variables.VariableScope

/**
  * @param name Name scope (also acting as variable scope) for this layer.
26 | * 27 | * @author Emmanouil Antonios Platanios 28 | */ 29 | abstract class RNNCell[Out, State, OutShape, StateShape]( 30 | override val name: String 31 | )(implicit 32 | val evOutputToShapeOut: OutputToShape.Aux[Out, OutShape], 33 | val evOutputToShapeState: OutputToShape.Aux[State, StateShape] 34 | ) extends Layer[Tuple[Out, State], Tuple[Out, State]](name) { 35 | def createCellWithoutContext( 36 | mode: Mode, 37 | inputShape: OutShape 38 | ): ops.rnn.cell.RNNCell[Out, State, OutShape, StateShape] 39 | 40 | final def createCell( 41 | mode: Mode, 42 | inputShape: OutShape 43 | ): ops.rnn.cell.RNNCell[Out, State, OutShape, StateShape] = { 44 | if (name != null) { 45 | VariableScope.scope(name, isPure = true) { 46 | createCellWithoutContext(mode, inputShape) 47 | } 48 | } else { 49 | createCellWithoutContext(mode, inputShape) 50 | } 51 | } 52 | 53 | override final def forwardWithoutContext( 54 | input: Tuple[Out, State] 55 | )(implicit mode: Mode): Tuple[Out, State] = { 56 | createCellWithoutContext(mode, evOutputToShapeOut.shape(input.output)).forward(input) 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/learn/layers/rnn/package.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package org.platanios.tensorflow.api.learn.layers

/** RNN layers package. Its `API` trait re-exports the RNN layer types and companions, plus everything in
  * `rnn.cell.API`.
  *
  * @author Emmanouil Antonios Platanios
  */
package object rnn {
  private[layers] trait API
      extends rnn.cell.API {
    type RNN[Out, State, OutShape, StateShape] = rnn.RNN[Out, State, OutShape, StateShape]
    type BidirectionalRNN[Out, State, OutShape, StateShape] = rnn.BidirectionalRNN[Out, State, OutShape, StateShape]

    val RNN             : rnn.RNN.type              = rnn.RNN
    val BidirectionalRNN: rnn.BidirectionalRNN.type = rnn.BidirectionalRNN
  }
}

--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Documentation.scala:
--------------------------------------------------------------------------------
/* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package org.platanios.tensorflow.api.ops

/** Groups together documentation related to constructing symbolic ops.
19 | * 20 | * @author Emmanouil Antonios Platanios 21 | */ 22 | private[api] trait Documentation 23 | extends basic.Basic.Documentation 24 | with Cast.Documentation 25 | with Checks.Documentation 26 | with Clip.Documentation 27 | with Embedding.Documentation 28 | with Image.Documentation 29 | with Logging.Documentation 30 | with math.Math.Documentation 31 | with NN.Documentation 32 | with Parsing.Documentation 33 | with Random.Documentation 34 | with Sets.Documentation 35 | with Sparse.Documentation 36 | with Statistics.Documentation 37 | with Summary.Documentation 38 | with Text.Documentation 39 | with control_flow.ControlFlow.Documentation 40 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Input.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License. 14 | */ 15 | 16 | package org.platanios.tensorflow.api.ops 17 | 18 | import org.platanios.tensorflow.api.core.Graph 19 | import org.platanios.tensorflow.api.core.types.DataType 20 | import org.platanios.tensorflow.api.utilities.using 21 | import org.platanios.tensorflow.jni.{Op => NativeOp} 22 | 23 | /** Wrapper around an op meant to represent one of its inputs. 
Actual op inputs have type [[Output]] since they 24 | * represent outputs of other ops. Currently, [[Input]] is only useful for representing consumers of an [[Op]]'s 25 | * outputs. 26 | * 27 | * @param op Op whose input this class represents. 28 | * @param index Input index. 29 | * 30 | * @author Emmanouil Antonios Platanios 31 | */ 32 | final case class Input[T] private[ops]( 33 | op: UntypedOp, 34 | index: Int 35 | ) { 36 | /** Name of this op input. This is simply set to `":"`. */ 37 | lazy val name: String = { 38 | s"${op.name}:$index" 39 | } 40 | 41 | /** Data type of this op input. */ 42 | lazy val dataType: DataType[T] = { 43 | using(graph.reference) { r => 44 | DataType.fromCValue( 45 | NativeOp.inputDataType(r.nativeHandle, op.nativeHandle, index) 46 | ).asInstanceOf[DataType[T]] 47 | } 48 | } 49 | 50 | /** Graph where the op belongs. */ 51 | def graph: Graph = { 52 | op.graph 53 | } 54 | 55 | override def toString: String = { 56 | s"Op.Input(name = $name, dataType = $dataType)" 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/ops/OpSpecification.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License. 
14 | */ 15 | 16 | package org.platanios.tensorflow.api.ops 17 | 18 | /** 19 | * @author Emmanouil Antonios Platanios 20 | */ 21 | final case class OpSpecification( 22 | name: String, 23 | opType: String, 24 | device: String) 25 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/ops/control_flow/package.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License. 14 | */ 15 | 16 | package org.platanios.tensorflow.api.ops 17 | 18 | /** 19 | * @author Emmanouil Antonios Platanios 20 | */ 21 | package object control_flow { 22 | private[ops] trait API extends ControlFlow 23 | } 24 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/ops/data/package.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. 
You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package org.platanios.tensorflow.api.ops

/** Data package. Its `API` trait exposes `Data`, the `DatasetIterator` API, and aliases for the dataset types.
  *
  * @author Emmanouil Antonios Platanios
  */
package object data {
  private[ops] trait API
      extends Data
          with data.DatasetIterator.API {
    type Dataset[T] = data.Dataset[T]
    type DatasetIterator[T] = data.DatasetIterator[T]
  }
}

--------------------------------------------------------------------------------
/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/lookup/LookupTableInitializer.scala:
--------------------------------------------------------------------------------
/* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package org.platanios.tensorflow.api.ops.lookup

import org.platanios.tensorflow.api.core.types.{DataType, TF}
import org.platanios.tensorflow.api.ops.UntypedOp

/** Lookup table initializer.
22 | * 23 | * @param keysDataType Data type of the table keys. 24 | * @param valuesDataType Data type of the table values. 25 | * 26 | * @author Emmanouil Antonios Platanios 27 | */ 28 | abstract class LookupTableInitializer[K: TF, V: TF]( 29 | val keysDataType: DataType[K], 30 | val valuesDataType: DataType[V] 31 | ) { 32 | /** Creates and returns an op that initializes the provided table. 33 | * 34 | * @param table Table to initialize. 35 | * @return Created initialization op for `table`. 36 | */ 37 | def initialize( 38 | table: InitializableLookupTable[K, V], 39 | name: String = "Initialize" 40 | )(implicit evVTF: TF[V]): UntypedOp 41 | } 42 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/ops/lookup/LookupTableTensorInitializer.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License. 
14 | */ 15 | 16 | package org.platanios.tensorflow.api.ops.lookup 17 | 18 | import org.platanios.tensorflow.api.core.Graph 19 | import org.platanios.tensorflow.api.core.types.{Resource, TF} 20 | import org.platanios.tensorflow.api.ops.{Op, Output, UntypedOp} 21 | 22 | /** Lookup table initializer that uses the provided tensors (containing keys and corresponding values) for initializing 23 | * a lookup table. 24 | * 25 | * @param keys Tensor containing the table keys. 26 | * @param values Tensor containing the table values. 27 | * 28 | * @author Emmanouil Antonios Platanios 29 | */ 30 | class LookupTableTensorInitializer[K: TF, V: TF] protected ( 31 | val keys: Output[K], 32 | val values: Output[V] 33 | ) extends LookupTableInitializer(keys.dataType, values.dataType) { 34 | /** Creates and returns an op that initializes the provided table. 35 | * Builds an `InitializeTableV2` op named `name`, inside a name scope also called `name`. The op takes the table handle, `keys`, and `values` as inputs and is registered in the graph's `TABLE_INITIALIZERS` collection. 36 | * @param table Table to initialize. 37 | * @return Created initialization op for `table`. 38 | */ 39 | override def initialize( 40 | table: InitializableLookupTable[K, V], 41 | name: String = "Initialize" 42 | )(implicit evVTF: TF[V]): UntypedOp = { 43 | Op.nameScope(name) { 44 | val initializationOp = Op.Builder[(Output[Resource], Output[K], Output[V]), Unit]( 45 | opType = "InitializeTableV2", 46 | name = name, 47 | input = (table.handle, keys, values) 48 | ).build() 49 | Op.currentGraph.addToCollection(Graph.Keys.TABLE_INITIALIZERS)(initializationOp.asUntyped) 50 | initializationOp.asUntyped 51 | } 52 | } 53 | } 54 | 55 | object LookupTableTensorInitializer { 56 | /** Creates a [[LookupTableTensorInitializer]] from the provided `keys` and `values` tensors. */ def apply[K: TF, V: TF]( 57 | keys: Output[K], 58 | values: Output[V] 59 | ): LookupTableTensorInitializer[K, V] = { 60 | new LookupTableTensorInitializer(keys, values) 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/ops/lookup/package.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19,
Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License. 14 | */ 15 | 16 | package org.platanios.tensorflow.api.ops 17 | 18 | /** Exposes the lookup API (the `Lookup` trait), together with aliases for lookup tables, table initializers, and text-file field extractors, to the `ops` package. 19 | * @author Emmanouil Antonios Platanios 20 | */ 21 | package object lookup { 22 | private[ops] trait API 23 | extends Lookup { 24 | type LookupTable[K, V] = lookup.LookupTable[K, V] 25 | type HashTable[K, V] = lookup.HashTable[K, V] 26 | type IDLookupTableWithHashBuckets[K] = lookup.IDLookupTableWithHashBuckets[K] 27 | 28 | val HashTable : lookup.HashTable.type = lookup.HashTable 29 | val IDLookupTableWithHashBuckets: lookup.IDLookupTableWithHashBuckets.type = lookup.IDLookupTableWithHashBuckets 30 | 31 | type LookupTableInitializer[K, V] = lookup.LookupTableInitializer[K, V] 32 | type LookupTableTensorInitializer[K, V] = lookup.LookupTableTensorInitializer[K, V] 33 | type LookupTableTextFileInitializer[K, V] = lookup.LookupTableTextFileInitializer[K, V] 34 | 35 | val LookupTableTensorInitializer: lookup.LookupTableTensorInitializer.type = { 36 | lookup.LookupTableTensorInitializer 37 | } 38 | 39 | val LookupTableTextFileInitializer: lookup.LookupTableTextFileInitializer.type = { 40 | lookup.LookupTableTextFileInitializer 41 | } 42 | 43 | type TextFileFieldExtractor[K] = lookup.TextFileFieldExtractor[K] 44 | 45 | val TextFileLineNumber: lookup.TextFileLineNumber.type = lookup.TextFileLineNumber 46 | val TextFileWholeLine : lookup.TextFileWholeLine.type =
lookup.TextFileWholeLine 47 | val TextFileColumn : lookup.TextFileColumn.type = lookup.TextFileColumn 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/ops/rnn/cell/RNNCell.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License. 14 | */ 15 | 16 | package org.platanios.tensorflow.api.ops.rnn.cell 17 | 18 | import org.platanios.tensorflow.api.implicits.helpers.{OutputToShape, Zero} 19 | import org.platanios.tensorflow.api.ops.Output 20 | 21 | /** Base class for recurrent neural network (RNN) cells. A cell's `forward` method maps an input `Tuple[Out, State]` to an output `Tuple[Out, State]`, and `apply` simply delegates to `forward`.
22 | * `zeroOutput` and `zeroState` build zero values for a given batch size from `outputShape` and `stateShape`, respectively, using the implicit `Zero` evidence. 23 | * @author Emmanouil Antonios Platanios 24 | */ 25 | abstract class RNNCell[Out, State, OutShape, StateShape](implicit 26 | val evOutputToShapeOut: OutputToShape.Aux[Out, OutShape], 27 | val evOutputToShapeState: OutputToShape.Aux[State, StateShape] 28 | ) { 29 | def outputShape: OutShape 30 | def stateShape: StateShape 31 | 32 | def zeroOutput( 33 | batchSize: Output[Int], 34 | name: String = "ZeroOutput" 35 | )(implicit evZero: Zero.Aux[Out, OutShape]): Out = { 36 | evZero.zero(batchSize, outputShape, name) 37 | } 38 | 39 | def zeroState( 40 | batchSize: Output[Int], 41 | name: String = "ZeroState" 42 | )(implicit evZero: Zero.Aux[State, StateShape]): State = { 43 | evZero.zero(batchSize, stateShape, name) 44 | } 45 | /** Applies this cell to `input`. Implementations are declared as possibly throwing `IllegalArgumentException` (presumably for invalid inputs; confirm in subclasses). */ 46 | @throws[IllegalArgumentException] 47 | def forward(input: Tuple[Out, State]): Tuple[Out, State] 48 | 49 | def apply(input: Tuple[Out, State]): Tuple[Out, State] = { 50 | forward(input) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/ops/rnn/package.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License.
14 | */ 15 | 16 | package org.platanios.tensorflow.api.ops 17 | 18 | /** Exposes the RNN API (the `attention` and `cell` sub-package APIs plus the `RNN` ops) to the `ops` package. 19 | * @author Emmanouil Antonios Platanios 20 | */ 21 | package object rnn { 22 | private[ops] trait API 23 | extends attention.API 24 | with cell.API 25 | with RNN 26 | } 27 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/ops/training/optimizers/package.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License.
14 | */ 15 | 16 | package org.platanios.tensorflow.api.ops.training 17 | 18 | /** Exposes the optimizer type aliases and companion-object references, plus the `schedules` API, to the `training` package. 19 | * @author Emmanouil Antonios Platanios 20 | */ 21 | package object optimizers { 22 | private[training] trait API 23 | extends schedules.API { 24 | type Optimizer = optimizers.Optimizer 25 | type AdaDelta = optimizers.AdaDelta 26 | type Adafactor = optimizers.Adafactor 27 | type AdaGrad = optimizers.AdaGrad 28 | type RMSProp = optimizers.RMSProp 29 | type Adam = optimizers.Adam 30 | type AMSGrad = optimizers.AMSGrad 31 | type GradientDescent = optimizers.GradientDescent 32 | type LazyAdam = optimizers.LazyAdam 33 | type LazyAMSGrad = optimizers.LazyAMSGrad 34 | type YellowFin = optimizers.YellowFin 35 | type Yogi = optimizers.Yogi 36 | 37 | val AdaDelta : optimizers.AdaDelta.type = optimizers.AdaDelta 38 | val Adafactor : optimizers.Adafactor.type = optimizers.Adafactor 39 | val AdaGrad : optimizers.AdaGrad.type = optimizers.AdaGrad 40 | val RMSProp : optimizers.RMSProp.type = optimizers.RMSProp 41 | val Adam : optimizers.Adam.type = optimizers.Adam 42 | val AMSGrad : optimizers.AMSGrad.type = optimizers.AMSGrad 43 | val GradientDescent: optimizers.GradientDescent.type = optimizers.GradientDescent 44 | val LazyAdam : optimizers.LazyAdam.type = optimizers.LazyAdam 45 | val LazyAMSGrad : optimizers.LazyAMSGrad.type = optimizers.LazyAMSGrad 46 | val YellowFin : optimizers.YellowFin.type = optimizers.YellowFin 47 | val Yogi : optimizers.Yogi.type = optimizers.Yogi 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/ops/training/optimizers/schedules/FixedSchedule.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License.
You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License. 14 | */ 15 | 16 | package org.platanios.tensorflow.api.ops.training.optimizers.schedules 17 | 18 | import org.platanios.tensorflow.api.core.types.{TF, IsIntOrLong} 19 | import org.platanios.tensorflow.api.ops.Output 20 | import org.platanios.tensorflow.api.ops.variables.Variable 21 | 22 | /** Dummy scheduling method representing no schedule being used. Useful as a default value for `Schedule`-valued 23 | * function arguments. 24 | * Its `apply` method returns `value` unchanged and ignores `step`. 25 | * @author Emmanouil Antonios Platanios 26 | */ 27 | case class FixedSchedule[T]() extends Schedule[T] { 28 | /** Applies the scheduling method to `value`, the current iteration in the optimization loop is `step` and returns the 29 | * result. 30 | * 31 | * @param value Value to change based on this schedule. 32 | * @param step Option containing current iteration in the optimization loop, if one has been provided. 33 | * @return `value`, unchanged. 34 | * @throws IllegalArgumentException If the scheduling method requires a value for `step` but the provided option is 35 | * empty. This schedule never throws, since it does not use `step`.
36 | */ 37 | @throws[IllegalArgumentException] 38 | override def apply[I: TF : IsIntOrLong]( 39 | value: Output[T], 40 | step: Option[Variable[I]] 41 | ): Output[T] = { 42 | value 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/ops/training/optimizers/schedules/package.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License.
14 | */ 15 | 16 | package org.platanios.tensorflow.api.ops.training.optimizers 17 | 18 | /** Exposes the schedule type aliases and companion-object references to the `optimizers` package. 19 | * @author Emmanouil Antonios Platanios 20 | */ 21 | package object schedules { 22 | private[optimizers] trait API { 23 | type Schedule[T] = schedules.Schedule[T] 24 | type FixedSchedule[T] = schedules.FixedSchedule[T] 25 | type CosineDecay = schedules.CosineDecay 26 | type CycleLinear10xDecay = schedules.CycleLinear10xDecay 27 | type ExponentialDecay = schedules.ExponentialDecay 28 | type LuongExponentialDecay = schedules.LuongExponentialDecay 29 | type RSqrtDecay = schedules.RSqrtDecay 30 | type WarmUpExponentialSchedule = schedules.WarmUpExponentialSchedule 31 | type WarmUpLinearSchedule = schedules.WarmUpLinearSchedule 32 | 33 | val FixedSchedule : schedules.FixedSchedule.type = schedules.FixedSchedule 34 | val CosineDecay : schedules.CosineDecay.type = schedules.CosineDecay 35 | val CycleLinear10xDecay : schedules.CycleLinear10xDecay.type = schedules.CycleLinear10xDecay 36 | val ExponentialDecay : schedules.ExponentialDecay.type = schedules.ExponentialDecay 37 | val LuongExponentialDecay : schedules.LuongExponentialDecay.type = schedules.LuongExponentialDecay 38 | val RSqrtDecay : schedules.RSqrtDecay.type = schedules.RSqrtDecay 39 | val WarmUpExponentialSchedule: schedules.WarmUpExponentialSchedule.type = schedules.WarmUpExponentialSchedule 40 | val WarmUpLinearSchedule : schedules.WarmUpLinearSchedule.type = schedules.WarmUpLinearSchedule 41 | 42 | // TODO: Piecewise constant. 43 | // TODO: Polynomial. 44 | // TODO: Natural exp. 45 | // TODO: Inverse time. 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/ops/training/package.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License. 14 | */ 15 | 16 | package org.platanios.tensorflow.api.ops 17 | 18 | /** Exposes the optimizers API and `ExponentialMovingAverage` (type alias and companion reference) to the `ops` package. 19 | * @author Emmanouil Antonios Platanios 20 | */ 21 | package object training { 22 | private[ops] trait API extends optimizers.API { 23 | type ExponentialMovingAverage = training.ExponentialMovingAverage 24 | val ExponentialMovingAverage: training.ExponentialMovingAverage.type = training.ExponentialMovingAverage 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/ops/variables/Regularizer.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License.
14 | */ 15 | 16 | package org.platanios.tensorflow.api.ops.variables 17 | 18 | import org.platanios.tensorflow.api.ops.Output 19 | 20 | /** A variable regularizer is simply a function that takes a tensor representing the variable value as input, and 21 | * returns another tensor representing the regularizer value as output. 22 | * The input and the output share the same element type `T`. 23 | * @author Emmanouil Antonios Platanios 24 | */ 25 | trait Regularizer { 26 | def apply[T](value: Output[T]): Output[T] 27 | } 28 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/ops/variables/Reuse.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License. 14 | */ 15 | 16 | package org.platanios.tensorflow.api.ops.variables 17 | 18 | /** Enumeration of possible variable reuse options, used by variable scopes and variable stores. 19 | * 20 | * The supported options are: 21 | * - [[ReuseExistingOnly]]: Reuse existing variables only and throw an exception if no appropriate variable exists. 22 | * - [[CreateNewOnly]]: Create new variables only and throw an exception if a variable with the same name exists. 23 | * - [[ReuseOrCreateNew]]: Reuse existing variables or create new ones, if no variable with the provided name exists.
24 | * [[ReuseExistingOnly]] and [[ReuseOrCreateNew]] extend [[ReuseAllowed]]; [[CreateNewOnly]] does not. 25 | * @author Emmanouil Antonios Platanios 26 | */ 27 | sealed trait Reuse 28 | 29 | /** Trait marking the variable reuse modes that allow reusing existing variables. */ 30 | sealed trait ReuseAllowed extends Reuse 31 | 32 | /** Reuse existing variables only and throw an exception if no appropriate variable exists. */ 33 | case object ReuseExistingOnly extends ReuseAllowed 34 | 35 | /** Create new variables only and throw an exception if a variable with the same name exists. */ 36 | case object CreateNewOnly extends Reuse 37 | 38 | /** Reuse existing variables or create new ones, if no variable with the provided name exists. */ 39 | case object ReuseOrCreateNew extends ReuseAllowed 40 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/ops/variables/VariableScopeStore.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License. 14 | */ 15 | 16 | package org.platanios.tensorflow.api.ops.variables 17 | 18 | import org.platanios.tensorflow.api.ops.Op 19 | 20 | /** A thread-local store for the current variable scope and scope counts.
21 | * A new store starts with a `CreateNewOnly` variable scope and no scope counts; `VariableScopeStore.current` reads the store from the current graph. 22 | * @author Emmanouil Antonios Platanios 23 | */ 24 | case class VariableScopeStore private[api]() { 25 | private[api] var scope: VariableScope = { 26 | VariableScope(CreateNewOnly) 27 | } 28 | 29 | /** Map with variable scope names as keys and the corresponding use counts as values. */ 30 | private[api] var variableScopeCounts: Map[String, Int] = { 31 | Map.empty[String, Int] 32 | } 33 | /** Increments the use count of `scope`, starting from 0 if it has not been used yet. */ 34 | private[api] def enterVariableScope(scope: String): Unit = { 35 | variableScopeCounts += scope -> (variableScopeCounts.getOrElse(scope, 0) + 1) 36 | } 37 | /** Removes the use counts of all sub-scopes of `scope` (names starting with `scope/`); the count of `scope` itself is kept. */ 38 | private[api] def closeVariableSubScopes(scope: String): Unit = { 39 | variableScopeCounts.keySet.filter(_.startsWith(s"$scope/")).foreach(variableScopeCounts -= _) 40 | } 41 | 42 | /** Returns the use count of the provided scope in this variable store. 43 | * 44 | * @param scope Variable scope name. 45 | * @return Number of usages of the provided variable scope name, in this variable store. 46 | */ 47 | private[api] def variableScopeCount(scope: String): Int = { 48 | variableScopeCounts.getOrElse(scope, 0) 49 | } 50 | } 51 | 52 | object VariableScopeStore { 53 | /** Returns the variable scope store of the current graph. */ def current: VariableScopeStore = { 54 | Op.currentGraph.variableScopeStore.value 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Cast.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License.
You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License. 14 | */ 15 | 16 | package org.platanios.tensorflow.api.tensors.ops 17 | 18 | import org.platanios.tensorflow.api.core.types._ 19 | import org.platanios.tensorflow.api.tensors._ 20 | import org.platanios.tensorflow.jni.generated.tensors.{Math => NativeTensorOpsMath} 21 | 22 | /** Contains functions for executing cast-related ops. 23 | * 24 | * @author Emmanouil Antonios Platanios 25 | */ 26 | trait Cast { 27 | /** $OpDocCastCast 28 | * If `input` already has the data type of `R`, it is returned as-is and no native op is executed. Otherwise the native cast op is applied, and `truncate` is forwarded to it. 29 | * @group CastOps 30 | * 31 | * @param input Tensor to cast. 32 | * @tparam R Target data type. 33 | * @return Result as a new tensor, or `input` itself when no cast is needed. 34 | */ 35 | private[tensors] def cast[T, R: TF, TL[TT] <: TensorLike[TT]]( 36 | input: TL[T], 37 | truncate: Boolean = false 38 | )(implicit ev: TensorOps.Aux[TL, T]): TL[R] = { 39 | val dataType = implicitly[TF[R]].dataType 40 | if (input.dataType == dataType) { 41 | input.asInstanceOf[TL[R]] 42 | } else { 43 | ev.applyUnary(input, t => { 44 | Tensor.fromNativeHandle[R](NativeTensorOpsMath.cast( 45 | executionContext.value.nativeHandle, t.nativeHandle, dataType.cValue, truncate)) 46 | }) 47 | } 48 | } 49 | 50 | // TODO: [OPS] saturateCast 51 | 52 | /** $OpDocCastBitcast 53 | * Applies the native `bitcast` op to `input`, targeting the data type of `R`. 54 | * @group CastOps 55 | * 56 | * @param input Input tensor. 57 | * @tparam R Target data type. 58 | * @return Result as a new tensor.
59 | */ 60 | private[tensors] def bitcast[T: IsNumeric, R: TF, TL[TT] <: TensorLike[TT]]( 61 | input: TL[T] 62 | )(implicit ev: TensorOps.Aux[TL, T]): TL[R] = { 63 | val dataType = implicitly[TF[R]].dataType 64 | ev.applyUnary(input, t => { 65 | Tensor.fromNativeHandle[R](NativeTensorOpsMath.bitcast( 66 | executionContext.value.nativeHandle, t.nativeHandle, dataType.cValue)) 67 | }) 68 | } 69 | } 70 | 71 | object Cast extends Cast 72 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Random.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License. 14 | */ 15 | 16 | package org.platanios.tensorflow.api.tensors.ops 17 | 18 | import org.platanios.tensorflow.api.core.types.TF 19 | import org.platanios.tensorflow.api.ops.Op 20 | import org.platanios.tensorflow.api.tensors._ 21 | import org.platanios.tensorflow.jni.generated.tensors.{Random => NativeTensorOpsRandom} 22 | 23 | /** Contains functions for executing ops related to random numbers and tensors. 24 | * 25 | * @author Emmanouil Antonios Platanios 26 | */ 27 | trait Random { 28 | /** $OpDocRandomRandomShuffle 29 | * The graph-level and op-level seeds come from `Op.currentGraphRandomSeed(seed)`; a missing seed is passed to the native op as `0`. 30 | * @group RandomOps 31 | * @param value Tensor to be shuffled.
32 | * @param seed Optional random seed, used to generate a random seed pair for the random number generator, when 33 | * combined with the graph-level seed. 34 | * @return Result as a new tensor. 35 | */ 36 | def randomShuffle[T: TF]( 37 | value: Tensor[T], 38 | seed: Option[Int] = None 39 | ): Tensor[T] = { 40 | val (graphSeed, opSeed) = Op.currentGraphRandomSeed(seed) 41 | Tensor.fromNativeHandle[T](NativeTensorOpsRandom.randomShuffle( 42 | executionContext.value.nativeHandle, value.nativeHandle, 43 | graphSeed.getOrElse(0).toLong, opSeed.getOrElse(0).toLong)) 44 | } 45 | } 46 | 47 | object Random extends Random 48 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/package.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License.
14 | */ 15 | 16 | package org.platanios.tensorflow.api.tensors 17 | 18 | /** Aggregates the tensor ops (`Basic`, `Cast`, `Math`, `NN`, and `Random`) into a single `API` trait. 19 | * @author Emmanouil Antonios Platanios 20 | */ 21 | package object ops { 22 | private[api] trait API 23 | extends Basic 24 | with Cast 25 | with Math 26 | with NN 27 | with Random 28 | } 29 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/package.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License. 14 | */ 15 | 16 | package org.platanios.tensorflow.api 17 | 18 | import scala.util.DynamicVariable 19 | 20 | /** Defines `executionContext`, the `DynamicVariable` holding the `Context` used when executing tensor ops (initialized from `core.defaultSessionConfig`), and exposes the tensor ops API. 21 | * @author Emmanouil Antonios Platanios 22 | */ 23 | package object tensors { 24 | private[api] val executionContext: DynamicVariable[Context] = { 25 | new DynamicVariable[Context](Context(Some(core.defaultSessionConfig))) 26 | } 27 | 28 | private[api] trait API 29 | extends tensors.ops.API 30 | } 31 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/utilities/Collections.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License. 14 | */ 15 | 16 | package org.platanios.tensorflow.api.utilities 17 | 18 | /** Contains helper functions for manipulating collections. 19 | * 20 | * @author Emmanouil Antonios Platanios 21 | */ 22 | object Collections { 23 | /** Segments a sequence according to a provided sequence of segment lengths. 24 | * 25 | * For example: 26 | * {{{ 27 | * val xs = Seq(3, 5, 2, 77, 12, 45, 78, 21, 89, 1, 0, -1, 123) 28 | * val n = Seq(3, 1, 0, 2, 5, 2) 29 | * segment(xs, n) = Seq(Seq(3, 5, 2), Seq(77), Seq(), Seq(12, 45), Seq(78, 21, 89, 1, 0), Seq(-1, 123)) 30 | * }}} 31 | * 32 | * Note that the function returns once `n` is exhausted, dropping any remaining elements of `xs`; if `xs` is exhausted first, every remaining length in `n` yields an empty segment. This means that no exception is thrown 33 | * if the provided segment lengths do not match the original sequence length. 34 | * 35 | * @param xs Sequence to segment. 36 | * @param n Segment lengths. 37 | * @return Sequence containing the segments of `xs`.
38 | */ 39 | def segment[V](xs: Seq[V], n: Seq[Int]): Seq[Seq[V]] = { 40 | if (n.isEmpty) { // No segment lengths left: stop here (never call `n.head` on an empty `n`). 41 | Nil 42 | } else { 43 | val (ys, zs) = xs.splitAt(n.head) 44 | ys +: segment(zs, n.tail) 45 | } 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/utilities/Proto.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License. 14 | */ 15 | 16 | package org.platanios.tensorflow.api.utilities 17 | 18 | import org.platanios.tensorflow.api.io.FileIO 19 | 20 | import com.google.protobuf.GeneratedMessageV3 21 | 22 | import java.nio.file.{Files, Path} 23 | 24 | /** Contains helper functions for working with ProtoBuf. 25 | * 26 | * @author Emmanouil Antonios Platanios 27 | */ 28 | object Proto { 29 | /** Writes `message` to the specified file. 30 | * The directory is created if it does not exist, unless it is a `gs:` path. In binary mode, the output stream is always closed after writing, even if serialization fails. 31 | * @param directory Directory in which to write the file. 32 | * @param filename Name of the file. 33 | * @param message ProtoBuf message to write. 34 | * @param asText Boolean value indicating whether to serialize the ProtoBuf message in the human-friendly text 35 | * format, or in the more efficient binary format. 36 | * @return Path of the written file.
37 | */ 38 | def write(directory: Path, filename: String, message: GeneratedMessageV3, asText: Boolean = false): Path = { 39 | // GCS does not have the concept of a directory at the moment. 40 | if (!Files.exists(directory) && !directory.startsWith("gs:")) { 41 | Files.createDirectories(directory) 42 | } 43 | val filePath = directory.resolve(filename) 44 | if (asText) 45 | FileIO.writeStringToFileAtomic(filePath, message.toString) 46 | else { 47 | val stream = Files.newOutputStream(filePath); try message.writeTo(stream) finally stream.close() } // `writeTo` does not close the stream, so close it here to release the file handle. 48 | filePath 49 | } 50 | 51 | /** Trait that all ProtoBuf-serializable objects should extend. */ 52 | trait Serializable { 53 | /** Converts this object to its corresponding ProtoBuf object. 54 | * 55 | * @return ProtoBuf object corresponding to this object. 56 | */ 57 | def toProto: GeneratedMessageV3 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /modules/api/src/main/scala/org/platanios/tensorflow/api/utilities/package.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License.
14 | */

package org.platanios.tensorflow.api

/** Package object holding general-purpose API utilities.
  *
  * @author Emmanouil Antonios Platanios
  */
package object utilities {
  /** Mix-in for objects holding resources that must be released explicitly by calling `close()`. */
  trait Closeable {
    /** Function that `close()` invokes to perform the actual release. */
    protected val closeFn: () => Unit

    /** Releases the native resources associated with this object. */
    def close(): Unit = closeFn.apply()
  }

  /** Lends `resource` to `block` and closes it afterwards, whether `block` returns normally or throws.
    *
    * A `null` resource is still passed to `block`, but nothing is closed afterwards.
    *
    * @param resource Resource to lend to `block`.
    * @param block    Computation that uses `resource`.
    * @return Result of `block`.
    */
  def using[T <: Closeable, R](resource: T)(block: T => R): R =
    try block(resource)
    finally Option(resource).foreach(_.close())
}
-------------------------------------------------------------------------------- /modules/api/src/test/scala/org/platanios/tensorflow/api/core/SessionSpec.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License.
14 | */

package org.platanios.tensorflow.api.core

import org.platanios.tensorflow.api._

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

/** Tests for running a `Session` with feeds and fetches looked up by name.
  *
  * @author Emmanouil Antonios Platanios
  */
class SessionSpec extends AnyFlatSpec with Matchers {
  "Session run fetch by name" should "return the correct result" in {
    val graph = Graph()
    try {
      tf.createWith(graph = graph) {
        val a = tf.constant(Tensor(Tensor(2, 3)), name = "A")
        val x = tf.placeholder[Int](Shape(1, 2), name = "X")
        tf.subtract(tf.constant(1), tf.matmul(a = a, b = x, transposeB = true), name = "Y")
      }
      val session = Session(graph = graph)
      // Release the session (and then the graph) even when the assertion fails.
      try {
        val feeds = Map(graph.getOutputByName("X:0").asInstanceOf[Output[Int]] -> Tensor(Tensor(5, 7)))
        val fetches = graph.getOutputByName("Y:0").asInstanceOf[Output[Int]]
        val output = session.run(feeds, fetches)
        // Y = 1 - A * X^T = 1 - (2 * 5 + 3 * 7) = -30.
        val expectedResult = Tensor(Tensor(-30))
        assert(output.scalar == expectedResult.scalar)
      } finally {
        session.close()
      }
    } finally {
      graph.close()
    }
  }
}
-------------------------------------------------------------------------------- /modules/api/src/test/scala/org/platanios/tensorflow/api/ops/BasicSpec.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License.
14 | */

package org.platanios.tensorflow.api.ops

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

/** Tests for basic ops.
  *
  * NOTE(review): the only test in this spec is commented out, so the spec currently runs no tests. The disabled code
  * appears to target an older API (`ArrayOps`, `DataType.Int32`, `Tensor.create`) that this file does not import, and
  * it refers to an `array` value that is not defined anywhere in this file, so it cannot be re-enabled as-is.
  *
  * @author Emmanouil Antonios Platanios
  */
class BasicSpec extends AnyFlatSpec with Matchers {
  // "'ArrayOps.constant'" must "create a constant op when provided a Tensor of the same data type and shape" in {
  //   // DataType.Int32 Tensor
  //   val tensor1 = Tensor(Tensor(Tensor(2, 3), Tensor(0, 0), Tensor(5, 7)),
  //                        Tensor(Tensor(1, 23), Tensor(4, -5), Tensor(7, 9)),
  //                        Tensor(Tensor(56, 1), Tensor(-2, -4), Tensor(-7, -9)))
  //   val constant1 = ArrayOps.constant(tensor1)
  //   val constantValue1 = constant1.value()
  //   assert(tensor1.get(1, 1, 1) === -5)
  //   assert(constant1.shape === Shape(3, 3, 2))
  //   assert(constant1.dataType === DataType.Int32)
  //   assert(constantValue1.shape === Shape(3, 3, 2))
  //   assert(constantValue1.dataType === DataType.Int32)
  //   assert(constantValue1(1, 1, 1) === -5)
  //   tensor1.close()
  //
  //   // DataType.Float64 Tensor
  //   val tensor2 = Tensor.create(array, dataType = DataType.Float64)
  //   val constant2 = ArrayOps.constant(tensor2)
  //   val constantValue2 = constant2.value()
  //   assert(tensor2(1, 1, 1) === -5.0)
  //   assert(constant2.shape === Shape(3, 3, 2))
  //   assert(constant2.dataType === DataType.Float64)
  //   assert(constantValue2.shape === Shape(3, 3, 2))
  //   assert(constantValue2.dataType === DataType.Float64)
  //   assert(constantValue2(1, 1, 1) === -5.0)
  //   tensor2.close()
  // }
}
-------------------------------------------------------------------------------- /modules/api/src/test/scala/org/platanios/tensorflow/api/ops/NNSpec.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved.
2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License. 14 | */

package org.platanios.tensorflow.api.ops

import org.platanios.tensorflow.api._

import org.scalatest.matchers.should.Matchers
import org.scalatestplus.junit.JUnitSuite
import org.junit.Test

/** Tests for neural network ops.
  *
  * @author Emmanouil Antonios Platanios
  */
class NNSpec extends JUnitSuite with Matchers {
  @Test def testLogSoftmax(): Unit = {
    // 3 x 3 x 2 input, so each expected result below has 18 elements.
    val tensor = Tensor(Tensor(Tensor(2, 3), Tensor(0, 0), Tensor(5, 7)),
                        Tensor(Tensor(1, 23), Tensor(4, -5), Tensor(7, 9)),
                        Tensor(Tensor(56, 1), Tensor(-2, -4), Tensor(-7, -9)))
    val constant = tf.constant(tensor).toFloat
    val logSoftmaxLastAxis = tf.logSoftmax(constant, axis = -1)
    val logSoftmaxPenultimateAxis = tf.logSoftmax(constant, axis = 1)
    val session = Session()
    // Release the session even when an assertion fails.
    try {
      assertApproximatelyEqual(
        session.run(fetches = logSoftmaxLastAxis).toArray,
        Array(
          -1.3132616f, -0.31326163f, -0.6931472f, -0.6931472f, -2.126928f, -0.12692805f,
          -22.0f, 0.0f, -1.23374e-4f, -9.000123f, -2.126928f, -0.12692805f, 0.0f, -55.0f,
          -0.12692805f, -2.126928f, -0.12692805f, -2.126928f,
        ),
      )
      assertApproximatelyEqual(
        session.run(fetches = logSoftmaxPenultimateAxis).toArray,
        Array(
          -3.0549853f, -4.019045f, -5.054985f, -7.019045f, -0.054985214f, -0.019044992f,
          -6.0509458f, -8.344647e-7f, -3.0509458f, -28.0f, -0.05094571f, -14.000001f, 0.0f,
          -0.0067604627f, -58.0f, -5.0067606f, -63.0f, -10.006761f,
        ),
      )
    } finally {
      session.close()
    }
  }

  /** Asserts that `x` and `y` have the same length and are element-wise equal within `tolerance`.
    *
    * @param x         Actual values.
    * @param y         Expected values.
    * @param tolerance Maximum allowed absolute difference between corresponding elements.
    */
  def assertApproximatelyEqual(x: Array[Float], y: Array[Float], tolerance: Float = 1e-6f): Unit = {
    // `zip` silently truncates to the shorter array, so a length mismatch would otherwise go unnoticed.
    assert(x.length == y.length, s"Expected ${y.length} elements, but got ${x.length}.")
    x.zip(y).foreach { case (xElement, yElement) =>
      assert(xElement === yElement +- tolerance)
    }
  }
}
-------------------------------------------------------------------------------- /modules/api/src/test/scala/org/platanios/tensorflow/api/ops/data/FilterDatasetSuite.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License.
14 | */

package org.platanios.tensorflow.api.ops.data

import org.platanios.tensorflow.api.core.client.Session
import org.platanios.tensorflow.api.core.{Graph, Shape}
import org.platanios.tensorflow.api.implicits.Implicits._
import org.platanios.tensorflow.api.ops.Op
import org.platanios.tensorflow.api.ops.math.Math
import org.platanios.tensorflow.api.tensors.Tensor
import org.platanios.tensorflow.api.utilities.using

import org.junit.Test
import org.scalatestplus.junit.JUnitSuite

/** Tests for filtering datasets.
  *
  * @author Emmanouil Antonios Platanios
  */
class FilterDatasetSuite extends JUnitSuite {
  @Test def testFilterRange(): Unit = using(Graph()) { graph =>
    Op.createWith(graph) {
      // Keep only the elements of the range that are not congruent to 2 modulo 3.
      val dataset = Data.datasetFromRange(0, 100).filter(x => {
        Math.notEqual(Math.mod(x, 3L), 2L)
      })
      val iterator = dataset.createInitializableIterator()
      val initOp = iterator.initializer
      val nextOutput = iterator.next()
      assert(nextOutput.shape == Shape.scalar())
      val session = Session()
      // `using` only releases the graph, so close the session explicitly, even when an assertion fails.
      try {
        session.run(targets = initOp)
        // 2 and 5 are dropped by the filter.
        Seq(0L, 1L, 3L, 4L, 6L).foreach { expected =>
          assert(session.run(fetches = nextOutput) == (expected: Tensor[Long]))
        }
      } finally {
        session.close()
      }
    }
  }
}
-------------------------------------------------------------------------------- /modules/api/src/test/scala/org/platanios/tensorflow/api/utilities/CRC32CSuite.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License.
You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License. 14 | */

package org.platanios.tensorflow.api.utilities

import org.scalatestplus.junit.JUnitSuite
import org.junit.Test

/** Known-answer and round-trip tests for `CRC32C`.
  *
  * @author Emmanouil Antonios Platanios
  */
class CRC32CSuite extends JUnitSuite {
  @Test def testValue(): Unit = {
    // Known-answer vectors from RFC 3720, section B.4.
    val zeros = Array.fill[Byte](32)(0x00.toByte)
    val ones = Array.fill[Byte](32)(0xff.toByte)
    val ascending = (0 until 32).map(_.toByte).toArray
    val descending = (31 to 0 by -1).map(_.toByte).toArray
    assert(CRC32C.value(zeros) === 0x8a9136aa)
    assert(CRC32C.value(ones) === 0x62a8ab43)
    assert(CRC32C.value(ascending) === 0x46dd794e)
    assert(CRC32C.value(descending) === 0x113fdb5c)

    // Another vector from the same section: a 48-byte iSCSI command PDU.
    val commandPdu = Seq(
      0x01, 0xc0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
      0x00, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00,
      0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x18, 0x28, 0x00, 0x00, 0x00,
      0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00
    ).map(_.toByte).toArray
    assert(CRC32C.value(commandPdu) === 0xd9963a56)

    // Sanity check that different strings produce different checksums.
    assert(CRC32C.value("a") !== CRC32C.value("foo"))
  }

  @Test def testExtend(): Unit = {
    // Extending the checksum of a prefix must equal the checksum of the whole string.
    val whole = CRC32C.value("hello world")
    val extended = CRC32C.extend(CRC32C.value("hello "), "world")
    assert(whole === extended)
  }

  @Test def testMask(): Unit = {
    val crc = CRC32C.value("foo")
    val maskedOnce = CRC32C.mask(crc)
    val maskedTwice = CRC32C.mask(maskedOnce)
    // Masking must change the value, and unmasking must exactly undo each masking step.
    assert(crc !== maskedOnce)
    assert(crc !== maskedTwice)
    assert(crc === CRC32C.unmask(maskedOnce))
    assert(crc === CRC32C.unmask(CRC32C.unmask(maskedTwice)))
  }
}
-------------------------------------------------------------------------------- /modules/data/src/main/resources/logback.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | %d{yyyy-MM-dd HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /modules/data/src/main/scala/org/platanios/tensorflow/data/utilities/CompressedFiles.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License.
14 | */

package org.platanios.tensorflow.data.utilities

import org.apache.commons.compress.archivers.tar.TarArchiveInputStream
import org.apache.commons.compress.utils.IOUtils

import java.io.{File, FileOutputStream, IOException, InputStream}
import java.nio.file.{Files, Path}
import java.util.zip.GZIPInputStream

/** Helpers for extracting tar archives, optionally gzip-compressed.
  *
  * @author Emmanouil Antonios Platanios
  */
object CompressedFiles {
  /** Extracts the gzip-compressed tar archive at `tgzFilePath` into `destinationPath`.
    *
    * @param tgzFilePath     Path of the gzip-compressed tar archive.
    * @param destinationPath Directory into which the archive's file entries are extracted.
    * @param bufferSize      Buffer size, in bytes, used while decompressing and copying. Must be positive.
    * @throws IOException If reading the archive or writing an extracted file fails, or if an entry would be
    *                     extracted outside of `destinationPath`.
    */
  def decompressTGZ(tgzFilePath: Path, destinationPath: Path, bufferSize: Int = 8192): Unit = {
    // Closed here too, so that the file handle is released even if decompression fails part-way.
    val fileStream = Files.newInputStream(tgzFilePath)
    try {
      decompressTGZStream(fileStream, destinationPath, bufferSize)
    } finally {
      fileStream.close()
    }
  }

  /** Extracts the (uncompressed) tar archive at `tarFilePath` into `destinationPath`.
    *
    * @param tarFilePath     Path of the tar archive.
    * @param destinationPath Directory into which the archive's file entries are extracted.
    * @param bufferSize      Buffer size, in bytes, used while copying. Must be positive.
    * @throws IOException If reading the archive or writing an extracted file fails, or if an entry would be
    *                     extracted outside of `destinationPath`.
    */
  def decompressTar(tarFilePath: Path, destinationPath: Path, bufferSize: Int = 8192): Unit = {
    // Closed here too, so that the file handle is released even if extraction fails part-way.
    val fileStream = Files.newInputStream(tarFilePath)
    try {
      decompressTarStream(fileStream, destinationPath, bufferSize)
    } finally {
      fileStream.close()
    }
  }

  /** Extracts a gzip-compressed tar archive read from `tgzStream` into `destinationPath`.
    *
    * The stream is closed by `decompressTarStream` once extraction starts.
    *
    * @param tgzStream       Stream containing the gzip-compressed tar archive.
    * @param destinationPath Directory into which the archive's file entries are extracted.
    * @param bufferSize      Buffer size, in bytes, used while decompressing and copying. Must be positive.
    */
  def decompressTGZStream(tgzStream: InputStream, destinationPath: Path, bufferSize: Int = 8192): Unit = {
    decompressTarStream(new GZIPInputStream(tgzStream, bufferSize), destinationPath, bufferSize)
  }

  /** Extracts a tar archive read from `tarStream` into `destinationPath`.
    *
    * Directory entries are skipped; the parent directories of file entries are created on demand. `tarStream` is
    * closed before this method returns, including when extraction fails.
    *
    * @param tarStream       Stream containing the tar archive.
    * @param destinationPath Directory into which the archive's file entries are extracted.
    * @param bufferSize      Buffer size, in bytes, used while copying entry contents. Must be positive.
    * @throws IllegalArgumentException If `bufferSize` is not positive.
    * @throws IOException              If reading the archive or writing an extracted file fails, or if an entry
    *                                  would be extracted outside of `destinationPath`.
    */
  def decompressTarStream(tarStream: InputStream, destinationPath: Path, bufferSize: Int = 8192): Unit = {
    val destinationRoot = destinationPath.toAbsolutePath.normalize()
    val inputStream = new TarArchiveInputStream(tarStream)
    try {
      require(bufferSize > 0, s"The buffer size must be positive, but was $bufferSize.")
      var entry = inputStream.getNextTarEntry
      while (entry != null) {
        if (!entry.isDirectory) {
          // Entry names come from the archive and may contain `..` segments. Reject any entry that would be written
          // outside of the destination directory (the "Zip Slip" path-traversal issue).
          val entryPath = new File(destinationRoot.toFile, entry.getName).toPath.normalize()
          if (!entryPath.startsWith(destinationRoot))
            throw new IOException(s"Archive entry '${entry.getName}' would be extracted outside of '$destinationRoot'.")
          val entryFile = entryPath.toFile
          val parentFile = entryFile.getParentFile
          if (!parentFile.exists)
            parentFile.mkdirs()
          // Each output file must be closed; previously every extracted file leaked a file handle.
          val outputStream = new FileOutputStream(entryFile)
          try {
            IOUtils.copy(inputStream, outputStream, bufferSize)
          } finally {
            outputStream.close()
          }
        }
        entry = inputStream.getNextTarEntry
      }
    } finally {
      inputStream.close()
    }
  }
}
-------------------------------------------------------------------------------- /modules/data/src/test/scala/org/platanios/tensorflow/data/image/MNISTLoaderSpec.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios.
All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License. 14 | */

package org.platanios.tensorflow.data.image

import org.scalatest.flatspec.AnyFlatSpec

/** Placeholder spec for the MNIST data set loader.
  *
  * NOTE(review): the body of the only test is commented out, so it currently passes without exercising the loader.
  * The commented-out setup hard-codes a developer-specific absolute path, and `Path`, `Paths`, and `Files` are not
  * imported in this file. It would need a portable (e.g. temporary) directory and those imports before being
  * re-enabled.
  *
  * @author Emmanouil Antonios Platanios
  */
class MNISTLoaderSpec extends AnyFlatSpec {
  // val directory: Path = Paths.get("/Users/Anthony/Development/GitHub/tensorflow_scala/temp/data/mnist")
  // Files.createDirectories(directory)

  "The MNIST data set loader" must "work" in {
    // val dataSet = MNISTLoader.load(directory)
    // val label0 = dataSet.trainLabels.summarize(10)
    // print(dataSet)
  }
}
-------------------------------------------------------------------------------- /modules/examples/src/main/resources/logback.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | %d{yyyy-MM-dd HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /modules/examples/src/main/resources/python2scala/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "linear-regression" 2 | all_model_checkpoint_paths: "virgin-linear-regression" 3 | all_model_checkpoint_paths: "linear-regression" 4 |
-------------------------------------------------------------------------------- /modules/examples/src/main/resources/python2scala/linear-regression.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eaplatanios/tensorflow_scala/39053eefe0859bde4c48c8da20212d6919deb38e/modules/examples/src/main/resources/python2scala/linear-regression.data-00000-of-00001 -------------------------------------------------------------------------------- /modules/examples/src/main/resources/python2scala/linear-regression.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eaplatanios/tensorflow_scala/39053eefe0859bde4c48c8da20212d6919deb38e/modules/examples/src/main/resources/python2scala/linear-regression.index -------------------------------------------------------------------------------- /modules/examples/src/main/resources/python2scala/linear-regression.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eaplatanios/tensorflow_scala/39053eefe0859bde4c48c8da20212d6919deb38e/modules/examples/src/main/resources/python2scala/linear-regression.meta -------------------------------------------------------------------------------- /modules/examples/src/main/resources/python2scala/virgin-linear-regression.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eaplatanios/tensorflow_scala/39053eefe0859bde4c48c8da20212d6919deb38e/modules/examples/src/main/resources/python2scala/virgin-linear-regression.data-00000-of-00001 -------------------------------------------------------------------------------- /modules/examples/src/main/resources/python2scala/virgin-linear-regression.index: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/eaplatanios/tensorflow_scala/39053eefe0859bde4c48c8da20212d6919deb38e/modules/examples/src/main/resources/python2scala/virgin-linear-regression.index -------------------------------------------------------------------------------- /modules/examples/src/main/resources/python2scala/virgin-linear-regression.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eaplatanios/tensorflow_scala/39053eefe0859bde4c48c8da20212d6919deb38e/modules/examples/src/main/resources/python2scala/virgin-linear-regression.meta -------------------------------------------------------------------------------- /modules/jni/src/main/native/checkpoint_reader.h: -------------------------------------------------------------------------------- 1 | /* DO NOT EDIT THIS FILE - it is machine generated */ 2 | #include 3 | /* Header for class org_platanios_tensorflow_jni_CheckpointReader__ */ 4 | 5 | #ifndef _Included_org_platanios_tensorflow_jni_CheckpointReader__ 6 | #define _Included_org_platanios_tensorflow_jni_CheckpointReader__ 7 | #ifdef __cplusplus 8 | extern "C" { 9 | #endif 10 | /* 11 | * Class: org_platanios_tensorflow_jni_CheckpointReader__ 12 | * Method: newCheckpointReader 13 | * Signature: (Ljava/lang/String;)J 14 | */ 15 | JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_CheckpointReader_00024_newCheckpointReader 16 | (JNIEnv *, jobject, jstring); 17 | 18 | /* 19 | * Class: org_platanios_tensorflow_jni_CheckpointReader__ 20 | * Method: debugString 21 | * Signature: (J)Ljava/lang/String; 22 | */ 23 | JNIEXPORT jstring JNICALL Java_org_platanios_tensorflow_jni_CheckpointReader_00024_debugString 24 | (JNIEnv *, jobject, jlong); 25 | 26 | /* 27 | * Class: org_platanios_tensorflow_jni_CheckpointReader__ 28 | * Method: hasTensor 29 | * Signature: (JLjava/lang/String;)B 30 | */ 31 | JNIEXPORT jboolean JNICALL Java_org_platanios_tensorflow_jni_CheckpointReader_00024_hasTensor 32 | 
(JNIEnv *, jobject, jlong, jstring); 33 | 34 | /* 35 | * Class: org_platanios_tensorflow_jni_CheckpointReader__ 36 | * Method: getTensor 37 | * Signature: (JLjava/lang/String;)J 38 | */ 39 | JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_CheckpointReader_00024_getTensor 40 | (JNIEnv *, jobject, jlong, jstring); 41 | 42 | /* 43 | * Class: org_platanios_tensorflow_jni_CheckpointReader__ 44 | * Method: variableShapes 45 | * Signature: (J)Lorg/platanios/tensorflow/jni/VariableShapes; 46 | */ 47 | JNIEXPORT jobject JNICALL Java_org_platanios_tensorflow_jni_CheckpointReader_00024_variableShapes 48 | (JNIEnv *, jobject, jlong); 49 | 50 | /* 51 | * Class: org_platanios_tensorflow_jni_CheckpointReader__ 52 | * Method: variableDataTypes 53 | * Signature: (J)Lorg/platanios/tensorflow/jni/VariableDataTypes; 54 | */ 55 | JNIEXPORT jobject JNICALL Java_org_platanios_tensorflow_jni_CheckpointReader_00024_variableDataTypes 56 | (JNIEnv *, jobject, jlong); 57 | 58 | /* 59 | * Class: org_platanios_tensorflow_jni_CheckpointReader__ 60 | * Method: delete 61 | * Signature: (J)V 62 | */ 63 | JNIEXPORT void JNICALL Java_org_platanios_tensorflow_jni_CheckpointReader_00024_delete 64 | (JNIEnv *, jobject, jlong); 65 | 66 | #ifdef __cplusplus 67 | } 68 | #endif 69 | #endif 70 | -------------------------------------------------------------------------------- /modules/jni/src/main/native/function.h: -------------------------------------------------------------------------------- 1 | /* DO NOT EDIT THIS FILE - it is machine generated */ 2 | #include 3 | /* Header for class org_platanios_tensorflow_jni_Function__ */ 4 | 5 | #ifndef _Included_org_platanios_tensorflow_jni_Function__ 6 | #define _Included_org_platanios_tensorflow_jni_Function__ 7 | #ifdef __cplusplus 8 | extern "C" { 9 | #endif 10 | 11 | /* 12 | * Class: org_platanios_tensorflow_jni_Function__ 13 | * Method: graphToFunction 14 | * Signature: (JLjava/lang/String;Z[J[J[I[J[I[Ljava/lang/String;)J 15 | */ 16 | JNIEXPORT 
jlong JNICALL Java_org_platanios_tensorflow_jni_Function_00024_graphToFunction 17 | (JNIEnv *, jobject, jlong, jstring, jboolean, jlongArray, jlongArray, jintArray, jlongArray, jintArray, jobjectArray); 18 | 19 | /* 20 | * Class: org_platanios_tensorflow_jni_Function__ 21 | * Method: copyToGraph 22 | * Signature: (JJJ)V 23 | */ 24 | JNIEXPORT void JNICALL Java_org_platanios_tensorflow_jni_Function_00024_copyToGraph 25 | (JNIEnv *, jobject, jlong, jlong, jlong); 26 | 27 | /* 28 | * Class: org_platanios_tensorflow_jni_Function__ 29 | * Method: toFunctionDef 30 | * Signature: (J)[B 31 | */ 32 | JNIEXPORT jbyteArray JNICALL Java_org_platanios_tensorflow_jni_Function_00024_toFunctionDef 33 | (JNIEnv *, jobject, jlong); 34 | 35 | /* 36 | * Class: org_platanios_tensorflow_jni_Function__ 37 | * Method: delete 38 | * Signature: (J)V 39 | */ 40 | JNIEXPORT void JNICALL Java_org_platanios_tensorflow_jni_Function_00024_delete 41 | (JNIEnv *, jobject, jlong); 42 | 43 | #ifdef __cplusplus 44 | } 45 | #endif 46 | #endif 47 | -------------------------------------------------------------------------------- /modules/jni/src/main/native/generated/tensor_random_ops.h: -------------------------------------------------------------------------------- 1 | /* DO NOT EDIT THIS FILE - it is machine generated */ 2 | #include 3 | /* Header for class org_platanios_tensorflow_jni_generated_tensors_Random__ */ 4 | 5 | #ifndef _Included_org_platanios_tensorflow_jni_generated_tensors_Random__ 6 | #define _Included_org_platanios_tensorflow_jni_generated_tensors_Random__ 7 | #ifdef __cplusplus 8 | extern "C" { 9 | #endif 10 | /* 11 | * Class: org_platanios_tensorflow_jni_generated_tensors_Random__ 12 | * Method: randomShuffle 13 | * Signature: (JJJJ)J 14 | */ 15 | JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Random_00024_randomShuffle 16 | (JNIEnv *, jobject, jlong, jlong, jlong, jlong); 17 | 18 | /* 19 | * Class: org_platanios_tensorflow_jni_generated_tensors_Random__ 
20 | * Method: randomUniform 21 | * Signature: (JJIJJ)J 22 | */ 23 | JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Random_00024_randomUniform 24 | (JNIEnv *, jobject, jlong, jlong, jint, jlong, jlong); 25 | 26 | /* 27 | * Class: org_platanios_tensorflow_jni_generated_tensors_Random__ 28 | * Method: randomUniformInt 29 | * Signature: (JJJJJJ)J 30 | */ 31 | JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Random_00024_randomUniformInt 32 | (JNIEnv *, jobject, jlong, jlong, jlong, jlong, jlong, jlong); 33 | 34 | /* 35 | * Class: org_platanios_tensorflow_jni_generated_tensors_Random__ 36 | * Method: randomStandardNormal 37 | * Signature: (JJIJJ)J 38 | */ 39 | JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Random_00024_randomStandardNormal 40 | (JNIEnv *, jobject, jlong, jlong, jint, jlong, jlong); 41 | 42 | /* 43 | * Class: org_platanios_tensorflow_jni_generated_tensors_Random__ 44 | * Method: truncatedNormal 45 | * Signature: (JJIJJ)J 46 | */ 47 | JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Random_00024_truncatedNormal 48 | (JNIEnv *, jobject, jlong, jlong, jint, jlong, jlong); 49 | 50 | #ifdef __cplusplus 51 | } 52 | #endif 53 | #endif 54 | -------------------------------------------------------------------------------- /modules/jni/src/main/native/generated/tensor_sparse_ops.h: -------------------------------------------------------------------------------- 1 | /* DO NOT EDIT THIS FILE - it is machine generated */ 2 | #include 3 | /* Header for class org_platanios_tensorflow_jni_generated_tensors_Sparse__ */ 4 | 5 | #ifndef _Included_org_platanios_tensorflow_jni_generated_tensors_Sparse__ 6 | #define _Included_org_platanios_tensorflow_jni_generated_tensors_Sparse__ 7 | #ifdef __cplusplus 8 | extern "C" { 9 | #endif 10 | /* 11 | * Class: org_platanios_tensorflow_jni_generated_tensors_Sparse__ 12 | * Method: sparseToDense 13 | * Signature: (JJJJJZ)J 14 
| */ 15 | JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Sparse_00024_sparseToDense 16 | (JNIEnv *, jobject, jlong, jlong, jlong, jlong, jlong, jboolean); 17 | 18 | #ifdef __cplusplus 19 | } 20 | #endif 21 | #endif 22 | -------------------------------------------------------------------------------- /modules/jni/src/main/native/graph.h: -------------------------------------------------------------------------------- 1 | /* DO NOT EDIT THIS FILE - it is machine generated */ 2 | #include 3 | /* Header for class org_platanios_tensorflow_jni_Graph__ */ 4 | 5 | #ifndef _Included_org_platanios_tensorflow_jni_Graph__ 6 | #define _Included_org_platanios_tensorflow_jni_Graph__ 7 | #ifdef __cplusplus 8 | extern "C" { 9 | #endif 10 | 11 | /* 12 | * Class: org_platanios_tensorflow_jni_Graph__ 13 | * Method: allocate 14 | * Signature: ()J 15 | */ 16 | JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_Graph_00024_allocate 17 | (JNIEnv *, jobject); 18 | 19 | /* 20 | * Class: org_platanios_tensorflow_jni_Graph__ 21 | * Method: delete 22 | * Signature: (J)V 23 | */ 24 | JNIEXPORT void JNICALL Java_org_platanios_tensorflow_jni_Graph_00024_delete 25 | (JNIEnv *, jobject, jlong); 26 | 27 | /* 28 | * Class: org_platanios_tensorflow_jni_Graph__ 29 | * Method: findOp 30 | * Signature: (JLjava/lang/String;)J 31 | */ 32 | JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_Graph_00024_findOp 33 | (JNIEnv *, jobject, jlong, jstring); 34 | 35 | /* 36 | * Class: org_platanios_tensorflow_jni_Graph__ 37 | * Method: ops 38 | * Signature: (J)[J 39 | */ 40 | JNIEXPORT jlongArray JNICALL Java_org_platanios_tensorflow_jni_Graph_00024_ops 41 | (JNIEnv *, jobject, jlong); 42 | 43 | /* 44 | * Class: org_platanios_tensorflow_jni_Graph__ 45 | * Method: addGradients 46 | * Signature: (J[Lorg/platanios/tensorflow/jni/Output;[Lorg/platanios/tensorflow/jni/Output;[Lorg/platanios/tensorflow/jni/Output;)[Lorg/platanios/tensorflow/jni/Output; 47 | */ 48 | JNIEXPORT 
jobjectArray JNICALL Java_org_platanios_tensorflow_jni_Graph_00024_addGradients 49 | (JNIEnv *, jobject, jlong, jobjectArray, jobjectArray, jobjectArray); 50 | 51 | /* 52 | * Class: org_platanios_tensorflow_jni_Graph__ 53 | * Method: importGraphDef 54 | * Signature: (J[BLjava/lang/String;[Ljava/lang/String;[I[J[I[Ljava/lang/String;[J[J)V 55 | */ 56 | JNIEXPORT void JNICALL Java_org_platanios_tensorflow_jni_Graph_00024_importGraphDef 57 | (JNIEnv *, jobject, jlong, jbyteArray, jstring, jobjectArray, jintArray, jlongArray, jintArray, jobjectArray, jlongArray, jlongArray); 58 | 59 | /* 60 | * Class: org_platanios_tensorflow_jni_Graph__ 61 | * Method: toGraphDef 62 | * Signature: (J)[B 63 | */ 64 | JNIEXPORT jbyteArray JNICALL Java_org_platanios_tensorflow_jni_Graph_00024_toGraphDef 65 | (JNIEnv *, jobject, jlong); 66 | 67 | #ifdef __cplusplus 68 | } 69 | #endif 70 | #endif 71 | -------------------------------------------------------------------------------- /modules/jni/src/main/native/include/tensorflow/c/c_api_macros.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_C_C_API_MACROS_H_ 17 | #define TENSORFLOW_C_C_API_MACROS_H_ 18 | 19 | #ifdef SWIG 20 | #define TF_CAPI_EXPORT 21 | #else 22 | #if defined(_WIN32) 23 | #ifdef TF_COMPILE_LIBRARY 24 | #define TF_CAPI_EXPORT __declspec(dllexport) 25 | #else 26 | #define TF_CAPI_EXPORT __declspec(dllimport) 27 | #endif // TF_COMPILE_LIBRARY 28 | #else 29 | #define TF_CAPI_EXPORT __attribute__((visibility("default"))) 30 | #endif // _WIN32 31 | #endif // SWIG 32 | 33 | // TF_Bool is the C API typedef for unsigned char, while TF_BOOL is 34 | // the datatype for boolean tensors. 35 | #ifndef TF_Bool 36 | #define TF_Bool unsigned char 37 | #endif // TF_Bool 38 | 39 | // Macro used to calculate struct size for maintaining ABI stability across 40 | // different struct implementations. 41 | #ifndef TF_OFFSET_OF_END 42 | #define TF_OFFSET_OF_END(TYPE, MEMBER) \ 43 | (offsetof(TYPE, MEMBER) + sizeof(((TYPE *)0)->MEMBER)) 44 | #endif // TF_OFFSET_OF_END 45 | 46 | #endif // TENSORFLOW_C_C_API_MACROS_H_ 47 | -------------------------------------------------------------------------------- /modules/jni/src/main/native/include/tensorflow/c/eager/dlpack.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_C_EAGER_DLPACK_H_ 17 | #define TENSORFLOW_C_EAGER_DLPACK_H_ 18 | 19 | #include "tensorflow/c/eager/c_api.h" 20 | 21 | namespace tensorflow { 22 | 23 | // PyCapsule name for DLPack Tensor 24 | const char* const kDlTensorCapsuleName = "dltensor"; 25 | 26 | // Converts eager tensor handle to DLPack (DLManagedTensor*), and return the 27 | // void* for further PyCapsule construction. 28 | TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h, 29 | TF_Status* status); 30 | 31 | // Converts DLPack (DLManagedTensor*) to eager tensor handle. 32 | TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, 33 | TF_Status* status, 34 | TFE_Context* ctx); 35 | 36 | // Calls the destructor of DLManagedTensor, used in the destructor of PyCapsule. 37 | TF_CAPI_EXPORT extern void TFE_CallDLManagedTensorDeleter(void* dlm_ptr); 38 | } // namespace tensorflow 39 | 40 | #endif // TENSORFLOW_C_EAGER_DLPACK_H_ 41 | -------------------------------------------------------------------------------- /modules/jni/src/main/native/include/tensorflow/c/tf_attrtype.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | #ifndef TENSORFLOW_C_TF_ATTRTYPE_H_ 16 | #define TENSORFLOW_C_TF_ATTRTYPE_H_ 17 | 18 | #ifdef __cplusplus 19 | extern "C" { 20 | #endif 21 | 22 | // TF_AttrType describes the type of the value of an attribute on an operation. 23 | typedef enum TF_AttrType { 24 | TF_ATTR_STRING = 0, 25 | TF_ATTR_INT = 1, 26 | TF_ATTR_FLOAT = 2, 27 | TF_ATTR_BOOL = 3, 28 | TF_ATTR_TYPE = 4, 29 | TF_ATTR_SHAPE = 5, 30 | TF_ATTR_TENSOR = 6, 31 | TF_ATTR_PLACEHOLDER = 7, 32 | TF_ATTR_FUNC = 8, 33 | } TF_AttrType; 34 | 35 | #ifdef __cplusplus 36 | } /* end extern "C" */ 37 | #endif 38 | 39 | #endif // TENSORFLOW_C_TF_ATTRTYPE_H_ 40 | -------------------------------------------------------------------------------- /modules/jni/src/main/native/include/tensorflow/c/tf_file_statistics.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_C_TF_FILE_STATISTICS_H_ 17 | #define TENSORFLOW_C_TF_FILE_STATISTICS_H_ 18 | 19 | #include 20 | 21 | typedef struct TF_FileStatistics { 22 | // The length of the file in bytes. 23 | int64_t length; 24 | // The last modified time in nanoseconds. 
25 | int64_t mtime_nsec; 26 | // Whether the name refers to a directory. 27 | bool is_directory; 28 | } TF_FileStatistics; 29 | 30 | // TODO(mihaimaruseac): `tensorflow::FileStatistics` from 31 | // `core/platform/file_statistics.h` is a duplicate of this so maybe try to 32 | // remove duplication later? 33 | 34 | #endif // TENSORFLOW_C_TF_FILE_STATISTICS_H_ 35 | -------------------------------------------------------------------------------- /modules/jni/src/main/native/include/tensorflow/c/tf_tstring.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | #ifndef TENSORFLOW_C_TF_TSTRING_H_ 16 | #define TENSORFLOW_C_TF_TSTRING_H_ 17 | 18 | #include "tensorflow/core/platform/ctstring.h" 19 | 20 | #endif // THIRD_PARTY_TENSORFLOW_C_TF_TSTRING_H_ 21 | -------------------------------------------------------------------------------- /modules/jni/src/main/native/ops/beam_search_ops.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_CONTRIB_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_ 17 | #define TENSORFLOW_CONTRIB_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_ 18 | 19 | #include "tensorflow/core/framework/tensor_types.h" 20 | #include "tensorflow/core/platform/types.h" 21 | 22 | namespace tensorflow { 23 | class OpKernelContext; 24 | 25 | namespace functor { 26 | 27 | template 28 | struct GatherTree { 29 | void operator()(OpKernelContext* ctx, const Device& d, 30 | typename TTypes::ConstTensor step_ids, 31 | typename TTypes::ConstTensor parent_ids, 32 | TTypes::ConstVec max_sequence_lengths, 33 | const T end_token, typename TTypes::Tensor beams); 34 | }; 35 | 36 | } // namespace functor 37 | } // namespace tensorflow 38 | 39 | #endif // TENSORFLOW_CONTRIB_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_ 40 | -------------------------------------------------------------------------------- /modules/jni/src/main/native/ops/jvm_callback_op.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. 
You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License. 14 | */ 15 | 16 | #ifndef TENSORFLOW_JVM_CALLBACK_OP_H_ 17 | #define TENSORFLOW_JVM_CALLBACK_OP_H_ 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | #include "tensorflow/c/c_api.h" 25 | #include "tensorflow/c/kernels.h" 26 | 27 | // A call to the registered JVM function. 28 | struct JVMCall { 29 | JNIEnv* env; 30 | jclass registry; 31 | jmethodID call_method_id; 32 | 33 | TF_OpKernelContext* ctx; 34 | 35 | // True if and only if this op has been placed on a GPU. 36 | bool gpu; 37 | 38 | // Passed to the JVM to call the function registered with this ID. 39 | int id; 40 | 41 | // Inputs and outputs of this function invocation. 
42 | std::vector inputs; 43 | std::vector outputs; 44 | }; 45 | 46 | #endif // TENSORFLOW_JVM_CALLBACK_OP_H_ 47 | -------------------------------------------------------------------------------- /modules/jni/src/main/native/server.h: -------------------------------------------------------------------------------- 1 | /* DO NOT EDIT THIS FILE - it is machine generated */ 2 | #include 3 | /* Header for class org_platanios_tensorflow_jni_Server__ */ 4 | 5 | #ifndef _Included_org_platanios_tensorflow_jni_Server__ 6 | #define _Included_org_platanios_tensorflow_jni_Server__ 7 | #ifdef __cplusplus 8 | extern "C" { 9 | #endif 10 | /* 11 | * Class: org_platanios_tensorflow_jni_Server__ 12 | * Method: newServer 13 | * Signature: ([B)J 14 | */ 15 | JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_Server_00024_newServer 16 | (JNIEnv *, jobject, jbyteArray); 17 | 18 | /* 19 | * Class: org_platanios_tensorflow_jni_Server__ 20 | * Method: target 21 | * Signature: (J)Ljava/lang/String; 22 | */ 23 | JNIEXPORT jstring JNICALL Java_org_platanios_tensorflow_jni_Server_00024_target 24 | (JNIEnv *, jobject, jlong); 25 | 26 | /* 27 | * Class: org_platanios_tensorflow_jni_Server__ 28 | * Method: startServer 29 | * Signature: (J)V 30 | */ 31 | JNIEXPORT void JNICALL Java_org_platanios_tensorflow_jni_Server_00024_startServer 32 | (JNIEnv *, jobject, jlong); 33 | 34 | /* 35 | * Class: org_platanios_tensorflow_jni_Server__ 36 | * Method: stopServer 37 | * Signature: (J)V 38 | */ 39 | JNIEXPORT void JNICALL Java_org_platanios_tensorflow_jni_Server_00024_stopServer 40 | (JNIEnv *, jobject, jlong); 41 | 42 | /* 43 | * Class: org_platanios_tensorflow_jni_Server__ 44 | * Method: joinServer 45 | * Signature: (J)V 46 | */ 47 | JNIEXPORT void JNICALL Java_org_platanios_tensorflow_jni_Server_00024_joinServer 48 | (JNIEnv *, jobject, jlong); 49 | 50 | /* 51 | * Class: org_platanios_tensorflow_jni_Server__ 52 | * Method: deleteServer 53 | * Signature: (J)V 54 | */ 55 | JNIEXPORT void JNICALL 
Java_org_platanios_tensorflow_jni_Server_00024_deleteServer 56 | (JNIEnv *, jobject, jlong); 57 | 58 | #ifdef __cplusplus 59 | } 60 | #endif 61 | #endif 62 | -------------------------------------------------------------------------------- /modules/jni/src/main/native/session.h: -------------------------------------------------------------------------------- 1 | /* DO NOT EDIT THIS FILE - it is machine generated */ 2 | #include 3 | /* Header for class org_platanios_tensorflow_jni_Session__ */ 4 | 5 | #ifndef _Included_org_platanios_tensorflow_jni_Session__ 6 | #define _Included_org_platanios_tensorflow_jni_Session__ 7 | #ifdef __cplusplus 8 | extern "C" { 9 | #endif 10 | /* 11 | * Class: org_platanios_tensorflow_jni_Session__ 12 | * Method: allocate 13 | * Signature: (JLjava/lang/String;[B)J 14 | */ 15 | JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_Session_00024_allocate 16 | (JNIEnv *, jobject, jlong, jstring, jbyteArray); 17 | 18 | /* 19 | * Class: org_platanios_tensorflow_jni_Session__ 20 | * Method: delete 21 | * Signature: (J)V 22 | */ 23 | JNIEXPORT void JNICALL Java_org_platanios_tensorflow_jni_Session_00024_delete 24 | (JNIEnv *, jobject, jlong); 25 | 26 | /* 27 | * Class: org_platanios_tensorflow_jni_Session__ 28 | * Method: run 29 | * Signature: (J[B[J[J[I[J[I[JZ[J)[B 30 | */ 31 | JNIEXPORT jbyteArray JNICALL Java_org_platanios_tensorflow_jni_Session_00024_run 32 | (JNIEnv *, jobject, jlong, jbyteArray, jlongArray, jlongArray, jintArray, jlongArray, jintArray, jlongArray, jboolean, jlongArray); 33 | 34 | ///* 35 | // * Class: org_platanios_tensorflow_jni_Session__ 36 | // * Method: extend 37 | // * Signature: (J)V 38 | // */ 39 | //JNIEXPORT void JNICALL Java_org_platanios_tensorflow_jni_Session_00024_extend 40 | // (JNIEnv *, jobject, jlong); 41 | 42 | /* 43 | * Class: org_platanios_tensorflow_jni_Session__ 44 | * Method: deviceList 45 | * Signature: ([B)[[B 46 | */ 47 | JNIEXPORT jobjectArray JNICALL 
Java_org_platanios_tensorflow_jni_Session_00024_deviceList 48 | (JNIEnv *, jobject, jbyteArray); 49 | 50 | #ifdef __cplusplus 51 | } 52 | #endif 53 | #endif 54 | -------------------------------------------------------------------------------- /modules/jni/src/main/resources/logback.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | %d{yyyy-MM-dd HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /modules/jni/src/main/scala/org/platanios/tensorflow/jni/CheckpointReader.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License. 
14 | */ 15 | 16 | package org.platanios.tensorflow.jni 17 | 18 | /** 19 | * @author Emmanouil Antonios Platanios 20 | */ 21 | object CheckpointReader { 22 | TensorFlow.load() 23 | 24 | @native def newCheckpointReader(filePattern: String): Long 25 | @native def debugString(handle: Long): String 26 | @native def hasTensor(handle: Long, name: String): Boolean 27 | @native def getTensor(handle: Long, name: String): Long 28 | @native def variableShapes(handle: Long): VariableShapes 29 | @native def variableDataTypes(handle: Long): VariableDataTypes 30 | @native def delete(handle: Long): Unit 31 | } 32 | 33 | case class VariableShapes(variables: Array[String], shapes: Array[Array[Long]]) 34 | case class VariableDataTypes(variables: Array[String], dataTypes: Array[Int]) 35 | -------------------------------------------------------------------------------- /modules/jni/src/main/scala/org/platanios/tensorflow/jni/Function.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License. 
14 | */ 15 | 16 | package org.platanios.tensorflow.jni 17 | 18 | /** 19 | * @author Emmanouil Antonios Platanios 20 | */ 21 | object Function { 22 | TensorFlow.load() 23 | 24 | @native def graphToFunction( 25 | fnBodyGraphHandle: Long, 26 | fnName: String, 27 | appendHashToFnName: Boolean, 28 | opHandles: Array[Long], 29 | inputOpHandles: Array[Long], 30 | inputOpIndices: Array[Int], 31 | outputOpHandles: Array[Long], 32 | outputOpIndices: Array[Int], 33 | outputNames: Array[String] 34 | ): Long 35 | 36 | @native def copyToGraph(graphHandle: Long, functionHandle: Long, gradientHandle: Long): Unit 37 | @native def toFunctionDef(handle: Long): Array[Byte] 38 | @native def delete(handle: Long): Unit 39 | } 40 | -------------------------------------------------------------------------------- /modules/jni/src/main/scala/org/platanios/tensorflow/jni/Graph.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License. 
14 | */ 15 | 16 | package org.platanios.tensorflow.jni 17 | 18 | /** 19 | * @author Emmanouil Antonios Platanios 20 | */ 21 | object Graph { 22 | TensorFlow.load() 23 | 24 | @native def allocate(): Long 25 | @native def delete(handle: Long): Unit 26 | @native def findOp(handle: Long, name: String): Long 27 | @native def ops(handle: Long): Array[Long] 28 | 29 | @native def addGradients( 30 | handle: Long, 31 | y: Array[Output], 32 | x: Array[Output], 33 | dx: Array[Output] 34 | ): Array[Output] 35 | 36 | @native def importGraphDef( 37 | handle: Long, 38 | graphDef: Array[Byte], 39 | prefix: String, 40 | inputsMapSourceOpNames: Array[String], 41 | inputsMapSourceOutputIndices: Array[Int], 42 | inputsMapDestinationOpHandles: Array[Long], 43 | inputsMapDestinationOutputIndices: Array[Int], 44 | controlDependenciesMapSourceOpNames: Array[String], 45 | controlDependenciesMapDestinationOpHandles: Array[Long], 46 | controlDependenciesOpHandles: Array[Long]): Unit 47 | 48 | @native def toGraphDef(handle: Long): Array[Byte] 49 | } 50 | -------------------------------------------------------------------------------- /modules/jni/src/main/scala/org/platanios/tensorflow/jni/ScalaCallbacksRegistry.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License. 
14 | */ 15 | 16 | package org.platanios.tensorflow.jni 17 | 18 | import scala.collection.mutable 19 | 20 | /** Keeps a map from unique tokens (i.e., integer IDs) to Scala functions (i.e., callbacks), each of which takes an 21 | * array with handles of tensors as input and returns an array with handles of tensors as output. 22 | * 23 | * @author Emmanouil Antonios Platanios 24 | */ 25 | object ScalaCallbacksRegistry { 26 | private[this] var uniqueId = 0 27 | private[this] val callbacks = mutable.Map.empty[Int, Array[Long] => Array[Long]] 28 | 29 | /** Number of callbacks currently registered. */ 30 | def size: Int = callbacks.size 31 | 32 | /** Registers the provided callback function and returns a unique token to use when creating ops invoking it. */ 33 | def register(function: Array[Long] => Array[Long]): Int = this synchronized { 34 | val token = uniqueId 35 | callbacks.update(uniqueId, function) 36 | uniqueId += 1 37 | token 38 | } 39 | 40 | /** De-registers (i.e., removes from this registry) the function that corresponds to the provided token. */ 41 | def deregister(token: Int): Unit = this synchronized { 42 | callbacks.remove(token) 43 | } 44 | 45 | /** Invokes the callback identified by `token` using the provides input arguments. */ 46 | def call(token: Int, inputs: Array[Long]): Array[Long] = { 47 | val callback = this synchronized callbacks(token) 48 | callback(inputs) 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /modules/jni/src/main/scala/org/platanios/tensorflow/jni/Server.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. 
You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License. 14 | */ 15 | 16 | package org.platanios.tensorflow.jni 17 | 18 | /** 19 | * @author Emmanouil Antonios Platanios 20 | */ 21 | object Server { 22 | TensorFlow.load() 23 | 24 | @native def newServer(serverDef: Array[Byte]): Long 25 | @native def target(serverHandle: Long): String 26 | @native def startServer(serverHandle: Long): Unit 27 | @native def stopServer(serverHandle: Long): Unit 28 | @native def joinServer(serverHandle: Long): Unit 29 | @native def deleteServer(serverHandle: Long): Unit 30 | } 31 | -------------------------------------------------------------------------------- /modules/jni/src/main/scala/org/platanios/tensorflow/jni/Tensor.scala: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License. 
14 | */ 15 | 16 | package org.platanios.tensorflow.jni 17 | 18 | import java.nio.ByteBuffer 19 | 20 | /** 21 | * @author Emmanouil Antonios Platanios 22 | */ 23 | object Tensor { 24 | TensorFlow.load() 25 | 26 | //region TensorFlow C Tensors 27 | 28 | @native def allocate(dataType: Int, shape: Array[Long], numBytes: Long): Long 29 | @native def fromBuffer(dataType: Int, shape: Array[Long], numBytes: Long, buffer: ByteBuffer): Long 30 | @native def dataType(handle: Long): Int 31 | @native def shape(handle: Long): Array[Long] 32 | @native def buffer(handle: Long): ByteBuffer 33 | @native def delete(handle: Long): Unit 34 | 35 | //endregion TensorFlow C Tensors 36 | 37 | //region TensorFlow Eager Tensors 38 | 39 | @native def eagerAllocateContext(configProto: Array[Byte]): Long 40 | @native def eagerDeleteContext(handle: Long): Unit 41 | @native def eagerAllocate(tensorHandle: Long): Long 42 | @native def eagerDataType(handle: Long): Int 43 | @native def eagerShape(handle: Long): Array[Long] 44 | @native def eagerDevice(handle: Long): String 45 | @native def eagerDelete(handle: Long): Unit 46 | @native def eagerResolve(handle: Long): Long 47 | @native def eagerCopyToDevice(handle: Long, contextHandle: Long, device: String): Long 48 | 49 | @native def eagerNewOp(contextHandle: Long, opOrFunctionName: String): Long 50 | @native def eagerDeleteOp(opHandle: Long): Unit 51 | @native def eagerSetOpDevice(opHandle: Long, device: String): Unit 52 | 53 | //endregion TensorFlow Eager Tensors 54 | 55 | //region String Helpers 56 | 57 | @native def setStringBytes(stringBytes: Array[Byte], buffer: ByteBuffer): Int 58 | @native def getStringBytes(buffer: ByteBuffer): Array[Byte] 59 | @native def tfStringSize(): Int 60 | 61 | //endregion String Helpers 62 | } 63 | -------------------------------------------------------------------------------- /modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Random.scala: 
-------------------------------------------------------------------------------- 1 | /* DO NOT EDIT THIS FILE - it is machine generated */ 2 | 3 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 6 | * use this file except in compliance with the License. You may obtain a copy of 7 | * the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 13 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 14 | * License for the specific language governing permissions and limitations under 15 | * the License. 16 | */ 17 | 18 | package org.platanios.tensorflow.jni.generated.tensors 19 | 20 | import org.platanios.tensorflow.jni.TensorFlow 21 | 22 | object Random { 23 | TensorFlow.load() 24 | 25 | @native def randomShuffle(contextHandle: Long, value: Long, seed: Long, seed2: Long): Long 26 | @native def randomUniform(contextHandle: Long, shape: Long, dtype: Int, seed: Long, seed2: Long): Long 27 | @native def randomUniformInt(contextHandle: Long, shape: Long, minval: Long, maxval: Long, seed: Long, seed2: Long): Long 28 | @native def randomStandardNormal(contextHandle: Long, shape: Long, dtype: Int, seed: Long, seed2: Long): Long 29 | @native def truncatedNormal(contextHandle: Long, shape: Long, dtype: Int, seed: Long, seed2: Long): Long 30 | } 31 | -------------------------------------------------------------------------------- /modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Sparse.scala: -------------------------------------------------------------------------------- 1 | /* DO NOT EDIT THIS FILE - it is machine generated */ 2 | 3 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 
4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 6 | * use this file except in compliance with the License. You may obtain a copy of 7 | * the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 13 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 14 | * License for the specific language governing permissions and limitations under 15 | * the License. 16 | */ 17 | 18 | package org.platanios.tensorflow.jni.generated.tensors 19 | 20 | import org.platanios.tensorflow.jni.TensorFlow 21 | 22 | object Sparse { 23 | TensorFlow.load() 24 | 25 | @native def sparseToDense(contextHandle: Long, sparse_indices: Long, output_shape: Long, sparse_values: Long, default_value: Long, validate_indices: Boolean): Long 26 | } 27 | -------------------------------------------------------------------------------- /modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Text.scala: -------------------------------------------------------------------------------- 1 | /* DO NOT EDIT THIS FILE - it is machine generated */ 2 | 3 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 6 | * use this file except in compliance with the License. You may obtain a copy of 7 | * the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 13 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 14 | * License for the specific language governing permissions and limitations under 15 | * the License. 
16 | */ 17 | 18 | package org.platanios.tensorflow.jni.generated.tensors 19 | 20 | import org.platanios.tensorflow.jni.TensorFlow 21 | 22 | object Text { 23 | TensorFlow.load() 24 | 25 | @native def stringJoin(contextHandle: Long, inputs: Array[Long], separator: Array[Byte]): Long 26 | @native def stringSplit(contextHandle: Long, input: Long, delimiter: Long, skip_empty: Boolean): Array[Long] 27 | @native def encodeBase64(contextHandle: Long, input: Long, pad: Boolean): Long 28 | @native def decodeBase64(contextHandle: Long, input: Long): Long 29 | @native def stringToHashBucket(contextHandle: Long, string_tensor: Long, num_buckets: Long): Long 30 | @native def stringToHashBucketFast(contextHandle: Long, input: Long, num_buckets: Long): Long 31 | @native def stringToHashBucketStrong(contextHandle: Long, input: Long, num_buckets: Long, key: Array[Long]): Long 32 | } 33 | -------------------------------------------------------------------------------- /modules/proto/src/main/proto/allocation_description.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package org.platanios.tensorflow.proto; 4 | 5 | option cc_enable_arenas = true; 6 | option java_outer_classname = "AllocationDescriptionProtos"; 7 | option java_multiple_files = true; 8 | option java_package = "org.platanios.tensorflow.proto"; 9 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/allocation_description_go_proto"; 10 | 11 | message AllocationDescription { 12 | // Total number of bytes requested 13 | int64 requested_bytes = 1; 14 | 15 | // Total number of bytes allocated if known 16 | int64 allocated_bytes = 2; 17 | 18 | // Name of the allocator used 19 | string allocator_name = 3; 20 | 21 | // Identifier of the allocated buffer if known 22 | int64 allocation_id = 4; 23 | 24 | // Set if this tensor only has one remaining reference 25 | bool has_single_reference = 5; 26 | 27 | // Address of the 
allocation. 28 | uint64 ptr = 6; 29 | } 30 | -------------------------------------------------------------------------------- /modules/proto/src/main/proto/autotuning.proto: -------------------------------------------------------------------------------- 1 | // This file defines protos that store the results of autotuning various 2 | // operations. 3 | // 4 | // They are in proto format because we want to log them structured. They offer 5 | // tremendous statistical, testing, and debugging value. 6 | syntax = "proto3"; 7 | 8 | package org.platanios.tensorflow.proto; 9 | 10 | import "any.proto"; 11 | import "duration.proto"; 12 | 13 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto"; 14 | 15 | message CudnnVersion { 16 | int32 major = 1; 17 | int32 minor = 2; 18 | int32 patch = 3; 19 | } 20 | 21 | message ComputeCapability { 22 | int32 major = 1; 23 | int32 minor = 2; 24 | } 25 | 26 | message AutotuneResult { 27 | enum FailureKind { 28 | UNKNOWN = 0; 29 | REDZONE_MODIFIED = 1; 30 | WRONG_RESULT = 2; 31 | } 32 | 33 | message FailureResult { 34 | FailureKind kind = 1; 35 | string msg = 2; 36 | 37 | // For failure_kind == WRONG_RESULT, this field indicates the reference 38 | // configuration that we compared against. 39 | // 40 | // Note that the reference algorithm isn't always correct. However, 41 | // empirically it's more correct, as it's "algo 0", less fancy than the 42 | // compared one. 
43 | oneof key { 44 | ConvKey reference_conv = 11; 45 | GemmKey reference_gemm = 12; 46 | } 47 | 48 | int64 buffer_address = 13; 49 | } 50 | 51 | message ConvKey { 52 | int64 algorithm = 1; 53 | bool tensor_ops_enabled = 2; 54 | } 55 | 56 | message GemmKey { 57 | int64 algorithm = 1; 58 | } 59 | 60 | int64 scratch_bytes = 8; 61 | tensorflow.proto.Duration run_time = 9; 62 | 63 | FailureResult failure = 7; 64 | 65 | oneof key { 66 | ConvKey conv = 5; 67 | GemmKey gemm = 6; 68 | } 69 | 70 | // Next ID: 14 71 | } 72 | 73 | message AutotuningLog { 74 | google.Any instr = 1; 75 | 76 | // Records all auto-tuning results per algorithm. 77 | repeated AutotuneResult results = 2; 78 | 79 | CudnnVersion cudnn_version = 3; 80 | ComputeCapability compute_capability = 4; 81 | 82 | // stream_executor::DeviceDescription::pci_bus_id. 83 | string device_pci_bus_id = 5; 84 | 85 | string blas_version = 6; 86 | 87 | // Next ID: 7 88 | } 89 | -------------------------------------------------------------------------------- /modules/proto/src/main/proto/bfc_memory_map.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package org.platanios.tensorflow.proto; 4 | 5 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto"; 6 | 7 | // Some of the data from AllocatorStats 8 | message MemAllocatorStats { 9 | int64 num_allocs = 1; 10 | int64 bytes_in_use = 2; 11 | int64 peak_bytes_in_use = 3; 12 | int64 largest_alloc_size = 4; 13 | float fragmentation_metric = 5; 14 | } 15 | 16 | message MemChunk { 17 | uint64 address = 1; 18 | int64 size = 2; 19 | int64 requested_size = 3; 20 | int32 bin = 4; 21 | string op_name = 5; 22 | uint64 freed_at_count = 6; 23 | uint64 action_count = 7; 24 | bool in_use = 8; 25 | uint64 step_id = 9; 26 | } 27 | 28 | message BinSummary { 29 | int32 bin = 1; 30 | int64 total_bytes_in_use = 2; 31 | int64 total_bytes_in_bin = 3; 32 | int64 total_chunks_in_use = 4; 33 
| int64 total_chunks_in_bin = 5; 34 | } 35 | 36 | message SnapShot { 37 | uint64 action_count = 1; 38 | int64 size = 2; 39 | } 40 | 41 | message MemoryDump { 42 | string allocator_name = 1; 43 | repeated BinSummary bin_summary = 2; 44 | repeated MemChunk chunk = 3; 45 | repeated SnapShot snap_shot = 4; 46 | MemAllocatorStats stats = 5; 47 | } 48 | -------------------------------------------------------------------------------- /modules/proto/src/main/proto/checkpoint_state.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package org.platanios.tensorflow.proto; 4 | option cc_enable_arenas = true; 5 | 6 | option java_package = "org.platanios.tensorflow.proto"; 7 | option java_outer_classname = "CheckpointStateProto"; 8 | 9 | // Protocol buffer representing the checkpoint state. 10 | message CheckpointState { 11 | // Path to the most-recent model checkpoint. 12 | string model_checkpoint_path = 1; 13 | 14 | // Paths to all not-yet-deleted model checkpoints, sorted from oldest to newest. 15 | // Note that the value of model_checkpoint_path should be the last item in this list. 16 | repeated string all_model_checkpoint_paths = 2; 17 | } 18 | -------------------------------------------------------------------------------- /modules/proto/src/main/proto/critical_section.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package org.platanios.tensorflow.proto; 4 | 5 | option cc_enable_arenas = true; 6 | option java_outer_classname = "CriticalSectionProtos"; 7 | option java_multiple_files = true; 8 | option java_package = "org.platanios.tensorflow.proto"; 9 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto"; 10 | 11 | // Protocol buffer representing a CriticalSection. 12 | message CriticalSectionDef { 13 | // Name of the critical section handle. 
14 | string critical_section_name = 1; 15 | } 16 | 17 | // Protocol buffer representing a CriticalSection execution. 18 | message CriticalSectionExecutionDef { 19 | // Name of the critical section handle. 20 | string execute_in_critical_section_name = 1; 21 | // Whether this operation requires exclusive access to its resources, 22 | // (i.e., no other CriticalSections may request the same resources). 23 | bool exclusive_resource_access = 2; 24 | } 25 | -------------------------------------------------------------------------------- /modules/proto/src/main/proto/device_attributes.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package org.platanios.tensorflow.proto; 4 | 5 | option cc_enable_arenas = true; 6 | option java_outer_classname = "DeviceAttributesProtos"; 7 | option java_multiple_files = true; 8 | option java_package = "org.platanios.tensorflow.proto"; 9 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/device_attributes_go_proto"; 10 | 11 | message InterconnectLink { 12 | int32 device_id = 1; 13 | string type = 2; 14 | int32 strength = 3; 15 | } 16 | 17 | message LocalLinks { 18 | repeated InterconnectLink link = 1; 19 | } 20 | 21 | message DeviceLocality { 22 | // Optional bus locality of device. Default value of 0 means 23 | // no specific locality. Specific localities are indexed from 1. 24 | int32 bus_id = 1; 25 | 26 | // Optional NUMA locality of device. 27 | int32 numa_node = 2; 28 | 29 | // Optional local interconnect links to other devices. 30 | LocalLinks links = 3; 31 | } 32 | 33 | message DeviceAttributes { 34 | // Fully specified name of the device within a cluster. 35 | string name = 1; 36 | 37 | // String representation of device_type. 38 | string device_type = 2; 39 | 40 | // Memory capacity of device in bytes. 
41 | int64 memory_limit = 4; 42 | 43 | // Platform-specific data about device that may be useful 44 | // for supporting efficient data transfers. 45 | DeviceLocality locality = 5; 46 | 47 | // A device is assigned a global unique number each time it is 48 | // initialized. "incarnation" should never be 0. 49 | fixed64 incarnation = 6; 50 | 51 | // String representation of the physical device that this device maps to. 52 | string physical_device_desc = 7; 53 | } 54 | -------------------------------------------------------------------------------- /modules/proto/src/main/proto/device_properties.proto: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | syntax = "proto3"; 17 | 18 | package org.platanios.tensorflow.proto; 19 | 20 | option cc_enable_arenas = true; 21 | option java_outer_classname = "DevicePropertiesProtos"; 22 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto"; 23 | 24 | message DeviceProperties { 25 | // Device type (CPU, GPU, ...) 26 | string type = 1; 27 | // Vendor (Intel, nvidia, ...) 28 | string vendor = 2; 29 | // Model (Haswell, K40, ...) 
30 | string model = 3; 31 | // Core Frequency in Mhz 32 | int64 frequency = 4; 33 | // Number of cores 34 | int64 num_cores = 5; 35 | // Version of the tools and libraries used with this device (e.g. gcc 4.9, 36 | // cudnn 5.1) 37 | map environment = 6; 38 | // Number of registers per core. 39 | int64 num_registers = 7; 40 | // L1 cache size in bytes 41 | int64 l1_cache_size = 8; 42 | // L2 cache size in bytes 43 | int64 l2_cache_size = 9; 44 | // L3 cache size in bytes 45 | int64 l3_cache_size = 10; 46 | // Shared memory size per multiprocessor in bytes. This field is 47 | // applicable to GPUs only. 48 | int64 shared_memory_size_per_multiprocessor = 11; 49 | // Memory size in bytes 50 | int64 memory_size = 12; 51 | // Memory bandwidth in KB/s 52 | int64 bandwidth = 13; 53 | } 54 | 55 | message NamedDevice { 56 | string name = 1; 57 | DeviceProperties properties = 2; 58 | } 59 | -------------------------------------------------------------------------------- /modules/proto/src/main/proto/graph.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package org.platanios.tensorflow.proto; 4 | 5 | import "function.proto"; 6 | import "node_def.proto"; 7 | import "versions.proto"; 8 | 9 | option cc_enable_arenas = true; 10 | option java_outer_classname = "GraphProtos"; 11 | option java_multiple_files = true; 12 | option java_package = "org.platanios.tensorflow.proto"; 13 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/graph_go_proto"; 14 | 15 | // Represents the graph of operations 16 | message GraphDef { 17 | repeated NodeDef node = 1; 18 | 19 | // Compatibility versions of the graph. See core/public/version.h for version 20 | // history. The GraphDef version is distinct from the TensorFlow version, and 21 | // each release of TensorFlow will support a range of GraphDef versions. 
22 | VersionDef versions = 4; 23 | 24 | // Deprecated single version field; use versions above instead. Since all 25 | // GraphDef changes before "versions" was introduced were forward 26 | // compatible, this field is entirely ignored. 27 | int32 version = 3 [deprecated = true]; 28 | 29 | // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. 30 | // 31 | // "library" provides user-defined functions. 32 | // 33 | // Naming: 34 | // * library.function.name are in a flat namespace. 35 | // NOTE: We may need to change it to be hierarchical to support 36 | // different orgs. E.g., 37 | // { "/google/nn", { ... }}, 38 | // { "/google/vision", { ... }} 39 | // { "/org_foo/module_bar", { ... }} 40 | // map named_lib; 41 | // * If node[i].op is the name of one function in "library", 42 | // node[i] is deemed as a function call. Otherwise, node[i].op 43 | // must be a primitive operation supported by the runtime. 44 | // 45 | // 46 | // Function call semantics: 47 | // 48 | // * The callee may start execution as soon as some of its inputs 49 | // are ready. The caller may want to use Tuple() mechanism to 50 | // ensure all inputs are ready in the same time. 51 | // 52 | // * The consumer of return values may start executing as soon as 53 | // the return values the consumer depends on are ready. The 54 | // consumer may want to use Tuple() mechanism to ensure the 55 | // consumer does not start until all return values of the callee 56 | // function are ready. 
57 | FunctionDefLibrary library = 2; 58 | } 59 | -------------------------------------------------------------------------------- /modules/proto/src/main/proto/graph_debug_info.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package org.platanios.tensorflow.proto; 4 | 5 | option cc_enable_arenas = true; 6 | option java_outer_classname = "GraphDebugInfoProtos"; 7 | option java_multiple_files = true; 8 | option java_package = "org.platanios.tensorflow.proto"; 9 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto"; 10 | 11 | message GraphDebugInfo { 12 | // This represents a file/line location in the source code. 13 | message FileLineCol { 14 | // File name index, which can be used to retrieve the file name string from 15 | // `files`. The value should be between 0 and (len(files)-1) 16 | int32 file_index = 1; 17 | 18 | // Line number in the file. 19 | int32 line = 2; 20 | 21 | // Col number in the file line. 22 | int32 col = 3; 23 | 24 | // Name of function contains the file line. 25 | string func = 4; 26 | 27 | // Source code contained in this file line. 28 | string code = 5; 29 | } 30 | 31 | // This represents a stack trace which is a ordered list of `FileLineCol`. 32 | message StackTrace { 33 | // Each line in the stack trace. 34 | repeated FileLineCol file_line_cols = 1; 35 | } 36 | 37 | // This stores all the source code file names and can be indexed by the 38 | // `file_index`. 39 | repeated string files = 1; 40 | 41 | // This maps a node name to a stack trace in the source code. 42 | // The map key is a mangling of the containing function and op name with 43 | // syntax: 44 | // op.name '@' func_name 45 | // For ops in the top-level graph, the func_name is the empty string. 46 | // Note that op names are restricted to a small number of characters which 47 | // exclude '@', making it impossible to collide keys of this form. 
Function 48 | // names accept a much wider set of characters. 49 | // It would be preferable to avoid mangling and use a tuple key of (op.name, 50 | // func_name), but this is not supported with protocol buffers. 51 | map traces = 2; 52 | } 53 | -------------------------------------------------------------------------------- /modules/proto/src/main/proto/graph_transfer_info.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package org.platanios.tensorflow.proto; 4 | 5 | import "types.proto"; 6 | 7 | option cc_enable_arenas = true; 8 | option java_outer_classname = "GraphTransferInfoProto"; 9 | option java_multiple_files = true; 10 | option java_package = "org.platanios.tensorflow.proto"; 11 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/graph_transfer_info_go_proto"; 12 | 13 | message GraphTransferNodeInput { 14 | int32 node_id = 1; 15 | int32 output_port = 2; 16 | } 17 | message GraphTransferNodeInfo { 18 | string name = 1; 19 | int32 node_id = 2; 20 | string type_name = 3; 21 | int32 soc_op_id = 4; 22 | int32 padding_id = 5; 23 | int32 input_count = 6; 24 | int32 output_count = 7; 25 | } 26 | message GraphTransferConstNodeInfo { 27 | string name = 1; 28 | int32 node_id = 2; 29 | repeated int64 shape = 3; 30 | bytes data = 4; 31 | DataType dtype = 5; 32 | } 33 | message GraphTransferNodeInputInfo { 34 | int32 node_id = 1; 35 | repeated GraphTransferNodeInput node_input = 2; 36 | } 37 | message GraphTransferNodeOutputInfo { 38 | int32 node_id = 1; 39 | repeated int32 max_byte_size = 2; 40 | } 41 | message GraphTransferGraphInputNodeInfo { 42 | string name = 1; 43 | repeated int64 shape = 2; 44 | DataType dtype = 3; 45 | } 46 | 47 | message GraphTransferGraphOutputNodeInfo { 48 | string name = 1; 49 | repeated int64 shape = 2; 50 | DataType dtype = 3; 51 | } 52 | 53 | // Protocol buffer representing a handle to a tensorflow resource. 
Handles are 54 | // not valid across executions, but can be serialized back and forth from within 55 | // a single run. 56 | message GraphTransferInfo { 57 | enum Destination { 58 | NOP = 0; 59 | HEXAGON = 1; 60 | } 61 | 62 | repeated GraphTransferNodeInfo node_info = 1; 63 | repeated GraphTransferConstNodeInfo const_node_info = 2; 64 | repeated GraphTransferNodeInputInfo node_input_info = 3; 65 | repeated GraphTransferNodeOutputInfo node_output_info = 4; 66 | // Input Node parameters of transferred graph 67 | repeated GraphTransferGraphInputNodeInfo graph_input_node_info = 5; 68 | repeated GraphTransferGraphOutputNodeInfo graph_output_node_info = 6; 69 | // Destination of graph transfer 70 | Destination destination = 7; 71 | } 72 | -------------------------------------------------------------------------------- /modules/proto/src/main/proto/kernel_def.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package org.platanios.tensorflow.proto; 4 | 5 | import "attr_value.proto"; 6 | 7 | option cc_enable_arenas = true; 8 | option java_outer_classname = "KernelDefProtos"; 9 | option java_multiple_files = true; 10 | option java_package = "org.platanios.tensorflow.proto"; 11 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/kernel_def_go_proto"; 12 | 13 | message KernelDef { 14 | // Must match the name of an Op. 15 | string op = 1; 16 | 17 | // Type of device this kernel runs on. 18 | string device_type = 2; 19 | 20 | message AttrConstraint { 21 | // Name of an attr from the Op. 22 | string name = 1; 23 | 24 | // A list of values that this kernel supports for this attr. 25 | // Like OpDef.AttrDef.allowed_values, except for kernels instead of Ops. 26 | AttrValue allowed_values = 2; 27 | } 28 | repeated AttrConstraint constraint = 3; 29 | 30 | // Names of the Op's input_/output_args that reside in host memory 31 | // instead of device memory. 
32 | repeated string host_memory_arg = 4; 33 | 34 | // This allows experimental kernels to be registered for an op that 35 | // won't be used unless the user specifies a "_kernel" attr with 36 | // value matching this. 37 | string label = 5; 38 | 39 | // Prioritization of kernel amongst different devices. By default we assume 40 | // priority is 0. The higher the priority the better. By default (i.e. if 41 | // this is not set), we prefer GPU kernels over CPU. 42 | int32 priority = 6; 43 | } 44 | 45 | // A collection of KernelDefs 46 | message KernelList { 47 | repeated KernelDef kernel = 1; 48 | } 49 | -------------------------------------------------------------------------------- /modules/proto/src/main/proto/named_tensor.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package org.platanios.tensorflow.proto; 4 | 5 | import "tensor.proto"; 6 | 7 | option cc_enable_arenas = true; 8 | option java_outer_classname = "NamedTensorProtos"; 9 | option java_multiple_files = true; 10 | option java_package = "org.platanios.tensorflow.proto"; 11 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto"; 12 | 13 | // A pair of tensor name and tensor values. 14 | message NamedTensorProto { 15 | // Name of the tensor. 16 | string name = 1; 17 | 18 | // The client can populate a TensorProto using a tensorflow::Tensor`, or 19 | // directly using the protobuf field accessors. 20 | // 21 | // The client specifies whether the returned tensor values should be 22 | // filled tensor fields (float_val, int_val, etc.) or encoded in a 23 | // compact form in tensor.tensor_content. 
24 | TensorProto tensor = 2; 25 | } 26 | -------------------------------------------------------------------------------- /modules/proto/src/main/proto/queue_runner.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package org.platanios.tensorflow.proto; 4 | 5 | import "error_codes.proto"; 6 | 7 | option cc_enable_arenas = true; 8 | option java_outer_classname = "QueueRunnerProtos"; 9 | option java_multiple_files = true; 10 | option java_package = "org.platanios.tensorflow.proto"; 11 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto"; 12 | 13 | // Protocol buffer representing a QueueRunner. 14 | message QueueRunnerDef { 15 | // Queue name. 16 | string queue_name = 1; 17 | 18 | // A list of enqueue operations. 19 | repeated string enqueue_op_name = 2; 20 | 21 | // The operation to run to close the queue. 22 | string close_op_name = 3; 23 | 24 | // The operation to run to cancel the queue. 25 | string cancel_op_name = 4; 26 | 27 | // A list of exception types considered to signal a safely closed queue 28 | // if raised during enqueue operations. 29 | repeated error.Code queue_closed_exception_types = 5; 30 | } 31 | -------------------------------------------------------------------------------- /modules/proto/src/main/proto/reader_base.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package org.platanios.tensorflow.proto; 4 | 5 | option cc_enable_arenas = true; 6 | option java_outer_classname = "ReaderBaseProtos"; 7 | option java_multiple_files = true; 8 | option java_package = "org.platanios.tensorflow.proto"; 9 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/reader_base_go_proto"; 10 | 11 | // For serializing and restoring the state of ReaderBase, see 12 | // reader_base.h for details. 
13 | message ReaderBaseState { 14 | int64 work_started = 1; 15 | int64 work_finished = 2; 16 | int64 num_records_produced = 3; 17 | bytes current_work = 4; 18 | } 19 | -------------------------------------------------------------------------------- /modules/proto/src/main/proto/remote_tensor_handle.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package org.platanios.tensorflow.proto; 4 | 5 | import "tensor_shape.proto"; 6 | import "types.proto"; 7 | 8 | option cc_enable_arenas = true; 9 | option java_outer_classname = "RemoteTensorHandleProtos"; 10 | option java_multiple_files = true; 11 | option java_package = "org.platanios.tensorflow.proto"; 12 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto"; 13 | 14 | message ResourceDtypeAndShape { 15 | DataType dtype = 1; 16 | TensorShapeProto shape = 2; 17 | } 18 | 19 | message RemoteTensorHandle { 20 | // The ID of the operation that produced this tensor. 21 | int64 op_id = 1; 22 | // The index into the outputs of the operation that produced this tensor. 23 | int32 output_num = 2; 24 | // Device of the operation that produced this tensor. Cannot be empty. 25 | // For multi-device functions, it's the default device passed to placer. 26 | string device = 3; 27 | // Device where the tensor is located. Can be empty if the operation producing 28 | // this tensor is a multi-device function. 29 | string op_device = 4; 30 | // Tensor type. 31 | DataType dtype = 5; 32 | // Optional data types and shapes of a remote resource variable. 
33 | repeated ResourceDtypeAndShape resource_dtypes_and_shapes = 6; 34 | } 35 | -------------------------------------------------------------------------------- /modules/proto/src/main/proto/replay_log.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package org.platanios.tensorflow.proto; 4 | 5 | import "master.proto"; 6 | 7 | option cc_enable_arenas = true; 8 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto"; 9 | 10 | // Records the creation of a new replay session. We record the device listing 11 | // here to capture the state of the cluster. 12 | message NewReplaySession { 13 | ListDevicesResponse devices = 1; 14 | string session_handle = 2; 15 | } 16 | 17 | message ReplayOp { 18 | double start_time_us = 31; 19 | double end_time_us = 32; 20 | 21 | oneof op { 22 | CreateSessionRequest create_session = 1; 23 | ExtendSessionRequest extend_session = 2; 24 | PartialRunSetupRequest partial_run_setup = 3; 25 | RunStepRequest run_step = 4; 26 | CloseSessionRequest close_session = 5; 27 | ListDevicesRequest list_devices = 6; 28 | ResetRequest reset_request = 7; 29 | MakeCallableRequest make_callable = 8; 30 | RunCallableRequest run_callable = 9; 31 | ReleaseCallableRequest release_callable = 10; 32 | NewReplaySession new_replay_session = 11; 33 | } 34 | 35 | oneof response { 36 | CreateSessionResponse create_session_response = 21; 37 | ExtendSessionResponse extend_session_response = 22; 38 | PartialRunSetupResponse partial_run_setup_response = 23; 39 | RunStepResponse run_step_response = 24; 40 | CloseSessionResponse close_session_response = 25; 41 | ListDevicesResponse list_devices_response = 26; 42 | ResetResponse reset_request_response = 27; 43 | MakeCallableResponse make_callable_response = 28; 44 | RunCallableResponse run_callable_response = 29; 45 | ReleaseCallableResponse release_callable_response = 30; 46 | } 47 | } 48 | 
-------------------------------------------------------------------------------- /modules/proto/src/main/proto/resource_handle.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package org.platanios.tensorflow.proto; 4 | 5 | import "tensor_shape.proto"; 6 | import "types.proto"; 7 | 8 | option cc_enable_arenas = true; 9 | option java_outer_classname = "ResourceHandle"; 10 | option java_multiple_files = true; 11 | option java_package = "org.platanios.tensorflow.proto"; 12 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/resource_handle_go_proto"; 13 | 14 | // Protocol buffer representing a handle to a tensorflow resource. Handles are 15 | // not valid across executions, but can be serialized back and forth from within 16 | // a single run. 17 | message ResourceHandleProto { 18 | // Unique name for the device containing the resource. 19 | string device = 1; 20 | 21 | // Container in which this resource is placed. 22 | string container = 2; 23 | 24 | // Unique name of this resource. 25 | string name = 3; 26 | 27 | // Hash code for the type of the resource. Is only valid in the same device 28 | // and in the same execution. 29 | uint64 hash_code = 4; 30 | 31 | // For debug-only, the name of the type pointed to by this handle, if 32 | // available. 33 | string maybe_type_name = 5; 34 | 35 | // Protocol buffer representing a pair of (data type, tensor shape). 36 | message DtypeAndShape { 37 | DataType dtype = 1; 38 | TensorShapeProto shape = 2; 39 | } 40 | 41 | // Data types and shapes for the underlying resource. 42 | repeated DtypeAndShape dtypes_and_shapes = 6; 43 | 44 | // A set of devices containing the resource. If empty, the resource only 45 | // exists on `device`. 
46 | repeated string allowed_devices = 7; 47 | } 48 | -------------------------------------------------------------------------------- /modules/proto/src/main/proto/saved_model.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package org.platanios.tensorflow.proto; 4 | 5 | import "meta_graph.proto"; 6 | 7 | option cc_enable_arenas = true; 8 | option java_outer_classname = "SavedModelProtos"; 9 | option java_multiple_files = true; 10 | option java_package = "org.platanios.tensorflow.proto"; 11 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto"; 12 | 13 | // SavedModel is the high level serialization format for TensorFlow Models. 14 | // See [todo: doc links, similar to session_bundle] for more information. 15 | message SavedModel { 16 | // The schema version of the SavedModel instance. Used for versioning when 17 | // making future changes to the specification/implementation. Initial value 18 | // at release will be 1. 19 | int64 saved_model_schema_version = 1; 20 | 21 | // One or more MetaGraphs. 22 | repeated MetaGraphDef meta_graphs = 2; 23 | } 24 | -------------------------------------------------------------------------------- /modules/proto/src/main/proto/saver.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package org.platanios.tensorflow.proto; 4 | 5 | option cc_enable_arenas = true; 6 | option java_outer_classname = "SaverProtos"; 7 | option java_multiple_files = true; 8 | option java_package = "org.platanios.tensorflow.proto"; 9 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto"; 10 | 11 | // Protocol buffer representing the configuration of a Saver. 12 | message SaverDef { 13 | // The name of the tensor in which to specify the filename when saving or 14 | // restoring a model checkpoint. 
15 | string filename_tensor_name = 1; 16 | 17 | // The operation to run when saving a model checkpoint. 18 | string save_tensor_name = 2; 19 | 20 | // The operation to run when restoring a model checkpoint. 21 | string restore_op_name = 3; 22 | 23 | // Maximum number of checkpoints to keep. If 0, no checkpoints are deleted. 24 | int32 max_to_keep = 4; 25 | 26 | // Shard the save files, one per device that has Variable nodes. 27 | bool sharded = 5; 28 | 29 | // How often to keep an additional checkpoint. If not specified, only the last 30 | // "max_to_keep" checkpoints are kept; if specified, in addition to keeping 31 | // the last "max_to_keep" checkpoints, an additional checkpoint will be kept 32 | // for every n hours of training. 33 | float keep_checkpoint_every_n_hours = 6; 34 | 35 | // A version number that identifies a different on-disk checkpoint format. 36 | // Usually, each subclass of BaseSaverBuilder works with a particular 37 | // version/format. However, it is possible that the same builder may be 38 | // upgraded to support a newer checkpoint format in the future. 39 | enum CheckpointFormatVersion { 40 | // Internal legacy format. 41 | LEGACY = 0; 42 | // Deprecated format: tf.Saver() which works with tensorflow::table::Table. 43 | V1 = 1; 44 | // Current format: more efficient. 45 | V2 = 2; 46 | } 47 | CheckpointFormatVersion version = 7; 48 | } 49 | -------------------------------------------------------------------------------- /modules/proto/src/main/proto/source_context.proto: -------------------------------------------------------------------------------- 1 | // Protocol Buffers - Google's data interchange format 2 | // Copyright 2008 Google Inc. All rights reserved. 
3 | // https://developers.google.com/protocol-buffers/ 4 | // 5 | // Redistribution and use in source and binary forms, with or without 6 | // modification, are permitted provided that the following conditions are 7 | // met: 8 | // 9 | // * Redistributions of source code must retain the above copyright 10 | // notice, this list of conditions and the following disclaimer. 11 | // * Redistributions in binary form must reproduce the above 12 | // copyright notice, this list of conditions and the following disclaimer 13 | // in the documentation and/or other materials provided with the 14 | // distribution. 15 | // * Neither the name of Google Inc. nor the names of its 16 | // contributors may be used to endorse or promote products derived from 17 | // this software without specific prior written permission. 18 | // 19 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 20 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 21 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 22 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 23 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 24 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 25 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 26 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 27 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 28 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | 31 | syntax = "proto3"; 32 | 33 | package org.platanios.tensorflow.proto; // NOTE(review): this file appears adapted from Google's google/protobuf/source_context.proto; the csharp_namespace, objc_class_prefix and go_package options below still carry Google's values even though the package is org.platanios.tensorflow.proto - confirm this is intended. 34 | 35 | option csharp_namespace = "Google.Protobuf.WellKnownTypes"; 36 | option java_package = "org.platanios.tensorflow.proto"; 37 | option java_outer_classname = "SourceContextProto"; 38 | option java_multiple_files = true; 39 | option objc_class_prefix = "GPB"; 40 | option go_package = "google.golang.org/genproto/protobuf/source_context;source_context"; 41 | 42 | // `SourceContext` represents information about the source of a 43 | // protobuf element, like the file in which it is defined. 44 | message SourceContext { 45 | // The path-qualified name of the .proto file that contained the associated 46 | // protobuf element. For example: `"google/protobuf/source_context.proto"`. 47 | string file_name = 1; 48 | } 49 | -------------------------------------------------------------------------------- /modules/proto/src/main/proto/tensor_bundle.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package org.platanios.tensorflow.proto; 4 | 5 | import "tensor_shape.proto"; 6 | import "tensor_slice.proto"; 7 | import "types.proto"; 8 | import "versions.proto"; 9 | 10 | option cc_enable_arenas = true; 11 | option java_outer_classname = "TensorBundleProtos"; 12 | option java_multiple_files = true; 13 | option java_package = "org.platanios.tensorflow.proto"; 14 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto"; 15 | 16 | // Protos used in the tensor bundle module (tf/core/util/tensor_bundle/). 17 | 18 | // Special header that is associated with a bundle. 
19 | // 20 | // TODO(zongheng,zhifengc): maybe in the future, we can add information about 21 | // which binary produced this checkpoint, timestamp, etc. Sometimes, these can be 22 | // valuable debugging information.
And if needed, these can be used as defensive 23 | // information ensuring that the reader (binary version) of the checkpoint and the writer 24 | // (binary version) match within a certain range, etc. 25 | message BundleHeaderProto { 26 | // Number of data files in the bundle. 27 | int32 num_shards = 1; 28 | 29 | // An enum indicating the endianness of the platform that produced this 30 | // bundle. A bundle can only be read by a platform with matching endianness. 31 | // Defaults to LITTLE, as most modern platforms are little-endian. 32 | // 33 | // Affects the binary tensor data bytes only, not the metadata in protobufs. 34 | enum Endianness { 35 | LITTLE = 0; 36 | BIG = 1; 37 | } 38 | Endianness endianness = 2; 39 | 40 | // Versioning of the tensor bundle format. 41 | VersionDef version = 3; 42 | } 43 | 44 | // Describes the metadata related to a checkpointed tensor. 45 | message BundleEntryProto { 46 | // The tensor dtype and shape. 47 | DataType dtype = 1; 48 | TensorShapeProto shape = 2; 49 | // The binary content of the tensor lies in: 50 | // File "shard_id": bytes [offset, offset + size). 51 | int32 shard_id = 3; 52 | int64 offset = 4; 53 | int64 size = 5; 54 | 55 | // The CRC32C checksum of the tensor bytes. 56 | fixed32 crc32c = 6; 57 | 58 | // Iff present, this entry represents a partitioned tensor. The previous 59 | // fields are interpreted as follows: 60 | // 61 | // "dtype", "shape": describe the full tensor. 62 | // "shard_id", "offset", "size", "crc32c": all IGNORED. 63 | // This information for each slice can be looked up in that slice's own 64 | // BundleEntryProto, keyed by its "slice_name".
65 | repeated TensorSliceProto slices = 7; 66 | } 67 | -------------------------------------------------------------------------------- /modules/proto/src/main/proto/tensor_description.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package org.platanios.tensorflow.proto; 4 | 5 | import "allocation_description.proto"; 6 | import "tensor_shape.proto"; 7 | import "types.proto"; 8 | 9 | option cc_enable_arenas = true; 10 | option java_outer_classname = "TensorDescriptionProtos"; 11 | option java_multiple_files = true; 12 | option java_package = "org.platanios.tensorflow.proto"; 13 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/tensor_description_go_proto"; 14 | // Describes a tensor: its element data type, shape, and allocation information. 15 | message TensorDescription { 16 | // Data type of tensor elements. 17 | DataType dtype = 1; 18 | 19 | // Shape of the tensor. 20 | TensorShapeProto shape = 2; 21 | // NOTE(review): field number 3 is not used in this message (presumably retired upstream); do not reuse it without checking wire compatibility. 22 | // Information about the size and allocator used for the data. 23 | AllocationDescription allocation_description = 4; 24 | } 25 | -------------------------------------------------------------------------------- /modules/proto/src/main/proto/tensor_shape.proto: -------------------------------------------------------------------------------- 1 | // Protocol buffer representing the shape of tensors. 2 | 3 | syntax = "proto3"; 4 | 5 | package org.platanios.tensorflow.proto; 6 | 7 | option cc_enable_arenas = true; 8 | option java_outer_classname = "TensorShapeProtos"; 9 | option java_multiple_files = true; 10 | option java_package = "org.platanios.tensorflow.proto"; 11 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/tensor_shape_go_proto"; 12 | 13 | // Dimensions of a tensor. 14 | message TensorShapeProto { 15 | // One dimension of the tensor. 
16 | message Dim { 17 | // Size of the tensor in that dimension.
18 | // This value must be >= -1, but values of -1 are reserved for "unknown" 19 | // shapes (values of -1 mean "unknown" dimension). Certain wrappers 20 | // that work with TensorShapeProto may fail at runtime when deserializing 21 | // a TensorShapeProto containing a dim value of -1. 22 | int64 size = 1; 23 | 24 | // Optional name of the tensor dimension. 25 | string name = 2; 26 | }; 27 | 28 | // Dimensions of the tensor, such as {"input", 30}, {"output", 40} 29 | // for a 30 x 40 2D tensor. If an entry has size -1, this 30 | // corresponds to a dimension of unknown size. The names are 31 | // optional. 32 | // 33 | // The order of entries in "dim" matters: It indicates the layout of the 34 | // values in the tensor's in-memory representation. 35 | // 36 | // The first entry in "dim" is the outermost dimension used to lay out the 37 | // values, the last entry is the innermost dimension. This matches the 38 | // in-memory layout of RowMajor Eigen tensors. 39 | // 40 | // If "dim.size()" > 0, "unknown_rank" must be false. 41 | repeated Dim dim = 2; 42 | 43 | // If true, the number of dimensions in the shape is unknown. 44 | // 45 | // If true, "dim.size()" must be 0. 46 | bool unknown_rank = 3; 47 | }; 48 | -------------------------------------------------------------------------------- /modules/proto/src/main/proto/tensor_slice.proto: -------------------------------------------------------------------------------- 1 | // Protocol buffer representing slices of a tensor. 2 | 3 | syntax = "proto3"; 4 | 5 | package org.platanios.tensorflow.proto; 6 | 7 | option cc_enable_arenas = true; 8 | option java_outer_classname = "TensorSliceProtos"; 9 | option java_multiple_files = true; 10 | option java_package = "org.platanios.tensorflow.proto"; 11 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/tensor_slice_go_proto"; 12 | 13 | // Can only be interpreted if you know the corresponding TensorShape.
14 | message TensorSliceProto { 15 | // Extent of the slice in one dimension. 16 | message Extent { 17 | // Either both or neither of the attributes must be set. If neither is 18 | // set, the extent covers all data in that dimension. 19 | 20 | // Start index of the slice, starting at 0. 21 | int64 start = 1; 22 | 23 | // Length of the slice: if the length is missing or -1, we will 24 | // interpret this as "everything in this dimension". We use 25 | // "oneof" to preserve information about whether the length is 26 | // present without changing the serialization format from the 27 | // prior proto2 version of this proto. 28 | oneof has_length { 29 | int64 length = 2; 30 | } 31 | } 32 | 33 | // Extent of the slice in all tensor dimensions. 34 | // 35 | // Must have one entry for each dimension of the tensor that this 36 | // slice belongs to. The order of entries is the same as the order of 37 | // dimensions in the TensorShape. 38 | repeated Extent extent = 1; 39 | } 40 | -------------------------------------------------------------------------------- /modules/proto/src/main/proto/tensorflow_server.proto: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License.
14 | ==============================================================================*/ 15 | 16 | syntax = "proto3"; 17 | 18 | package org.platanios.tensorflow.proto; 19 | 20 | import "cluster.proto"; 21 | import "config.proto"; 22 | import "device_filters.proto"; 23 | 24 | option cc_enable_arenas = true; 25 | option java_outer_classname = "ServerProtos"; 26 | option java_multiple_files = true; 27 | option java_package = "org.platanios.tensorflow.proto"; 28 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto"; 29 | 30 | // Defines the configuration of a single TensorFlow server. 31 | message ServerDef { 32 | // The cluster of which this server is a member. 33 | ClusterDef cluster = 1; 34 | 35 | // The name of the job of which this server is a member. 36 | // 37 | // NOTE(mrry): The `cluster` field must contain a `JobDef` with a `name` field 38 | // that matches this name. 39 | string job_name = 2; 40 | 41 | // The task index of this server in its job. 42 | // 43 | // NOTE: The `cluster` field must contain a `JobDef` with a matching `name` 44 | // and a mapping in its `tasks` field for this index. 45 | int32 task_index = 3; 46 | 47 | // The default configuration for sessions that run on this server. 48 | ConfigProto default_session_config = 4; 49 | 50 | // The protocol to be used by this server. 51 | // 52 | // Acceptable values include: "grpc", "grpc+verbs". 53 | string protocol = 5; 54 | 55 | // The server port. If not set, then we identify the port from the job_name. 56 | int32 port = 6; 57 | 58 | // Device filters for remote tasks in the cluster. 59 | // NOTE: This is an experimental feature and only effective in TensorFlow 2.x.
60 | ClusterDeviceFilters cluster_device_filters = 7; 61 | } 62 | -------------------------------------------------------------------------------- /modules/proto/src/main/proto/trackable_object_graph.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package org.platanios.tensorflow.proto; 4 | 5 | option cc_enable_arenas = true; // NOTE(review): unlike most sibling files, no java_package / java_multiple_files / java_outer_classname options are set here, so Java code for this file is generated with protoc's defaults; confirm that is intended. 6 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto"; 7 | 8 | // A TensorBundle addition which saves extra information about the objects which 9 | // own variables, allowing for more robust checkpoint loading into modified 10 | // programs. 11 | 12 | message TrackableObjectGraph { 13 | message TrackableObject { // One node of the graph; nodes are stored in `nodes` below. 14 | message ObjectReference { // A named edge from this object to another node. 15 | // An index into `TrackableObjectGraph.nodes`, indicating the object 16 | // being referenced. 17 | int32 node_id = 1; 18 | // A user-provided name for the edge. 19 | string local_name = 2; 20 | } 21 | 22 | message SerializedTensor { // A tensor stored in the checkpoint as part of this object's serialized data. 23 | // A name for the Tensor. Simple variables have only one 24 | // `SerializedTensor` named "VARIABLE_VALUE" by convention. This value may 25 | // be restored on object creation as an optimization. 26 | string name = 1; 27 | // The full name of the variable/tensor, if applicable. Used to allow 28 | // name-based loading of checkpoints which were saved using an 29 | // object-based API. Should match the checkpoint key which would have been 30 | // assigned by tf.train.Saver. 31 | string full_name = 2; 32 | // The generated name of the Tensor in the checkpoint. 
33 | string checkpoint_key = 3; 34 | // Whether checkpoints should be considered as matching even without this 35 | // value restored. Used for non-critical values which don't affect the 36 | // TensorFlow graph, such as layer configurations.
37 | bool optional_restore = 4; 38 | } 39 | 40 | message SlotVariableReference { // Links a slot variable node to the variable node it was created for. 41 | // An index into `TrackableObjectGraph.nodes`, indicating the 42 | // variable object this slot was created for. 43 | int32 original_variable_node_id = 1; 44 | // The name of the slot (e.g. "m"/"v"). 45 | string slot_name = 2; 46 | // An index into `TrackableObjectGraph.nodes`, indicating the 47 | // `Object` with the value of the slot variable. 48 | int32 slot_variable_node_id = 3; 49 | } 50 | 51 | // Objects which this object depends on. 52 | repeated ObjectReference children = 1; 53 | // Serialized data specific to this object. 54 | repeated SerializedTensor attributes = 2; 55 | // Slot variables owned by this object. 56 | repeated SlotVariableReference slot_variables = 3; 57 | } 58 | // All objects in the graph; the node-id fields of the nested messages index into this list. 59 | repeated TrackableObject nodes = 1; 60 | } 61 | -------------------------------------------------------------------------------- /modules/proto/src/main/proto/transport_options.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package org.platanios.tensorflow.proto; 4 | 5 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto"; 6 | 7 | // Extra data needed on a non-RDMA RecvBufResponse. 8 | message RecvBufRespExtra { 9 | repeated bytes tensor_content = 1; 10 | } 11 | -------------------------------------------------------------------------------- /modules/proto/src/main/proto/verifier_config.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package org.platanios.tensorflow.proto; 4 | 5 | option cc_enable_arenas = true; 6 | option java_outer_classname = "VerifierConfigProtos"; 7 | option java_multiple_files = true; 8 | option java_package = "org.platanios.tensorflow.proto"; 9 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto"; 10 | 11 | // The config for graph verifiers.
12 | message VerifierConfig { 13 | enum Toggle { // Per-verifier switch; as the zero value, DEFAULT is what an unset proto3 field reads as. 14 | DEFAULT = 0; 15 | ON = 1; 16 | OFF = 2; 17 | } 18 | 19 | // Deadline for completion of all verification, i.e. all the Toggle ON 20 | // verifiers must complete execution within this time. 21 | int64 verification_timeout_in_ms = 1; 22 | 23 | // Perform structural validation on a tensorflow graph. Default is OFF. 24 | Toggle structure_verifier = 2; 25 | 26 | // Next tag: 3 27 | } 28 | -------------------------------------------------------------------------------- /modules/proto/src/main/proto/versions.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package org.platanios.tensorflow.proto; 4 | 5 | option cc_enable_arenas = true; 6 | option java_outer_classname = "VersionsProtos"; 7 | option java_multiple_files = true; 8 | option java_package = "org.platanios.tensorflow.proto"; 9 | option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/versions_go_proto"; 10 | 11 | // Version information for a piece of serialized data. 12 | // 13 | // There are different types of versions for each type of data 14 | // (GraphDef, etc.), but they all have the same common shape 15 | // described here. 16 | // 17 | // Each consumer has "consumer" and "min_producer" versions (specified 18 | // elsewhere). A consumer is allowed to consume this data if 19 | // 20 | // producer >= min_producer 21 | // consumer >= min_consumer 22 | // consumer not in bad_consumers 23 | // 24 | message VersionDef { 25 | // The version of the code that produced this data. 26 | int32 producer = 1; 27 | 28 | // Any consumer below this version is not allowed to consume this data. 29 | int32 min_consumer = 2; 30 | 31 | // Specific consumer versions which are disallowed (e.g. due to bugs).
32 | repeated int32 bad_consumers = 3; 33 | } 34 | -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version=1.4.7 2 | -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | /* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | * use this file except in compliance with the License. You may obtain a copy of 5 | * the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | * License for the specific language governing permissions and limitations under 13 | * the License. 14 | */ 15 | // Log only warnings and errors (Level.Warn). 16 | logLevel := Level.Warn 17 | // Libraries on the classpath of the build definition itself (this is project/plugins.sbt, not a module build). 18 | libraryDependencies ++= Seq( 19 | "ch.qos.logback" % "logback-classic" % "1.2.3", 20 | "org.ow2.asm" % "asm" % "6.2.1", 21 | // The following is needed to automatically generate the eager ops. 22 | "org.tensorflow" % "proto" % "1.15.0") 23 | // Build tooling plugins (Bloop and protobuf). 24 | addSbtPlugin("ch.epfl.scala" % "sbt-bloop" % "1.2.5") 25 | addSbtPlugin("com.github.sbt" % "sbt-protobuf" % "0.7.0") 26 | 27 | // Plugins used for the documentation website.
28 | addSbtPlugin("com.lightbend.paradox" % "sbt-paradox" % "0.8.0") 29 | addSbtPlugin("io.github.jonas" % "sbt-paradox-material-theme" % "0.6.0") 30 | addSbtPlugin("com.typesafe.sbt" % "sbt-site" % "1.3.2") 31 | addSbtPlugin("com.typesafe.sbt" % "sbt-ghpages" % "0.6.3") 32 | addSbtPlugin("com.thoughtworks.sbt-api-mappings" % "sbt-api-mappings" % "latest.release") // NOTE(review): "latest.release" is a dynamic version, so the resolved plugin can change between builds; consider pinning an explicit version for reproducibility. 33 | 34 | // Packaging and publishing related plugins. 35 | addSbtPlugin("com.github.gseitz" % "sbt-release" % "1.0.13") 36 | addSbtPlugin("com.jsuereth" % "sbt-pgp" % "2.0.0") 37 | addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "3.9.2") 38 | -------------------------------------------------------------------------------- /version.sbt: -------------------------------------------------------------------------------- 1 | version in ThisBuild := "0.6.6-SNAPSHOT" 2 | --------------------------------------------------------------------------------