├── .git-blame-ignore-revs ├── .github ├── dependabot.yml └── workflows │ ├── check-lint.yml │ ├── ci.yml │ └── release.yml ├── .gitignore ├── .scalafmt.conf ├── LICENSE.md ├── README.md ├── build.sbt ├── project ├── build.properties └── plugins.sbt ├── scripts ├── check-lint.sh ├── clangfmt ├── git-release └── scalafmt └── stensorflow └── src ├── main ├── resources │ └── scala-native │ │ └── tensorflow.c └── scala │ └── org │ └── ekrich │ └── tensorflow │ └── unsafe │ └── tensorflow.scala └── test └── scala └── org └── ekrich └── tensorflow └── unsafe └── TensorflowTest.scala /.git-blame-ignore-revs: -------------------------------------------------------------------------------- 1 | # Scala Steward: Reformat with scalafmt 3.7.5 2 | 41a085cce3219dfe1a485440af06e5bc153292ec 3 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # Set update schedule for GitHub Actions 2 | 3 | version: 2 4 | updates: 5 | 6 | - package-ecosystem: "github-actions" 7 | directory: "/" 8 | schedule: 9 | # Check for updates to GitHub Actions every week 10 | interval: "weekly" 11 | -------------------------------------------------------------------------------- /.github/workflows/check-lint.yml: -------------------------------------------------------------------------------- 1 | name: Check Lint 2 | on: 3 | pull_request: 4 | push: 5 | branches: 6 | - main 7 | jobs: 8 | check-lint: 9 | runs-on: ubuntu-22.04 10 | steps: 11 | - name: Install clang-format 12 | run: | 13 | sudo apt update 14 | sudo apt install clang-format-15 15 | - uses: actions/checkout@v4 16 | - run: ./scripts/check-lint.sh 17 | env: 18 | CLANG_FORMAT_PATH: "/usr/bin/clang-format-15" 19 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [ main] 7 | 8 | jobs: 9 | build: 10 | strategy: 11 | fail-fast: false 12 | matrix: 13 | os: [ubuntu-22.04, macos-13] 14 | java: [ '17' ] 15 | runs-on: ${{ matrix.os }} 16 | steps: 17 | - name: Checkout current branch (full) 18 | uses: actions/checkout@v4 19 | - name: Setup Java 20 | uses: actions/setup-java@v4 21 | with: 22 | distribution: 'adopt' 23 | java-version: ${{ matrix.java }} 24 | cache: 'sbt' 25 | - uses: sbt/setup-sbt@v1 26 | - name: Setup (Linux) 27 | if: startsWith(matrix.os, 'ubuntu') 28 | run: | 29 | curl -fsSL https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.15.0.tar.gz \ 30 | -o ~/libtensorflow.tar.gz 31 | mkdir -p ~/tensorflow && tar -xzf ~/libtensorflow.tar.gz -C ~/tensorflow 32 | - name: Setup (macOS) 33 | if: startsWith(matrix.os, 'macos') 34 | run: | 35 | brew update 36 | brew install libtensorflow 37 | - name: Run tests (Linux) Java ${{ matrix.java }} 38 | if: startsWith(matrix.os, 'ubuntu') 39 | run: | 40 | export LIBRARY_PATH=$LIBRARY_PATH:~/tensorflow/lib 41 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:~/tensorflow/lib 42 | export C_INCLUDE_PATH=$C_INCLUDE_PATH:~/tensorflow/include 43 | sbt +test 44 | - name: Run tests (macOS) Java ${{ matrix.java }} 45 | if: startsWith(matrix.os, 'macos') 46 | run: sbt +test 47 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: 
Release 2 | on: 3 | push: 4 | branches: [ main ] 5 | tags: ["*"] 6 | jobs: 7 | publish: 8 | runs-on: ubuntu-22.04 9 | steps: 10 | - uses: actions/checkout@v4 11 | with: 12 | fetch-depth: 0 13 | - uses: actions/setup-java@v4 14 | with: 15 | distribution: 'adopt' 16 | java-version: '17' 17 | cache: 'sbt' 18 | - run: sbt ci-release 19 | env: 20 | PGP_PASSPHRASE: ${{ secrets.PGP_PASSPHRASE }} 21 | PGP_SECRET: ${{ secrets.PGP_SECRET }} 22 | SONATYPE_PASSWORD: ${{ secrets.SONATYPE_PASSWORD }} 23 | SONATYPE_USERNAME: ${{ secrets.SONATYPE_USERNAME }} 24 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # sbt 2 | target/ 3 | 4 | # scala-cli 5 | /.scala-build/ 6 | 7 | # metals 8 | /.bloop/ 9 | /.metals/ 10 | /project/**/metals.sbt 11 | 12 | # Build Server Protocol, used by sbt 13 | /.bsp/ 14 | 15 | # vscode 16 | /.vscode/ 17 | 18 | # scripts generated 19 | /scripts/.coursier 20 | /scripts/.scalafmt* 21 | 22 | # clangd 23 | **/compile_flags.txt 24 | -------------------------------------------------------------------------------- /.scalafmt.conf: -------------------------------------------------------------------------------- 1 | # Test upgrades: $ scripts/scalafmt --test 2> diff.txt 2 | version = 3.9.7 3 | runner.dialect = scala3 4 | preset = default 5 | 6 | # Match Scala Native 7 | docstrings.style = AsteriskSpace 8 | assumeStandardLibraryStripMargin = true 9 | project.git = true 10 | 11 | # Added for CI error via --test option 12 | runner.fatalWarnings = true 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | _Copyright (c) 2017-2019 Eric K Richardson_ 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | 15 | ### Origin of this work 16 | 17 | This project brings the TensorFlow to Scala Native. 18 | The code is written in Scala to interface with the C library interface 19 | to TensorFlow. As such this is a derivative work so the applicable licenses 20 | follow to attribute to the original work as best as can be determined. 21 | 22 | ### License notice for TensorFlow 23 | 24 | TensorFlow is developed by Google and licensed via the Apache License, 25 | Version 2.0. 26 | 27 | 28 | ### License notice for Documentation 29 | 30 | The documentation comes from source files and the TensorFlow 31 | website and should be attributed to the original sources. 32 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # stensorflow - Scala Native TensorFlow 2 | ![CI](https://github.com/ekrich/stensorflow/workflows/CI/badge.svg) 3 | 4 | This library implements the C TensorFlow API adapted for the Scala Native platform. 
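
The bindings expose the C functions through Scala Native's `@extern` mechanism. The abbreviated sketch below mirrors the shape of `tensorflow.scala` in this repository (only a few of the declarations are shown here):

```scala
import scala.scalanative.unsafe._

@link("tensorflow")
@extern
object tensorflow {
  /** TF_Status holds error information. */
  type TF_Status = CStruct0

  /** Returns a string describing the TensorFlow library version. */
  def TF_Version(): CString = extern

  /** Return a new status object. */
  def TF_NewStatus(): Ptr[TF_Status] = extern

  /** Delete a previously created status object. */
  def TF_DeleteStatus(status: Ptr[TF_Status]): Unit = extern
}
```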

Scala Native is a unique platform that marries the high-level Scala language
with compilation to native code and a lightweight managed runtime that
includes a state-of-the-art garbage collector. The combination allows for a
great programming experience along with the ability to use high-performance
C libraries like TensorFlow.

Scala Native uses the Scala compiler to produce
[NIR](https://scala-native.readthedocs.io/en/latest/contrib/nir.html)
(Native Intermediate Representation) that is optimized and then
converted to [LLVM IR](http://llvm.org/). Finally, the LLVM IR is optimized
and compiled by [Clang](http://clang.llvm.org/) to produce a native executable.

## Getting started
[![Maven Central](https://img.shields.io/maven-central/v/org.ekrich/stensorflow_native0.5_3.svg)](https://maven-badges.herokuapp.com/maven-central/org.ekrich/stensorflow_native0.5_3)

You can use the Giter8 template [stensorflow.g8](https://github.com/ekrich/stensorflow.g8#stensorflowg8) to get started. The linked directions will send you back here to install the TensorFlow library below, but overall it is the easier path.

If you are already familiar with Scala Native you can jump right in by adding the following dependency to your `sbt` build file. Refer to the `TensorflowTest.scala` source in this repository or in the template referred to above for an example.

```scala
libraryDependencies += "org.ekrich" %%% "stensorflow" % "x.y.z"
```

To use in `sbt`, replace `x.y.z` with the version from the Maven Central badge above.
All available versions can be seen at the [Maven Repository](https://mvnrepository.com/artifact/org.ekrich/stensorflow).

Otherwise follow the [Getting Started](https://scala-native.readthedocs.io/en/latest/user/setup.html)
instructions for Scala Native if you are not already set up.

## Scala Build Versions

| Scala Version | Native (0.5.x) |
| ------------- | :------------: |
| 3.3.x (LTS)   |       ✅       |

* Use version `0.5.0` for Scala Native `0.5.x`.
* Use version `0.3.0` for Scala Native `0.4.9+`.

Note: Refer to the release notes for older versions of Scala and Scala Native.

## Tensorflow library

The TensorFlow C library is required; the current version is `2.16.1`
with `2.5.0` pre-built for Linux.

* Linux/Ubuntu can install TensorFlow by following these directions:

https://www.tensorflow.org/install/lang_c

The essential directions are repeated here, replacing `<version>` with the version above:

```
$ FILENAME=libtensorflow-cpu-linux-x86_64-<version>.tar.gz
$ wget -q --no-check-certificate https://storage.googleapis.com/tensorflow/libtensorflow/${FILENAME}
$ sudo tar -C /usr/local -xzf ${FILENAME}
$ sudo ldconfig /usr/local/lib
```

* macOS can install TensorFlow using [Homebrew](https://formulae.brew.sh/formula/libtensorflow),
which installs into the `/usr/local/Cellar/libtensorflow/` directory.

Note: macOS 12 or greater is recommended for installing TensorFlow via
Homebrew and is what CI uses. TensorFlow `2.16.1` is built for macOS `13.1`, so you
will get a linking warning if using an older OS version.

```
$ brew install libtensorflow
```

* Other OSes need to have `libtensorflow` available on the system.
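
Once the native library is installed, a quick sanity check is to print the library version from a small program; this is a minimal sketch (the object name `VersionCheck` is arbitrary and not part of this project):

```scala
import scala.scalanative.unsafe._
import org.ekrich.tensorflow.unsafe.tensorflow.TF_Version

object VersionCheck {
  def main(args: Array[String]): Unit =
    // Prints the version of the libtensorflow that was found and linked, e.g. "2.16.1"
    println(fromCString(TF_Version()))
}
```

If linking fails, double-check that the `libtensorflow` library and headers are on the paths shown above.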

## Usage and Help
[![scaladoc](https://www.javadoc.io/badge/org.ekrich/stensorflow_native0.5_3.svg?label=scaladoc)](https://www.javadoc.io/doc/org.ekrich/stensorflow_native0.5_3)
[![Discord](https://img.shields.io/discord/633356833498595365.svg?label=&logo=discord&logoColor=ffffff&color=404244&labelColor=6A7EC2)](https://discord.gg/XSj6hQs)

Refer to the scaladoc link above for API documentation. The documentation is a little sparse but will hopefully improve with time.

After `sbt` is installed and any other Scala Native prerequisites are met, you can use the following Giter8 template instructions to get a fully functional Scala Native application with an example in the body of the main program.

```
$ sbt new ekrich/stensorflow.g8
$ cd <project directory>
$ sbt run
```

In addition, look at the [stensorflow unit tests](https://github.com/ekrich/stensorflow/blob/main/stensorflow/src/test/scala/org/ekrich/tensorflow/unsafe/TensorflowTest.scala) for other examples of usage.

## TensorFlow References and External Documentation

[TensorFlow Website](https://www.tensorflow.org/)
[TensorFlow for JVM using JNI](https://platanios.org/tensorflow_scala/)

## Tensorflow for Scala Native contributors

The `clangd` language server can be used to aid development in VSCode or other editors. For VSCode, see the `clangd` plugin on the [Visual Studio Marketplace](https://marketplace.visualstudio.com/items?itemName=llvm-vs-code-extensions.vscode-clangd) for more info.

Add a `compile_flags.txt` with the following content to the `stensorflow/src/main/resources/scala-native` directory.

```
# Tensorflow Setup
# Standard path on macOS arm
-I
/opt/homebrew/include
```

Change the path to match your include path. There is only a small amount of official documentation about [compile_flags.txt](https://clang.llvm.org/docs/JSONCompilationDatabase.html); additional information can be found online.

## Versions

Release [0.5.0](https://github.com/ekrich/stensorflow/releases/tag/v0.5.0) - (2024-04-11)
Release [0.4.0](https://github.com/ekrich/stensorflow/releases/tag/v0.4.0) - (2024-03-01)
Release [0.3.0](https://github.com/ekrich/stensorflow/releases/tag/v0.3.0) - (2022-11-29)
Release [0.2.0](https://github.com/ekrich/stensorflow/releases/tag/v0.2.0) - (2021-12-13)
Release [0.1.0](https://github.com/ekrich/stensorflow/releases/tag/v0.1.0) - (2021-07-02)
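
## Example

As a companion to the unit tests linked above, the following is a minimal, illustrative sketch of using the raw C API bindings directly; it assumes the setup described earlier, and names such as `TensorSketch` are arbitrary rather than part of this project:

```scala
import scala.scalanative.unsafe._
import scala.scalanative.unsigned._
import org.ekrich.tensorflow.unsafe.tensorflow._
import org.ekrich.tensorflow.unsafe.tensorflowEnums._

object TensorSketch {
  def main(args: Array[String]): Unit = {
    println(s"TensorFlow ${fromCString(TF_Version())}")

    // Describe a 1-D tensor with 4 float elements.
    val numElems = 4
    val dims = stackalloc[int64_t]()
    !dims = numElems.toLong

    // TF_AllocateTensor owns its buffer; TF_DeleteTensor releases it.
    val tensor =
      TF_AllocateTensor(TF_FLOAT, dims, 1, sizeof[CFloat] * numElems.toUSize)

    // Write values through the raw data pointer.
    val data = TF_TensorData(tensor).asInstanceOf[Ptr[CFloat]]
    for (i <- 0 until numElems) !(data + i) = i.toFloat

    println(s"rank = ${TF_NumDims(tensor)}, dim(0) = ${TF_Dim(tensor, 0)}")
    println(s"bytes = ${TF_TensorByteSize(tensor)}")

    TF_DeleteTensor(tensor)
  }
}
```

Tensors created with `TF_AllocateTensor` use TensorFlow-managed memory, so the explicit `TF_DeleteTensor` call at the end is what frees the buffer.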
121 | -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | // stensorflow build 2 | val scala3 = "3.3.6" 3 | 4 | val versionsNative = Seq(scala3) 5 | 6 | ThisBuild / scalaVersion := scala3 7 | ThisBuild / crossScalaVersions := versionsNative 8 | ThisBuild / versionScheme := Some("early-semver") 9 | 10 | inThisBuild( 11 | List( 12 | description := "TensorFlow Interface for Scala Native", 13 | organization := "org.ekrich", 14 | homepage := Some(url("https://github.com/ekrich/stensorflow")), 15 | licenses := List( 16 | "Apache-2.0" -> url("http://www.apache.org/licenses/LICENSE-2.0") 17 | ), 18 | developers := List( 19 | Developer( 20 | id = "ekrich", 21 | name = "Eric K Richardson", 22 | email = "ekrichardson@gmail.com", 23 | url = url("http://github.ekrich.org/") 24 | ) 25 | ) 26 | ) 27 | ) 28 | 29 | lazy val commonSettings = Seq( 30 | testOptions += Tests.Argument(TestFrameworks.JUnit, "-a", "-s", "-v"), 31 | logLevel := Level.Info // Info, Debug 32 | ) 33 | 34 | lazy val root = project 35 | .in(file(".")) 36 | .settings( 37 | name := "stensorflow-root", 38 | crossScalaVersions := Nil, 39 | publish / skip := true, 40 | doc / aggregate := false, 41 | doc := (stensorflow / Compile / doc).value, 42 | packageDoc / aggregate := false, 43 | packageDoc := (stensorflow / Compile / packageDoc).value 44 | ) 45 | .aggregate(stensorflow) 46 | 47 | lazy val stensorflow = project 48 | .in(file("stensorflow")) 49 | .settings( 50 | crossScalaVersions := versionsNative, 51 | commonSettings 52 | ) 53 | .enablePlugins(ScalaNativePlugin, ScalaNativeJUnitPlugin) 54 | -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version = 1.11.1 2 | -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | resolvers ++= Resolver.sonatypeOssRepos("snapshots") 2 | 3 | // Current releases 4 | addSbtPlugin("org.scala-native" % "sbt-scala-native" % "0.5.8") 5 | addSbtPlugin("com.github.sbt" % "sbt-ci-release" % "1.11.1") 6 | addSbtPlugin("com.github.sbt" % "sbt-dynver" % "5.1.0") 7 | -------------------------------------------------------------------------------- /scripts/check-lint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | echo 'Running clangfmt...' 6 | scripts/clangfmt --test 7 | echo 'Running scalafmt...' 8 | scripts/scalafmt --test 9 | -------------------------------------------------------------------------------- /scripts/clangfmt: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Format C/C++ code using clang-format. 4 | # 5 | # To ensure reproducible formatting this script checks that clang-format 6 | # matches the lowest version number of LLVM supported by Scala Native. 7 | # 8 | # Usage: $0 [--test] 9 | # 10 | # Set CLANG_FORMAT_PATH to configure path to clang-format. 
11 | 12 | set -euo pipefail 13 | 14 | # The minimum version of clang-format with the new options 15 | CLANG_FORMAT_VERSION=15 16 | 17 | die() { 18 | while [ "$#" -gt 0 ]; do 19 | echo >&2 "$1"; shift 20 | done 21 | exit 1 22 | } 23 | 24 | # avoid unbound var 25 | version= 26 | 27 | check_clang_format_version() { 28 | cmd="$1" 29 | # version can be in 3rd or 4th position after the word "version" 30 | version=$("$cmd" --version \ 31 | | grep -E -i -o " version [0-9]+.[0-9]+" \ 32 | | grep -E -i -o "[0-9]+.[0-9]+") 33 | 34 | major=${version%%.*} 35 | [ $major -ge $CLANG_FORMAT_VERSION ] 36 | } 37 | 38 | clang_format= 39 | 40 | if [ -n "${CLANG_FORMAT_PATH:-}" ]; then 41 | if [ ! -e "$(type -P "${CLANG_FORMAT_PATH}")" ]; then 42 | echo "CLANG_FORMAT_PATH='$CLANG_FORMAT_PATH' does not exist or is not executable" >&2 43 | else 44 | if check_clang_format_version "$CLANG_FORMAT_PATH"; then 45 | clang_format="$CLANG_FORMAT_PATH" 46 | else 47 | echo "CLANG_FORMAT_PATH='$CLANG_FORMAT_PATH'" >&2 48 | fi 49 | fi 50 | else 51 | if [ ! -e "$(type -P clang-format)" ]; then 52 | echo "clang-format is not installed or not in the PATH." >&2 53 | else 54 | check_clang_format_version "clang-format" && \ 55 | clang_format=clang-format 56 | fi 57 | fi 58 | 59 | if [ -z "$clang_format" ]; then 60 | die "clang-format version '$CLANG_FORMAT_VERSION' expected, but version '$version' found." \ 61 | "Install LLVM version '$CLANG_FORMAT_VERSION' and rerun." \ 62 | "Hint: export CLANG_FORMAT_PATH='/path/to/clang-format'" 63 | fi 64 | 65 | test_mode= 66 | 67 | while [ "$#" -gt 0 ]; do 68 | arg="$1" 69 | case "$arg" in 70 | --test) test_mode=true; shift ;; 71 | --*) die "Unknown argument: $arg" "Usage: $0 [--test]" ;; 72 | *) break ;; 73 | esac 74 | done 75 | 76 | # Use this block for version 10 and above 77 | if [ "$test_mode" = true ]; then 78 | opts="--dry-run" 79 | err="--Werror" 80 | else 81 | opts="-i" 82 | err= 83 | fi 84 | 85 | if [ "$#" -gt 0 ]; then 86 | "$clang_format" --style=file "$opts" "$@" 87 | else 88 | find . -name "*.[ch]" -or -name "*.cpp" -or -name "*.hpp" | \ 89 | xargs "$clang_format" --style=file "$opts" $err || \ 90 | die "C/C++ code formatting changes detected" \ 91 | "Run '$0' to reformat." 92 | fi 93 | -------------------------------------------------------------------------------- /scripts/git-release: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Release script from Olafur 3 | set -eux 4 | version=$1 5 | # f option (force) allows a republish with the same versions 6 | git tag -af "v$version" -m "v$version" && git push -f origin v$version 7 | -------------------------------------------------------------------------------- /scripts/scalafmt: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # uncomment to debug 4 | # set -x 5 | 6 | HERE="`dirname $0`" 7 | VERSION=$(sed -nre "s#\s*version[^0-9]+([0-9.]+)#\1#p" $HERE/../.scalafmt.conf) 8 | COURSIER="$HERE/.coursier" 9 | SCALAFMT="$HERE/.scalafmt-$VERSION" 10 | 11 | if [ ! -f $COURSIER ]; then 12 | curl -L -o $COURSIER https://git.io/coursier-cli 13 | chmod +x $COURSIER 14 | fi 15 | 16 | if [ ! 
-f $SCALAFMT ]; then 17 | $COURSIER bootstrap org.scalameta:scalafmt-cli_2.13:$VERSION -r sonatype:snapshots --main org.scalafmt.cli.Cli -o $SCALAFMT 18 | chmod +x $SCALAFMT 19 | fi 20 | 21 | $SCALAFMT "$@" 22 | -------------------------------------------------------------------------------- /stensorflow/src/main/resources/scala-native/tensorflow.c: -------------------------------------------------------------------------------- 1 | 2 | /* 3 | * Converter functions for TF_Input and TF_Output pass by reference 4 | */ 5 | 6 | #include 7 | 8 | TF_Output *scalanative_TF_OperationInput(TF_Input *oper_in, 9 | TF_Output *oper_out) { 10 | TF_Output out = TF_OperationInput(*oper_in); 11 | oper_out->index = out.index; 12 | oper_out->oper = out.oper; 13 | return oper_out; 14 | } 15 | 16 | int scalanative_TF_OperationOutputNumConsumers(TF_Output *oper_out) { 17 | return TF_OperationOutputNumConsumers(*oper_out); 18 | } 19 | 20 | int scalanative_TF_OperationOutputConsumers(TF_Output *oper_out, 21 | TF_Input *consumers, 22 | int max_consumers) { 23 | return TF_OperationOutputConsumers(*oper_out, consumers, max_consumers); 24 | } 25 | 26 | void scalanative_TF_ImportGraphDefOptionsAddInputMapping( 27 | TF_ImportGraphDefOptions *opts, char *src_name, int src_index, 28 | TF_Output *dst) { 29 | return TF_ImportGraphDefOptionsAddInputMapping(opts, src_name, src_index, 30 | *dst); 31 | } 32 | 33 | unsigned char scalanative_TF_TryEvaluateConstant(TF_Graph *graph, 34 | TF_Output *output, 35 | TF_Tensor **result, 36 | TF_Status *status) { 37 | return TF_TryEvaluateConstant(graph, *output, result, status); 38 | } 39 | 40 | void scalanative_TF_GraphSetTensorShape(TF_Graph *graph, TF_Output *output, 41 | int64_t *dims, int num_dims, 42 | TF_Status *status) { 43 | return TF_GraphSetTensorShape(graph, *output, dims, num_dims, status); 44 | } 45 | 46 | int scalanative_TF_GraphGetTensorNumDims(TF_Graph *graph, TF_Output *output, 47 | TF_Status *status) { 48 | return TF_GraphGetTensorNumDims(graph, *output, status); 49 | } 50 | 51 | void scalanative_TF_GraphGetTensorShape(TF_Graph *graph, TF_Output *output, 52 | int64_t *dims, int num_dims, 53 | TF_Status *status) { 54 | return TF_GraphGetTensorShape(graph, *output, dims, num_dims, status); 55 | } 56 | 57 | void scalanative_TF_AddInput(TF_OperationDescription *desc, TF_Output *input) { 58 | return TF_AddInput(desc, *input); 59 | } 60 | 61 | TF_DataType scalanative_TF_OperationOutputType(TF_Output *oper_out) { 62 | return TF_OperationOutputType(*oper_out); 63 | } -------------------------------------------------------------------------------- /stensorflow/src/main/scala/org/ekrich/tensorflow/unsafe/tensorflow.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * - Copyright 2015 The TensorFlow Authors. All Rights Reserved. 3 | * - Copyright 2017-2022 Eric K Richardson 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); you may not 6 | * use this file except in compliance with the License. You may obtain a copy 7 | * of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 13 | * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 14 | * License for the specific language governing permissions and limitations 15 | * under the License. 
16 | */ 17 | package org.ekrich.tensorflow.unsafe 18 | 19 | import scalanative.unsafe._ 20 | 21 | /** Enums used in the API 22 | */ 23 | object tensorflowEnums { 24 | 25 | /** TF_DataType holds the type for a scalar value. E.g., one slot in a tensor. 26 | * The enum values here are identical to corresponding values in types.proto. 27 | */ 28 | type TF_DataType = CInt 29 | final val TF_FLOAT: TF_DataType = 1 30 | final val TF_DOUBLE: TF_DataType = 2 31 | // Int32 tensors are always in 'host' memory. 32 | final val TF_INT32: TF_DataType = 3 33 | final val TF_UINT8: TF_DataType = 4 34 | final val TF_INT16: TF_DataType = 5 35 | final val TF_INT8: TF_DataType = 6 36 | final val TF_STRING: TF_DataType = 7 37 | final val TF_COMPLEX64: TF_DataType = 8 // Single-precision complex 38 | // Old identifier kept for API backwards compatibility 39 | final val TF_COMPLEX: TF_DataType = 8 40 | final val TF_INT64: TF_DataType = 9 41 | final val TF_BOOL: TF_DataType = 10 42 | final val TF_QINT8: TF_DataType = 11 // Quantized int8 43 | final val TF_QUINT8: TF_DataType = 12 // Quantized uint8 44 | final val TF_QINT32: TF_DataType = 13 // Quantized int32 45 | // Float32 truncated to 16 bits. Only for cast ops. 46 | final val TF_BFLOAT16: TF_DataType = 14 47 | final val TF_QINT16: TF_DataType = 15 // Quantized int16 48 | final val TF_QUINT16: TF_DataType = 16 // Quantized uint16 49 | final val TF_UINT16: TF_DataType = 17 50 | final val TF_COMPLEX128: TF_DataType = 18 // Double-precision complex 51 | final val TF_HALF: TF_DataType = 19 52 | final val TF_RESOURCE: TF_DataType = 20 53 | final val TF_VARIANT: TF_DataType = 21 54 | final val TF_UINT32: TF_DataType = 22 55 | final val TF_UINT64: TF_DataType = 23 56 | 57 | /** TF_Code holds an error code. The enum values here are identical to 58 | * corresponding values in error_codes.proto. 59 | */ 60 | type TF_Code = CInt 61 | final val TF_OK: TF_Code = 0 62 | final val TF_CANCELLED: TF_Code = 1 63 | final val TF_UNKNOWN: TF_Code = 2 64 | final val TF_INVALID_ARGUMENT: TF_Code = 3 65 | final val TF_DEADLINE_EXCEEDED: TF_Code = 4 66 | final val TF_NOT_FOUND: TF_Code = 5 67 | final val TF_ALREADY_EXISTS: TF_Code = 6 68 | final val TF_PERMISSION_DENIED: TF_Code = 7 69 | final val TF_UNAUTHENTICATED: TF_Code = 16 70 | final val TF_RESOURCE_EXHAUSTED: TF_Code = 8 71 | final val TF_FAILED_PRECONDITION: TF_Code = 9 72 | final val TF_ABORTED: TF_Code = 10 73 | final val TF_OUT_OF_RANGE: TF_Code = 11 74 | final val TF_UNIMPLEMENTED: TF_Code = 12 75 | final val TF_INTERNAL: TF_Code = 13 76 | final val TF_UNAVAILABLE: TF_Code = 14 77 | final val TF_DATA_LOSS: TF_Code = 15 78 | 79 | /** TF_AttrType describes the type of the value of an attribute on an 80 | * operation. 81 | */ 82 | type TF_AttrType = CInt 83 | final val TF_ATTR_STRING: TF_AttrType = 0 84 | final val TF_ATTR_INT: TF_AttrType = 1 85 | final val TF_ATTR_FLOAT: TF_AttrType = 2 86 | final val TF_ATTR_BOOL: TF_AttrType = 3 87 | final val TF_ATTR_TYPE: TF_AttrType = 4 88 | final val TF_ATTR_SHAPE: TF_AttrType = 5 89 | final val TF_ATTR_TENSOR: TF_AttrType = 6 90 | final val TF_ATTR_PLACEHOLDER: TF_AttrType = 7 91 | final val TF_ATTR_FUNC: TF_AttrType = 8 92 | } 93 | 94 | import tensorflowEnums._ 95 | 96 | /** C API for TensorFlow. 97 | * 98 | * The API leans towards simplicity and uniformity instead of convenience since 99 | * most usage will be by language specific wrappers. 100 | * 101 | * Conventions: 102 | * - We use the prefix TF_ for everything in the API. 
103 | * - Objects are always passed around as pointers to opaque structs and these 104 | * structs are allocated/deallocated via the API. 105 | * - TF_Status holds error information. It is an object type and therefore is 106 | * passed around as a pointer to an opaque struct as mentioned above. 107 | * - Every call that has a TF_Status* argument clears it on success and fills 108 | * it with error info on failure. 109 | * - unsigned char is used for booleans (instead of the 'bool' type). In C++ 110 | * bool is a keyword while in C99 bool is a macro defined in stdbool.h. It 111 | * is possible for the two to be inconsistent. For example, neither the C99 112 | * nor the C++11 standard force a byte size on the bool type, so the macro 113 | * defined in stdbool.h could be inconsistent with the bool keyword in C++. 114 | * Thus, the use of stdbool.h is avoided and unsigned char is used instead. 115 | * - size_t is used to represent byte sizes of objects that are materialized 116 | * in the address space of the calling process. 117 | * - int is used as an index into arrays. 118 | * - Deletion functions are safe to call on nullptr. 119 | * 120 | * Questions left to address: 121 | * - Might at some point need a way for callers to provide their own Env. 122 | * - Maybe add TF_TensorShape that encapsulates dimension info. 123 | * 124 | * Design decisions made: 125 | * - Backing store for tensor memory has an associated deallocation function. 126 | * This deallocation function will point to client code for tensors 127 | * populated by the client. So the client can do things like shadowing a 128 | * numpy array. 129 | * - We do not provide TF_OK since it is not strictly necessary and we are 130 | * not optimizing for convenience. 131 | * - We make assumption that one session has one graph. This should be fine 132 | * since we have the ability to run sub-graphs. 133 | * - We could allow NULL for some arguments (e.g., NULL options arg). However 134 | * since convenience is not a primary goal, we don't do this. 135 | * - Devices are not in this API. Instead, they are created/used internally 136 | * and the API just provides high level controls over the number of devices 137 | * of each type. 138 | */ 139 | @link("tensorflow") 140 | @extern 141 | object tensorflow { 142 | 143 | type int64_t = CLongLong 144 | type uint64_t = CUnsignedLongLong 145 | 146 | /** TF_Status holds error information. It either has an OK code, or else an 147 | * error code with an associated error message. 148 | */ 149 | type TF_Status = CStruct0 150 | 151 | /** Represents a computation graph. Graphs may be shared between sessions. 152 | * Graphs are thread-safe when used as directed below. 153 | */ 154 | type TF_Graph = CStruct0 155 | 156 | /** Operation that has been added to the graph. Valid until the graph is 157 | * deleted -- in particular adding a new operation to the graph does not 158 | * invalidate old TF_Operation* pointers. 159 | */ 160 | type TF_Operation = CStruct0 161 | 162 | /** Operation being built. The underlying graph must outlive this. 163 | */ 164 | type TF_OperationDescription = CStruct0 165 | 166 | /** TF_Tensor holds a multi-dimensional array of elements of a single data 167 | * type. For all types other than TF_STRING, the data buffer stores elements 168 | * in row major order. E.g. if data is treated as a vector of TF_DataType: 169 | * 170 | * - element 0: index (0, ..., 0) 171 | * - element 1: index (0, ..., 1) 172 | * - ... 
173 | * 174 | * The format for TF_STRING tensors is: 175 | * - start_offset: array[uint64] 176 | * - data: byte[...] 177 | * 178 | * The string length (as a varint), followed by the contents of the string is 179 | * encoded at data[start_offset[i]]]. TF_StringEncode and TF_StringDecode 180 | * facilitate this encoding. 181 | */ 182 | type TF_Tensor = CStruct0 183 | 184 | /** TF_SessionOptions holds options that can be passed during session 185 | * creation. 186 | */ 187 | type TF_SessionOptions = CStruct0 188 | 189 | /** TF_Buffer holds a pointer to a block of data and its associated length. 190 | * Typically, the data consists of a serialized protocol buffer, but other 191 | * data may also be held in a buffer. 192 | * 193 | * By default, TF_Buffer itself does not do any memory management of the 194 | * pointed-to block. If need be, users of this struct should specify how to 195 | * deallocate the block by setting the `data_deallocator` function pointer. 196 | */ 197 | type TF_Buffer = 198 | CStruct3[Ptr[Byte], CSize, CFuncPtr2[Ptr[Byte], CSize, Unit]] 199 | 200 | /** Represents a specific input of an operation. 201 | */ 202 | type TF_Input = CStruct2[Ptr[TF_Operation], CInt] 203 | 204 | /** Represents a specific output of an operation. 205 | */ 206 | type TF_Output = CStruct2[Ptr[TF_Operation], CInt] 207 | 208 | /** TF_Function is a grouping of operations with defined inputs and outputs. 209 | * Once created and added to graphs, functions can be invoked by creating an 210 | * operation whose operation type matches the function name. 211 | */ 212 | type TF_Function = CStruct0 213 | 214 | /** Function definition options. 215 | */ 216 | type TF_FunctionOptions = CStruct0 217 | 218 | /** TF_AttrMetadata describes the value of an attribute on an operation. 219 | */ 220 | type TF_AttrMetadata = CStruct4[ 221 | /** A boolean: 1 if the attribute value is a list, 0 otherwise. */ 222 | CUnsignedChar, 223 | /** Length of the list if is_list is true. Undefined otherwise. */ 224 | int64_t, 225 | /** Type of elements of the list if is_list != 0. 226 | * 227 | * Type of the single value stored in the attribute if is_list == 0. 228 | */ 229 | TF_AttrType, 230 | /** Total size the attribute value. The units of total_size depend on 231 | * is_list and type. 232 | * - (1) If type == TF_ATTR_STRING and is_list == 0 then total_size is 233 | * the byte size of the string valued attribute. 234 | * - (2) If type == TF_ATTR_STRING and is_list == 1 then total_size is 235 | * the cumulative byte size of all the strings in the list. 236 | * - (3) If type == TF_ATTR_SHAPE and is_list == 0 then total_size is the 237 | * number of dimensions of the shape valued attribute, or -1 if its 238 | * rank is unknown. 239 | * - (4) If type == TF_ATTR_SHAPE and is_list == 1 then total_size is the 240 | * cumulative number of dimensions of all shapes in the list. 241 | * - (5) Otherwise, total_size is undefined. 242 | */ 243 | int64_t 244 | ] 245 | 246 | type TF_WhileParams = CStruct8[ 247 | /** The number of inputs to the while loop, i.e. the number of loop 248 | * variables. This is the size of cond_inputs, body_inputs, and 249 | * body_outputs. 250 | */ 251 | CInt, // ninputs 252 | /** The while condition graph. The inputs are the current values of the loop 253 | * variables. The output should be a scalar boolean. 254 | */ 255 | Ptr[TF_Graph], // cond_graph 256 | Ptr[TF_Output], // cond_inputs 257 | Ptr[TF_Output], // cond_output // TF_output 258 | /** The loop body graph. 
The inputs are the current values of the loop 259 | * variables. The outputs are the updated values of the loop variables. 260 | */ 261 | Ptr[TF_Graph], // body_graph 262 | Ptr[TF_Output], // body_inputs 263 | Ptr[TF_Output], // body_outputs 264 | /** The loop body graph. The inputs are the current values of the loop 265 | * variables. The outputs are the updated values of the loop variables. 266 | */ 267 | CString // name 268 | ] 269 | 270 | /** TF_Version returns a string describing version information of the 271 | * TensorFlow library. TensorFlow using semantic versioning. 272 | */ 273 | def TF_Version(): CString = extern 274 | 275 | /** TF_DataTypeSize returns the sizeof() for the underlying type corresponding 276 | * to the given TF_DataType enum value. Returns 0 for variable length types 277 | * (eg. TF_STRING) or on failure. 278 | */ 279 | def TF_DataTypeSize(value: TF_DataType): CSize = extern 280 | 281 | /** Return a new status object. 282 | */ 283 | def TF_NewStatus(): Ptr[TF_Status] = extern 284 | 285 | /** Delete a previously created status object. 286 | */ 287 | def TF_DeleteStatus(status: Ptr[TF_Status]): Unit = extern 288 | 289 | /** Record in *s. Any previous information is lost. A common use 290 | * is to clear a status: TF_SetStatus(s, TF_OK, ""); 291 | */ 292 | def TF_SetStatus(s: Ptr[TF_Status], code: TF_Code, msg: CString): Unit = 293 | extern 294 | 295 | /** Return the code record in *s. 296 | */ 297 | def TF_GetCode(s: Ptr[TF_Status]): TF_Code = extern 298 | 299 | /** Return a pointer to the (null-terminated) error message in *s. The return 300 | * value points to memory that is only usable until the next mutation to *s. 301 | * Always returns an empty string if TF_GetCode(s) is TF_OK. 302 | */ 303 | def TF_Message(s: Ptr[TF_Status]): CString = extern 304 | 305 | /** Makes a copy of the input and sets an appropriate deallocator. Useful for 306 | * passing in read-only, input protobufs. 307 | */ 308 | def TF_NewBufferFromString( 309 | proto: Ptr[Byte], 310 | proto_len: CSize 311 | ): Ptr[TF_Buffer] = extern 312 | 313 | /** Useful for passing *out* a protobuf. 314 | */ 315 | def TF_NewBuffer(): Ptr[TF_Buffer] = extern 316 | def TF_DeleteBuffer(buffer: Ptr[TF_Buffer]): Unit = extern 317 | def TF_GetBuffer(buffer: Ptr[TF_Buffer]): TF_Buffer = extern 318 | 319 | /** Return a new tensor that holds the bytes data[0,len-1]. 320 | * 321 | * The data will be deallocated by a subsequent call to TF_DeleteTensor via: 322 | * (*deallocator)(data, len, deallocator_arg) Clients must provide a custom 323 | * deallocator function so they can pass in memory managed by something like 324 | * numpy. 325 | * 326 | * May return NULL (and invoke the deallocator) if the provided data buffer 327 | * (data, len) is inconsistent with a tensor of the given TF_DataType and the 328 | * shape specified by (dims, num_dims). 329 | */ 330 | def TF_NewTensor( 331 | value: TF_DataType, 332 | dims: Ptr[int64_t], 333 | num_dims: CInt, 334 | data: Ptr[Byte], 335 | len: CSize, 336 | deallocator: CFuncPtr3[Ptr[Byte], CSize, Ptr[Byte], Unit], 337 | deallocator_arg: Ptr[Byte] 338 | ): Ptr[TF_Tensor] = extern 339 | 340 | /** Allocate and return a new Tensor. 341 | * 342 | * This function is an alternative to TF_NewTensor and should be used when 343 | * memory is allocated to pass the Tensor to the C API. The allocated memory 344 | * satisfies TensorFlow's memory alignment preferences and should be 345 | * preferred over calling malloc and free. 
346 | * 347 | * The caller must set the Tensor values by writing them to the pointer 348 | * returned by TF_TensorData with length TF_TensorByteSize. 349 | */ 350 | def TF_AllocateTensor( 351 | value: TF_DataType, 352 | dims: Ptr[int64_t], 353 | num_dims: CInt, 354 | len: CSize 355 | ): Ptr[TF_Tensor] = extern 356 | 357 | /** Deletes `tensor` and returns a new TF_Tensor with the same content if 358 | * possible. Returns nullptr and leaves `tensor` untouched if not. 359 | */ 360 | def TF_TensorMaybeMove(tensor: Ptr[TF_Tensor]): Ptr[TF_Tensor] = extern 361 | 362 | /** Destroy a tensor. 363 | */ 364 | def TF_DeleteTensor(tensor: Ptr[TF_Tensor]): Unit = extern 365 | 366 | /** Return the type of a tensor element. 367 | */ 368 | def TF_TensorType(tensor: Ptr[TF_Tensor]): TF_DataType = extern 369 | 370 | /** Return the number of dimensions that the tensor has. 371 | */ 372 | def TF_NumDims(tensor: Ptr[TF_Tensor]): CInt = extern 373 | 374 | /** Return the length of the tensor in the "dim_index" dimension. REQUIRES: 0 375 | * <= dim_index < TF_NumDims(tensor) 376 | */ 377 | def TF_Dim(tensor: Ptr[TF_Tensor], dim_index: CInt): int64_t = extern 378 | 379 | /** Return the size of the underlying data in bytes. 380 | */ 381 | def TF_TensorByteSize(tensor: Ptr[TF_Tensor]): CSize = extern 382 | 383 | /** Return a pointer to the underlying data buffer. 384 | */ 385 | def TF_TensorData(tensor: Ptr[TF_Tensor]): Ptr[Byte] = extern 386 | 387 | /** Encode the string `src` (`src_len` bytes long) into `dst` in the format 388 | * required by TF_STRING tensors. Does not write to memory more than 389 | * `dst_len` bytes beyond `*dst`. `dst_len` should be at least 390 | * TF_StringEncodedSize(src_len). 391 | * 392 | * On success returns the size in bytes of the encoded string. Returns an 393 | * error into `status` otherwise. 394 | */ 395 | def TF_StringEncode( 396 | src: CString, 397 | src_len: CSize, 398 | dst: CString, 399 | dst_len: CSize, 400 | status: Ptr[TF_Status] 401 | ): CSize = extern 402 | 403 | /** Decode a string encoded using TF_StringEncode. 404 | * 405 | * On success, sets `*dst` to the start of the decoded string and `*dst_len` 406 | * to its length. Returns the number of bytes starting at `src` consumed 407 | * while decoding. `*dst` points to memory within the encoded buffer. On 408 | * failure, `*dst` and `*dst_len` are undefined and an error is set in 409 | * `status`. 410 | * 411 | * Does not read memory more than `src_len` bytes beyond `src`. 412 | */ 413 | def TF_StringDecode( 414 | src: CString, 415 | src_len: CSize, 416 | dst: Ptr[CString], 417 | dst_len: Ptr[CSize], 418 | status: Ptr[TF_Status] 419 | ): CSize = extern 420 | 421 | /** Return the size in bytes required to encode a string `len` bytes long into 422 | * a TF_STRING tensor. 423 | */ 424 | def TF_StringEncodedSize(len: CSize): CSize = extern 425 | 426 | /** Return a new options object. 427 | */ 428 | def TF_NewSessionOptions(): Ptr[TF_SessionOptions] = extern 429 | 430 | /** Set the target in TF_SessionOptions.options. target can be empty, a single 431 | * entry, or a comma separated list of entries. Each entry is in one of the 432 | * following formats: 433 | * - "local" 434 | * - ip:port 435 | * - host:port 436 | */ 437 | def TF_SetTarget(options: Ptr[TF_SessionOptions], target: CString): Unit = 438 | extern 439 | 440 | /** Set the config in TF_SessionOptions.options. config should be a serialized 441 | * tensorflow.ConfigProto proto. 
If config was not parsed successfully as a 442 | * ConfigProto, record the error information in *status. 443 | */ 444 | def TF_SetConfig( 445 | options: Ptr[TF_SessionOptions], 446 | proto: Ptr[Byte], 447 | proto_len: CSize, 448 | status: Ptr[TF_Status] 449 | ): Unit = extern 450 | 451 | /** Destroy an options object. 452 | */ 453 | def TF_DeleteSessionOptions(sessionOptions: Ptr[TF_SessionOptions]): Unit = 454 | extern 455 | 456 | /** Return a new graph object. 457 | */ 458 | def TF_NewGraph(): Ptr[TF_Graph] = extern 459 | 460 | /** Destroy an options object. Graph will be deleted once no more TFSession's 461 | * are referencing it. 462 | */ 463 | def TF_DeleteGraph(graph: Ptr[TF_Graph]): Unit = extern 464 | 465 | /** Sets the shape of the Tensor referenced by `output` in `graph` to the 466 | * shape described by `dims` and `num_dims`. 467 | * 468 | * If the number of dimensions is unknown, `num_dims` must be set to -1 and 469 | * `dims` can be null. If a dimension is unknown, the corresponding entry in 470 | * the `dims` array must be -1. 471 | * 472 | * This does not overwrite the existing shape associated with `output`, but 473 | * merges the input shape with the existing shape. For example, setting a 474 | * shape of [-1, 2] with an existing shape [2, -1] would set a final shape of 475 | * [2, 2] based on shape merging semantics. 476 | * 477 | * Returns an error into `status` if: 478 | * - `output` is not in `graph`. 479 | * - An invalid shape is being set (e.g., the shape being set is 480 | * incompatible with the existing shape). 481 | */ 482 | @name("scalanative_TF_GraphSetTensorShape") 483 | def TF_GraphSetTensorShape( 484 | graph: Ptr[TF_Graph], 485 | output: Ptr[TF_Output], // TF_output 486 | dims: Ptr[int64_t], 487 | num_dims: CInt, 488 | status: Ptr[TF_Status] 489 | ): Unit = extern 490 | 491 | /** Returns the number of dimensions of the Tensor referenced by `output` in 492 | * `graph`. 493 | * 494 | * If the number of dimensions in the shape is unknown, returns -1. 495 | * 496 | * Returns an error into `status` if: 497 | * - `output` is not in `graph`. 498 | */ 499 | @name("scalanative_TF_GraphGetTensorNumDims") 500 | def TF_GraphGetTensorNumDims( 501 | graph: Ptr[TF_Graph], 502 | output: Ptr[TF_Output], // TF_output 503 | status: Ptr[TF_Status] 504 | ): CInt = extern 505 | 506 | /** Returns the shape of the Tensor referenced by `output` in `graph` into 507 | * `dims`. `dims` must be an array large enough to hold `num_dims` entries 508 | * (e.g., the return value of TF_GraphGetTensorNumDims). 509 | * 510 | * If the number of dimensions in the shape is unknown or the shape is a 511 | * scalar, `dims` will remain untouched. Otherwise, each element of `dims` 512 | * will be set corresponding to the size of the dimension. An unknown 513 | * dimension is represented by `-1`. 514 | * 515 | * Returns an error into `status` if: 516 | * - `output` is not in `graph`. 517 | * - `num_dims` does not match the actual number of dimensions. 518 | */ 519 | @name("scalanative_TF_GraphGetTensorShape") 520 | def TF_GraphGetTensorShape( 521 | graph: Ptr[TF_Graph], 522 | output: Ptr[TF_Output], // TF_output 523 | dims: Ptr[int64_t], 524 | num_dims: CInt, 525 | status: Ptr[TF_Status] 526 | ): Unit = extern 527 | 528 | /** Operation will only be added to *graph when TF_FinishOperation() is called 529 | * (assuming TF_FinishOperation() does not return an error). *graph must not 530 | * be deleted until after TF_FinishOperation() is called. 
531 | */ 532 | def TF_NewOperation( 533 | graph: Ptr[TF_Graph], 534 | op_type: CString, 535 | oper_name: CString 536 | ): Ptr[TF_OperationDescription] = extern 537 | 538 | /** Specify the device for `desc`. Defaults to empty, meaning unconstrained. 539 | */ 540 | def TF_SetDevice(desc: Ptr[TF_OperationDescription], device: CString): Unit = 541 | extern 542 | 543 | /** The calls to TF_AddInput and TF_AddInputList must match (in number, order, 544 | * and type) the op declaration. For example, the "Concat" op has 545 | * registration: 546 | * {{{ 547 | * REGISTER_OP("Concat") 548 | * .Input("concat_dim: int32") 549 | * .Input("values: N * T") 550 | * .Output("output: T") 551 | * .Attr("N: int >= 2") 552 | * .Attr("T: type"); 553 | * }}} 554 | * that defines two inputs, "concat_dim" and "values" (in that order). You 555 | * must use TF_AddInput() for the first input (since it takes a single 556 | * tensor), and TF_AddInputList() for the second input (since it takes a 557 | * list, even if you were to pass a list with a single tensor), as in: 558 | * {{{ 559 | * TF_OperationDescription* desc = TF_NewOperation(graph, "Concat", "c"); 560 | * TF_Output concat_dim_input = {...}; 561 | * TF_AddInput(desc, concat_dim_input); 562 | * TF_Output values_inputs[5] = {{...}, ..., {...}}; 563 | * TF_AddInputList(desc,values_inputs, 5); 564 | * }}} 565 | * For inputs that take a single tensor. 566 | */ 567 | @name("scalanative_TF_AddInput") 568 | def TF_AddInput( 569 | desc: Ptr[TF_OperationDescription], 570 | input: Ptr[TF_Output] 571 | ): Unit = 572 | extern // TF_output 573 | 574 | /** For inputs that take a list of tensors. inputs must point to 575 | * TF_Output[num_inputs]. 576 | */ 577 | def TF_AddInputList( 578 | desc: Ptr[TF_OperationDescription], 579 | inputs: Ptr[TF_Output], 580 | num_inputs: CInt 581 | ): Unit = extern 582 | 583 | /** Call once per control input to `desc`. 584 | */ 585 | def TF_AddControlInput( 586 | desc: Ptr[TF_OperationDescription], 587 | input: Ptr[TF_Operation] 588 | ): Unit = extern 589 | 590 | /** Request that `desc` be co-located on the device where `op` is placed. 591 | * 592 | * Use of this is discouraged since the implementation of device placement is 593 | * subject to change. Primarily intended for internal libraries 594 | */ 595 | def TF_ColocateWith( 596 | desc: Ptr[TF_OperationDescription], 597 | op: Ptr[TF_Operation] 598 | ): Unit = extern 599 | 600 | /** Call some TF_SetAttr*() function for every attr that is not inferred from 601 | * an input and doesn't have a default value you wish to keep. 602 | * 603 | * `value` must point to a string of length `length` bytes. 604 | */ 605 | def TF_SetAttrString( 606 | desc: Ptr[TF_OperationDescription], 607 | attr_name: CString, 608 | value: Ptr[Byte], 609 | length: CSize 610 | ): Unit = extern 611 | 612 | /** `values` and `lengths` each must have lengths `num_values`. `values[i]` 613 | * must point to a string of length `lengths[i]` bytes. 
614 | */ 615 | def TF_SetAttrStringList( 616 | desc: Ptr[TF_OperationDescription], 617 | attr_name: CString, 618 | values: Ptr[Ptr[Byte]], 619 | lengths: Ptr[CSize], 620 | num_values: CInt 621 | ): Unit = extern 622 | 623 | /** */ 624 | def TF_SetAttrInt( 625 | desc: Ptr[TF_OperationDescription], 626 | attr_name: CString, 627 | value: int64_t 628 | ): Unit = extern 629 | 630 | /** */ 631 | def TF_SetAttrIntList( 632 | desc: Ptr[TF_OperationDescription], 633 | attr_name: CString, 634 | values: Ptr[int64_t], 635 | num_values: CInt 636 | ): Unit = extern 637 | 638 | /** */ 639 | def TF_SetAttrFloat( 640 | desc: Ptr[TF_OperationDescription], 641 | attr_name: CString, 642 | value: CFloat 643 | ): Unit = extern 644 | 645 | /** */ 646 | def TF_SetAttrFloatList( 647 | desc: Ptr[TF_OperationDescription], 648 | attr_name: CString, 649 | values: Ptr[CFloat], 650 | num_values: CInt 651 | ): Unit = extern 652 | 653 | /** */ 654 | def TF_SetAttrBool( 655 | desc: Ptr[TF_OperationDescription], 656 | attr_name: CString, 657 | value: CUnsignedChar 658 | ): Unit = extern 659 | 660 | /** */ 661 | def TF_SetAttrBoolList( 662 | desc: Ptr[TF_OperationDescription], 663 | attr_name: CString, 664 | values: Ptr[CUnsignedChar], 665 | num_values: CInt 666 | ): Unit = extern 667 | 668 | /** */ 669 | def TF_SetAttrType( 670 | desc: Ptr[TF_OperationDescription], 671 | attr_name: CString, 672 | value: TF_DataType 673 | ): Unit = extern 674 | 675 | /** */ 676 | def TF_SetAttrTypeList( 677 | desc: Ptr[TF_OperationDescription], 678 | attr_name: CString, 679 | values: Ptr[TF_DataType], 680 | num_values: CInt 681 | ): Unit = extern 682 | 683 | /** Set a 'func' attribute to the specified name. `value` must point to a 684 | * string of length `length` bytes. 685 | */ 686 | def TF_SetAttrFuncName( 687 | desc: Ptr[TF_OperationDescription], 688 | attr_name: CString, 689 | value: CString, 690 | length: CSize 691 | ): Unit = extern 692 | 693 | /** Set `num_dims` to -1 to represent "unknown rank". Otherwise, `dims` points 694 | * to an array of length `num_dims`. `dims[i]` must be >= -1, with -1 meaning 695 | * "unknown dimension". 696 | */ 697 | def TF_SetAttrShape( 698 | desc: Ptr[TF_OperationDescription], 699 | attr_name: CString, 700 | dims: Ptr[int64_t], 701 | num_dims: CInt 702 | ): Unit = extern 703 | 704 | /** `dims` and `num_dims` must point to arrays of length `num_shapes`. Set 705 | * `num_dims[i]` to -1 to represent "unknown rank". Otherwise, `dims[i]` 706 | * points to an array of length `num_dims[i]`. `dims[i][j]` must be >= -1, 707 | * with -1 meaning "unknown dimension". 708 | */ 709 | def TF_SetAttrShapeList( 710 | desc: Ptr[TF_OperationDescription], 711 | attr_name: CString, 712 | dims: Ptr[Ptr[int64_t]], 713 | num_dims: Ptr[CInt], 714 | num_shapes: CInt 715 | ): Unit = extern 716 | 717 | /** `proto` must point to an array of `proto_len` bytes representing a 718 | * binary-serialized TensorShapeProto. 719 | */ 720 | def TF_SetAttrTensorShapeProto( 721 | desc: Ptr[TF_OperationDescription], 722 | attr_name: CString, 723 | proto: Ptr[Byte], 724 | proto_len: CSize, 725 | status: Ptr[TF_Status] 726 | ): Unit = extern 727 | 728 | /** `protos` and `proto_lens` must point to arrays of length `num_shapes`. 729 | * `protos[i]` must point to an array of `proto_lens[i]` bytes representing a 730 | * binary-serialized TensorShapeProto. 
731 | */ 732 | def TF_SetAttrTensorShapeProtoList( 733 | desc: Ptr[TF_OperationDescription], 734 | attr_name: CString, 735 | protos: Ptr[Ptr[Byte]], 736 | proto_lens: Ptr[CSize], 737 | num_shapes: CInt, 738 | status: Ptr[TF_Status] 739 | ): Unit = extern 740 | 741 | /** */ 742 | def TF_SetAttrTensor( 743 | desc: Ptr[TF_OperationDescription], 744 | attr_name: CString, 745 | value: Ptr[TF_Tensor], 746 | status: Ptr[TF_Status] 747 | ): Unit = extern 748 | 749 | /** */ 750 | def TF_SetAttrTensorList( 751 | desc: Ptr[TF_OperationDescription], 752 | attr_name: CString, 753 | values: Ptr[Ptr[TF_Tensor]], 754 | num_values: CInt, 755 | status: Ptr[TF_Status] 756 | ): Unit = extern 757 | 758 | /** `proto` should point to a sequence of bytes of length `proto_len` 759 | * representing a binary serialization of an AttrValue protocol buffer. 760 | */ 761 | def TF_SetAttrValueProto( 762 | desc: Ptr[TF_OperationDescription], 763 | attr_name: CString, 764 | proto: Ptr[Byte], 765 | proto_len: CSize, 766 | status: Ptr[TF_Status] 767 | ): Unit = extern 768 | 769 | /** If this function succeeds: 770 | * - *status is set to an OK value, 771 | * - a TF_Operation is added to the graph, 772 | * - a non-null value pointing to the added operation is returned -- this 773 | * value is valid until the underlying graph is deleted. Otherwise: 774 | * - *status is set to a non-OK value, 775 | * - the graph is not modified, 776 | * - a null value is returned. In either case, it deletes `desc`. 777 | */ 778 | def TF_FinishOperation( 779 | desc: Ptr[TF_OperationDescription], 780 | status: Ptr[TF_Status] 781 | ): Ptr[TF_Operation] = extern 782 | 783 | /** TF_Operation functions. Operations are immutable once created, so these 784 | * are all query functions. 785 | */ 786 | def TF_OperationName(oper: Ptr[TF_Operation]): CString = extern 787 | 788 | /** */ 789 | def TF_OperationOpType(oper: Ptr[TF_Operation]): CString = extern 790 | 791 | /** */ 792 | def TF_OperationDevice(oper: Ptr[TF_Operation]): CString = extern 793 | 794 | /** */ 795 | def TF_OperationNumOutputs(oper: Ptr[TF_Operation]): CInt = extern 796 | 797 | /** */ 798 | @name("scalanative_TF_OperationOutputType") 799 | def TF_OperationOutputType(oper_out: Ptr[TF_Output]): TF_DataType = 800 | extern // TF_output 801 | 802 | /** */ 803 | def TF_OperationOutputListLength( 804 | oper: Ptr[TF_Operation], 805 | arg_name: CString, 806 | status: Ptr[TF_Status] 807 | ): CInt = extern 808 | 809 | /** */ 810 | def TF_OperationNumInputs(oper: Ptr[TF_Operation]): CInt = extern 811 | 812 | /** */ 813 | def TF_OperationInputType(oper_in: Ptr[TF_Input]): TF_DataType = 814 | extern // TF_Input 815 | 816 | /** */ 817 | def TF_OperationInputListLength( 818 | oper: Ptr[TF_Operation], 819 | arg_name: CString, 820 | status: Ptr[TF_Status] 821 | ): CInt = extern 822 | 823 | /** In this code: 824 | * {{{ 825 | * TF_Output producer = TF_OperationInput(consumer); 826 | * }}} 827 | * There is an edge from producer.oper's output (given by producer.index) to 828 | * consumer.oper's input (given by consumer.index). 829 | * 830 | * Note: for Scala Native we need to pass an additonal Ptr[TF_Output] to 831 | * capture the original rvalue (stack, pass by value). 832 | */ 833 | @name("scalanative_TF_OperationInput") 834 | def TF_OperationInput( 835 | oper_in: Ptr[TF_Input], 836 | oper_out: Ptr[TF_Output] 837 | ): Ptr[TF_Output] = 838 | extern // TF_Input TF_Output 839 | 840 | /** Get the number of current consumers of a specific output of an operation. 
841 | * Note that this number can change when new operations are added to the 842 | * graph. 843 | */ 844 | @name("scalanative_TF_OperationOutputNumConsumers") 845 | def TF_OperationOutputNumConsumers(oper_out: Ptr[TF_Output]): CInt = 846 | extern // TF_output 847 | 848 | /** Get list of all current consumers of a specific output of an operation. 849 | * `consumers` must point to an array of length at least `max_consumers` 850 | * (ideally set to TF_OperationOutputNumConsumers(oper_out)). Beware that a 851 | * concurrent modification of the graph can increase the number of consumers 852 | * of an operation. Returns the number of output consumers (should match 853 | * TF_OperationOutputNumConsumers(oper_out)). 854 | */ 855 | @name("scalanative_TF_OperationOutputConsumers") 856 | def TF_OperationOutputConsumers( 857 | oper_out: Ptr[TF_Output], // TF_output 858 | consumers: Ptr[TF_Input], 859 | max_consumers: CInt 860 | ): CInt = extern 861 | 862 | /** Get the number of control inputs to an operation. 863 | */ 864 | def TF_OperationNumControlInputs(oper: Ptr[TF_Operation]): CInt = extern 865 | 866 | /** Get list of all control inputs to an operation. `control_inputs` must 867 | * point to an array of length `max_control_inputs` (ideally set to 868 | * TF_OperationNumControlInputs(oper)). Returns the number of control inputs 869 | * (should match TF_OperationNumControlInputs(oper)). 870 | */ 871 | def TF_OperationGetControlInputs( 872 | oper: Ptr[TF_Operation], 873 | control_inputs: Ptr[Ptr[TF_Operation]], 874 | max_control_inputs: CInt 875 | ): CInt = extern 876 | 877 | /** Get the number of operations that have `*oper` as a control input. Note 878 | * that this number can change when new operations are added to the graph. 879 | */ 880 | def TF_OperationNumControlOutputs(oper: Ptr[TF_Operation]): CInt = extern 881 | 882 | /** Get the list of operations that have `*oper` as a control input. 883 | * `control_outputs` must point to an array of length at least 884 | * `max_control_outputs` (ideally set to 885 | * TF_OperationNumControlOutputs(oper)). Beware that a concurrent 886 | * modification of the graph can increase the number of control outputs. 887 | * Returns the number of control outputs (should match 888 | * TF_OperationNumControlOutputs(oper)). 889 | */ 890 | def TF_OperationGetControlOutputs( 891 | oper: Ptr[TF_Operation], 892 | control_outputs: Ptr[Ptr[TF_Operation]], 893 | max_control_outputs: CInt 894 | ): CInt = extern 895 | 896 | /** Returns metadata about the value of the attribute `attr_name` of `oper`. 897 | */ 898 | def TF_OperationGetAttrMetadata( 899 | oper: Ptr[TF_Operation], 900 | attr_name: CString, 901 | status: Ptr[TF_Status] 902 | ): TF_AttrMetadata = 903 | extern 904 | 905 | /** Fills in `value` with the value of the attribute `attr_name`. `value` must 906 | * point to an array of length at least `max_length` (ideally set to 907 | * TF_AttrMetadata.total_size from TF_OperationGetAttrMetadata(oper, 908 | * attr_name)). 909 | */ 910 | def TF_OperationGetAttrString( 911 | oper: Ptr[TF_Operation], 912 | attr_name: CString, 913 | value: Ptr[Byte], 914 | max_length: CSize, 915 | status: Ptr[TF_Status] 916 | ): Unit = extern 917 | 918 | /** Get the list of strings in the value of the attribute `attr_name`. Fills 919 | * in `values` and `lengths`, each of which must point to an array of length 920 | * at least `max_values`. 921 | * 922 | * The elements of values will point to addresses in `storage` which must be 923 | * at least `storage_size` bytes in length. 
Ideally, max_values would be set 924 | * to TF_AttrMetadata.list_size and `storage` would be at least 925 | * TF_AttrMetadata.total_size, obtained from 926 | * TF_OperationGetAttrMetadata(oper, attr_name). 927 | * 928 | * Fails if storage_size is too small to hold the requested number of 929 | * strings. 930 | */ 931 | def TF_OperationGetAttrStringList( 932 | oper: Ptr[TF_Operation], 933 | attr_name: CString, 934 | values: Ptr[Ptr[Byte]], 935 | lengths: Ptr[CSize], 936 | max_values: CInt, 937 | storage: Ptr[Byte], 938 | storage_size: CSize, 939 | status: Ptr[TF_Status] 940 | ): Unit = extern 941 | 942 | /** */ 943 | def TF_OperationGetAttrInt( 944 | oper: Ptr[TF_Operation], 945 | attr_name: CString, 946 | value: Ptr[int64_t], 947 | status: Ptr[TF_Status] 948 | ): Unit = extern 949 | 950 | /** Fills in `values` with the value of the attribute `attr_name` of `oper`. 951 | * `values` must point to an array of length at least `max_values` (ideally 952 | * set TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper, 953 | * attr_name)). 954 | */ 955 | def TF_OperationGetAttrIntList( 956 | oper: Ptr[TF_Operation], 957 | attr_name: CString, 958 | values: Ptr[int64_t], 959 | max_values: CInt, 960 | status: Ptr[TF_Status] 961 | ): Unit = extern 962 | 963 | /** */ 964 | def TF_OperationGetAttrFloat( 965 | oper: Ptr[TF_Operation], 966 | attr_name: CString, 967 | value: Ptr[CFloat], 968 | status: Ptr[TF_Status] 969 | ): Unit = extern 970 | 971 | /** Fills in `values` with the value of the attribute `attr_name` of `oper`. 972 | * `values` must point to an array of length at least `max_values` (ideally 973 | * set to TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper, 974 | * attr_name)). 975 | */ 976 | def TF_OperationGetAttrFloatList( 977 | oper: Ptr[TF_Operation], 978 | attr_name: CString, 979 | values: Ptr[CFloat], 980 | max_values: CInt, 981 | status: Ptr[TF_Status] 982 | ): Unit = extern 983 | 984 | /** */ 985 | def TF_OperationGetAttrBool( 986 | oper: Ptr[TF_Operation], 987 | attr_name: CString, 988 | value: Ptr[CUnsignedChar], 989 | status: Ptr[TF_Status] 990 | ): Unit = extern 991 | 992 | /** Fills in `values` with the value of the attribute `attr_name` of `oper`. 993 | * `values` must point to an array of length at least `max_values` (ideally 994 | * set to TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper, 995 | * attr_name)). 996 | */ 997 | def TF_OperationGetAttrBoolList( 998 | oper: Ptr[TF_Operation], 999 | attr_name: CString, 1000 | values: Ptr[CUnsignedChar], 1001 | max_values: CInt, 1002 | status: Ptr[TF_Status] 1003 | ): Unit = extern 1004 | 1005 | /** */ 1006 | def TF_OperationGetAttrType( 1007 | oper: Ptr[TF_Operation], 1008 | attr_name: CString, 1009 | value: Ptr[TF_DataType], 1010 | status: Ptr[TF_Status] 1011 | ): Unit = extern 1012 | 1013 | /** Fills in `values` with the value of the attribute `attr_name` of `oper`. 1014 | * `values` must point to an array of length at least `max_values` (ideally 1015 | * set to TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper, 1016 | * attr_name)). 1017 | */ 1018 | def TF_OperationGetAttrTypeList( 1019 | oper: Ptr[TF_Operation], 1020 | attr_name: CString, 1021 | values: Ptr[TF_DataType], 1022 | max_values: CInt, 1023 | status: Ptr[TF_Status] 1024 | ): Unit = extern 1025 | 1026 | /** Fills in `value` with the value of the attribute `attr_name` of `oper`. 
1027 | * `values` must point to an array of length at least `num_dims` (ideally set 1028 | * to TF_Attr_Meta.size from TF_OperationGetAttrMetadata(oper, attr_name)). 1029 | */ 1030 | def TF_OperationGetAttrShape( 1031 | oper: Ptr[TF_Operation], 1032 | attr_name: CString, 1033 | value: Ptr[int64_t], 1034 | num_dims: CInt, 1035 | status: Ptr[TF_Status] 1036 | ): Unit = extern 1037 | 1038 | /** Fills in `dims` with the list of shapes in the attribute `attr_name` of 1039 | * `oper` and `num_dims` with the corresponding number of dimensions. On 1040 | * return, for every i where `num_dims[i]` > 0, `dims[i]` will be an array of 1041 | * `num_dims[i]` elements. A value of -1 for `num_dims[i]` indicates that the 1042 | * i-th shape in the list is unknown. 1043 | * 1044 | * The elements of `dims` will point to addresses in `storage` which must be 1045 | * large enough to hold at least `storage_size` int64_ts. Ideally, 1046 | * `num_shapes` would be set to TF_AttrMetadata.list_size and `storage_size` 1047 | * would be set to TF_AttrMetadata.total_size from 1048 | * TF_OperationGetAttrMetadata(oper, attr_name). 1049 | * 1050 | * Fails if storage_size is insufficient to hold the requested shapes. 1051 | */ 1052 | def TF_OperationGetAttrShapeList( 1053 | oper: Ptr[TF_Operation], 1054 | attr_name: CString, 1055 | dims: Ptr[Ptr[int64_t]], 1056 | num_dims: Ptr[CInt], 1057 | num_shapes: CInt, 1058 | storage: Ptr[int64_t], 1059 | storage_size: CInt, 1060 | status: Ptr[TF_Status] 1061 | ): Unit = extern 1062 | 1063 | /** Sets `value` to the binary-serialized TensorShapeProto of the value of 1064 | * `attr_name` attribute of `oper`'. 1065 | */ 1066 | def TF_OperationGetAttrTensorShapeProto( 1067 | oper: Ptr[TF_Operation], 1068 | attr_name: CString, 1069 | value: Ptr[TF_Buffer], 1070 | status: Ptr[TF_Status] 1071 | ): Unit = extern 1072 | 1073 | /** Fills in `values` with binary-serialized TensorShapeProto values of the 1074 | * attribute `attr_name` of `oper`. `values` must point to an array of length 1075 | * at least `num_values` (ideally set to TF_AttrMetadata.list_size from 1076 | * TF_OperationGetAttrMetadata(oper, attr_name)). 1077 | */ 1078 | def TF_OperationGetAttrTensorShapeProtoList( 1079 | oper: Ptr[TF_Operation], 1080 | attr_name: CString, 1081 | values: Ptr[Ptr[TF_Buffer]], 1082 | max_values: CInt, 1083 | status: Ptr[TF_Status] 1084 | ): Unit = 1085 | extern 1086 | 1087 | /** Gets the TF_Tensor valued attribute of `attr_name` of `oper`. 1088 | * 1089 | * Allocates a new TF_Tensor which the caller is expected to take ownership 1090 | * of (and can deallocate using TF_DeleteTensor). 1091 | */ 1092 | def TF_OperationGetAttrTensor( 1093 | oper: Ptr[TF_Operation], 1094 | attr_name: CString, 1095 | value: Ptr[Ptr[TF_Tensor]], 1096 | status: Ptr[TF_Status] 1097 | ): Unit = extern 1098 | 1099 | /** Fills in `values` with the TF_Tensor values of the attribute `attr_name` 1100 | * of `oper`. `values` must point to an array of TF_Tensor* of length at 1101 | * least `max_values` (ideally set to TF_AttrMetadata.list_size from 1102 | * TF_OperationGetAttrMetadata(oper, attr_name)). 1103 | * 1104 | * The caller takes ownership of all the non-null TF_Tensor* entries in 1105 | * `values` (which can be deleted using TF_DeleteTensor(values[i])). 
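* A minimal illustrative sketch from the Scala Native side (not part of the
* upstream header docs). The attribute name and count below are placeholders,
* and it assumes `oper`, `status`, and the TF_DeleteTensor binding declared
* earlier in this file are in scope:
* {{{
* Zone {
*   val maxValues = 4 // e.g. TF_AttrMetadata.list_size for the attribute
*   val tensors = alloc[Ptr[TF_Tensor]](maxValues)
*   TF_OperationGetAttrTensorList(oper, c"some_attr", tensors, maxValues, status)
*   // the caller owns every non-null entry
*   for (i <- 0 until maxValues if tensors(i) != null) TF_DeleteTensor(tensors(i))
* }
* }}}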
1106 | */ 1107 | def TF_OperationGetAttrTensorList( 1108 | oper: Ptr[TF_Operation], 1109 | attr_name: CString, 1110 | values: Ptr[Ptr[TF_Tensor]], 1111 | max_values: CInt, 1112 | status: Ptr[TF_Status] 1113 | ): Unit = extern 1114 | 1115 | /** Sets `output_attr_value` to the binary-serialized AttrValue proto 1116 | * representation of the value of the `attr_name` attr of `oper`. 1117 | */ 1118 | def TF_OperationGetAttrValueProto( 1119 | oper: Ptr[TF_Operation], 1120 | attr_name: CString, 1121 | output_attr_value: Ptr[TF_Buffer], 1122 | status: Ptr[TF_Status] 1123 | ): Unit = extern 1124 | 1125 | /** Returns the operation in the graph with `oper_name`. Returns nullptr if no 1126 | * operation found. 1127 | */ 1128 | def TF_GraphOperationByName( 1129 | graph: Ptr[TF_Graph], 1130 | oper_name: CString 1131 | ): Ptr[TF_Operation] = extern 1132 | 1133 | /** Iterate through the operations of a graph. To use: 1134 | * {{{ 1135 | * size_t pos = 0; 1136 | * TF_Operation* oper; 1137 | * while ((oper = TF_GraphNextOperation(graph, &pos)) != nullptr) { 1138 | * DoSomethingWithOperation(oper); 1139 | * } 1140 | * }}} 1141 | */ 1142 | def TF_GraphNextOperation( 1143 | graph: Ptr[TF_Graph], 1144 | pos: Ptr[CSize] 1145 | ): Ptr[TF_Operation] = extern 1146 | 1147 | /** Write out a serialized representation of `graph` (as a GraphDef protocol 1148 | * message) to `output_graph_def` (allocated by TF_NewBuffer()). 1149 | * `output_graph_def`'s underlying buffer will be freed when 1150 | * TF_DeleteBuffer() is called. 1151 | * 1152 | * May fail on very large graphs in the future. 1153 | */ 1154 | def TF_GraphToGraphDef( 1155 | graph: Ptr[TF_Graph], 1156 | output_graph_def: Ptr[TF_Buffer], 1157 | status: Ptr[TF_Status] 1158 | ): Unit = extern 1159 | 1160 | /** Returns the serialized OpDef proto with name `op_name`, or a bad status if 1161 | * no such op exists. This can return OpDefs of functions copied into the 1162 | * graph. 1163 | */ 1164 | def TF_GraphGetOpDef( 1165 | graph: Ptr[TF_Graph], 1166 | op_name: CString, 1167 | output_op_def: Ptr[TF_Buffer], 1168 | status: Ptr[TF_Status] 1169 | ): Unit = extern 1170 | 1171 | /** Returns the serialized VersionDef proto for this graph. 1172 | */ 1173 | def TF_GraphVersions( 1174 | graph: Ptr[TF_Graph], 1175 | output_version_def: Ptr[TF_Buffer], 1176 | status: Ptr[TF_Status] 1177 | ): Unit = extern 1178 | 1179 | /** TF_ImportGraphDefOptions holds options that can be passed to 1180 | * TF_GraphImportGraphDef. 1181 | */ 1182 | type TF_ImportGraphDefOptions = CStruct0 1183 | 1184 | /** */ 1185 | def TF_NewImportGraphDefOptions(): Ptr[TF_ImportGraphDefOptions] = extern 1186 | 1187 | /** */ 1188 | def TF_DeleteImportGraphDefOptions( 1189 | opts: Ptr[TF_ImportGraphDefOptions] 1190 | ): Unit = extern 1191 | 1192 | /** Set the prefix to be prepended to the names of nodes in `graph_def` that 1193 | * will be imported into `graph`. `prefix` is copied and has no lifetime 1194 | * requirements. 1195 | */ 1196 | def TF_ImportGraphDefOptionsSetPrefix( 1197 | opts: Ptr[TF_ImportGraphDefOptions], 1198 | prefix: CString 1199 | ): Unit = extern 1200 | 1201 | /** Set the execution device for nodes in `graph_def`. Only applies to nodes 1202 | * where a device was not already explicitly specified. `device` is copied 1203 | * and has no lifetime requirements. 1204 | */ 1205 | def TF_ImportGraphDefOptionsSetDefaultDevice( 1206 | opts: Ptr[TF_ImportGraphDefOptions], 1207 | device: CString 1208 | ): Unit = extern 1209 | 1210 | /** Set whether to uniquify imported operation names. 
If true, imported 1211 | * operation names will be modified if their name already exists in the 1212 | * graph. If false, conflicting names will be treated as an error. Note that 1213 | * this option has no effect if a prefix is set, since the prefix will 1214 | * guarantee all names are unique. Defaults to false. 1215 | */ 1216 | def TF_ImportGraphDefOptionsSetUniquifyNames( 1217 | opts: Ptr[TF_ImportGraphDefOptions], 1218 | uniquify_names: CUnsignedChar 1219 | ): Unit = extern 1220 | 1221 | /** If true, the specified prefix will be modified if it already exists as an 1222 | * operation name or prefix in the graph. If false, a conflicting prefix will 1223 | * be treated as an error. This option has no effect if no prefix is 1224 | * specified. 1225 | */ 1226 | def TF_ImportGraphDefOptionsSetUniquifyPrefix( 1227 | opts: Ptr[TF_ImportGraphDefOptions], 1228 | uniquify_prefix: CUnsignedChar 1229 | ): Unit = extern 1230 | 1231 | /** Set any imported nodes with input `src_name:src_index` to have that input 1232 | * replaced with `dst`. `src_name` refers to a node in the graph to be 1233 | * imported, `dst` references a node already existing in the graph being 1234 | * imported into. `src_name` is copied and has no lifetime requirements. 1235 | */ 1236 | @name("scalanative_TF_ImportGraphDefOptionsAddInputMapping") 1237 | def TF_ImportGraphDefOptionsAddInputMapping( 1238 | opts: Ptr[TF_ImportGraphDefOptions], 1239 | src_name: CString, 1240 | src_index: CInt, 1241 | dst: Ptr[TF_Output] 1242 | ): Unit = extern // TF_output 1243 | 1244 | /** Set any imported nodes with control input `src_name` to have that input 1245 | * replaced with `dst`. `src_name` refers to a node in the graph to be 1246 | * imported, `dst` references an operation already existing in the graph 1247 | * being imported into. `src_name` is copied and has no lifetime 1248 | * requirements. 1249 | */ 1250 | def TF_ImportGraphDefOptionsRemapControlDependency( 1251 | opts: Ptr[TF_ImportGraphDefOptions], 1252 | src_name: CString, 1253 | dst: Ptr[TF_Operation] 1254 | ): Unit = extern 1255 | 1256 | /** Cause the imported graph to have a control dependency on `oper`. `oper` 1257 | * should exist in the graph being imported into. 1258 | */ 1259 | def TF_ImportGraphDefOptionsAddControlDependency( 1260 | opts: Ptr[TF_ImportGraphDefOptions], 1261 | oper: Ptr[TF_Operation] 1262 | ): Unit = extern 1263 | 1264 | /** Add an output in `graph_def` to be returned via the `return_outputs` 1265 | * output parameter of TF_GraphImportGraphDef(). If the output is remapped 1266 | * via an input mapping, the corresponding existing tensor in `graph` will be 1267 | * returned. `oper_name` is copied and has no lifetime requirements. 1268 | */ 1269 | def TF_ImportGraphDefOptionsAddReturnOutput( 1270 | opts: Ptr[TF_ImportGraphDefOptions], 1271 | oper_name: CString, 1272 | index: CInt 1273 | ): Unit = extern 1274 | 1275 | /** Returns the number of return outputs added via 1276 | * TF_ImportGraphDefOptionsAddReturnOutput(). 1277 | */ 1278 | def TF_ImportGraphDefOptionsNumReturnOutputs( 1279 | opts: Ptr[TF_ImportGraphDefOptions] 1280 | ): CInt = extern 1281 | 1282 | /** Add an operation in `graph_def` to be returned via the `return_opers` 1283 | * output parameter of TF_GraphImportGraphDef(). `oper_name` is copied and 1284 | * has no lifetime requirements. 
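* A brief hypothetical sketch of building import options from Scala Native
* (the prefix "imported" and operation name "init" are placeholders, not from
* the upstream docs):
* {{{
* val opts = TF_NewImportGraphDefOptions()
* TF_ImportGraphDefOptionsSetPrefix(opts, c"imported")
* TF_ImportGraphDefOptionsAddReturnOperation(opts, c"init")
* // ... use opts with TF_GraphImportGraphDefWithResults, then:
* TF_DeleteImportGraphDefOptions(opts)
* }}}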
1285 | */ 1286 | def TF_ImportGraphDefOptionsAddReturnOperation( 1287 | opts: Ptr[TF_ImportGraphDefOptions], 1288 | oper_name: CString 1289 | ): Unit = extern 1290 | 1291 | /** Returns the number of return operations added via 1292 | * TF_ImportGraphDefOptionsAddReturnOperation(). 1293 | */ 1294 | def TF_ImportGraphDefOptionsNumReturnOperations( 1295 | opts: Ptr[TF_ImportGraphDefOptions] 1296 | ): CInt = extern 1297 | 1298 | /** TF_ImportGraphDefResults holds results that are generated by 1299 | * TF_GraphImportGraphDefWithResults(). 1300 | */ 1301 | type TF_ImportGraphDefResults = CStruct0 1302 | 1303 | /** Fetches the return outputs requested via 1304 | * TF_ImportGraphDefOptionsAddReturnOutput(). The number of fetched outputs 1305 | * is returned in `num_outputs`. The array of return outputs is returned in 1306 | * `outputs`. `*outputs` is owned by and has the lifetime of `results`. 1307 | */ 1308 | def TF_ImportGraphDefResultsReturnOutputs( 1309 | results: Ptr[TF_ImportGraphDefResults], 1310 | num_outputs: Ptr[CInt], 1311 | outputs: Ptr[Ptr[TF_Output]] 1312 | ): Unit = extern 1313 | 1314 | /** Fetches the return operations requested via 1315 | * TF_ImportGraphDefOptionsAddReturnOperation(). The number of fetched 1316 | * operations is returned in `num_opers`. The array of return operations is 1317 | * returned in `opers`. `*opers` is owned by and has the lifetime of 1318 | * `results`. 1319 | */ 1320 | def TF_ImportGraphDefResultsReturnOperations( 1321 | results: Ptr[TF_ImportGraphDefResults], 1322 | num_opers: Ptr[CInt], 1323 | opers: Ptr[Ptr[Ptr[TF_Operation]]] 1324 | ): Unit = extern 1325 | 1326 | /** Fetches any input mappings requested via 1327 | * TF_ImportGraphDefOptionsAddInputMapping() that didn't appear in the 1328 | * GraphDef and weren't used as input to any node in the imported graph def. 1329 | * The number of fetched mappings is returned in 1330 | * `num_missing_unused_input_mappings`. The array of each mapping's source 1331 | * node name is returned in `src_names`, and the array of each mapping's 1332 | * source index is returned in `src_indexes`. 1333 | * 1334 | * `*src_names`, `*src_indexes`, and the memory backing each string in 1335 | * `src_names` are owned by and have the lifetime of `results`. 1336 | */ 1337 | def TF_ImportGraphDefResultsMissingUnusedInputMappings( 1338 | results: Ptr[TF_ImportGraphDefResults], 1339 | num_missing_unused_input_mappings: Ptr[CInt], 1340 | src_names: Ptr[Ptr[CString]], 1341 | src_indexes: Ptr[Ptr[CInt]] 1342 | ): Unit = extern 1343 | 1344 | /** Deletes a results object returned by TF_GraphImportGraphDefWithResults(). 1345 | */ 1346 | def TF_DeleteImportGraphDefResults( 1347 | results: Ptr[TF_ImportGraphDefResults] 1348 | ): Unit = extern 1349 | 1350 | /** Import the graph serialized in `graph_def` into `graph`. Returns nullptr 1351 | * and a bad status on error. Otherwise, returns a populated 1352 | * TF_ImportGraphDefResults instance. The returned instance must be deleted 1353 | * via TF_DeleteImportGraphDefResults(). 1354 | */ 1355 | def TF_GraphImportGraphDefWithResults( 1356 | graph: Ptr[TF_Graph], 1357 | graph_def: Ptr[TF_Buffer], 1358 | options: Ptr[TF_ImportGraphDefOptions], 1359 | status: Ptr[TF_Status] 1360 | ): Ptr[TF_ImportGraphDefResults] = extern 1361 | 1362 | /** Import the graph serialized in `graph_def` into `graph`. Convenience 1363 | * function for when only return outputs are needed. 1364 | * 1365 | * `num_return_outputs` must be the number of return outputs added (i.e. 
the 1366 | * result of TF_ImportGraphDefOptionsNumReturnOutputs()). If 1367 | * `num_return_outputs` is non-zero, `return_outputs` must be of length 1368 | * `num_return_outputs`. Otherwise it can be null. 1369 | */ 1370 | def TF_GraphImportGraphDefWithReturnOutputs( 1371 | graph: Ptr[TF_Graph], 1372 | graph_def: Ptr[TF_Buffer], 1373 | options: Ptr[TF_ImportGraphDefOptions], 1374 | return_outputs: Ptr[TF_Output], 1375 | num_return_outputs: CInt, 1376 | status: Ptr[TF_Status] 1377 | ): Unit = extern 1378 | 1379 | /** Import the graph serialized in `graph_def` into `graph`. Convenience 1380 | * function for when no results are needed. 1381 | */ 1382 | def TF_GraphImportGraphDef( 1383 | graph: Ptr[TF_Graph], 1384 | graph_def: Ptr[TF_Buffer], 1385 | options: Ptr[TF_ImportGraphDefOptions], 1386 | status: Ptr[TF_Status] 1387 | ): Unit = extern 1388 | 1389 | /** Adds a copy of function `func` and optionally its gradient function `grad` 1390 | * to `g`. Once `func`/`grad` is added to `g`, it can be called by creating 1391 | * an operation using the function's name. Any changes to `func`/`grad` 1392 | * (including deleting it) done after this method returns, won't affect the 1393 | * copy of `func`/`grad` in `g`. If `func` or `grad` are already in `g`, 1394 | * TF_GraphCopyFunction has no effect on them, but can establish the 1395 | * function->gradient relationship between them if `func` does not already 1396 | * have a gradient. If `func` already has a gradient different from `grad`, 1397 | * an error is returned. 1398 | * 1399 | * `func` must not be null. If `grad` is null and `func` is not in `g`, 1400 | * `func` is added without a gradient. If `grad` is null and `func` is in 1401 | * `g`, TF_GraphCopyFunction is a noop. `grad` must have appropriate 1402 | * signature as described in the doc of GradientDef in 1403 | * tensorflow/core/framework/function.proto. 1404 | * 1405 | * If successful, status is set to OK and `func` and `grad` are added to `g`. 1406 | * Otherwise, status is set to the encountered error and `g` is unmodified. 1407 | */ 1408 | def TF_GraphCopyFunction( 1409 | g: Ptr[TF_Graph], 1410 | func: Ptr[TF_Function], 1411 | grad: Ptr[TF_Function], 1412 | status: Ptr[TF_Status] 1413 | ): Unit = extern 1414 | 1415 | /** Returns the number of TF_Functions registered in `g`. 1416 | */ 1417 | def TF_GraphNumFunctions(g: Ptr[TF_Graph]): CInt = extern 1418 | 1419 | /** Fills in `funcs` with the TF_Function* registered in `g`. `funcs` must 1420 | * point to an array of TF_Function* of length at least `max_func`. In usual 1421 | * usage, max_func should be set to the result of TF_GraphNumFunctions(g). In 1422 | * this case, all the functions registered in `g` will be returned. Else, an 1423 | * unspecified subset. 1424 | * 1425 | * If successful, returns the number of TF_Function* successfully set in 1426 | * `funcs` and sets status to OK. The caller takes ownership of all the 1427 | * returned TF_Functions. They must be deleted with TF_DeleteFunction. On 1428 | * error, returns 0, sets status to the encountered error, and the contents 1429 | * of funcs will be undefined. 1430 | */ 1431 | def TF_GraphGetFunctions( 1432 | g: Ptr[TF_Graph], 1433 | funcs: Ptr[Ptr[TF_Function]], 1434 | max_func: CInt, 1435 | status: Ptr[TF_Status] 1436 | ): CInt = extern 1437 | 1438 | /** Note: The following function may fail on very large protos in the future. 
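* An illustrative sketch of serializing an operation, assuming `oper`,
* `status`, and the TF_NewBuffer/TF_DeleteBuffer bindings declared elsewhere
* in this file:
* {{{
* val buf = TF_NewBuffer()
* TF_OperationToNodeDef(oper, buf, status)
* // on success, buf holds the serialized NodeDef proto bytes
* TF_DeleteBuffer(buf)
* }}}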
1439 | */ 1440 | def TF_OperationToNodeDef( 1441 | oper: Ptr[TF_Operation], 1442 | output_node_def: Ptr[TF_Buffer], 1443 | status: Ptr[TF_Status] 1444 | ): Unit = extern 1445 | 1446 | /** Creates a TF_WhileParams for creating a while loop in `g`. `inputs` are 1447 | * outputs that already exist in `g` used as initial values for the loop 1448 | * variables. 1449 | * 1450 | * The returned TF_WhileParams will have all fields initialized except 1451 | * `cond_output`, `body_outputs`, and `name`. The `body_outputs` buffer will 1452 | * be allocated to size `ninputs`. The caller should build `cond_graph` and 1453 | * `body_graph` starting from the inputs, and store the final outputs in 1454 | * `cond_output` and `body_outputs`. 1455 | * 1456 | * If `status` is OK, the caller must call either TF_FinishWhile or 1457 | * TF_AbortWhile on the returned TF_WhileParams. If `status` isn't OK, the 1458 | * returned TF_WhileParams is not valid, and the caller should not call 1459 | * TF_FinishWhile() or TF_AbortWhile(). 1460 | * 1461 | * Missing functionality (TODO): 1462 | * - Gradients 1463 | * - Reference-type inputs 1464 | * - Directly referencing external tensors from the cond/body graphs (this 1465 | * is possible in the Python API) 1466 | */ 1467 | def TF_NewWhile( 1468 | g: Ptr[TF_Graph], 1469 | inputs: Ptr[TF_Output], 1470 | ninputs: CInt, 1471 | status: Ptr[TF_Status] 1472 | ): TF_WhileParams = extern 1473 | 1474 | /** Builds the while loop specified by `params` and returns the output tensors 1475 | * of the while loop in `outputs`. `outputs` should be allocated to size 1476 | * `params.ninputs`. 1477 | * 1478 | * `params` is no longer valid once this returns. 1479 | * 1480 | * Either this or TF_AbortWhile() must be called after a successful 1481 | * TF_NewWhile() call. 1482 | */ 1483 | def TF_FinishWhile( 1484 | params: Ptr[TF_WhileParams], 1485 | status: Ptr[TF_Status], 1486 | outputs: Ptr[TF_Output] 1487 | ): Unit = extern 1488 | 1489 | /** Frees `params`'s resources without building a while loop. `params` is no 1490 | * longer valid after this returns. Either this or TF_FinishWhile() must be 1491 | * called after a successful TF_NewWhile() call. 1492 | */ 1493 | def TF_AbortWhile(params: Ptr[TF_WhileParams]): Unit = extern 1494 | 1495 | /** Adds operations to compute the partial derivatives of sum of `y`s w.r.t 1496 | * `x`s, i.e., d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2... 1497 | * 1498 | * `dx` are used as initial gradients (which represent the symbolic partial 1499 | * derivatives of some loss function `L` w.r.t. `y`). `dx` must be nullptr or 1500 | * have size `ny`. If `dx` is nullptr, the implementation will use dx of 1501 | * `OnesLike` for all shapes in `y`. The partial derivatives are returned in 1502 | * `dy`. `dy` should be allocated to size `nx`. 1503 | * 1504 | * Gradient nodes are automatically named under the "gradients/" prefix. To 1505 | * guarantee name uniqueness, subsequent calls to the same graph will append 1506 | * an incremental tag to the prefix: "gradients_1/", "gradients_2/", ... See 1507 | * TF_AddGradientsWithPrefix, which provides a means to specify a custom name 1508 | * prefix for operations added to a graph to compute the gradients. 1509 | * 1510 | * WARNING: This function does not yet support all the gradients that python 1511 | * supports. See 1512 | * https://www.tensorflow.org/code/tensorflow/cc/gradients/README.md for 1513 | * instructions on how to add more C++ gradients.
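* A hypothetical calling sketch (the graph `g`, the `y`/`x` arrays and their
* counts are placeholders; `dy` must provide room for `nx` outputs):
* {{{
* Zone {
*   val dy = alloc[TF_Output](nx)
*   // a null dx means OnesLike is used as the initial gradient
*   TF_AddGradients(g, y, ny, x, nx, null, status, dy)
* }
* }}}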
1514 | */ 1515 | def TF_AddGradients( 1516 | g: Ptr[TF_Graph], 1517 | y: Ptr[TF_Output], 1518 | ny: CInt, 1519 | x: Ptr[TF_Output], 1520 | nx: CInt, 1521 | dx: Ptr[TF_Output], 1522 | status: Ptr[TF_Status], 1523 | dy: Ptr[TF_Output] 1524 | ): Unit = extern 1525 | 1526 | /** Adds operations to compute the partial derivatives of sum of `y`s w.r.t 1527 | * `x`s, i.e., d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2... This is a 1528 | * variant of TF_AddGradients that allows the caller to pass a custom name 1529 | * prefix to the operations added to a graph to compute the gradients. 1530 | * 1531 | * `dx` are used as initial gradients (which represent the symbolic partial 1532 | * derivatives of some loss function `L` w.r.t. `y`). `dx` must be nullptr or 1533 | * have size `ny`. If `dx` is nullptr, the implementation will use dx of 1534 | * `OnesLike` for all shapes in `y`. The partial derivatives are returned in 1535 | * `dy`. `dy` should be allocated to size `nx`. `prefix` names the scope into 1536 | * which all gradients operations are being added. `prefix` must be unique 1537 | * within the provided graph otherwise this operation will fail. If `prefix` 1538 | * is nullptr, the default prefixing behaviour takes place, see 1539 | * TF_AddGradients for more details. 1540 | * 1541 | * WARNING: This function does not yet support all the gradients that python 1542 | * supports. See 1543 | * https://www.tensorflow.org/code/tensorflow/cc/gradients/README.md for 1544 | * instructions on how to add more C++ gradients. 1545 | */ 1546 | def TF_AddGradientsWithPrefix( 1547 | g: Ptr[TF_Graph], 1548 | prefix: CString, 1549 | y: Ptr[TF_Output], 1550 | ny: CInt, 1551 | x: Ptr[TF_Output], 1552 | nx: CInt, 1553 | dx: Ptr[TF_Output], 1554 | status: Ptr[TF_Status], 1555 | dy: Ptr[TF_Output] 1556 | ): Unit = extern 1557 | 1558 | /** Create a TF_Function from a TF_Graph. 1559 | * 1560 | * Params: 1561 | * 1562 | * fn_body 1563 | * - the graph whose operations (or subset of whose operations) will be 1564 | * converted to TF_Function. 1565 | * 1566 | * fn_name 1567 | * - the name of the new TF_Function. Should match the operation name 1568 | * (OpDef.name) regexp [A-Z][A-Za-z0-9_.\\-/]*. If 1569 | * `append_hash_to_fn_name` is false, `fn_name` must be distinct from 1570 | * other function and operation names (at least those registered in 1571 | * graphs where this function will be used). 1572 | * 1573 | * append_hash_to_fn_name 1574 | * - Must be 0 or 1. If set to 1, the actual name of the function will be 1575 | * `fn_name` appended with '_' and a hash of the function's definition. If 1576 | * set to 0, the function's name will be `fn_name`. 1577 | * 1578 | * num_opers 1579 | * - `num_opers` contains the number of elements in the `opers` array or a 1580 | * special value of -1 meaning that no array is given. The distinction 1581 | * between an empty array of operations and no array of operations is 1582 | * necessary to distinguish the case of creating a function with no body 1583 | * (e.g. identity or permutation) and the case of creating a function 1584 | * whose body contains all the nodes in the graph (except for the 1585 | * automatic skipping, see below). 1586 | * 1587 | * opers 1588 | * - Array of operations to become the body of the function or null. 1589 | * - If no array is given (`num_opers` = -1), all the operations in 1590 | * `fn_body` will become part of the function except operations 1591 | * referenced in `inputs`.
These operations must have a single output 1592 | * (these operations are typically placeholders created for the sole 1593 | * purpose of representing an input. We can relax this constraint if 1594 | * there are compelling use cases). 1595 | * - If an array is given (`num_opers` >= 0), all operations in it will 1596 | * become part of the function. In particular, no automatic skipping of 1597 | * dummy input operations is performed. 1598 | * 1599 | * ninputs 1600 | * - number of elements in `inputs` array 1601 | * 1602 | * inputs 1603 | * - array of TF_Outputs that specify the inputs to the function. If 1604 | * `ninputs` is zero (the function takes no inputs), `inputs` can be 1605 | * null. The names used for function inputs are normalized names of the 1606 | * operations (usually placeholders) pointed to by `inputs`. These 1607 | * operation names should start with a letter. Normalization will convert 1608 | * all letters to lowercase and non-alphanumeric characters to '_' to 1609 | * make resulting names match the "[a-z][a-z0-9_]*" pattern for operation 1610 | * argument names. `inputs` cannot contain the same tensor twice. 1611 | * 1612 | * noutputs 1613 | * - number of elements in `outputs` array outputs - array of TF_Outputs 1614 | * that specify the outputs of the function. If `noutputs` is zero (the 1615 | * function returns no outputs), `outputs` can be null. `outputs` can 1616 | * contain the same tensor more than once. 1617 | * 1618 | * output_names 1619 | * - The names of the function's outputs. `output_names` array must either 1620 | * have the same length as `outputs` (i.e. `noutputs`) or be null. In the 1621 | * former case, the names should match the regular expression for ArgDef 1622 | * names - "[a-z][a-z0-9_]*". In the latter case, names for outputs will 1623 | * be generated automatically. 1624 | * 1625 | * opts 1626 | * - various options for the function, e.g. XLA's inlining control. 1627 | * 1628 | * description 1629 | * - optional human-readable description of this function. 1630 | * 1631 | * status 1632 | * - Set to OK on success and an appropriate error on failure. 1633 | * 1634 | * Note that when the same TF_Output is listed as both an input and an 1635 | * output, the corresponding function's output will equal to this input, 1636 | * instead of the original node's output. 1637 | * 1638 | * Callers must also satisfy the following constraints: 1639 | * - `inputs` cannot refer to TF_Outputs within a control flow context. For 1640 | * example, one cannot use the output of "switch" node as input. 1641 | * - `inputs` and `outputs` cannot have reference types. Reference types 1642 | * are not exposed through C API and are being replaced with Resources. 1643 | * We support reference types inside function's body to support legacy 1644 | * code. Do not use them in new code. 1645 | * - Every node in the function's body must have all of its inputs 1646 | * (including control inputs). In other words, for every node in the 1647 | * body, each input must be either listed in `inputs` or must come from 1648 | * another node in the body. In particular, it is an error to have a 1649 | * control edge going from a node outside of the body into a node in the 1650 | * body. This applies to control edges going from nodes referenced in 1651 | * `inputs` to nodes in the body when the former nodes are not in the 1652 | * body (automatically skipped or not included in explicitly specified 1653 | * body). 1654 | * 1655 | * Returns: On success, a newly created TF_Function instance. 
It must be 1656 | * deleted by calling TF_DeleteFunction. 1657 | */ 1658 | def TF_GraphToFunction( 1659 | fn_body: Ptr[TF_Graph], 1660 | fn_name: CString, 1661 | append_hash_to_fn_name: CUnsignedChar, 1662 | num_opers: CInt, 1663 | opers: Ptr[Ptr[TF_Operation]], 1664 | ninputs: CInt, 1665 | inputs: Ptr[TF_Output], 1666 | noutputs: CInt, 1667 | outputs: Ptr[TF_Output], 1668 | output_names: Ptr[CString], 1669 | opts: Ptr[TF_FunctionOptions], 1670 | description: CString, 1671 | status: Ptr[TF_Status] 1672 | ): Ptr[TF_Function] = extern 1673 | 1674 | /** Returns the name of the graph function. The return value points to memory 1675 | * that is only usable until the next mutation to *func. 1676 | */ 1677 | def TF_FunctionName(func: Ptr[TF_Function]): CString = extern 1678 | 1679 | /** Write out a serialized representation of `func` (as a FunctionDef protocol 1680 | * message) to `output_func_def` (allocated by TF_NewBuffer()). 1681 | * `output_func_def`'s underlying buffer will be freed when TF_DeleteBuffer() 1682 | * is called. 1683 | * 1684 | * May fail on very large graphs in the future. 1685 | */ 1686 | def TF_FunctionToFunctionDef( 1687 | func: Ptr[TF_Function], 1688 | output_func_def: Ptr[TF_Buffer], 1689 | status: Ptr[TF_Status] 1690 | ): Unit = extern 1691 | 1692 | /** Construct and return the function whose FunctionDef representation is 1693 | * serialized in `proto`. `proto_len` must equal the number of bytes pointed 1694 | * to by `proto`. Returns: On success, a newly created TF_Function instance. 1695 | * It must be deleted by calling TF_DeleteFunction. 1696 | * 1697 | * On failure, null. 1698 | */ 1699 | def TF_FunctionImportFunctionDef( 1700 | proto: Ptr[Byte], 1701 | proto_len: CSize, 1702 | status: Ptr[TF_Status] 1703 | ): Ptr[TF_Function] = 1704 | extern 1705 | 1706 | /** Sets function attribute named `attr_name` to value stored in `proto`. If 1707 | * this attribute is already set to another value, it is overridden. `proto` 1708 | * should point to a sequence of bytes of length `proto_len` representing a 1709 | * binary serialization of an AttrValue protocol buffer. 1710 | */ 1711 | def TF_FunctionSetAttrValueProto( 1712 | func: Ptr[TF_Function], 1713 | attr_name: CString, 1714 | proto: Ptr[Byte], 1715 | proto_len: CSize, 1716 | status: Ptr[TF_Status] 1717 | ): Unit = extern 1718 | 1719 | /** Sets `output_attr_value` to the binary-serialized AttrValue proto 1720 | * representation of the value of the `attr_name` attr of `func`. If 1721 | * `attr_name` attribute is not present, status is set to an error. 1722 | */ 1723 | def TF_FunctionGetAttrValueProto( 1724 | func: Ptr[TF_Function], 1725 | attr_name: CString, 1726 | output_attr_value: Ptr[TF_Buffer], 1727 | status: Ptr[TF_Status] 1728 | ): Unit = extern 1729 | 1730 | /** Frees the memory used by the `func` struct. TF_DeleteFunction is a noop if 1731 | * `func` is null. Deleting a function does not remove it from any graphs it 1732 | * was copied to. 1733 | */ 1734 | def TF_DeleteFunction(func: Ptr[TF_Function]): Unit = extern 1735 | 1736 | /** Attempts to evaluate `output`. This will only be possible if `output` 1737 | * doesn't depend on any graph inputs (this function is safe to call if this 1738 | * isn't the case though). 1739 | * 1740 | * If the evaluation is successful, this function returns true and `output`s 1741 | * value is returned in `result`. Otherwise returns false. An error status is 1742 | * returned if something is wrong with the graph or input. 
Note that this may 1743 | * return false even if no error status is set. 1744 | */ 1745 | @name("scalanative_TF_TryEvaluateConstant") 1746 | def TF_TryEvaluateConstant( 1747 | graph: Ptr[TF_Graph], 1748 | output: Ptr[TF_Output], // TF_output 1749 | result: Ptr[Ptr[TF_Tensor]], 1750 | status: Ptr[TF_Status] 1751 | ): CUnsignedChar = extern 1752 | 1753 | /** API for driving Graph execution. 1754 | */ 1755 | type TF_Session = CStruct0 1756 | 1757 | /** Return a new execution session with the associated graph, or NULL on 1758 | * error. Does not take ownership of any input parameters. 1759 | * 1760 | * *`graph` must be a valid graph (not deleted or nullptr). `graph` will be 1761 | * kept alive for the lifetime of the returned TF_Session. New nodes can 1762 | * still be added to `graph` after this call. 1763 | */ 1764 | def TF_NewSession( 1765 | graph: Ptr[TF_Graph], 1766 | opts: Ptr[TF_SessionOptions], 1767 | status: Ptr[TF_Status] 1768 | ): Ptr[TF_Session] = extern 1769 | 1770 | /** This function creates a new TF_Session (which is created on success) using 1771 | * `session_options`, and then initializes state (restoring tensors and other 1772 | * assets) using `run_options`. 1773 | * 1774 | * Any NULL and non-NULL value combinations for (`run_options`, 1775 | * `meta_graph_def`) are valid. 1776 | * 1777 | * - `export_dir` must be set to the path of the exported SavedModel. 1778 | * - `tags` must include the set of tags used to identify one MetaGraphDef 1779 | * in the SavedModel. 1780 | * - `graph` must be a graph newly allocated with TF_NewGraph(). 1781 | * 1782 | * If successful, populates `graph` with the contents of the Graph and 1783 | * `meta_graph_def` with the MetaGraphDef of the loaded model. 1784 | */ 1785 | def TF_LoadSessionFromSavedModel( 1786 | session_options: Ptr[TF_SessionOptions], 1787 | run_options: Ptr[TF_Buffer], 1788 | export_dir: CString, 1789 | tags: Ptr[CString], 1790 | tags_len: CInt, 1791 | graph: Ptr[TF_Graph], 1792 | meta_graph_def: Ptr[TF_Buffer], 1793 | status: Ptr[TF_Status] 1794 | ): Ptr[TF_Session] = 1795 | extern 1796 | 1797 | /** Close a session. 1798 | * 1799 | * Contacts any other processes associated with the session, if applicable. 1800 | * May not be called after TF_DeleteSession(). 1801 | */ 1802 | def TF_CloseSession(session: Ptr[TF_Session], status: Ptr[TF_Status]): Unit = 1803 | extern 1804 | 1805 | /** Destroy a session object. 1806 | * 1807 | * Even if error information is recorded in *status, this call discards all 1808 | * local resources associated with the session. The session may not be used 1809 | * during or after this call (and the session drops its reference to the 1810 | * corresponding graph). 1811 | */ 1812 | def TF_DeleteSession(session: Ptr[TF_Session], status: Ptr[TF_Status]): Unit = 1813 | extern 1814 | 1815 | /** Run the graph associated with the session starting with the supplied 1816 | * inputs (inputs[0,ninputs-1] with corresponding values in 1817 | * input_values[0,ninputs-1]). 1818 | * 1819 | * Any NULL and non-NULL value combinations for (`run_options`, 1820 | * `run_metadata`) are valid. 1821 | * 1822 | * - `run_options` may be NULL, in which case it will be ignored; or 1823 | * non-NULL, in which case it must point to a `TF_Buffer` containing the 1824 | * serialized representation of a `RunOptions` protocol buffer.
1825 | * - `run_metadata` may be NULL, in which case it will be ignored; or 1826 | * non-NULL, in which case it must point to an empty, freshly allocated 1827 | * `TF_Buffer` that may be updated to contain the serialized 1828 | * representation of a `RunMetadata` protocol buffer. 1829 | * 1830 | * The caller retains ownership of `input_values` (which can be deleted using 1831 | * TF_DeleteTensor). The caller also retains ownership of `run_options` 1832 | * and/or `run_metadata` (when not NULL) and should manually call 1833 | * TF_DeleteBuffer on them. 1834 | * 1835 | * On success, the tensors corresponding to outputs[0,noutputs-1] are placed 1836 | * in output_values[]. Ownership of the elements of output_values[] is 1837 | * transferred to the caller, which must eventually call TF_DeleteTensor on 1838 | * them. 1839 | * 1840 | * On failure, output_values[] contains NULLs. 1841 | */ 1842 | def TF_SessionRun( 1843 | session: Ptr[TF_Session], 1844 | // RunOptions 1845 | run_options: Ptr[TF_Buffer], 1846 | // Input tensors 1847 | inputs: Ptr[TF_Output], 1848 | input_values: Ptr[Ptr[TF_Tensor]], 1849 | ninputs: CInt, 1850 | // Output tensors 1851 | outputs: Ptr[TF_Output], 1852 | output_values: Ptr[Ptr[TF_Tensor]], 1853 | noutputs: CInt, 1854 | // Target operations 1855 | target_opers: Ptr[Ptr[TF_Operation]], 1856 | ntargets: CInt, 1857 | // RunMetadata 1858 | run_metadata: Ptr[TF_Buffer], 1859 | // Output status 1860 | status: Ptr[TF_Status] 1861 | ): Unit = extern 1862 | 1863 | /** Set up the graph with the intended feeds (inputs) and fetches (outputs) 1864 | * for a sequence of partial run calls. 1865 | * 1866 | * On success, returns a handle that is used for subsequent PRun calls. The 1867 | * handle should be deleted with TF_DeletePRunHandle when it is no longer 1868 | * needed. 1869 | * 1870 | * On failure, out_status contains a tensorflow::Status with an error 1871 | * message. *handle is set to nullptr. 1872 | */ 1873 | def TF_SessionPRunSetup( 1874 | session: Ptr[TF_Session], 1875 | // Input names 1876 | inputs: Ptr[TF_Output], 1877 | ninputs: CInt, 1878 | // Output names 1879 | outputs: Ptr[TF_Output], 1880 | noutputs: CInt, 1881 | // Target operations 1882 | target_opers: Ptr[Ptr[TF_Operation]], 1883 | ntargets: CInt, 1884 | // Output handle 1885 | handle: Ptr[CString], 1886 | // Output status 1887 | status: Ptr[TF_Status] 1888 | ): Unit = extern 1889 | 1890 | /** Continue to run the graph with additional feeds and fetches. The execution 1891 | * state is uniquely identified by the handle. 1892 | */ 1893 | def TF_SessionPRun( 1894 | session: Ptr[TF_Session], 1895 | handle: CString, 1896 | // Input tensors 1897 | inputs: Ptr[TF_Output], 1898 | input_values: Ptr[Ptr[TF_Tensor]], 1899 | ninputs: CInt, 1900 | // Output tensors 1901 | outputs: Ptr[TF_Output], 1902 | output_values: Ptr[Ptr[TF_Tensor]], 1903 | noutputs: CInt, 1904 | // Target operations 1905 | target_opers: Ptr[Ptr[TF_Operation]], 1906 | ntargets: CInt, 1907 | // Output status 1908 | status: Ptr[TF_Status] 1909 | ): Unit = extern 1910 | 1911 | /** Deletes a handle allocated by TF_SessionPRunSetup. Once called, no more 1912 | * calls to TF_SessionPRun should be made. 1913 | */ 1914 | def TF_DeletePRunHandle(handle: CString): Unit = extern 1915 | 1916 | /** The deprecated session API. Please switch to the above instead of 1917 | * TF_ExtendGraph(). This deprecated API can be removed at any time without 1918 | * notice. 
1919 | */ 1920 | type TF_DeprecatedSession = CStruct0 1921 | 1922 | /** */ 1923 | def TF_NewDeprecatedSession( 1924 | sessionOptions: Ptr[TF_SessionOptions], 1925 | status: Ptr[TF_Status] 1926 | ): Ptr[TF_DeprecatedSession] = extern 1927 | 1928 | /** */ 1929 | def TF_CloseDeprecatedSession( 1930 | deprecatedSession: Ptr[TF_DeprecatedSession], 1931 | status: Ptr[TF_Status] 1932 | ): Unit = extern 1933 | 1934 | /** */ 1935 | def TF_DeleteDeprecatedSession( 1936 | deprecatedSession: Ptr[TF_DeprecatedSession], 1937 | status: Ptr[TF_Status] 1938 | ): Unit = extern 1939 | 1940 | /** */ 1941 | def TF_Reset( 1942 | opt: Ptr[TF_SessionOptions], 1943 | containers: Ptr[CString], 1944 | ncontainers: CInt, 1945 | status: Ptr[TF_Status] 1946 | ): Unit = extern 1947 | 1948 | /** Treat the bytes proto[0,proto_len-1] as a serialized GraphDef and add the 1949 | * nodes in that GraphDef to the graph for the session. 1950 | * 1951 | * Prefer use of TF_Session and TF_GraphImportGraphDef over this. 1952 | */ 1953 | def TF_ExtendGraph( 1954 | deprecatedSession: Ptr[TF_DeprecatedSession], 1955 | proto: Ptr[Byte], 1956 | proto_len: CSize, 1957 | status: Ptr[TF_Status] 1958 | ): Unit = extern 1959 | 1960 | /** See TF_SessionRun() above. 1961 | */ 1962 | def TF_Run( 1963 | deprecatedSession: Ptr[TF_DeprecatedSession], 1964 | run_options: Ptr[TF_Buffer], 1965 | input_names: Ptr[CString], 1966 | inputs: Ptr[Ptr[TF_Tensor]], 1967 | ninputs: CInt, 1968 | output_names: Ptr[CString], 1969 | outputs: Ptr[Ptr[TF_Tensor]], 1970 | noutputs: CInt, 1971 | target_oper_names: Ptr[CString], 1972 | ntargets: CInt, 1973 | run_metadata: Ptr[TF_Buffer], 1974 | status: Ptr[TF_Status] 1975 | ): Unit = extern 1976 | 1977 | /** See TF_SessionPRunSetup() above. 1978 | */ 1979 | def TF_PRunSetup( 1980 | deprecatedSession: Ptr[TF_DeprecatedSession], 1981 | input_names: Ptr[CString], 1982 | ninputs: CInt, 1983 | output_names: Ptr[CString], 1984 | noutputs: CInt, 1985 | target_oper_names: Ptr[CString], 1986 | ntargets: CInt, 1987 | handle: Ptr[CString], 1988 | status: Ptr[TF_Status] 1989 | ): Unit = extern 1990 | 1991 | /** See TF_SessionPRun above. 1992 | */ 1993 | def TF_PRun( 1994 | deprecatedSession: Ptr[TF_DeprecatedSession], 1995 | handle: CString, 1996 | input_names: Ptr[CString], 1997 | inputs: Ptr[Ptr[TF_Tensor]], 1998 | ninputs: CInt, 1999 | output_names: Ptr[CString], 2000 | outputs: Ptr[Ptr[TF_Tensor]], 2001 | noutputs: CInt, 2002 | target_oper_names: Ptr[CString], 2003 | ntargets: CInt, 2004 | status: Ptr[TF_Status] 2005 | ): Unit = extern 2006 | 2007 | type TF_DeviceList = CStruct0 2008 | 2009 | /** Lists all devices in a TF_Session. 2010 | * 2011 | * Caller takes ownership of the returned TF_DeviceList* which must 2012 | * eventually be freed with a call to TF_DeleteDeviceList. 2013 | */ 2014 | def TF_SessionListDevices( 2015 | session: Ptr[TF_Session], 2016 | status: Ptr[TF_Status] 2017 | ): Ptr[TF_DeviceList] = extern 2018 | 2019 | /** Lists all devices in a TF_Session. 2020 | * 2021 | * Caller takes ownership of the returned TF_DeviceList* which must 2022 | * eventually be freed with a call to TF_DeleteDeviceList. 2023 | */ 2024 | def TF_DeprecatedSessionListDevices( 2025 | session: Ptr[TF_DeprecatedSession], 2026 | status: Ptr[TF_Status] 2027 | ): Ptr[TF_DeviceList] = extern 2028 | 2029 | /** Deallocates the device list. 2030 | */ 2031 | def TF_DeleteDeviceList(list: Ptr[TF_DeviceList]): Unit = extern 2032 | 2033 | /** Counts the number of elements in the device list. 
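* An illustrative sketch of enumerating the devices of a session (it assumes
* a valid `session` and `status` are already in scope):
* {{{
* val devices = TF_SessionListDevices(session, status)
* val n = TF_DeviceListCount(devices)
* for (i <- 0 until n)
*   println(fromCString(TF_DeviceListName(devices, i, status)))
* TF_DeleteDeviceList(devices)
* }}}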
2034 | */ 2035 | def TF_DeviceListCount(list: Ptr[TF_DeviceList]): CInt = extern 2036 | 2037 | /** Retrieves the full name of the device (e.g. /job:worker/replica:0/...) The 2038 | * return value will be a pointer to a null terminated string. The caller 2039 | * must not modify or delete the string. It will be deallocated upon a call 2040 | * to TF_DeleteDeviceList. 2041 | * 2042 | * If index is out of bounds, an error code will be set in the status object, 2043 | * and a null pointer will be returned. 2044 | */ 2045 | def TF_DeviceListName( 2046 | list: Ptr[TF_DeviceList], 2047 | index: CInt, 2048 | status: Ptr[TF_Status] 2049 | ): CString = extern 2050 | 2051 | /** Retrieves the type of the device at the given index. 2052 | * 2053 | * The caller must not modify or delete the string. It will be deallocated 2054 | * upon a call to TF_DeleteDeviceList. 2055 | * 2056 | * If index is out of bounds, an error code will be set in the status object, 2057 | * and a null pointer will be returned. 2058 | */ 2059 | def TF_DeviceListType( 2060 | list: Ptr[TF_DeviceList], 2061 | index: CInt, 2062 | status: Ptr[TF_Status] 2063 | ): CString = extern 2064 | 2065 | /** Retrieve the amount of memory associated with a given device. 2066 | * 2067 | * If index is out of bounds, an error code will be set in the status object, 2068 | * and -1 will be returned. 2069 | */ 2070 | def TF_DeviceListMemoryBytes( 2071 | list: Ptr[TF_DeviceList], 2072 | index: CInt, 2073 | status: Ptr[TF_Status] 2074 | ): int64_t = extern 2075 | 2076 | /** Retrieve the incarnation number of a given device. 2077 | * 2078 | * If index is out of bounds, an error code will be set in the status object, 2079 | * and 0 will be returned. 2080 | */ 2081 | def TF_DeviceListIncarnation( 2082 | list: Ptr[TF_DeviceList], 2083 | index: CInt, 2084 | status: Ptr[TF_Status] 2085 | ): uint64_t = extern 2086 | 2087 | // Load plugins containing custom ops and kernels 2088 | 2089 | /** TF_Library holds information about dynamically loaded TensorFlow plugins. 2090 | */ 2091 | type TF_Library = CStruct0 2092 | 2093 | /** Load the library specified by library_filename and register the ops and 2094 | * kernels present in that library. 2095 | * 2096 | * Pass "library_filename" to a platform-specific mechanism for dynamically 2097 | * loading a library. The rules for determining the exact location of the 2098 | * library are platform-specific and are not documented here. 2099 | * 2100 | * On success, place OK in status and return the newly created library 2101 | * handle. The caller owns the library handle. 2102 | * 2103 | * On failure, place an error status in status and return NULL. 2104 | */ 2105 | def TF_LoadLibrary( 2106 | library_filename: CString, 2107 | status: Ptr[TF_Status] 2108 | ): Ptr[TF_Library] = extern 2109 | 2110 | /** Get the OpList of OpDefs defined in the library pointed by lib_handle. 2111 | * 2112 | * Returns a TF_Buffer. The memory pointed to by the result is owned by 2113 | * lib_handle. The data in the buffer will be the serialized OpList proto for 2114 | * ops defined in the library. 2115 | */ 2116 | def TF_GetOpList(lib_handle: Ptr[TF_Library]): TF_Buffer = extern 2117 | 2118 | /** Frees the memory associated with the library handle. Does NOT unload the 2119 | * library. 2120 | */ 2121 | def TF_DeleteLibraryHandle(lib_handle: Ptr[TF_Library]): Unit = extern 2122 | 2123 | /** Get the OpList of all OpDefs defined in this address space. 
Returns a 2124 | * TF_Buffer, ownership of which is transferred to the caller (and can be 2125 | * freed using TF_DeleteBuffer). 2126 | * 2127 | * The data in the buffer will be the serialized OpList proto for ops 2128 | * registered in this address space. 2129 | */ 2130 | def TF_GetAllOpList(): Ptr[TF_Buffer] = extern 2131 | 2132 | /** TF_ApiDefMap encapsulates a collection of API definitions for an 2133 | * operation. 2134 | * 2135 | * This object maps the name of a TensorFlow operation to a description of 2136 | * the API to generate for it, as defined by the ApiDef protocol buffer ( 2137 | * https://www.tensorflow.org/code/tensorflow/core/framework/api_def.proto) 2138 | * 2139 | * The ApiDef messages are typically used to generate convenience wrapper 2140 | * functions for TensorFlow operations in various language bindings. 2141 | */ 2142 | type TF_ApiDefMap = CStruct0 2143 | 2144 | /** Creates a new TF_ApiDefMap instance. 2145 | * 2146 | * Params: 2147 | * 2148 | * op_list_buffer 2149 | * - TF_Buffer instance containing serialized OpList protocol buffer. (See 2150 | * https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto 2151 | * for the OpList proto definition). 2152 | * 2153 | * status 2154 | * - Set to OK on success and an appropriate error on failure. 2155 | */ 2156 | def TF_NewApiDefMap( 2157 | op_list_buffer: Ptr[TF_Buffer], 2158 | status: Ptr[TF_Status] 2159 | ): Ptr[TF_ApiDefMap] = extern 2160 | 2161 | /** Deallocates a TF_ApiDefMap. 2162 | */ 2163 | def TF_DeleteApiDefMap(apimap: Ptr[TF_ApiDefMap]): Unit = extern 2164 | 2165 | /** Add ApiDefs to the map. 2166 | * 2167 | * `text` corresponds to a text representation of an ApiDefs protocol 2168 | * message. 2169 | * (https://www.tensorflow.org/code/tensorflow/core/framework/api_def.proto). 2170 | * 2171 | * The provided ApiDefs will be merged with existing ones in the map, with 2172 | * precedence given to the newly added version in case of conflicts with 2173 | * previous calls to TF_ApiDefMapPut. 2174 | */ 2175 | def TF_ApiDefMapPut( 2176 | api_def_map: Ptr[TF_ApiDefMap], 2177 | text: CString, 2178 | text_len: CSize, 2179 | status: Ptr[TF_Status] 2180 | ): Unit = extern 2181 | 2182 | /** Returns a serialized ApiDef protocol buffer for the TensorFlow operation 2183 | * named `name`. 2184 | */ 2185 | def TF_ApiDefMapGet( 2186 | api_def_map: Ptr[TF_ApiDefMap], 2187 | name: CString, 2188 | name_len: CSize, 2189 | status: Ptr[TF_Status] 2190 | ): Ptr[TF_Buffer] = extern 2191 | 2192 | // Kernel definition information. 2193 | 2194 | /** Returns a serialized KernelList protocol buffer containing KernelDefs for 2195 | * all registered kernels. 2196 | */ 2197 | def TF_GetAllRegisteredKernels(status: Ptr[TF_Status]): Ptr[TF_Buffer] = 2198 | extern 2199 | 2200 | /** Returns a serialized KernelList protocol buffer containing KernelDefs for 2201 | * all kernels registered for the operation named `name`. 2202 | */ 2203 | def TF_GetRegisteredKernelsForOp( 2204 | name: CString, 2205 | status: Ptr[TF_Status] 2206 | ): Ptr[TF_Buffer] = 2207 | extern 2208 | 2209 | /** In-process TensorFlow server functionality, for use in distributed 2210 | * training. A Server instance encapsulates a set of devices and a Session 2211 | * target that can participate in distributed training. A server belongs to a 2212 | * cluster (specified by a ClusterSpec), and corresponds to a particular task 2213 | * in a named job. The server can communicate with any other server in the 2214 | * same cluster. 2215 | * 2216 | * In-process TensorFlow server. 
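* A hypothetical lifecycle sketch (the serialized ServerDef bytes `proto` and
* their length `proto_len` are assumed to come from elsewhere):
* {{{
* val server = TF_NewServer(proto, proto_len, status)
* TF_ServerStart(server, status)
* // ... sessions may connect via TF_ServerTarget(server) ...
* TF_ServerStop(server, status)
* TF_ServerJoin(server, status)
* TF_DeleteServer(server)
* }}}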
2217 | */ 2218 | type TF_Server = CStruct0 2219 | 2220 | /** Creates a new in-process TensorFlow server configured using a serialized 2221 | * ServerDef protocol buffer provided via `proto` and `proto_len`. 2222 | * 2223 | * The server will not serve any requests until TF_ServerStart is invoked. 2224 | * The server will stop serving requests once TF_ServerStop or 2225 | * TF_DeleteServer is invoked. 2226 | */ 2227 | def TF_NewServer( 2228 | proto: Ptr[Byte], 2229 | proto_len: CSize, 2230 | status: Ptr[TF_Status] 2231 | ): Ptr[TF_Server] = extern 2232 | 2233 | /** Starts an in-process TensorFlow server. 2234 | */ 2235 | def TF_ServerStart(server: Ptr[TF_Server], status: Ptr[TF_Status]): Unit = 2236 | extern 2237 | 2238 | /** Stops an in-process TensorFlow server. 2239 | */ 2240 | def TF_ServerStop(server: Ptr[TF_Server], status: Ptr[TF_Status]): Unit = 2241 | extern 2242 | 2243 | /** Blocks until the server has been successfully stopped (via TF_ServerStop 2244 | * or TF_ServerClose). 2245 | */ 2246 | def TF_ServerJoin(server: Ptr[TF_Server], status: Ptr[TF_Status]): Unit = 2247 | extern 2248 | 2249 | /** Returns the target string that can be provided to TF_SetTarget() to 2250 | * connect a TF_Session to `server`. 2251 | * 2252 | * The returned string is valid only until TF_DeleteServer is invoked. 2253 | */ 2254 | def TF_ServerTarget(server: Ptr[TF_Server]): CString = extern 2255 | 2256 | /** Destroy an in-process TensorFlow server, frees memory. If server is 2257 | * running it will be stopped and joined. 2258 | */ 2259 | def TF_DeleteServer(server: Ptr[TF_Server]): Unit = extern 2260 | } 2261 | 2262 | import tensorflow._ 2263 | 2264 | object tensorflowOps { 2265 | 2266 | implicit class TF_Buffer_ops(val p: Ptr[TF_Buffer]) extends AnyVal { 2267 | def data: Ptr[Byte] = p._1 2268 | def data_=(value: Ptr[Byte]): Unit = p._1 = value 2269 | def length: CSize = p._2 2270 | def length_=(value: CSize): Unit = p._2 = value 2271 | def data_deallocator: CFuncPtr2[Ptr[Byte], CSize, Unit] = p._3 2272 | def data_deallocator_=(value: CFuncPtr2[Ptr[Byte], CSize, Unit]): Unit = 2273 | p._3 = value 2274 | } 2275 | 2276 | def TF_Buffer()(implicit z: Zone): Ptr[TF_Buffer] = 2277 | alloc[TF_Buffer]() 2278 | 2279 | implicit class TF_Input_ops(val p: Ptr[TF_Input]) extends AnyVal { 2280 | def oper: Ptr[TF_Operation] = p._1 2281 | def oper_=(value: Ptr[TF_Operation]): Unit = p._1 = value 2282 | def index: CInt = p._2 2283 | def index_=(value: CInt): Unit = p._2 = value 2284 | } 2285 | 2286 | def TF_Input()(implicit z: Zone): Ptr[TF_Input] = 2287 | alloc[TF_Input]() 2288 | 2289 | implicit class TF_Output_ops(val p: Ptr[TF_Output]) extends AnyVal { 2290 | def oper: Ptr[TF_Operation] = p._1 2291 | def oper_=(value: Ptr[TF_Operation]): Unit = p._1 = value 2292 | def index: CInt = p._2 2293 | def index_=(value: CInt): Unit = p._2 = value 2294 | } 2295 | 2296 | def TF_Output()(implicit z: Zone): Ptr[TF_Output] = 2297 | alloc[TF_Output]() 2298 | 2299 | implicit class TF_AttrMetadata_ops(val p: Ptr[TF_AttrMetadata]) 2300 | extends AnyVal { 2301 | def is_list: CUnsignedChar = p._1 2302 | def is_list_=(value: CUnsignedChar): Unit = p._1 = value 2303 | def list_size: int64_t = p._2 2304 | def list_size_=(value: int64_t): Unit = p._2 = value 2305 | def `type`: TF_AttrType = p._3 2306 | def `type_=`(value: TF_AttrType): Unit = p._3 = value 2307 | def total_size: int64_t = p._4 2308 | def total_size_=(value: int64_t): Unit = p._4 = value 2309 | } 2310 | 2311 | def TF_AttrMetadata()(implicit z: Zone): Ptr[TF_AttrMetadata] = 
2312 | alloc[TF_AttrMetadata]() 2313 | 2314 | implicit class TF_WhileParams_ops(val p: Ptr[TF_WhileParams]) extends AnyVal { 2315 | def ninputs: CInt = p._1 2316 | def ninputs_=(value: CInt): Unit = p._1 = value 2317 | def cond_graph: Ptr[TF_Graph] = p._2 2318 | def cond_graph_=(value: Ptr[TF_Graph]): Unit = p._2 = value 2319 | def cond_inputs: Ptr[TF_Output] = p._3 2320 | def cond_inputs_=(value: Ptr[TF_Output]): Unit = p._3 = value 2321 | def cond_output: Ptr[TF_Output] = p._4 // TF_output 2322 | def cond_output_=(value: Ptr[TF_Output]): Unit = p._4 = value // TF_output 2323 | def body_graph: Ptr[TF_Graph] = p._5 2324 | def body_graph_=(value: Ptr[TF_Graph]): Unit = p._5 = value 2325 | def body_inputs: Ptr[TF_Output] = p._6 2326 | def body_inputs_=(value: Ptr[TF_Output]): Unit = p._6 = value 2327 | def body_outputs: Ptr[TF_Output] = p._7 2328 | def body_outputs_=(value: Ptr[TF_Output]): Unit = p._7 = value 2329 | def name: CString = p._8 2330 | def name_=(value: CString): Unit = p._8 = value 2331 | } 2332 | 2333 | def TF_WhileParams()(implicit z: Zone): Ptr[TF_WhileParams] = 2334 | alloc[TF_WhileParams]() 2335 | } 2336 | -------------------------------------------------------------------------------- /stensorflow/src/test/scala/org/ekrich/tensorflow/unsafe/TensorflowTest.scala: -------------------------------------------------------------------------------- 1 | package org.ekrich.tensorflow.unsafe 2 | 3 | import org.junit.Assert._ 4 | import org.junit.Test 5 | 6 | import scalanative.libc.stdlib 7 | import scalanative.unsafe.{CFloat, CFuncPtr3, CSize, fromCString} 8 | import scalanative.unsafe.{Ptr, Zone, alloc, sizeof} 9 | import scalanative.unsigned._ 10 | 11 | import org.ekrich.tensorflow.unsafe.tensorflow._ 12 | import org.ekrich.tensorflow.unsafe.tensorflowEnums._ 13 | 14 | class TensorflowTest { 15 | 16 | val tfMinVersion = "2.15" 17 | 18 | // major, minor 19 | case class Version(major: Int, minor: Int) extends Ordered[Version] { 20 | import math.Ordered.orderingToOrdered 21 | def compare(that: Version): Int = 22 | (this.major, this.minor).compare(that.major, that.minor) 23 | } 24 | 25 | def version(ver: String): Version = { 26 | val arr = ver.split("[.]") 27 | Version(arr(0).toInt, arr(1).toInt) 28 | } 29 | 30 | type DeallocateTensor = CFuncPtr3[Ptr[Byte], CSize, Ptr[Byte], Unit] 31 | 32 | val deallocateTensor: DeallocateTensor = 33 | (data: Ptr[Byte], len: CSize, deallocateArg: Ptr[Byte]) => { 34 | stdlib.free(data) 35 | println("Free Original Tensor") 36 | } 37 | 38 | @Test def TF_VersionTest(): Unit = { 39 | Zone { 40 | val tfVersion = fromCString(TF_Version()) 41 | println(s"Tensorflow version: ${tfVersion}") 42 | val swVersion = version(tfVersion) 43 | val minVersion = version(tfMinVersion) 44 | assertTrue( 45 | s"Looking for version: $tfMinVersion", 46 | minVersion <= swVersion 47 | ) 48 | } 49 | } 50 | @Test def TF_ExampleTest(): Unit = { 51 | Zone { 52 | println("Running example...") 53 | 54 | // handle dims 55 | val dimsVals = Seq(1, 5, 12) 56 | val dimsSize = dimsVals.size 57 | val dimsBytes = dimsSize.toUSize * sizeof[int64_t] 58 | val dims = stdlib.malloc(dimsBytes).asInstanceOf[Ptr[int64_t]] 59 | 60 | // copy to memory 61 | for (i <- 0 until dimsSize) { 62 | dims(i) = dimsVals(i) 63 | } 64 | 65 | // handle data based on dims 66 | val dataVals = Seq( 67 | -0.4809832f, -0.3770838f, 0.1743573f, 0.7720509f, -0.4064746f, 68 | 0.0116595f, 0.0051413f, 0.9135732f, 0.7197526f, -0.0400658f, 0.1180671f, 69 | -0.6829428f, -0.4810135f, -0.3772099f, 0.1745346f, 0.7719303f, 70 | 
-0.4066443f, 0.0114614f, 0.0051195f, 0.9135003f, 0.7196983f, 71 | -0.0400035f, 0.1178188f, -0.6830465f, -0.4809143f, -0.3773398f, 72 | 0.1746384f, 0.7719052f, -0.4067171f, 0.0111654f, 0.0054433f, 0.9134697f, 73 | 0.7192584f, -0.0399981f, 0.1177435f, -0.6835230f, -0.4808300f, 74 | -0.3774327f, 0.1748246f, 0.7718700f, -0.4070232f, 0.0109549f, 75 | 0.0059128f, 0.9133330f, 0.7188759f, -0.0398740f, 0.1181437f, 76 | -0.6838635f, -0.4807833f, -0.3775733f, 0.1748378f, 0.7718275f, 77 | -0.4073670f, 0.0107582f, 0.0062978f, 0.9131795f, 0.7187147f, 78 | -0.0394935f, 0.1184392f, -0.6840039f 79 | ) 80 | 81 | // dimensions need to match data 82 | val dataSize = dimsVals.reduceLeft(_ * _) 83 | val dataBytes = dataSize.toUSize * sizeof[CFloat] 84 | // val data = alloc[CFloat](dataSize) 85 | val data = stdlib.malloc(dataBytes).asInstanceOf[Ptr[CFloat]] 86 | 87 | // copy to memory 88 | for (i <- 0 until dataSize) { 89 | data(i) = dataVals(i) 90 | } 91 | 92 | println(dimsVals) 93 | println(dimsSize) 94 | println(dims) 95 | 96 | println(dataVals) 97 | println(dataSize) 98 | println(dataBytes) 99 | println(data) 100 | 101 | // same as null? 102 | val nullptr = alloc[Byte]() 103 | !nullptr = 0x00 104 | 105 | println("Create Tensor") 106 | val tensor = 107 | TF_NewTensor( 108 | TF_FLOAT, 109 | dims, 110 | dimsSize, 111 | data.asInstanceOf[Ptr[Byte]], 112 | dataBytes, 113 | deallocateTensor, 114 | nullptr 115 | ); 116 | 117 | println(s"Tensor: $tensor") 118 | 119 | if (tensor == null) { 120 | println("Wrong create tensor") 121 | } 122 | 123 | if (TF_TensorType(tensor) != TF_FLOAT) { 124 | println("Wrong tensor type") 125 | } 126 | 127 | if (TF_NumDims(tensor) != dimsSize) { 128 | println(s"Wrong number of dimensions") 129 | } 130 | 131 | for (i <- 0 until dimsSize) { 132 | if (TF_Dim(tensor, i) != dims(i)) { 133 | println(s"Wrong dimension size for dim: $i") 134 | } 135 | } 136 | 137 | if (TF_TensorByteSize(tensor) != dataBytes) { 138 | println("Wrong tensor byte size") 139 | } 140 | 141 | val tensor_data = TF_TensorData(tensor).asInstanceOf[Ptr[Float]] 142 | 143 | if (tensor_data == null) { 144 | println("Wrong data tensor") 145 | } 146 | 147 | for (i <- 0 until dataSize) { 148 | if (tensor_data(i) != dataVals(i)) { 149 | println(s"Element: $i does not match") 150 | } 151 | } 152 | 153 | println("Success create tensor") 154 | TF_DeleteTensor(tensor) 155 | println("Done.") 156 | } 157 | } 158 | } 159 | --------------------------------------------------------------------------------