├── .github ├── dependabot.yml └── workflows │ ├── build.yml │ └── security.yml ├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── build.gradle ├── bump_ort_version.sh ├── download.sh ├── gradle.properties ├── gradle └── wrapper │ ├── gradle-wrapper.jar │ └── gradle-wrapper.properties ├── gradlew ├── gradlew.bat ├── jextract.sh ├── onnx ├── build.gradle └── src │ └── main │ └── proto │ └── onnx-ml.proto ├── onnxruntime-benchmark ├── benchmark.sh ├── build.gradle └── src │ ├── jmh │ └── java │ │ └── com │ │ └── jyuzawa │ │ └── onnxruntime_benchmark │ │ └── Microbenchmark.java │ └── main │ ├── java │ └── com │ │ └── jyuzawa │ │ └── onnxruntime_benchmark │ │ ├── Benchmark.java │ │ ├── Microsoft.java │ │ ├── OnnxruntimeJava.java │ │ ├── OnnxruntimeJavaIoBinding.java │ │ └── Wrapper.java │ └── resources │ └── model.onnx ├── onnxruntime-sample-application ├── build.gradle └── src │ └── main │ └── java │ └── com │ └── example │ └── onnxruntime_sample_application │ └── Application.java ├── onnxruntime-sample-library ├── build.gradle └── src │ └── main │ ├── java │ └── com │ │ └── example │ │ └── onnxruntime_sample_library │ │ └── MyOracle.java │ └── resources │ └── floatIdentity.onnx ├── onnxruntime_all.h ├── settings.gradle ├── src ├── dist │ └── META-INF │ │ └── NOTICE.txt ├── main │ └── java │ │ ├── com │ │ └── jyuzawa │ │ │ ├── onnxruntime │ │ │ ├── Api.java │ │ │ ├── ApiImpl.java │ │ │ ├── Environment.java │ │ │ ├── EnvironmentImpl.java │ │ │ ├── ExecutionProvider.java │ │ │ ├── ExecutionProviderCPUConfig.java │ │ │ ├── ExecutionProviderCUDAConfig.java │ │ │ ├── ExecutionProviderConfig.java │ │ │ ├── ExecutionProviderConfigFactory.java │ │ │ ├── ExecutionProviderCoreMLConfig.java │ │ │ ├── ExecutionProviderDnnlConfig.java │ │ │ ├── ExecutionProviderMIGraphXConfig.java │ │ │ ├── ExecutionProviderMapConfig.java │ │ │ ├── ExecutionProviderObjectConfig.java │ │ │ ├── ExecutionProviderOpenVINOConfig.java │ │ │ ├── ExecutionProviderROCMConfig.java │ │ │ ├── 
ExecutionProviderSimpleMapConfig.java │ │ │ ├── ExecutionProviderTensorRTConfig.java │ │ │ ├── ExecutionProviderVitisAIConfig.java │ │ │ ├── IoBinding.java │ │ │ ├── IoBindingImpl.java │ │ │ ├── Loader.java │ │ │ ├── ManagedImpl.java │ │ │ ├── MapInfo.java │ │ │ ├── MapInfoImpl.java │ │ │ ├── ModelMetadata.java │ │ │ ├── ModelMetadataImpl.java │ │ │ ├── NamedCollection.java │ │ │ ├── NamedCollectionImpl.java │ │ │ ├── NodeInfo.java │ │ │ ├── NodeInfoImpl.java │ │ │ ├── OnnxMap.java │ │ │ ├── OnnxMapImpl.java │ │ │ ├── OnnxMapLongImpl.java │ │ │ ├── OnnxMapStringImpl.java │ │ │ ├── OnnxOpaque.java │ │ │ ├── OnnxOptional.java │ │ │ ├── OnnxRuntime.java │ │ │ ├── OnnxRuntimeException.java │ │ │ ├── OnnxRuntimeExecutionMode.java │ │ │ ├── OnnxRuntimeImpl.java │ │ │ ├── OnnxRuntimeLoggingLevel.java │ │ │ ├── OnnxRuntimeOptimizationLevel.java │ │ │ ├── OnnxSequence.java │ │ │ ├── OnnxSequenceImpl.java │ │ │ ├── OnnxSparseTensor.java │ │ │ ├── OnnxTensor.java │ │ │ ├── OnnxTensorBufferImpl.java │ │ │ ├── OnnxTensorByteImpl.java │ │ │ ├── OnnxTensorDoubleImpl.java │ │ │ ├── OnnxTensorElementDataType.java │ │ │ ├── OnnxTensorFloatImpl.java │ │ │ ├── OnnxTensorImpl.java │ │ │ ├── OnnxTensorIntImpl.java │ │ │ ├── OnnxTensorLongImpl.java │ │ │ ├── OnnxTensorShortImpl.java │ │ │ ├── OnnxTensorStringImpl.java │ │ │ ├── OnnxType.java │ │ │ ├── OnnxTypedMap.java │ │ │ ├── OnnxValue.java │ │ │ ├── OnnxValueImpl.java │ │ │ ├── OpaqueInfo.java │ │ │ ├── Session.java │ │ │ ├── SessionImpl.java │ │ │ ├── TensorInfo.java │ │ │ ├── TensorInfoImpl.java │ │ │ ├── Transaction.java │ │ │ ├── TransactionImpl.java │ │ │ ├── TypeInfo.java │ │ │ ├── TypeInfoImpl.java │ │ │ ├── ValueContext.java │ │ │ └── package-info.java │ │ │ └── onnxruntime_extern │ │ │ ├── OrtAllocator.java │ │ │ ├── OrtApi.java │ │ │ ├── OrtApiBase.java │ │ │ ├── OrtCUDAProviderOptions.java │ │ │ ├── OrtCustomCreateThreadFn.java │ │ │ ├── OrtCustomHandleType.java │ │ │ ├── OrtCustomJoinThreadFn.java │ │ │ ├── 
OrtCustomOp.java │ │ │ ├── OrtLoggingFunction.java │ │ │ ├── OrtMIGraphXProviderOptions.java │ │ │ ├── OrtOpenVINOProviderOptions.java │ │ │ ├── OrtROCMProviderOptions.java │ │ │ ├── OrtTensorRTProviderOptions.java │ │ │ ├── OrtThreadWorkerFn.java │ │ │ ├── RegisterCustomOpsFn.java │ │ │ ├── RunAsyncCallbackFn.java │ │ │ └── onnxruntime_all_h.java │ │ └── module-info.java └── test │ ├── java │ └── com │ │ └── jyuzawa │ │ └── onnxruntime │ │ ├── ConcurrencyTest.java │ │ └── SessionTest.java │ └── resources │ ├── libcustom_op_library.dylib │ └── log4j2.xml └── symbols.conf /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: gradle 4 | directory: "/" 5 | schedule: 6 | interval: "weekly" 7 | day: "monday" 8 | time: "18:00" 9 | timezone: "America/New_York" 10 | - package-ecosystem: "github-actions" 11 | directory: "/" 12 | schedule: 13 | interval: "weekly" 14 | day: "monday" 15 | time: "18:00" 16 | timezone: "America/New_York" -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: 4 | push: 5 | branches: 6 | - develop 7 | - master 8 | pull_request: 9 | branches: 10 | - develop 11 | - master 12 | 13 | jobs: 14 | build: 15 | strategy: 16 | fail-fast: false 17 | matrix: 18 | # baseline version 19 | java: [22] 20 | os: ['ubuntu-latest', 'macos-latest', 'windows-latest'] 21 | include: 22 | # latest version 23 | - java: 23 24 | os: ubuntu-latest 25 | runs-on: ${{ matrix.os }} 26 | name: Build on ${{ matrix.os }} on Java ${{ matrix.java }} 27 | steps: 28 | - uses: actions/checkout@v4 29 | - uses: gradle/actions/wrapper-validation@v4 30 | - name: Set up JDK ${{ matrix.java }} 31 | uses: actions/setup-java@v4 32 | with: 33 | distribution: oracle 34 | java-version: ${{ matrix.java }} 35 | - name: Setup Gradle 36 
| uses: gradle/actions/setup-gradle@v4 37 | - name: Build with Gradle 38 | run: | 39 | ./gradlew --no-daemon -s :build 40 | - name: Upload coverage to Codecov 41 | if: ${{ matrix.java == '22' && matrix.os == 'ubuntu-latest' }} 42 | uses: codecov/codecov-action@v5.4.3 43 | env: 44 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} 45 | - name: Save reports 46 | uses: actions/upload-artifact@v4 47 | if: ${{ failure() }} 48 | with: 49 | name: reports-${{ matrix.os }}-${{ matrix.java }} 50 | path: build/reports/ 51 | -------------------------------------------------------------------------------- /.github/workflows/security.yml: -------------------------------------------------------------------------------- 1 | name: Security 2 | 3 | on: 4 | workflow_dispatch: 5 | schedule: 6 | - cron: '27 23 * * 1' 7 | 8 | jobs: 9 | security: 10 | runs-on: ubuntu-latest 11 | permissions: 12 | contents: read 13 | actions: read 14 | security-events: write 15 | steps: 16 | - uses: actions/checkout@v4 17 | - uses: gradle/actions/wrapper-validation@v4 18 | - name: Use JDK 22 19 | uses: actions/setup-java@v4 20 | with: 21 | distribution: oracle 22 | java-version: 22 23 | - name: Initialize CodeQL 24 | uses: github/codeql-action/init@v3 25 | with: 26 | languages: java 27 | - name: Setup Gradle 28 | uses: gradle/actions/setup-gradle@v4 29 | with: 30 | cache-read-only: true 31 | - name: Build with Gradle 32 | run: | 33 | ./gradlew --no-daemon --no-build-cache -s :build 34 | - name: Perform CodeQL Analysis 35 | uses: github/codeql-action/analyze@v3 36 | with: 37 | category: "/language:java" -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.gradle/ 2 | build/ 3 | 4 | .idea/ 5 | *.iml 6 | *.ipr 7 | *.iws 8 | 9 | .DS_Store 10 | .settings 11 | 12 | .vscode/ 13 | 14 | cmake-build-*/ 15 | 16 | 17 | .project 18 | .classpath 19 | bin/ 20 | 21 | cache/
-------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM eclipse-temurin:17-jdk-jammy AS builder 2 | WORKDIR /build 3 | ADD https://github.com/llvm/llvm-project/releases/download/llvmorg-15.0.6/clang+llvm-15.0.6-x86_64-linux-gnu-ubuntu-18.04.tar.xz clang_llvm.tar.xz 4 | RUN apt-get update && apt-get install -y -q \ 5 | xz-utils \ 6 | && apt-get clean 7 | RUN tar xvf clang_llvm.tar.xz && rm clang_llvm.tar.xz && mv clang* clang_llvm 8 | ADD https://download.java.net/java/GA/jdk22/830ec9fcccef480bb3e73fb7ecafe059/36/GPL/openjdk-22_linux-x64_bin.tar.gz openjdk.tar.gz 9 | RUN tar xvzf openjdk.tar.gz && rm openjdk.tar.gz 10 | ADD https://github.com/openjdk/jextract/archive/b9ec8879cff052b463237fdd76382b3a5cd8ff2b.tar.gz jextract.tar.gz 11 | RUN tar xvzf jextract.tar.gz && \ 12 | rm jextract.tar.gz && \ 13 | mv jextract-* jextract 14 | WORKDIR /build/jextract 15 | RUN sh ./gradlew --version 16 | RUN sh ./gradlew --no-daemon -Pjdk22_home=/build/jdk-22 -Pllvm_home=/build/clang_llvm clean verify 17 | 18 | 19 | FROM eclipse-temurin:17-jre-jammy 20 | RUN apt-get update && apt-get install -y -q \ 21 | build-essential \ 22 | && apt-get clean 23 | WORKDIR /jextract 24 | COPY --from=builder /build/jextract/build/jextract .
25 | ENTRYPOINT ["/jextract/bin/jextract"] 26 | CMD ["/bin/bash"] 27 | WORKDIR /workdir -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2024 James Yuzawa (https://www.jyuzawa.com/) 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining 4 | a copy of this software and associated documentation files (the 5 | "Software"), to deal in the Software without restriction, including 6 | without limitation the rights to use, copy, modify, merge, publish, 7 | distribute, sublicense, and/or sell copies of the Software, and to 8 | permit persons to whom the Software is furnished to do so, subject to 9 | the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be 12 | included in all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 15 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 16 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 17 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 18 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 19 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 20 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # onnxruntime-java 2 | by [@yuzawa-san](https://github.com/yuzawa-san/) 3 | 4 | [![build](https://github.com/yuzawa-san/onnxruntime-java/workflows/build/badge.svg)](https://github.com/yuzawa-san/onnxruntime-java/actions) 5 | [![codecov](https://codecov.io/gh/yuzawa-san/onnxruntime-java/branch/master/graph/badge.svg)](https://codecov.io/gh/yuzawa-san/onnxruntime-java) 6 | 7 | This is a **performant** and **modern** Java binding to Microsoft's [ONNX Runtime](https://github.com/microsoft/onnxruntime) which uses Java's new Foreign Function & Memory API (a.k.a. Project Panama). 8 | 9 | This project's goals are to provide a type-safe, lightweight, and performant binding which abstracts a lot of the native and C API intricacies away behind a Java-friendly interface. 10 | This is loosely coupled to the upstream project and built off of the public (and stable) [C API](https://onnxruntime.ai/docs/api/c/struct_ort_api.html). 11 | 12 | The minimum supported Java version is 22, since the Foreign Function & Memory API was finalized (taken out of preview) in that version. 13 | There are [other](https://github.com/bytedeco/javacpp-presets/tree/master/onnxruntime) [fine](https://github.com/microsoft/onnxruntime/tree/main/java) bindings which use JNI and are capable of supporting earlier Java versions. 14 | 15 | ## Usage 16 | 17 | 18 | This project is released to [Maven Central](https://search.maven.org/artifact/com.jyuzawa/onnxruntime) and can be used in your project. 19 | 20 | ### Artifacts 21 | 22 | The library is currently built for Linux, Windows, and macOS, on both arm64 and x86_64. 23 | These were chosen since the upstream project publishes artifacts for these environments. 24 | The published artifacts are listed below. 25 | Snapshot releases are periodically published for testing and experimentation.
26 | 27 | #### onnxruntime 28 | 29 | [![maven](https://img.shields.io/maven-central/v/com.jyuzawa/onnxruntime)](https://search.maven.org/artifact/com.jyuzawa/onnxruntime) [![javadoc](https://javadoc.io/badge2/com.jyuzawa/onnxruntime/javadoc.svg)](https://javadoc.io/doc/com.jyuzawa/onnxruntime) [![maven-snapshot](https://img.shields.io/maven-metadata/v?metadataUrl=https%3A%2F%2Fs01.oss.sonatype.org%2Fcontent%2Frepositories%2Fsnapshots%2Fcom%2Fjyuzawa%2Fonnxruntime%2Fmaven-metadata.xml)](https://s01.oss.sonatype.org/content/repositories/snapshots/com/jyuzawa/onnxruntime/) 30 | 31 | The binding itself, without any native libraries. For use as an implementation dependency. 32 | 33 | The native library (from [Microsoft](https://github.com/microsoft/onnxruntime/releases)) will need to be provided at runtime using one of the next two artifacts. 34 | Alternatively, the Java library path (`java.library.path`) will be used if neither of those artifacts is provided. 35 | This allows users to "bring their own" shared library. 36 | The API validates that the shared library is minor-version compatible with this library. 37 | 38 | #### onnxruntime-cpu 39 | 40 | [![maven](https://img.shields.io/maven-central/v/com.jyuzawa/onnxruntime-cpu)](https://search.maven.org/artifact/com.jyuzawa/onnxruntime-cpu) [![maven-snapshot](https://img.shields.io/maven-metadata/v?metadataUrl=https%3A%2F%2Fs01.oss.sonatype.org%2Fcontent%2Frepositories%2Fsnapshots%2Fcom%2Fjyuzawa%2Fonnxruntime-cpu%2Fmaven-metadata.xml)](https://s01.oss.sonatype.org/content/repositories/snapshots/com/jyuzawa/onnxruntime-cpu/) 41 | 42 | A collection of native libraries with CPU support for several common OS/architecture combinations. For use as an optional runtime dependency. Include one of the OS/architecture classifiers, such as `osx-x86_64`, to provide support for that platform.
43 | 44 | #### onnxruntime-gpu 45 | 46 | See https://github.com/yuzawa-san/onnxruntime-java/issues/258 47 | 48 | ### In your library 49 | 50 | There is an example library in the `onnxruntime-sample-library` directory. 51 | The library should use `onnxruntime` as an implementation dependency. 52 | This puts the burden of providing a native library on your end user. 53 | 54 | ### In your application 55 | 56 | There is an example application in the `onnxruntime-sample-application` directory. 57 | The application should use `onnxruntime` as an implementation dependency. 58 | The application needs to have access to the native library. 59 | You have the option of providing it via a runtime dependency on one of the classifier variants of `onnxruntime-cpu`. 60 | Otherwise, the Java library path will be used to load the native library. 61 | 62 | 63 | The example application can be run: 64 | ````bash 65 | ./gradlew onnxruntime-sample-application:run 66 | ```` 67 | 68 | #### JVM Arguments 69 | 70 | Since this uses a native library, the runtime requires the `--enable-native-access` JVM option, typically `--enable-native-access=ALL-UNNAMED` when the library is on the class path. 71 | 72 | ### Execution Providers 73 | 74 | Only those which are exposed in the C API are supported. 75 | If you wish to use another execution provider which is present in the C API, but not in any of the artifacts from the upstream project, you can choose to bring your own onnxruntime shared library to link against. 76 | 77 | ## Versioning 78 | 79 | The version of the upstream project used will be reflected in the release notes. 80 | Semantic versioning is used. 81 | The major version will be bumped when this API or the underlying C API has backward-incompatible changes. 82 | Upstream major version changes will typically be major version changes here. 83 | The minor version will be bumped for smaller, but compatible, changes. 84 | Upstream minor version changes will typically be minor version changes here.
85 | 86 | The `onnxruntime-cpu` artifacts are versioned to match the upstream versions and depend on a minimum compatible `onnxruntime` version. 87 | -------------------------------------------------------------------------------- /bump_ort_version.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | NEW_VERSION=$(curl -s https://api.github.com/repos/microsoft/onnxruntime/tags | jq -r ".[0].name" | sed 's/^v//g') 4 | printf "Use version %s (y/n)? " "$NEW_VERSION"; read -r answer 5 | case "$answer" in 6 | y*|Y* ) 7 | echo "Using version $NEW_VERSION" 8 | ;; 9 | * ) 10 | exit 1; 11 | ;; 12 | esac 13 | git checkout ort/$NEW_VERSION && echo "please delete existing branch ort/$NEW_VERSION" && exit 1 14 | 15 | git checkout master && \ 16 | git pull && \ 17 | git checkout -b ort/$NEW_VERSION && \ 18 | sed -i '' "s/^com.jyuzawa.onnxruntime.library_version=.*/com.jyuzawa.onnxruntime.library_version=$NEW_VERSION/" gradle.properties && \ 19 | git commit -am "ORT v$NEW_VERSION bump" && \ 20 | ./gradlew jextract && \ 21 | git add src/main/java/com/jyuzawa/onnxruntime_extern && \ 22 | git commit -am "ORT v$NEW_VERSION jextract" 23 | ./gradlew clean build -------------------------------------------------------------------------------- /download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | doIt () { 4 | OS_ARCH=$1 5 | GPU_SUFFIX=$2 6 | DOWNLOAD_NAME=$3 7 | COMPRESSION=$4 8 | LIBRARY_SUFFIX=$5 9 | mkdir -p build/onnxruntime-${ORT_VERSION} 10 | FILENAME=onnxruntime-${DOWNLOAD_NAME}-${ORT_VERSION}.${COMPRESSION} 11 | cd build/onnxruntime-${ORT_VERSION} 12 | rm -rf ${FILENAME} ${OS_ARCH}${GPU_SUFFIX} onnxruntime-${DOWNLOAD_NAME}-${ORT_VERSION} 13 | curl -fOL https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VERSION}/${FILENAME} 14 | if [ "${COMPRESSION}" = "zip" ]; then 15 | unzip ${FILENAME} 16 | else 17 | # https://github.com/msys2/MSYS2-packages/issues/1216 18 |
https://github.com/heroku/heroku-slugs/issues/3 19 | tar xvzf ${FILENAME} || tar xvzf ${FILENAME} 20 | fi 21 | mv onnxruntime-${DOWNLOAD_NAME}-${ORT_VERSION} ${OS_ARCH}${GPU_SUFFIX} 22 | ASSEMBLY_DIR=${OS_ARCH}${GPU_SUFFIX}/assembly/com/jyuzawa/onnxruntime/native/${OS_ARCH} 23 | mkdir -p ${ASSEMBLY_DIR} 24 | for i in $(ls ${OS_ARCH}${GPU_SUFFIX}/lib | grep "\.${LIBRARY_SUFFIX}" | grep -v "${ORT_VERSION}"); do 25 | cp ${OS_ARCH}${GPU_SUFFIX}/lib/$i ${ASSEMBLY_DIR} 26 | done 27 | cd ${ASSEMBLY_DIR} 28 | ls | grep -v libraries > libraries 29 | } 30 | 31 | case $VARIANT in 32 | linux-aarch_64) 33 | doIt linux-aarch_64 "" linux-aarch64 tgz so 34 | ;; 35 | linux-x86_64) 36 | doIt linux-x86_64 "" linux-x64 tgz so 37 | ;; 38 | linux-x86_64-gpu) 39 | doIt linux-x86_64 -gpu linux-x64-gpu tgz so 40 | ;; 41 | osx-x86_64) 42 | doIt osx-x86_64 "" osx-x86_64 tgz dylib 43 | ;; 44 | osx-aarch_64) 45 | doIt osx-aarch_64 "" osx-arm64 tgz dylib 46 | ;; 47 | windows-aarch_64) 48 | doIt windows-aarch_64 "" win-arm64 zip dll 49 | ;; 50 | windows-x86_64) 51 | doIt windows-x86_64 "" win-x64 zip dll 52 | ;; 53 | windows-x86_64-gpu) 54 | doIt windows-x86_64 -gpu win-x64-gpu zip dll 55 | ;; 56 | *) 57 | echo "unsupported VARIANT: ${VARIANT}" >&2 58 | exit 1 59 | ;; 60 | esac -------------------------------------------------------------------------------- /gradle.properties: -------------------------------------------------------------------------------- 1 | version=2.0.0-SNAPSHOT 2 | com.jyuzawa.onnxruntime.library_version=1.20.1 3 | com.jyuzawa.onnxruntime.library_baseline=2.0.0 4 | org.gradle.parallel=true 5 | org.gradle.caching=true 6 | org.gradle.jvmargs=--add-exports jdk.compiler/com.sun.tools.javac.api=ALL-UNNAMED \ 7 | --add-exports jdk.compiler/com.sun.tools.javac.file=ALL-UNNAMED \ 8 | --add-exports jdk.compiler/com.sun.tools.javac.parser=ALL-UNNAMED \ 9 | --add-exports jdk.compiler/com.sun.tools.javac.tree=ALL-UNNAMED \ 10 | --add-exports jdk.compiler/com.sun.tools.javac.util=ALL-UNNAMED
-------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuzawa-san/onnxruntime-java/922a424b9c103e4f235e7b8f21afa1be520faa11/gradle/wrapper/gradle-wrapper.jar -------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.properties: -------------------------------------------------------------------------------- 1 | distributionBase=GRADLE_USER_HOME 2 | distributionPath=wrapper/dists 3 | distributionUrl=https\://services.gradle.org/distributions/gradle-8.12-bin.zip 4 | networkTimeout=10000 5 | validateDistributionUrl=true 6 | zipStoreBase=GRADLE_USER_HOME 7 | zipStorePath=wrapper/dists 8 | -------------------------------------------------------------------------------- /gradlew.bat: -------------------------------------------------------------------------------- 1 | @rem 2 | @rem Copyright 2015 the original author or authors. 3 | @rem 4 | @rem Licensed under the Apache License, Version 2.0 (the "License"); 5 | @rem you may not use this file except in compliance with the License. 6 | @rem You may obtain a copy of the License at 7 | @rem 8 | @rem https://www.apache.org/licenses/LICENSE-2.0 9 | @rem 10 | @rem Unless required by applicable law or agreed to in writing, software 11 | @rem distributed under the License is distributed on an "AS IS" BASIS, 12 | @rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | @rem See the License for the specific language governing permissions and 14 | @rem limitations under the License. 
15 | @rem 16 | @rem SPDX-License-Identifier: Apache-2.0 17 | @rem 18 | 19 | @if "%DEBUG%"=="" @echo off 20 | @rem ########################################################################## 21 | @rem 22 | @rem Gradle startup script for Windows 23 | @rem 24 | @rem ########################################################################## 25 | 26 | @rem Set local scope for the variables with windows NT shell 27 | if "%OS%"=="Windows_NT" setlocal 28 | 29 | set DIRNAME=%~dp0 30 | if "%DIRNAME%"=="" set DIRNAME=. 31 | @rem This is normally unused 32 | set APP_BASE_NAME=%~n0 33 | set APP_HOME=%DIRNAME% 34 | 35 | @rem Resolve any "." and ".." in APP_HOME to make it shorter. 36 | for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi 37 | 38 | @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 39 | set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" 40 | 41 | @rem Find java.exe 42 | if defined JAVA_HOME goto findJavaFromJavaHome 43 | 44 | set JAVA_EXE=java.exe 45 | %JAVA_EXE% -version >NUL 2>&1 46 | if %ERRORLEVEL% equ 0 goto execute 47 | 48 | echo. 1>&2 49 | echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 1>&2 50 | echo. 1>&2 51 | echo Please set the JAVA_HOME variable in your environment to match the 1>&2 52 | echo location of your Java installation. 1>&2 53 | 54 | goto fail 55 | 56 | :findJavaFromJavaHome 57 | set JAVA_HOME=%JAVA_HOME:"=% 58 | set JAVA_EXE=%JAVA_HOME%/bin/java.exe 59 | 60 | if exist "%JAVA_EXE%" goto execute 61 | 62 | echo. 1>&2 63 | echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 1>&2 64 | echo. 1>&2 65 | echo Please set the JAVA_HOME variable in your environment to match the 1>&2 66 | echo location of your Java installation. 
1>&2 67 | 68 | goto fail 69 | 70 | :execute 71 | @rem Setup the command line 72 | 73 | set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar 74 | 75 | 76 | @rem Execute Gradle 77 | "%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* 78 | 79 | :end 80 | @rem End local scope for the variables with windows NT shell 81 | if %ERRORLEVEL% equ 0 goto mainEnd 82 | 83 | :fail 84 | rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of 85 | rem the _cmd.exe /c_ return code! 86 | set EXIT_CODE=%ERRORLEVEL% 87 | if %EXIT_CODE% equ 0 set EXIT_CODE=1 88 | if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% 89 | exit /b %EXIT_CODE% 90 | 91 | :mainEnd 92 | if "%OS%"=="Windows_NT" endlocal 93 | 94 | :omega 95 | -------------------------------------------------------------------------------- /jextract.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # TODO: enable jextract 4 | # GENERATED_DIR=build/generated/source/jextract 5 | GENERATED_DIR=src/main/java 6 | HEADER_DIR=build/onnxruntime-${ORT_VERSION}/headers 7 | mkdir -p ${HEADER_DIR} 8 | cp build/onnxruntime-${ORT_VERSION}/osx-x86_64/include/*.h ${HEADER_DIR} 9 | cp build/onnxruntime-${ORT_VERSION}/linux-x86_64-gpu/include/*.h ${HEADER_DIR} 10 | HEADER_FILE=onnxruntime_all.h 11 | rm -rf 'symbols.conf' 'src/main/java/com/jyuzawa/onnxruntime_extern' 12 | docker build -t onnxruntime-jextract . 
13 | #docker run --rm -v `pwd`:/workdir onnxruntime-jextract 14 | ${JEXTRACT:-/Users/jtyuzawa/Documents/personal_repos/jextract/build/jextract/bin/jextract} --use-system-load-library --output ${GENERATED_DIR} -l onnxruntime --target-package com.jyuzawa.onnxruntime_extern -I /usr/include -I /usr/include/x86_64-linux-gnu -I ${HEADER_DIR} --dump-includes symbols.conf ${HEADER_FILE} 15 | # strip out the irrelevant symbols 16 | csplit symbols.conf "/headers/" 17 | rm xx00 18 | mv xx01 symbols.conf 19 | ${JEXTRACT:-/Users/jtyuzawa/Documents/personal_repos/jextract/build/jextract/bin/jextract} --use-system-load-library --output ${GENERATED_DIR} -l onnxruntime --target-package com.jyuzawa.onnxruntime_extern -I /usr/include -I /usr/include/x86_64-linux-gnu -I ${HEADER_DIR} @symbols.conf ${HEADER_FILE} 20 | # strip out loads since we'll manage load 21 | RUNTIME_HELPER=${GENERATED_DIR}/com/jyuzawa/onnxruntime_extern/onnxruntime_all_h.java 22 | grep -v "System.loadLibrary" < ${RUNTIME_HELPER} > ${RUNTIME_HELPER}.bak 23 | mv ${RUNTIME_HELPER}.bak ${RUNTIME_HELPER} 24 | -------------------------------------------------------------------------------- /onnx/build.gradle: -------------------------------------------------------------------------------- 1 | plugins { 2 | id "com.google.protobuf" version "0.9.5" 3 | } 4 | apply plugin: "java-library" 5 | 6 | dependencies { 7 | api 'com.google.protobuf:protobuf-java:4.31.1' 8 | } 9 | 10 | protobuf { 11 | protoc { 12 | artifact = "com.google.protobuf:protoc:4.31.1" 13 | } 14 | generateProtoTasks { 15 | ofSourceSet('main') 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /onnxruntime-benchmark/benchmark.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | mkdir -p build/jfr 4 | JFR=build/jfr/ort-`date +"%Y-%m-%dT%H-%M-%S%z"`.jfr 5 | echo "writing output to ${JFR}"; 6 | export JAVA_HOME=${JAVA_HOME:-/Library/Java/JavaVirtualMachines/jdk-22.jdk/Contents/Home} 7 | java
-version 8 | export ONNXRUNTIME_BENCHMARK_OPTS="-agentpath:/Users/jtyuzawa/Downloads/async-profiler-2.9-macos/build/libasyncProfiler.dylib=start,event=cpu,alloc,lock,file=${JFR}" 9 | exec build/install/onnxruntime-benchmark/bin/onnxruntime-benchmark -------------------------------------------------------------------------------- /onnxruntime-benchmark/build.gradle: -------------------------------------------------------------------------------- 1 | plugins { 2 | id "me.champeau.jmh" version "0.7.3" 3 | id "io.morethan.jmhreport" version "0.9.6" 4 | } 5 | apply plugin: "application" 6 | 7 | sourceSets { 8 | jmh { 9 | compileClasspath += sourceSets.main.output 10 | runtimeClasspath += sourceSets.main.output 11 | } 12 | } 13 | 14 | configurations { 15 | jmhImplementation.extendsFrom implementation 16 | jmhRuntimeOnly.extendsFrom runtimeOnly 17 | } 18 | 19 | dependencies { 20 | implementation project(":") 21 | implementation project(path: ":", configuration: 'cpu') 22 | implementation project(':onnx') 23 | implementation "com.microsoft.onnxruntime:onnxruntime:1.22.0" 24 | // jmh "org.bytedeco:onnxruntime-platform:1.13.1-1.5.8" 25 | jmh 'org.openjdk.jmh:jmh-core:1.37' 26 | jmh 'org.openjdk.jmh:jmh-generator-annprocess:1.37' 27 | } 28 | 29 | application { 30 | mainClass = 'com.jyuzawa.onnxruntime_benchmark.Benchmark' 31 | applicationDefaultJvmArgs = ['--enable-native-access=ALL-UNNAMED'] 32 | } 33 | 34 | task benchmark(type: Exec) { 35 | commandLine "${projectDir}/benchmark.sh" 36 | dependsOn installDist 37 | } 38 | 39 | jmh { 40 | resultFormat = "JSON" 41 | resultsFile = project.file("${project.buildDir}/reports/jmh/result.json") 42 | profilers = ["gc"] //, "async:libPath=/Users/jtyuzawa/Downloads/async-profiler-2.9-macos/build/libasyncProfiler.dylib"] 43 | } 44 | 45 | jmhReport { 46 | jmhResultPath = project.file("${project.buildDir}/reports/jmh/result.json") 47 | jmhReportOutput = project.file("${project.buildDir}/reports/jmh") 48 | } 49 | tasks.jmh.finalizedBy 
tasks.jmhReport 50 | -------------------------------------------------------------------------------- /onnxruntime-benchmark/src/jmh/java/com/jyuzawa/onnxruntime_benchmark/Microbenchmark.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime_benchmark; 6 | 7 | import java.util.concurrent.ThreadLocalRandom; 8 | import java.util.concurrent.TimeUnit; 9 | import onnx.OnnxMl.GraphProto; 10 | import onnx.OnnxMl.ModelProto; 11 | import onnx.OnnxMl.NodeProto; 12 | import onnx.OnnxMl.OperatorSetIdProto; 13 | import onnx.OnnxMl.TensorProto.DataType; 14 | import onnx.OnnxMl.TensorShapeProto; 15 | import onnx.OnnxMl.TensorShapeProto.Dimension; 16 | import onnx.OnnxMl.TypeProto; 17 | import onnx.OnnxMl.TypeProto.Tensor; 18 | import onnx.OnnxMl.ValueInfoProto; 19 | import org.openjdk.jmh.annotations.Benchmark; 20 | import org.openjdk.jmh.annotations.BenchmarkMode; 21 | import org.openjdk.jmh.annotations.Fork; 22 | import org.openjdk.jmh.annotations.Measurement; 23 | import org.openjdk.jmh.annotations.Mode; 24 | import org.openjdk.jmh.annotations.OutputTimeUnit; 25 | import org.openjdk.jmh.annotations.Param; 26 | import org.openjdk.jmh.annotations.Setup; 27 | import org.openjdk.jmh.annotations.State; 28 | import org.openjdk.jmh.annotations.Threads; 29 | import org.openjdk.jmh.annotations.Warmup; 30 | import org.openjdk.jmh.infra.Blackhole; 31 | 32 | @BenchmarkMode(Mode.AverageTime) 33 | @Warmup(iterations = 3, time = 1, timeUnit = TimeUnit.SECONDS) 34 | @Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS) 35 | @State(org.openjdk.jmh.annotations.Scope.Thread) 36 | @OutputTimeUnit(TimeUnit.NANOSECONDS) 37 | @Fork( 38 | value = 3, 39 | jvmArgsAppend = {"--enable-native-access=ALL-UNNAMED"}) // FFM is final since Java 22; no --enable-preview needed 40 | public class Microbenchmark { 41 | 42 | private static final String
ONNXRUNTIME_JAVA = "onnxruntime-java"; 43 | private static final String ONNXRUNTIME_JAVA_ARENA = "onnxruntime-java-arena"; 44 | private static final String ONNXRUNTIME_JAVA_IOBINDING = "onnxruntime-java-iobinding"; 45 | private static final String MICROSOFT = "microsoft"; 46 | 47 | @Param(value = {ONNXRUNTIME_JAVA, ONNXRUNTIME_JAVA_ARENA, ONNXRUNTIME_JAVA_IOBINDING, MICROSOFT}) 48 | private String implementation; 49 | 50 | @Param({"16", "256", "4096"}) 51 | private int size; 52 | 53 | private long[] input; 54 | private Wrapper wrapper; 55 | 56 | @Setup 57 | public void setup() throws Exception { 58 | TypeProto type = TypeProto.newBuilder() 59 | .setTensorType(Tensor.newBuilder() 60 | .setElemType(DataType.INT64_VALUE) 61 | .setShape(TensorShapeProto.newBuilder() 62 | .addDim(Dimension.newBuilder().setDimValue(1)) 63 | .addDim(Dimension.newBuilder().setDimValue(size)))) 64 | .build(); 65 | ModelProto model = ModelProto.newBuilder() 66 | .setIrVersion(8) 67 | .addOpsetImport(OperatorSetIdProto.newBuilder().setVersion(15)) 68 | .setGraph(GraphProto.newBuilder() 69 | .addNode(NodeProto.newBuilder() 70 | .addInput("input") 71 | .addOutput("output") 72 | .setOpType("Identity")) 73 | .addInput(ValueInfoProto.newBuilder().setName("input").setType(type)) 74 | .addOutput(ValueInfoProto.newBuilder().setName("output").setType(type))) 75 | .build(); 76 | byte[] bytes = model.toByteArray(); 77 | input = new long[size]; 78 | ThreadLocalRandom random = ThreadLocalRandom.current(); 79 | for (int i = 0; i < size; i++) { 80 | input[i] = random.nextLong(); 81 | } 82 | wrapper = switch (implementation) { 83 | case ONNXRUNTIME_JAVA -> new OnnxruntimeJava(bytes, false, size); 84 | case ONNXRUNTIME_JAVA_ARENA -> new OnnxruntimeJava(bytes, true, size); 85 | case ONNXRUNTIME_JAVA_IOBINDING -> new OnnxruntimeJavaIoBinding(bytes, false, size); 86 | case MICROSOFT -> new Microsoft(bytes); 87 | default -> throw new IllegalArgumentException();}; 88 | } 89 | 90 | @Benchmark 91 | 
@Threads(Threads.MAX) 92 | public void run(Blackhole bh) throws Exception { 93 | bh.consume(wrapper.evaluate(input)); 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /onnxruntime-benchmark/src/main/java/com/jyuzawa/onnxruntime_benchmark/Benchmark.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime_benchmark; 6 | 7 | import java.io.InputStream; 8 | import java.util.Arrays; 9 | import java.util.List; 10 | 11 | public final class Benchmark { 12 | 13 | public static final void main(String[] args) throws Exception { 14 | byte[] bytes; 15 | try (InputStream is = Benchmark.class.getResourceAsStream("/model.onnx")) { 16 | bytes = is.readAllBytes(); 17 | } 18 | long[] input = new long[] {1, 2, 3}; 19 | List wrappers = List.of(new OnnxruntimeJava(bytes, false, input.length), new Microsoft(bytes)); 20 | long i = 0; 21 | long startMs = System.currentTimeMillis(); 22 | while (i >= 0) { 23 | for (Wrapper wrapper : wrappers) { 24 | long[] output = wrapper.evaluate(input); 25 | if (!Arrays.equals(input, output)) { 26 | throw new RuntimeException("mismatch"); 27 | } 28 | } 29 | if (i % 10000 == 0) { 30 | if (System.currentTimeMillis() - startMs > 60000) { 31 | break; 32 | } 33 | } 34 | i++; 35 | } 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /onnxruntime-benchmark/src/main/java/com/jyuzawa/onnxruntime_benchmark/Microsoft.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime_benchmark; 6 | 7 | import ai.onnxruntime.OnnxTensor; 8 | import ai.onnxruntime.OrtEnvironment; 9 | import ai.onnxruntime.OrtLoggingLevel; 10 | 
import ai.onnxruntime.OrtSession; 11 | import ai.onnxruntime.OrtSession.Result; 12 | import java.util.Map; 13 | 14 | final class Microsoft implements Wrapper { 15 | 16 | private static final OrtEnvironment ENVIRONMENT = 17 | OrtEnvironment.getEnvironment(OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING); 18 | 19 | private final OrtSession session; 20 | private final String inputName; 21 | 22 | Microsoft(byte[] bytes) throws Exception { 23 | this.session = ENVIRONMENT.createSession(bytes); 24 | this.inputName = session.getInputNames().iterator().next(); 25 | } 26 | 27 | @Override 28 | public void close() throws Exception { 29 | session.close(); 30 | } 31 | 32 | @Override 33 | public long[] evaluate(long[] input) throws Exception { 34 | 35 | try (OnnxTensor tensor = OnnxTensor.createTensor(ENVIRONMENT, new long[][] {input}); 36 | Result result = session.run(Map.of(inputName, tensor))) { 37 | return ((long[][]) result.get(0).getValue())[0]; 38 | } 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /onnxruntime-benchmark/src/main/java/com/jyuzawa/onnxruntime_benchmark/OnnxruntimeJava.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime_benchmark; 6 | 7 | import com.jyuzawa.onnxruntime.Environment; 8 | import com.jyuzawa.onnxruntime.ExecutionProvider; 9 | import com.jyuzawa.onnxruntime.NamedCollection; 10 | import com.jyuzawa.onnxruntime.OnnxRuntime; 11 | import com.jyuzawa.onnxruntime.OnnxRuntimeLoggingLevel; 12 | import com.jyuzawa.onnxruntime.OnnxValue; 13 | import com.jyuzawa.onnxruntime.Session; 14 | import com.jyuzawa.onnxruntime.Transaction; 15 | import java.io.IOException; 16 | import java.util.Map; 17 | 18 | class OnnxruntimeJava implements Wrapper { 19 | 20 | private static final Environment ENVIRONMENT = OnnxRuntime.get() 21 | .getApi() 
22 | .newEnvironment() 23 | .setLogSeverityLevel(OnnxRuntimeLoggingLevel.WARNING) 24 | .build(); 25 | 26 | protected final Session session; 27 | protected final long[] out; 28 | 29 | OnnxruntimeJava(byte[] bytes, boolean arena, int size) throws IOException { 30 | this.session = ENVIRONMENT 31 | .newSession() 32 | .setByteArray(bytes) 33 | .addProvider(ExecutionProvider.CPU_EXECUTION_PROVIDER, Map.of("use_arena", arena ? "1" : "0")) 34 | .build(); 35 | this.out = new long[size]; 36 | } 37 | 38 | @Override 39 | public void close() throws Exception { 40 | session.close(); 41 | } 42 | 43 | @Override 44 | public long[] evaluate(long[] input) { 45 | try (Transaction txn = session.newTransaction().build()) { 46 | txn.addInput(0).asTensor().getLongBuffer().put(input); 47 | txn.addOutput(0); 48 | NamedCollection result = txn.run(); 49 | result.get(0).asTensor().getLongBuffer().get(out); 50 | return out; 51 | } 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /onnxruntime-benchmark/src/main/java/com/jyuzawa/onnxruntime_benchmark/OnnxruntimeJavaIoBinding.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime_benchmark; 6 | 7 | import com.jyuzawa.onnxruntime.IoBinding; 8 | import java.io.IOException; 9 | import java.nio.LongBuffer; 10 | 11 | final class OnnxruntimeJavaIoBinding extends OnnxruntimeJava { 12 | 13 | private final IoBinding ioBinding; 14 | private final LongBuffer inputBuf; 15 | private final LongBuffer outputBuf; 16 | 17 | OnnxruntimeJavaIoBinding(byte[] bytes, boolean arena, int size) throws IOException { 18 | super(bytes, arena, size); 19 | this.ioBinding = session.newIoBinding().bindInput(0).bindOutput(0).build(); 20 | this.inputBuf = ioBinding.getInputs().get(0).asTensor().getLongBuffer(); 21 | this.outputBuf = 
ioBinding.getOutputs().get(0).asTensor().getLongBuffer(); 22 | } 23 | 24 | @Override 25 | public void close() throws Exception { 26 | ioBinding.close(); 27 | super.close(); 28 | } 29 | 30 | @Override 31 | public long[] evaluate(long[] input) { 32 | inputBuf.clear().put(input); 33 | ioBinding.run(); 34 | outputBuf.rewind().get(out); 35 | return out; 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /onnxruntime-benchmark/src/main/java/com/jyuzawa/onnxruntime_benchmark/Wrapper.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime_benchmark; 6 | 7 | public interface Wrapper extends AutoCloseable { 8 | 9 | long[] evaluate(long[] input) throws Exception; 10 | } 11 | -------------------------------------------------------------------------------- /onnxruntime-benchmark/src/main/resources/model.onnx: -------------------------------------------------------------------------------- 1 | :N 2 |  3 | inputoutput"IdentityZ 4 | input 5 |  6 |  7 | b 8 | output 9 |  10 |  11 | B -------------------------------------------------------------------------------- /onnxruntime-sample-application/build.gradle: -------------------------------------------------------------------------------- 1 | apply plugin: "application" 2 | 3 | dependencies { 4 | implementation project(":onnxruntime-sample-library") 5 | runtimeOnly project(path: ":", configuration: 'cpu') 6 | 7 | // In your application, use the following instead to use a compile dependency. 8 | // implementation "com.example:sample-library:X.Y.Z" 9 | // For the application to work, you will need to provide the native libraries. 
10 | // Optionally, provide the CPU libraries (for various OS/Architecture combinations) 11 | // runtimeOnly "com.jyuzawa:onnxruntime-cpu:1.X.0:osx-x86_64" 12 | // Alternatively, do nothing and the Java library path will be used 13 | } 14 | 15 | application { 16 | mainClass = 'com.example.onnxruntime_sample_application.Application' 17 | applicationDefaultJvmArgs = ['--enable-native-access=ALL-UNNAMED'] 18 | } 19 | -------------------------------------------------------------------------------- /onnxruntime-sample-application/src/main/java/com/example/onnxruntime_sample_application/Application.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.example.onnxruntime_sample_application; 6 | 7 | import com.example.onnxruntime_sample_library.MyOracle; 8 | import java.util.Random; 9 | 10 | public final class Application { 11 | 12 | public static final void main(String[] args) throws Exception { 13 | try (MyOracle oracle = new MyOracle()) { 14 | Random random = new Random(); 15 | for (int i = 0; i < 1000; i++) { 16 | float expected = random.nextFloat(); 17 | float output = oracle.getIdentity(expected); 18 | if (expected != output) { 19 | throw new IllegalArgumentException("invalid output"); 20 | } 21 | } 22 | } 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /onnxruntime-sample-library/build.gradle: -------------------------------------------------------------------------------- 1 | apply plugin: "java-library" 2 | 3 | dependencies { 4 | api project(":") 5 | 6 | // In your library, use the following instead to expose a compile dependency. 7 | // This allows the users of your Java library to bring the native libraries. 
8 | // api "com.jyuzawa:onnxruntime:X.Y.Z" 9 | } 10 | -------------------------------------------------------------------------------- /onnxruntime-sample-library/src/main/java/com/example/onnxruntime_sample_library/MyOracle.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.example.onnxruntime_sample_library; 6 | 7 | import com.jyuzawa.onnxruntime.Environment; 8 | import com.jyuzawa.onnxruntime.NamedCollection; 9 | import com.jyuzawa.onnxruntime.OnnxRuntime; 10 | import com.jyuzawa.onnxruntime.OnnxValue; 11 | import com.jyuzawa.onnxruntime.Session; 12 | import com.jyuzawa.onnxruntime.Transaction; 13 | import java.io.IOException; 14 | import java.io.InputStream; 15 | 16 | public final class MyOracle implements AutoCloseable { 17 | 18 | private static final Environment ENVIRONMENT = 19 | OnnxRuntime.get().getApi().newEnvironment().build(); 20 | private final Session session; 21 | 22 | public MyOracle() throws IOException { 23 | try (InputStream is = this.getClass().getResourceAsStream("/floatIdentity.onnx")) { 24 | this.session = 25 | ENVIRONMENT.newSession().setByteArray(is.readAllBytes()).build(); 26 | } 27 | } 28 | 29 | public float getIdentity(float input) { 30 | try (Transaction txn = session.newTransaction().build()) { 31 | txn.addInput(0).asTensor().getFloatBuffer().put(input); 32 | txn.addOutput(0); 33 | NamedCollection result = txn.run(); 34 | return result.get(0).asTensor().getFloatBuffer().get(); 35 | } 36 | } 37 | 38 | @Override 39 | public void close() { 40 | session.close(); 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /onnxruntime-sample-library/src/main/resources/floatIdentity.onnx: -------------------------------------------------------------------------------- 1 | :F 2 |  3 | inputoutput"IdentityZ 4 | input 5 | 6 |  7 | b 8 | 
output 9 | 10 |  11 | B -------------------------------------------------------------------------------- /onnxruntime_all.h: -------------------------------------------------------------------------------- 1 | #include "onnxruntime_c_api.h" 2 | // add EP's which are present but optionally included 3 | #include "coreml_provider_factory.h" -------------------------------------------------------------------------------- /settings.gradle: -------------------------------------------------------------------------------- 1 | rootProject.name = 'onnxruntime' 2 | include 'onnxruntime-sample-library' 3 | include 'onnxruntime-sample-application' 4 | include 'onnxruntime-benchmark' 5 | include 'onnx' 6 | -------------------------------------------------------------------------------- /src/dist/META-INF/NOTICE.txt: -------------------------------------------------------------------------------- 1 | This product contains compiled artifacts from onnxruntime: https://github.com/microsoft/onnxruntime/releases 2 | 3 | * LICENSE (MIT license) 4 | * license/LICENSE.onnxruntime.txt 5 | * HOMEPAGE: 6 | * https://github.com/microsoft/onnxruntime -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/Api.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | import java.util.Set; 8 | 9 | /** 10 | * The top-level API of the ONNX runtime. 11 | * 12 | * @since 1.0.0 13 | */ 14 | public interface Api { 15 | 16 | /** 17 | * Create a new environment. 18 | * @return a builder 19 | */ 20 | Environment.Builder newEnvironment(); 21 | 22 | /** 23 | * Determine what execution providers are available. 24 | * 25 | * NOTE: not all of these execution provider may be supported by this library. 26 | * @return a set of execution providers. 
27 | */ 28 | Set getAvailableProviders(); 29 | 30 | /** 31 | * This function returns the onnxruntime build information: including git branch, git commit id, build type(Debug/Release/RelWithDebInfo) and cmake cpp flags. 32 | * @return the version string 33 | * @since 1.2.0 34 | */ 35 | String getBuildString(); 36 | } 37 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/Environment.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | import java.util.Map; 8 | 9 | /** 10 | * The environment in which model evaluation sessions can be constructed. Only 11 | * one environment should be used in your application. This class is thread safe. 12 | * 13 | * @since 1.0.0 14 | */ 15 | public interface Environment extends AutoCloseable { 16 | 17 | /** 18 | * Releases native resources. 19 | */ 20 | @Override 21 | void close(); 22 | 23 | /** 24 | * Change the telemetry events. 25 | * 26 | * @param enabled true if telemetry events should be enabled. 27 | */ 28 | void setTelemetryEvents(boolean enabled); 29 | 30 | // TODO: allocator 31 | // void createAndRegisterAllocator(); 32 | 33 | /** 34 | * Create a new session for model evaluation. 35 | * 36 | * @return a builder 37 | */ 38 | Session.Builder newSession(); 39 | 40 | /** 41 | * A builder of an {@link Environment}. The default severity level is determined 42 | * from the logger level of the logger of {@link Environment}. 43 | * 44 | * @since 1.0.0 45 | */ 46 | public interface Builder { 47 | /** 48 | * Set the severity for logging. 49 | * 50 | * @param level the severity 51 | * @return the builder 52 | */ 53 | Builder setLogSeverityLevel(OnnxRuntimeLoggingLevel level); 54 | 55 | /** 56 | * Set the logging identifier. 
57 | * 58 | * @param id the identifier 59 | * @return the builder 60 | */ 61 | Builder setLogId(String id); 62 | 63 | /** 64 | * Set the shared thread pool's "denormal as zero" setting. 65 | * @param globalDenormalAsZero whether to denormal as zero 66 | * @return the builder 67 | */ 68 | Builder setGlobalDenormalAsZero(boolean globalDenormalAsZero); 69 | 70 | /** 71 | * Set the shared thread pool's "inter op" thread count. 72 | * @param numThreads a number of threads to use, or 0 to have the library use a default 73 | * @return the builder 74 | */ 75 | Builder setGlobalInterOpNumThreads(int numThreads); 76 | 77 | /** 78 | * Set the shared thread pool's "intra op" thread count. 79 | * @param numThreads a number of threads to use, or 0 to have the library use a default 80 | * @return the builder 81 | */ 82 | Builder setGlobalIntraOpNumThreads(int numThreads); 83 | 84 | /** 85 | * Set whether the shared thread pool's should spin when idle. 86 | * @param globalSpinControl whether to spin or not 87 | * @return the builder 88 | */ 89 | Builder setGlobalSpinControl(boolean globalSpinControl); 90 | 91 | /** 92 | * Set whether the environment's shared allocator for CPU should be arena-based 93 | * @param config the key/value configuration of the arena 94 | * @return the builder 95 | * @since 1.2.0 96 | */ 97 | Builder setArenaConfig(Map config); 98 | 99 | /** 100 | * Constructs the {@link Environment}. 
101 | * 102 | * @return a new instance 103 | */ 104 | Environment build(); 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/EnvironmentImpl.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_LONG; 8 | import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_POINTER; 9 | import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.ORT_PROJECTION_JAVA; 10 | import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.OrtArenaAllocator; 11 | import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.OrtMemTypeDefault; 12 | 13 | import java.lang.foreign.Arena; 14 | import java.lang.foreign.MemorySegment; 15 | import java.util.Map; 16 | 17 | final class EnvironmentImpl extends ManagedImpl implements Environment { 18 | 19 | private final MemorySegment address; 20 | final MemorySegment memoryInfo; 21 | final MemorySegment ortAllocator; 22 | 23 | EnvironmentImpl(Builder builder) { 24 | super(builder.api, Arena.ofShared()); 25 | try (Arena temporarySession = Arena.ofConfined()) { 26 | MemorySegment logName = temporarySession.allocateFrom(builder.logId); 27 | if (builder.useThreadingOptions) { 28 | MemorySegment threadingOptionsAddress = builder.newThreadingOptions(temporarySession); 29 | try { 30 | this.address = api.create( 31 | arena, 32 | out -> api.CreateEnvWithCustomLoggerAndGlobalThreadPools.apply( 33 | OnnxRuntimeLoggingLevel.LOG_CALLBACK, 34 | MemorySegment.NULL, 35 | builder.severityLevel.getNumber(), 36 | logName, 37 | threadingOptionsAddress, 38 | out)); 39 | } finally { 40 | api.ReleaseThreadingOptions.apply(threadingOptionsAddress); 41 | } 42 | } else { 43 | this.address = api.create( 44 | arena, 45 
| out -> api.CreateEnvWithCustomLogger.apply( 46 | OnnxRuntimeLoggingLevel.LOG_CALLBACK, 47 | MemorySegment.NULL, 48 | builder.severityLevel.getNumber(), 49 | logName, 50 | out)); 51 | } 52 | api.checkStatus(api.SetLanguageProjection.apply(address, ORT_PROJECTION_JAVA())); 53 | this.memoryInfo = api.create( 54 | arena, out -> api.CreateCpuMemoryInfo.apply(OrtArenaAllocator(), OrtMemTypeDefault(), out)); 55 | this.ortAllocator = api.create(arena, out -> api.GetAllocatorWithDefaultOptions.apply(out)); 56 | Map arenaConfig = builder.arenaConfig; 57 | if (arenaConfig == null) { 58 | api.RegisterAllocator.apply(address, ortAllocator); 59 | } else { 60 | int size = arenaConfig.size(); 61 | MemorySegment keyArray = temporarySession.allocate(C_POINTER, size); 62 | MemorySegment valueArray = temporarySession.allocate(C_LONG, size); 63 | int i = 0; 64 | for (Map.Entry entry : arenaConfig.entrySet()) { 65 | keyArray.setAtIndex(C_POINTER, i, temporarySession.allocateFrom(entry.getKey())); 66 | valueArray.setAtIndex(C_LONG, i, entry.getValue()); 67 | i++; 68 | } 69 | MemorySegment arenaConfigAddress = api.create( 70 | temporarySession, out -> api.CreateArenaCfgV2.apply(keyArray, valueArray, size, out)); 71 | api.checkStatus(api.CreateAndRegisterAllocator.apply(address, memoryInfo, arenaConfigAddress)); 72 | } 73 | } 74 | } 75 | 76 | @Override 77 | public void close() { 78 | api.ReleaseMemoryInfo.apply(memoryInfo); 79 | api.ReleaseEnv.apply(address); 80 | super.close(); 81 | } 82 | 83 | @Override 84 | MemorySegment address() { 85 | return address; 86 | } 87 | 88 | @Override 89 | public Session.Builder newSession() { 90 | return new SessionImpl.Builder(this); 91 | } 92 | 93 | @Override 94 | public void setTelemetryEvents(boolean enabled) { 95 | if (enabled) { 96 | api.checkStatus(api.EnableTelemetryEvents.apply(address)); 97 | } else { 98 | api.checkStatus(api.DisableTelemetryEvents.apply(address)); 99 | } 100 | } 101 | 102 | static final class Builder implements 
Environment.Builder { 103 | 104 | private final ApiImpl api; 105 | private OnnxRuntimeLoggingLevel severityLevel; 106 | private String logId; 107 | private boolean useThreadingOptions; 108 | private Boolean globalDenormalAsZero; 109 | private Integer globalInterOpNumThreads; 110 | private Integer globalIntraOpNumThreads; 111 | private Boolean globalSpinControl; 112 | private Map arenaConfig; 113 | 114 | Builder(ApiImpl api) { 115 | this.api = api; 116 | this.severityLevel = OnnxRuntimeLoggingLevel.DEFAULT; 117 | this.logId = "onnxruntime-java"; 118 | } 119 | 120 | @Override 121 | public Builder setLogSeverityLevel(OnnxRuntimeLoggingLevel severityLevel) { 122 | this.severityLevel = severityLevel; 123 | return this; 124 | } 125 | 126 | @Override 127 | public Builder setLogId(String id) { 128 | this.logId = id; 129 | return this; 130 | } 131 | 132 | @Override 133 | public Builder setGlobalDenormalAsZero(boolean globalDenormalAsZero) { 134 | this.useThreadingOptions = true; 135 | this.globalDenormalAsZero = globalDenormalAsZero; 136 | return this; 137 | } 138 | 139 | @Override 140 | public Builder setGlobalInterOpNumThreads(int numThreads) { 141 | this.useThreadingOptions = true; 142 | this.globalInterOpNumThreads = numThreads; 143 | return this; 144 | } 145 | 146 | @Override 147 | public Builder setGlobalIntraOpNumThreads(int numThreads) { 148 | this.useThreadingOptions = true; 149 | this.globalIntraOpNumThreads = numThreads; 150 | return this; 151 | } 152 | 153 | @Override 154 | public Builder setGlobalSpinControl(boolean globalSpinControl) { 155 | this.useThreadingOptions = true; 156 | this.globalSpinControl = globalSpinControl; 157 | return this; 158 | } 159 | 160 | @Override 161 | public Builder setArenaConfig(Map config) { 162 | this.arenaConfig = config; 163 | return this; 164 | } 165 | 166 | private MemorySegment newThreadingOptions(Arena arena) { 167 | MemorySegment threadingOptions = api.create(arena, out -> api.CreateThreadingOptions.apply(out)); 168 | if 
(globalDenormalAsZero != null && globalDenormalAsZero) { /* test the denormal flag itself (null means "not configured"); spin control is applied separately below */ 169 | api.checkStatus(api.SetGlobalDenormalAsZero.apply(threadingOptions)); 170 | } 171 | if (globalSpinControl != null) { 172 | api.checkStatus(api.SetGlobalSpinControl.apply(threadingOptions, globalSpinControl ? 1 : 0)); 173 | } 174 | if (globalInterOpNumThreads != null) { 175 | api.checkStatus(api.SetGlobalInterOpNumThreads.apply(threadingOptions, globalInterOpNumThreads)); 176 | } 177 | if (globalIntraOpNumThreads != null) { 178 | api.checkStatus(api.SetGlobalIntraOpNumThreads.apply(threadingOptions, globalIntraOpNumThreads)); 179 | } 180 | return threadingOptions; 181 | } 182 | 183 | @Override 184 | public Environment build() { 185 | return new EnvironmentImpl(this); 186 | } 187 | } 188 | } 189 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/ExecutionProvider.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | /** 8 | * The universe of possible execution providers. 9 | * 10 | * The identifiers are the constants defined in https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/graph/constants.h 11 | * 12 | * Support in the public C API and in the provided native library is a requirement for successful usage. 13 | * 14 | * @since 1.0.0 15 | */ 16 | public enum ExecutionProvider { 17 | /** 18 | * CPU: The default, added implicitly always.
19 | */ 20 | CPU_EXECUTION_PROVIDER("CPUExecutionProvider", ExecutionProviderCPUConfig::new), 21 | /** 22 | * CUDA: supported 23 | */ 24 | CUDA_EXECUTION_PROVIDER("CUDAExecutionProvider", ExecutionProviderCUDAConfig::new), 25 | /** 26 | * oneDNN (DNNL): supported 27 | */ 28 | DNNL_EXECUTION_PROVIDER("DnnlExecutionProvider", ExecutionProviderDnnlConfig::new), 29 | /** 30 | * OpenVINO: supported 31 | */ 32 | OPENVINO_EXECUTION_PROVIDER("OpenVINOExecutionProvider", ExecutionProviderOpenVINOConfig::new), 33 | /** 34 | * VitisAI: supported 35 | */ 36 | VITISAI_EXECUTION_PROVIDER("VitisAIExecutionProvider", ExecutionProviderVitisAIConfig::new), 37 | /** 38 | * TensorRT: supported 39 | */ 40 | TENSORRT_EXECUTION_PROVIDER("TensorrtExecutionProvider", ExecutionProviderTensorRTConfig::new), 41 | /** 42 | * NNApi: not supported 43 | */ 44 | NNAPI_EXECUTION_PROVIDER("NnapiExecutionProvider"), 45 | /** 46 | * QNN: supported 47 | * @since 1.2.0 48 | */ 49 | QNN_EXECUTION_PROVIDER("QNNExecutionProvider", ExecutionProviderSimpleMapConfig.of("QNN")), 50 | /** 51 | * RKNPU: not supported 52 | */ 53 | RKNPU_EXECUTION_PROVIDER("RknpuExecutionProvider"), 54 | /** 55 | * DML: not supported 56 | */ 57 | DML_EXECUTION_PROVIDER("DmlExecutionProvider"), 58 | /** 59 | * MIGraphX: supported 60 | */ 61 | MIGRAPHX_EXECUTION_PROVIDER("MIGraphXExecutionProvider", ExecutionProviderMIGraphXConfig::new), 62 | /** 63 | * ACL: not supported 64 | */ 65 | ACL_EXECUTION_PROVIDER("ACLExecutionProvider"), 66 | /** 67 | * ARMNN: not supported 68 | */ 69 | ARMNN_EXECUTION_PROVIDER("ArmNNExecutionProvider"), 70 | /** 71 | * ROCM: supported 72 | */ 73 | ROCM_EXECUTION_PROVIDER("ROCMExecutionProvider", ExecutionProviderROCMConfig::new), 74 | /** 75 | * CoreML: supported 76 | */ 77 | COREML_EXECUTION_PROVIDER("CoreMLExecutionProvider", ExecutionProviderCoreMLConfig::new), 78 | /** 79 | * JS Execution Provider: not supported 80 | * @since 1.2.0 81 | */ 82 | JS_EXECUTION_PROVIDER("JsExecutionProvider"), 83 | /** 
84 | * SNPE: supported 85 | */ 86 | SNPE_EXECUTION_PROVIDER("SNPEExecutionProvider", ExecutionProviderSimpleMapConfig.of("SNPE")), 87 | /** 88 | * TVM: not supported 89 | */ 90 | TVM_EXECUTION_PROVIDER("TvmExecutionProvider"), 91 | /** 92 | * XNNPACK: supported 93 | */ 94 | XNNPACK_EXECUTION_PROVIDER("XnnpackExecutionProvider", ExecutionProviderSimpleMapConfig.of("XNNPACK")), 95 | /** 96 | * WebNN: not supported 97 | * @since 1.2.0 98 | */ 99 | WEBNN_EXECUTION_PROVIDER("WebNNExecutionProvider"), 100 | /** 101 | * CANN: not supported 102 | */ 103 | CANN_EXECUTION_PROVIDER("CANNExecutionProvider"), 104 | /** 105 | * Azure: not supported 106 | * @since 1.2.0 107 | */ 108 | AZURE_EXECUTION_PROVIDER("AzureExecutionProvider"); 109 | 110 | private final String identifier; 111 | final ExecutionProviderConfigFactory factory; 112 | 113 | private ExecutionProvider(String identifier) { 114 | this(identifier, null); 115 | } 116 | 117 | private ExecutionProvider(String identifier, ExecutionProviderConfigFactory factory) { 118 | this.identifier = identifier; 119 | this.factory = factory; 120 | } 121 | 122 | boolean isSupported() { 123 | return factory != null; 124 | } 125 | 126 | /** 127 | * Find a enum value based on its identifier string. 
128 | * @param identifier the string constant name for this execution provider 129 | * @return an enum value 130 | */ 131 | public static final ExecutionProvider of(String identifier) { 132 | for (ExecutionProvider executionProvider : ExecutionProvider.values()) { 133 | if (executionProvider.identifier.equals(identifier)) { 134 | return executionProvider; 135 | } 136 | } 137 | return null; 138 | } 139 | } 140 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/ExecutionProviderCPUConfig.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | import java.lang.foreign.Arena; 8 | import java.lang.foreign.MemorySegment; 9 | import java.util.Map; 10 | 11 | final class ExecutionProviderCPUConfig extends ExecutionProviderConfig { 12 | 13 | private static final String USE_ARENA = "use_arena"; 14 | 15 | protected ExecutionProviderCPUConfig(Map properties) { 16 | super(properties); 17 | } 18 | 19 | @Override 20 | void appendToSessionOptions(Arena arena, ApiImpl api, MemorySegment sessionOptions) { 21 | // default is true 22 | // https://github.com/microsoft/onnxruntime/blob/fb85b31facb9fb3fc99c76f99c93ea8f06ada39b/onnxruntime/core/providers/cpu/cpu_execution_provider.h#L14 23 | String useArena = properties.getOrDefault(USE_ARENA, "1"); 24 | if ("1".equals(useArena) || "true".equals(useArena)) { 25 | api.checkStatus(api.EnableCpuMemArena.apply(sessionOptions)); 26 | } else { 27 | api.checkStatus(api.DisableCpuMemArena.apply(sessionOptions)); 28 | } 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/ExecutionProviderCUDAConfig.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 
(c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | import java.lang.foreign.MemorySegment; 8 | import java.util.Map; 9 | 10 | final class ExecutionProviderCUDAConfig extends ExecutionProviderObjectConfig { 11 | 12 | ExecutionProviderCUDAConfig(Map properties) { 13 | super(properties); 14 | } 15 | 16 | @Override 17 | protected MemorySegment create(ApiImpl api, MemorySegment out) { 18 | return api.CreateCUDAProviderOptions.apply(out); 19 | } 20 | 21 | @Override 22 | protected void release(ApiImpl api, MemorySegment config) { 23 | api.ReleaseCUDAProviderOptions.apply(config); 24 | } 25 | 26 | @Override 27 | protected MemorySegment update( 28 | ApiImpl api, MemorySegment config, MemorySegment keys, MemorySegment values, int numProperties) { 29 | return api.UpdateCUDAProviderOptions.apply(config, keys, values, numProperties); 30 | } 31 | 32 | @Override 33 | protected MemorySegment append(ApiImpl api, MemorySegment sessionOptions, MemorySegment config) { 34 | return api.SessionOptionsAppendExecutionProvider_CUDA_V2.apply(sessionOptions, config); 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/ExecutionProviderConfig.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | import java.lang.foreign.Arena; 8 | import java.lang.foreign.MemorySegment; 9 | import java.util.Map; 10 | import java.util.Optional; 11 | import java.util.function.BiConsumer; 12 | 13 | abstract class ExecutionProviderConfig { 14 | 15 | protected static final String TRUE_VALUE = "1"; 16 | 17 | protected final Map properties; 18 | 19 | protected ExecutionProviderConfig(Map properties) { 20 | this.properties = properties; 21 | } 22 | 23 | 
@Override 24 | public final String toString() { 25 | return properties.toString(); 26 | } 27 | 28 | private Optional get(String key) { 29 | return Optional.ofNullable(properties.get(key)); 30 | } 31 | 32 | protected void copyByte(String key, MemorySegment config, BiConsumer consumer) { 33 | get(key).ifPresent(val -> consumer.accept(config, Integer.valueOf(val).byteValue())); 34 | } 35 | 36 | protected void copyInteger(String key, MemorySegment config, BiConsumer consumer) { 37 | get(key).ifPresent(val -> consumer.accept(config, Integer.parseInt(val))); 38 | } 39 | 40 | protected void copyBoolean(String key, MemorySegment config, BiConsumer consumer) { 41 | get(key).ifPresent(val -> consumer.accept(config, TRUE_VALUE.equals(val) || Boolean.valueOf(val))); 42 | } 43 | 44 | protected void copyLong(String key, MemorySegment config, BiConsumer consumer) { 45 | get(key).ifPresent(val -> consumer.accept(config, Long.parseLong(val))); 46 | } 47 | 48 | protected void copyString( 49 | String key, MemorySegment config, Arena arena, BiConsumer consumer) { 50 | get(key).ifPresent(val -> consumer.accept(config, arena.allocateFrom(val))); 51 | } 52 | 53 | abstract void appendToSessionOptions(Arena arena, ApiImpl api, MemorySegment sessionOptions); 54 | } 55 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/ExecutionProviderConfigFactory.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | import java.util.Map; 8 | import java.util.function.Function; 9 | 10 | interface ExecutionProviderConfigFactory extends Function, ExecutionProviderConfig> {} 11 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/ExecutionProviderCoreMLConfig.java: 
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/)
 * SPDX-License-Identifier: MIT
 */
package com.jyuzawa.onnxruntime;

import com.jyuzawa.onnxruntime_extern.onnxruntime_all_h;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.util.Map;

/**
 * CoreML execution provider settings. Each recognized property sets one {@code COREML_FLAG_*} bit when its value is
 * exactly {@link #TRUE_VALUE}. The combined flags are passed to {@code OrtSessionOptionsAppendExecutionProvider_CoreML}.
 */
final class ExecutionProviderCoreMLConfig extends ExecutionProviderConfig {

    /** Key that selects {@code COREML_FLAG_USE_CPU_AND_GPU}. */
    private static final String USE_CPU_AND_GPU = "use_cpu_and_gpu";

    /** Misspelled key previously used for {@code COREML_FLAG_USE_CPU_AND_GPU}; still honored for compatibility. */
    private static final String LEGACY_USE_CPU_AND_GPU = "use_cpu_and_cpu";

    protected ExecutionProviderCoreMLConfig(Map<String, String> properties) {
        super(properties);
    }

    /** Return whether the property equals {@link #TRUE_VALUE}; an absent key or any other value means off. */
    private boolean isEnabled(String key) {
        return TRUE_VALUE.equals(properties.get(key));
    }

    @Override
    final void appendToSessionOptions(Arena arena, ApiImpl api, MemorySegment sessionOptions) {
        int flags = 0;
        if (isEnabled("use_cpu_only")) {
            flags |= onnxruntime_all_h.COREML_FLAG_USE_CPU_ONLY();
        }
        if (isEnabled("enable_on_subgraph")) {
            flags |= onnxruntime_all_h.COREML_FLAG_ENABLE_ON_SUBGRAPH();
        }
        if (isEnabled("device_with_ane")) {
            flags |= onnxruntime_all_h.COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE();
        }
        if (isEnabled("allow_static_input_shapes")) {
            flags |= onnxruntime_all_h.COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES();
        }
        if (isEnabled(USE_CPU_AND_GPU) || isEnabled(LEGACY_USE_CPU_AND_GPU)) {
            flags |= onnxruntime_all_h.COREML_FLAG_USE_CPU_AND_GPU();
        }
        try {
            api.checkStatus(onnxruntime_all_h.OrtSessionOptionsAppendExecutionProvider_CoreML(sessionOptions, flags));
        } catch (UnsatisfiedLinkError e) {
            // NOTE: CoreML is not always present; a missing native entry point surfaces as UnsatisfiedLinkError
            throw new OnnxRuntimeException("CoreML is not enabled in this build", e);
        }
    }
}
--------------------------------------------------------------------------------
/src/main/java/com/jyuzawa/onnxruntime/ExecutionProviderDnnlConfig.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/)
 * SPDX-License-Identifier: MIT
 */
package com.jyuzawa.onnxruntime;

import java.lang.foreign.MemorySegment;
import java.util.Map;

/**
 * DNNL execution provider settings. The string properties are forwarded untouched to a native options object built
 * with {@code CreateDnnlProviderOptions} and filled with {@code UpdateDnnlProviderOptions}.
 */
final class ExecutionProviderDnnlConfig extends ExecutionProviderObjectConfig {

    ExecutionProviderDnnlConfig(Map<String, String> properties) {
        super(properties);
    }

    @Override
    protected MemorySegment create(ApiImpl api, MemorySegment out) {
        return api.CreateDnnlProviderOptions.apply(out);
    }

    @Override
    protected void release(ApiImpl api, MemorySegment config) {
        api.ReleaseDnnlProviderOptions.apply(config);
    }

    @Override
    protected MemorySegment update(
            ApiImpl api, MemorySegment config, MemorySegment keys, MemorySegment values, int numProperties) {
        return api.UpdateDnnlProviderOptions.apply(config, keys, values, numProperties);
    }

    @Override
    protected MemorySegment append(ApiImpl api, MemorySegment sessionOptions, MemorySegment config) {
        return api.SessionOptionsAppendExecutionProvider_Dnnl.apply(sessionOptions, config);
    }
}
--------------------------------------------------------------------------------
/src/main/java/com/jyuzawa/onnxruntime/ExecutionProviderMIGraphXConfig.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/)
 * SPDX-License-Identifier: MIT
 */
package com.jyuzawa.onnxruntime;

import com.jyuzawa.onnxruntime_extern.OrtMIGraphXProviderOptions;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.util.Map;

/**
 * MIGraphX execution provider settings. Recognized properties are parsed and written field-by-field into an
 * {@code OrtMIGraphXProviderOptions} struct; properties not listed here are ignored.
 */
final class ExecutionProviderMIGraphXConfig extends ExecutionProviderConfig {

    protected ExecutionProviderMIGraphXConfig(Map<String, String> properties) {
        super(properties);
    }

    @Override
    final void appendToSessionOptions(Arena arena, ApiImpl api, MemorySegment sessionOptions) {
        // the struct and any strings copied into it are allocated from the same arena, so they share its lifetime
        MemorySegment config = OrtMIGraphXProviderOptions.allocate(arena);
        copyInteger("device_id", config, OrtMIGraphXProviderOptions::device_id);
        copyInteger("migraphx_fp16_enable", config, OrtMIGraphXProviderOptions::migraphx_fp16_enable);
        copyInteger("migraphx_int8_enable", config, OrtMIGraphXProviderOptions::migraphx_int8_enable);
        copyInteger(
                "migraphx_use_native_calibration_table",
                config,
                OrtMIGraphXProviderOptions::migraphx_use_native_calibration_table);
        copyString(
                "migraphx_int8_calibration_table_name",
                config,
                arena,
                OrtMIGraphXProviderOptions::migraphx_int8_calibration_table_name);
        copyInteger("migraphx_save_compiled_model", config, OrtMIGraphXProviderOptions::migraphx_save_compiled_model);
        copyString("migraphx_save_model_path", config, arena, OrtMIGraphXProviderOptions::migraphx_save_model_path);

        copyInteger("migraphx_load_compiled_model", config, OrtMIGraphXProviderOptions::migraphx_load_compiled_model);
        copyString("migraphx_load_model_path", config, arena, OrtMIGraphXProviderOptions::migraphx_load_model_path);
        copyBoolean("migraphx_exhaustive_tune", config, OrtMIGraphXProviderOptions::migraphx_exhaustive_tune);
        api.checkStatus(api.SessionOptionsAppendExecutionProvider_MIGraphX.apply(sessionOptions, config));
    }
}
--------------------------------------------------------------------------------
/src/main/java/com/jyuzawa/onnxruntime/ExecutionProviderMapConfig.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/)
 * SPDX-License-Identifier: MIT
 */
package com.jyuzawa.onnxruntime;

import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_POINTER;

import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.util.Map;

/**
 * Base class for providers configured through two parallel native arrays of C strings, where {@code keys[i]} maps to
 * {@code values[i]}.
 */
abstract class ExecutionProviderMapConfig extends ExecutionProviderConfig {

    protected ExecutionProviderMapConfig(Map<String, String> properties) {
        super(properties);
    }

    /**
     * Register the provider using the marshalled property arrays.
     *
     * @param arena arena that owns the key/value strings
     * @param api the onnxruntime API bindings
     * @param sessionOptions the native session options to modify
     * @param keys native array of {@code numProperties} C-string pointers
     * @param values native array of {@code numProperties} C-string pointers, aligned with {@code keys}
     * @param numProperties number of entries in both arrays
     */
    protected abstract void appendToSessionOptions(
            Arena arena,
            ApiImpl api,
            MemorySegment sessionOptions,
            MemorySegment keys,
            MemorySegment values,
            int numProperties);

    @Override
    final void appendToSessionOptions(Arena arena, ApiImpl api, MemorySegment sessionOptions) {
        int numProps = properties.size();
        MemorySegment keys = arena.allocate(C_POINTER, numProps);
        MemorySegment values = arena.allocate(C_POINTER, numProps);
        int i = 0;
        // a single pass over entrySet keeps each key at the same index as its value
        for (Map.Entry<String, String> entry : properties.entrySet()) {
            keys.setAtIndex(C_POINTER, i, arena.allocateFrom(entry.getKey()));
            values.setAtIndex(C_POINTER, i, arena.allocateFrom(entry.getValue()));
            i++;
        }
        appendToSessionOptions(arena, api, sessionOptions, keys, values, numProps);
    }
}
--------------------------------------------------------------------------------
/src/main/java/com/jyuzawa/onnxruntime/ExecutionProviderObjectConfig.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/)
 * SPDX-License-Identifier: MIT
 */
package com.jyuzawa.onnxruntime;

import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.util.Map;

/**
 * Base class for providers whose options live in an opaque native object. The lifecycle is create, then update from
 * the key/value arrays, then append to the session options, and finally release.
 */
abstract class ExecutionProviderObjectConfig extends ExecutionProviderMapConfig {

    protected ExecutionProviderObjectConfig(Map<String, String> properties) {
        super(properties);
    }

    @Override
    protected final void appendToSessionOptions(
            Arena arena,
            ApiImpl api,
            MemorySegment sessionOptions,
            MemorySegment keys,
            MemorySegment values,
            int numProperties) {
        MemorySegment config = api.create(arena, out -> create(api, out));
        try {
            api.checkStatus(update(api, config, keys, values, numProperties));
            api.checkStatus(append(api, sessionOptions, config));
        } finally {
            // always freed, even when update/append fail.
            // NOTE(review): presumes append copies what it needs from config — confirm against onnxruntime docs
            release(api, config);
        }
    }

    /** Allocate the native options object, writing its handle to {@code out}; returns an OrtStatus pointer. */
    protected abstract MemorySegment create(ApiImpl api, MemorySegment out);

    /** Free the native options object. */
    protected abstract void release(ApiImpl api, MemorySegment config);

    /** Apply the key/value arrays to the options object; returns an OrtStatus pointer. */
    protected abstract MemorySegment update(
            ApiImpl api, MemorySegment config, MemorySegment keys, MemorySegment values, int numProperties);

    /** Register the provider on the session options; returns an OrtStatus pointer. */
    protected abstract MemorySegment append(ApiImpl api, MemorySegment sessionOptions, MemorySegment config);
}
--------------------------------------------------------------------------------
/src/main/java/com/jyuzawa/onnxruntime/ExecutionProviderOpenVINOConfig.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/)
 * SPDX-License-Identifier: MIT
 */
package com.jyuzawa.onnxruntime;

import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.util.Map;

/** OpenVINO execution provider settings, passed as key/value arrays to {@code SessionOptionsAppendExecutionProvider_OpenVINO_V2}. */
final class ExecutionProviderOpenVINOConfig extends ExecutionProviderMapConfig {

    protected ExecutionProviderOpenVINOConfig(Map<String, String> properties) {
        super(properties);
    }

    @Override
    protected void appendToSessionOptions(
            Arena arena,
            ApiImpl api,
            MemorySegment sessionOptions,
            MemorySegment keys,
            MemorySegment values,
            int numProperties) {
        api.checkStatus(api.SessionOptionsAppendExecutionProvider_OpenVINO_V2.apply(
                sessionOptions, keys, values, numProperties));
    }
}
--------------------------------------------------------------------------------
/src/main/java/com/jyuzawa/onnxruntime/ExecutionProviderROCMConfig.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/)
 * SPDX-License-Identifier: MIT
 */
package com.jyuzawa.onnxruntime;

import com.jyuzawa.onnxruntime_extern.OrtROCMProviderOptions;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.util.Map;

/**
 * ROCm execution provider settings. Recognized properties are parsed and written field-by-field into an
 * {@code OrtROCMProviderOptions} struct; properties not listed here are ignored.
 */
final class ExecutionProviderROCMConfig extends ExecutionProviderConfig {

    protected ExecutionProviderROCMConfig(Map<String, String> properties) {
        super(properties);
    }

    @Override
    final void appendToSessionOptions(Arena arena, ApiImpl api, MemorySegment sessionOptions) {
        MemorySegment config = OrtROCMProviderOptions.allocate(arena);
        copyInteger("device_id", config, OrtROCMProviderOptions::device_id);
        copyInteger("miopen_conv_exhaustive_search", config, OrtROCMProviderOptions::miopen_conv_exhaustive_search);
        copyLong("gpu_mem_limit", config, OrtROCMProviderOptions::gpu_mem_limit);
        copyInteger("arena_extend_strategy", config, OrtROCMProviderOptions::arena_extend_strategy);
        copyInteger("do_copy_in_default_stream", config, OrtROCMProviderOptions::do_copy_in_default_stream);
        copyInteger("has_user_compute_stream", config, OrtROCMProviderOptions::has_user_compute_stream);
        // TODO: user_compute_stream
        // TODO: default_memory_arena_cfg
        copyInteger("enable_hip_graph", config, OrtROCMProviderOptions::enable_hip_graph);
        copyInteger("tunable_op_enable", config, OrtROCMProviderOptions::tunable_op_enable);
        copyInteger("tunable_op_tuning_enable", config, OrtROCMProviderOptions::tunable_op_tuning_enable);
        copyInteger(
                "tunable_op_max_tuning_duration_ms", config, OrtROCMProviderOptions::tunable_op_max_tuning_duration_ms);
        api.checkStatus(api.SessionOptionsAppendExecutionProvider_ROCM.apply(sessionOptions, config));
    }
}
--------------------------------------------------------------------------------
/src/main/java/com/jyuzawa/onnxruntime/ExecutionProviderSimpleMapConfig.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/)
 * SPDX-License-Identifier: MIT
 */
package com.jyuzawa.onnxruntime;

import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.util.Map;

/**
 * Settings for a provider that is registered by name through the generic {@code SessionOptionsAppendExecutionProvider}
 * call, with its properties passed as key/value arrays.
 */
final class ExecutionProviderSimpleMapConfig extends ExecutionProviderMapConfig {

    private final String name;

    private ExecutionProviderSimpleMapConfig(Map<String, String> properties, String name) {
        super(properties);
        this.name = name;
    }

    /**
     * Create a factory for a named provider.
     *
     * @param name the provider name passed to onnxruntime
     * @return a factory producing configs for that provider
     */
    static final ExecutionProviderConfigFactory of(String name) {
        return properties -> new ExecutionProviderSimpleMapConfig(properties, name);
    }

    @Override
    protected void appendToSessionOptions(
            Arena arena,
            ApiImpl api,
            MemorySegment sessionOptions,
            MemorySegment keys,
            MemorySegment values,
            int numProperties) {
        api.checkStatus(api.SessionOptionsAppendExecutionProvider.apply(
                sessionOptions, arena.allocateFrom(name), keys, values, numProperties));
    }
}
--------------------------------------------------------------------------------
/src/main/java/com/jyuzawa/onnxruntime/ExecutionProviderTensorRTConfig.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/)
 * SPDX-License-Identifier: MIT
 */
package com.jyuzawa.onnxruntime;

import java.lang.foreign.MemorySegment;
import java.util.Map;

/**
 * TensorRT execution provider settings. The string properties are forwarded untouched to a native options object
 * built with {@code CreateTensorRTProviderOptions} and filled with {@code UpdateTensorRTProviderOptions}.
 */
final class ExecutionProviderTensorRTConfig extends ExecutionProviderObjectConfig {

    ExecutionProviderTensorRTConfig(Map<String, String> properties) {
        super(properties);
    }

    @Override
    protected MemorySegment create(ApiImpl api, MemorySegment out) {
        return api.CreateTensorRTProviderOptions.apply(out);
    }

    @Override
    protected void release(ApiImpl api, MemorySegment config) {
        api.ReleaseTensorRTProviderOptions.apply(config);
    }

    @Override
    protected MemorySegment update(
            ApiImpl api, MemorySegment config, MemorySegment keys, MemorySegment values, int numProperties) {
        return api.UpdateTensorRTProviderOptions.apply(config, keys, values, numProperties);
    }

    @Override
    protected MemorySegment append(ApiImpl api, MemorySegment sessionOptions, MemorySegment config) {
        return api.SessionOptionsAppendExecutionProvider_TensorRT_V2.apply(sessionOptions, config);
    }
}
--------------------------------------------------------------------------------
/src/main/java/com/jyuzawa/onnxruntime/ExecutionProviderVitisAIConfig.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/)
 * SPDX-License-Identifier: MIT
 */
package com.jyuzawa.onnxruntime;

import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.util.Map;

/** VitisAI execution provider settings, passed as key/value arrays to {@code SessionOptionsAppendExecutionProvider_VitisAI}. */
final class ExecutionProviderVitisAIConfig extends ExecutionProviderMapConfig {

    protected ExecutionProviderVitisAIConfig(Map<String, String> properties) {
        super(properties);
    }

    @Override
    protected void appendToSessionOptions(
            Arena arena,
            ApiImpl api,
            MemorySegment sessionOptions,
            MemorySegment keys,
            MemorySegment values,
            int numProperties) {
        api.checkStatus(
                api.SessionOptionsAppendExecutionProvider_VitisAI.apply(sessionOptions, keys, values, numProperties));
    }
}
--------------------------------------------------------------------------------
/src/main/java/com/jyuzawa/onnxruntime/IoBinding.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/)
 * SPDX-License-Identifier: MIT
 */
package com.jyuzawa.onnxruntime;

import java.util.Map;

/**
 * A representation of a model evaluation. Capable of reuse in repetitive runs. More efficient than {@link Transaction}. Only supports tensors. Input and outputs are pre-allocated and must be of fixed size. This class is NOT thread-safe.
 *
 * @since 1.4.0
 */
public interface IoBinding extends AutoCloseable {

    /**
     * Set the severity for logging for the runs of this binding.
     * Can override the environment's or session's logger's severity.
     * @param level the minimum severity to log
     * @return this
     */
    IoBinding setLogSeverityLevel(OnnxRuntimeLoggingLevel level);

    /**
     * Set the verbosity for logging for the runs of this binding.
     * Can override the environment's or session's logger's verbosity.
     * @param level the verbosity level
     * @return this
     */
    IoBinding setLogVerbosityLevel(int level);

    /**
     * Set the run tag (which is the logger id).
     * @param runTag the tag to attach to this binding's runs
     * @return this
     */
    IoBinding setRunTag(String runTag);

    /**
     * Synchronize the bound inputs (onnxruntime {@code SynchronizeBoundInputs}).
     * @return this
     */
    IoBinding synchronizeBoundInputs();

    /**
     * Synchronize the bound outputs (onnxruntime {@code SynchronizeBoundOutputs}).
     * @return this
     */
    IoBinding synchronizeBoundOutputs();

    /**
     * Get the pre-allocated input values, to be populated before calling {@link #run()}.
     * @return the bound inputs, in the order they were bound
     */
    NamedCollection<OnnxValue> getInputs();

    /**
     * Get the pre-allocated output values bound to the model's outputs.
     * @return the bound outputs, in the order they were bound
     */
    NamedCollection<OnnxValue> getOutputs();

    /**
     * Run the model evaluation.
     */
    void run();

    /**
     * Frees the native resources (typically buffers) associated with this binding.
     */
    @Override
    void close();

    /**
     * A builder of a {@link IoBinding}. Should NOT be reused. This class is NOT thread-safe.
     *
     * @since 1.4.0
     */
    public interface Builder {
        /**
         * Bind a model input by name. Its value is allocated when the binding is built.
         * @param name the input name
         * @return the builder
         */
        Builder bindInput(String name);

        /**
         * Bind a model input by position. Its value is allocated when the binding is built.
         * @param index the input index
         * @return the builder
         */
        Builder bindInput(int index);

        /**
         * Bind a model output by name. Its value is allocated when the binding is built.
         * @param name the output name
         * @return the builder
         */
        Builder bindOutput(String name);

        /**
         * Bind a model output by position. Its value is allocated when the binding is built.
         * @param index the output index
         * @return the builder
         */
        Builder bindOutput(int index);

        /**
         * Set custom run configuration entries for this binding.
         * @param config key/value run config entries
         * @return the builder
         */
        Builder setConfigMap(Map<String, String> config);

        /**
         * Construct a {@link IoBinding}.
         *
         * @return a new instance
         */
        IoBinding build();
    }
}
--------------------------------------------------------------------------------
/src/main/java/com/jyuzawa/onnxruntime/IoBindingImpl.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/)
 * SPDX-License-Identifier: MIT
 */
package com.jyuzawa.onnxruntime;

import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * Native-backed {@link IoBinding}. It owns a shared arena, the native IoBinding and RunOptions handles, and one
 * pre-allocated value per bound input/output; {@link #close()} releases all of them.
 */
final class IoBindingImpl implements IoBinding {
    private final ApiImpl api;
    private final Arena arena;
    private final MemorySegment ioBinding;
    private final MemorySegment runOptions;
    private final NamedCollectionImpl<OnnxValue> inputs;
    private final NamedCollectionImpl<OnnxValue> outputs;
    private final List<Runnable> closeables;
    private final MemorySegment session;

    IoBindingImpl(Builder builder) {
        // NOTE: this is shared since we want to allow closing from another thread.
        this.arena = Arena.ofShared();
        this.api = builder.api;
        this.session = builder.session.address();
        this.ioBinding = builder.api.create(arena, out -> builder.api.CreateIoBinding.apply(session, out));
        this.runOptions = api.create(arena, out -> api.CreateRunOptions.apply(out));
        Map<String, String> config = builder.config;
        if (config != null && !config.isEmpty()) {
            for (Map.Entry<String, String> entry : config.entrySet()) {
                api.checkStatus(api.AddRunConfigEntry.apply(
                        runOptions, arena.allocateFrom(entry.getKey()), arena.allocateFrom(entry.getValue())));
            }
        }
        List<NodeInfoImpl> rawInputs = builder.inputs;
        List<NodeInfoImpl> rawOutputs = builder.outputs;
        this.closeables = new ArrayList<>(rawInputs.size() + rawOutputs.size());
        ValueContext valueContext = new ValueContext(
                builder.api,
                arena,
                builder.session.environment.ortAllocator,
                builder.session.environment.memoryInfo,
                closeables);
        this.inputs = add(rawInputs, valueContext, api, ioBinding, true);
        this.outputs = add(rawOutputs, valueContext, api, ioBinding, false);
    }

    /**
     * Allocate a value for each node, bind it as an input or output, and register its release in the
     * context's closeables.
     */
    private static final NamedCollectionImpl<OnnxValue> add(
            List<NodeInfoImpl> nodes,
            ValueContext valueContext,
            ApiImpl api,
            MemorySegment ioBinding,
            boolean isInput) {
        LinkedHashMap<String, OnnxValue> out = new LinkedHashMap<>(nodes.size());
        for (NodeInfoImpl node : nodes) {
            OnnxValueImpl output = node.getTypeInfo().newValue(valueContext, null);
            MemorySegment valueAddress = output.toNative();
            valueContext.closeables().add(() -> api.ReleaseValue.apply(valueAddress));
            out.put(node.getName(), output);
            final MemorySegment result;
            if (isInput) {
                result = api.BindInput.apply(ioBinding, node.nameSegment, valueAddress);
            } else {
                result = api.BindOutput.apply(ioBinding, node.nameSegment, valueAddress);
            }
            api.checkStatus(result);
        }
        return new NamedCollectionImpl<>(out);
    }

    @Override
    public void close() {
        api.ReleaseIoBinding.apply(ioBinding);
        api.ReleaseRunOptions.apply(runOptions);
        for (Runnable closeable : closeables) {
            closeable.run();
        }
        // closed last: the handles and values released above were allocated in this arena
        arena.close();
    }

    @Override
    public void run() {
        api.checkStatus(api.RunWithBinding.apply(session, runOptions, ioBinding));
    }

    static final class Builder implements IoBinding.Builder {

        final ApiImpl api;
        final SessionImpl session;
        private Map<String, String> config;
        final List<NodeInfoImpl> inputs;
        final List<NodeInfoImpl> outputs;

        public Builder(SessionImpl session) {
            this.api = session.api;
            this.session = session;
            this.inputs = new ArrayList<>(session.inputs.size());
            this.outputs = new ArrayList<>(session.outputs.size());
        }

        @Override
        public IoBinding build() {
            if (inputs.isEmpty()) {
                throw new IllegalArgumentException("No inputs specified");
            }
            if (outputs.isEmpty()) {
                throw new IllegalArgumentException("No outputs specified");
            }
            return new IoBindingImpl(this);
        }

        @Override
        public Builder setConfigMap(Map<String, String> config) {
            this.config = config;
            return this;
        }

        /** Append a resolved node, rejecting a null lookup result for an unknown name or index. */
        private void accumulate(List<NodeInfoImpl> list, NodeInfoImpl nodeInfo) {
            if (nodeInfo == null) {
                throw new IllegalArgumentException("node info missing");
            }
            list.add(nodeInfo);
        }

        @Override
        public Builder bindInput(String name) {
            accumulate(inputs, session.inputs.get(name));
            return this;
        }

        @Override
        public Builder bindInput(int index) {
            accumulate(inputs, session.inputs.get(index));
            return this;
        }

        @Override
        public Builder bindOutput(String name) {
            accumulate(outputs, session.outputs.get(name));
            return this;
        }

        @Override
        public Builder bindOutput(int index) {
            accumulate(outputs, session.outputs.get(index));
            return this;
        }
    }

    @Override
    public IoBinding setLogSeverityLevel(OnnxRuntimeLoggingLevel level) {
        api.checkStatus(api.RunOptionsSetRunLogSeverityLevel.apply(runOptions, level.getNumber()));
        return this;
    }

    @Override
    public IoBinding setLogVerbosityLevel(int level) {
        api.checkStatus(api.RunOptionsSetRunLogVerbosityLevel.apply(runOptions, level));
        return this;
    }

    @Override
    public IoBinding setRunTag(String runTag) {
        // the C string only lives for this call; presumably onnxruntime copies it — TODO confirm
        try (Arena arena = Arena.ofConfined()) {
            MemorySegment segment = arena.allocateFrom(runTag);
            api.checkStatus(api.RunOptionsSetRunTag.apply(runOptions, segment));
        }
        return this;
    }

    @Override
    public NamedCollection<OnnxValue> getInputs() {
        return inputs;
    }

    @Override
    public NamedCollection<OnnxValue> getOutputs() {
        return outputs;
    }

    @Override
    public IoBinding synchronizeBoundInputs() {
        api.checkStatus(api.SynchronizeBoundInputs.apply(ioBinding));
        return this;
    }

    @Override
    public IoBinding synchronizeBoundOutputs() {
        api.checkStatus(api.SynchronizeBoundOutputs.apply(ioBinding));
        return this;
    }
}
--------------------------------------------------------------------------------
/src/main/java/com/jyuzawa/onnxruntime/Loader.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/)
 * SPDX-License-Identifier: MIT
 */
package com.jyuzawa.onnxruntime;

import java.io.IOException;
import java.io.InputStream;
import java.lang.System.Logger;
import java.lang.System.Logger.Level;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Locale;
import java.util.Scanner;

/** Locates and loads the onnxruntime native library. */
final class Loader {

    private static final Logger LOG = System.getLogger(Loader.class.getName());

    /**
     * Load the onnxruntime native library.
     *
     * <p>If the classpath holds a {@code libraries} manifest under {@code <package>/native/<os>-<arch>/}, every
     * whitespace-separated file name it lists is copied from the same resource directory into a new temporary
     * directory (marked for deletion on exit), and the onnxruntime library is loaded from there. Otherwise this
     * falls back to {@link System#loadLibrary(String)}.
     *
     * @throws UnsatisfiedLinkError if the distributed libraries cannot be extracted (cause: the IOException) or loaded
     * @throws UnsupportedOperationException if the OS or architecture is not recognized
     */
    static void load() {
        try {
            String arch = normalizeArch(System.getProperty("os.arch", ""));
            String os = normalizeOs(System.getProperty("os.name", ""));
            LOG.log(Level.DEBUG, "OS: " + os + ", Architecture: " + arch);
            Class<?> clazz = Loader.class;
            String packagePath = clazz.getPackage().getName().replace('.', '/');
            String resourcePrefix = packagePath + "/native/" + os + "-" + arch + "/";
            LOG.log(Level.DEBUG, "Searching for distributed libraries using resource prefix " + resourcePrefix);
            ClassLoader classLoader = clazz.getClassLoader();
            try (InputStream inputStream = classLoader.getResourceAsStream(resourcePrefix + "libraries")) {
                if (inputStream == null) {
                    LOG.log(Level.DEBUG, "No distributed packages, falling back to System.loadLibrary()");
                    // not found, fall back to lib path
                    System.loadLibrary("onnxruntime");
                    return;
                }
                Path tempDir = Files.createTempDirectory("onnxruntime-java").toAbsolutePath();
                // registered before the files, so deleteOnExit (reverse registration order) removes it last
                tempDir.toFile().deleteOnExit();
                try (Scanner scanner = new Scanner(inputStream)) {
                    while (scanner.hasNext()) {
                        String fileName = scanner.next();
                        String resource = resourcePrefix + fileName;
                        Path filePath = tempDir.resolve(fileName).toAbsolutePath();
                        LOG.log(Level.DEBUG, "Copying " + filePath);
                        filePath.toFile().deleteOnExit();
                        try (InputStream objStream = classLoader.getResourceAsStream(resource)) {
                            // a manifest entry without a matching resource would otherwise NPE inside Files.copy
                            if (objStream == null) {
                                throw new IOException("Missing distributed library resource " + resource);
                            }
                            Files.copy(objStream, filePath);
                        }
                    }
                }
                String libraryPath = tempDir.resolve(System.mapLibraryName("onnxruntime"))
                        .toAbsolutePath()
                        .toString();
                LOG.log(Level.DEBUG, "Loading " + libraryPath);
                System.load(libraryPath);
            }
        } catch (IOException e) {
            UnsatisfiedLinkError ule = new UnsatisfiedLinkError("Failed to load onnxruntime");
            // UnsatisfiedLinkError has no (message, cause) constructor; attach the IOException to the error itself
            ule.initCause(e);
            throw ule;
        }
    }

    /** Lower-case the value and strip everything except ASCII letters and digits. */
    private static String normalize(String value) {
        return value.toLowerCase(Locale.US).replaceAll("[^a-z0-9]+", "");
    }

    /** Map an {@code os.arch} value to the resource directory suffix. */
    private static String normalizeArch(String value) {
        value = normalize(value);
        if (value.matches("^(x8664|amd64|ia32e|em64t|x64)$")) {
            return "x86_64";
        }
        if ("aarch64".equals(value)) {
            return "aarch_64";
        }
        throw new UnsupportedOperationException("Architecture is not supported");
    }

    /** Map an {@code os.name} value to the resource directory prefix. */
    private static String normalizeOs(String value) {
        value = normalize(value);
        if (value.startsWith("linux")) {
            return "linux";
        }
        if (value.startsWith("macosx") || value.startsWith("osx") || value.startsWith("darwin")) {
            return "osx";
        }
        if (value.startsWith("windows")) {
            return "windows";
        }
        throw new UnsupportedOperationException("Operating system is not supported");
    }
}
--------------------------------------------------------------------------------
/src/main/java/com/jyuzawa/onnxruntime/ManagedImpl.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/)
 * SPDX-License-Identifier: MIT
 */
package com.jyuzawa.onnxruntime;

import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;

/** Base class for objects that own an {@link Arena}; closing the object closes the arena. */
abstract class ManagedImpl implements AutoCloseable {

    protected final ApiImpl api;
    protected final Arena arena;

    protected ManagedImpl(ApiImpl api, Arena arena) {
        this.api = api;
        this.arena = arena;
    }

    @Override
    public void close() {
        arena.close();
    }

    /** The native address backing this object (used by {@link #toString()}). */
    abstract MemorySegment address();

    @Override
    public String toString() {
        return "{" + getClass().getSimpleName() + ": " + address() + "}";
    }
}
--------------------------------------------------------------------------------
/src/main/java/com/jyuzawa/onnxruntime/MapInfo.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/)
 * SPDX-License-Identifier: MIT
 */
package com.jyuzawa.onnxruntime;

/**
 * A description of the type information related to ONNX's "Map" type.
 *
 * @since 1.0.0
 */
public interface MapInfo {

    /**
     * Get the key information.
     * @return the key type
     */
    OnnxTensorElementDataType getKeyType();

    /**
     * Get the value information.
     * @return the value type
     */
    TypeInfo getValueType();
}

--------------------------------------------------------------------------------
/src/main/java/com/jyuzawa/onnxruntime/MapInfoImpl.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/)
 * SPDX-License-Identifier: MIT
 */
package com.jyuzawa.onnxruntime;

import java.lang.foreign.MemorySegment;

/**
 * Immutable {@link MapInfo}: the element type of the map's keys plus the type information of its values.
 */
final class MapInfoImpl implements MapInfo {

    private final OnnxTensorElementDataType keyType;
    private final TypeInfoImpl valueType;

    MapInfoImpl(OnnxTensorElementDataType keyType, TypeInfoImpl valueType) {
        this.keyType = keyType;
        this.valueType = valueType;
    }

    @Override
    public OnnxTensorElementDataType getKeyType() {
        return keyType;
    }

    @Override
    public TypeInfoImpl getValueType() {
        return valueType;
    }

    @Override
    public String toString() {
        return "map(" + keyType + "," + valueType + ")";
    }

    /**
     * Create a Java map value of this type.
     *
     * @param valueContext the context that owns the value's native resources
     * @param ortValueAddress an existing native value to read entries from, or {@code null} to start empty
     * @return a map implementation specialized for this map's key type
     * @throws UnsupportedOperationException if the values are not single-element tensors, or if the key
     *     type is neither INT64 nor STRING
     */
    OnnxValueImpl newValue(ValueContext valueContext, MemorySegment ortValueAddress) {
        // only scalar (single-element tensor) values are supported by the map implementations
        if (valueType.getType() != OnnxType.TENSOR || valueType.getTensorInfo().getElementCount() != 1) {
            throw new UnsupportedOperationException("OnnxMap only supports scalar values");
        }
        return switch (keyType) {
            case INT64 -> new OnnxMapLongImpl(this, valueContext, ortValueAddress);
            case STRING -> new OnnxMapStringImpl(this, valueContext, ortValueAddress);
            default -> throw new UnsupportedOperationException("OnnxMap does not support keys of type " + keyType);
        };
    }
}

--------------------------------------------------------------------------------
/src/main/java/com/jyuzawa/onnxruntime/ModelMetadata.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/)
 * SPDX-License-Identifier: MIT
 */
package com.jyuzawa.onnxruntime;

import java.util.Map;

/**
 * A representation of the metadata stored in an ONNX model.
 *
 * @since 1.0.0
 */
public interface ModelMetadata {
    /**
     * Access the description
     * @return description
     */
    String getDescription();

    /**
     * Access the domain
     * @return domain
     */
    String getDomain();

    /**
     * Access the graph description
     * @return graph description
     */
    String getGraphDescription();

    /**
     * Access the graph name
     * @return graph name
     */
    String getGraphName();

    /**
     * Access the producer name
     * @return producer
     */
    String getProducerName();

    /**
     * Access the version
     * @return version
     */
    long getVersion();

    /**
     * Access the custom metadata map.
     * @return an immutable map of custom key-value pairs.
     */
    Map<String, String> getCustomMetadata();
}

--------------------------------------------------------------------------------
/src/main/java/com/jyuzawa/onnxruntime/ModelMetadataImpl.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/)
 * SPDX-License-Identifier: MIT
 */
package com.jyuzawa.onnxruntime;

import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_LONG;
import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_POINTER;

import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;

/**
 * Eagerly copies all model metadata out of native memory into Java objects, releasing every
 * allocator-owned native string as soon as it has been copied.
 */
final class ModelMetadataImpl implements ModelMetadata {

    private final String description;
    private final String domain;
    private final String graphDescription;
    private final String graphName;
    private final String producerName;
    private final long version;
    private final Map<String, String> customMetadata;

    /**
     * Read all metadata fields.
     *
     * @param api the runtime API wrapper
     * @param arena the arena used for out-parameters
     * @param metadata the native model metadata handle
     * @param ortAllocator the allocator the runtime uses for returned strings; each one is freed with it
     */
    ModelMetadataImpl(ApiImpl api, Arena arena, MemorySegment metadata, MemorySegment ortAllocator) {
        this.description = copyAndFree(
                api,
                ortAllocator,
                api.create(arena, out -> api.ModelMetadataGetDescription.apply(metadata, ortAllocator, out)));
        this.domain = copyAndFree(
                api,
                ortAllocator,
                api.create(arena, out -> api.ModelMetadataGetDomain.apply(metadata, ortAllocator, out)));
        this.graphDescription = copyAndFree(
                api,
                ortAllocator,
                api.create(arena, out -> api.ModelMetadataGetGraphDescription.apply(metadata, ortAllocator, out)));
        this.graphName = copyAndFree(
                api,
                ortAllocator,
                api.create(arena, out -> api.ModelMetadataGetGraphName.apply(metadata, ortAllocator, out)));
        this.producerName = copyAndFree(
                api,
                ortAllocator,
                api.create(arena, out -> api.ModelMetadataGetProducerName.apply(metadata, ortAllocator, out)));
        this.version = api.extractLong(arena, out -> api.ModelMetadataGetVersion.apply(metadata, out));
        this.customMetadata = readCustomMetadata(api, arena, metadata, ortAllocator);
    }

    /** Copy a NUL-terminated native string into Java, then free it with the allocator that produced it. */
    private static String copyAndFree(ApiImpl api, MemorySegment ortAllocator, MemorySegment pointer) {
        try {
            return pointer.getString(0);
        } finally {
            // free even if copying fails, so the native string is never leaked
            api.checkStatus(api.AllocatorFree.apply(ortAllocator, pointer));
        }
    }

    /**
     * Read the custom metadata into an unmodifiable, insertion-ordered map, freeing every key, every
     * value and the key array itself.
     */
    private static Map<String, String> readCustomMetadata(
            ApiImpl api, Arena arena, MemorySegment metadata, MemorySegment ortAllocator) {
        MemorySegment count = arena.allocate(C_LONG);
        MemorySegment keys = api.create(
                arena, out -> api.ModelMetadataGetCustomMetadataMapKeys.apply(metadata, ortAllocator, out, count));
        long numKeys = count.getAtIndex(C_LONG, 0);
        if (numKeys == 0) {
            // per the ORT C API docs, no key array is allocated when there are no keys, so nothing to free
            return Collections.emptyMap();
        }
        Map<String, String> result = new LinkedHashMap<>(Math.toIntExact(numKeys));
        try {
            for (long i = 0; i < numKeys; i++) {
                MemorySegment key = keys.getAtIndex(C_POINTER, i);
                try {
                    MemorySegment value = api.create(
                            arena,
                            out -> api.ModelMetadataLookupCustomMetadataMap.apply(metadata, ortAllocator, key, out));
                    result.put(key.getString(0), copyAndFree(api, ortAllocator, value));
                } finally {
                    api.checkStatus(api.AllocatorFree.apply(ortAllocator, key));
                }
            }
        } finally {
            // the pointer array holding the keys is itself allocator-owned and must be released too
            api.checkStatus(api.AllocatorFree.apply(ortAllocator, keys));
        }
        return Collections.unmodifiableMap(result);
    }

    @Override
    public String getDescription() {
        return description;
    }

    @Override
    public String getDomain() {
        return domain;
    }

    @Override
    public String getGraphDescription() {
        return graphDescription;
    }

    @Override
    public String getGraphName() {
        return graphName;
    }

    @Override
    public String getProducerName() {
        return producerName;
    }

    @Override
    public long getVersion() {
        return version;
    }

    @Override
    public Map<String, String> getCustomMetadata() {
        return customMetadata;
    }

    @Override
    public String toString() {
        return "{ModelMetadata: graphName=" + graphName + ", version=" + version + ", producerName=" + producerName
                + ", domain=" + domain + "}";
    }
}

--------------------------------------------------------------------------------
/src/main/java/com/jyuzawa/onnxruntime/NamedCollection.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/)
 * SPDX-License-Identifier: MIT
 */
package com.jyuzawa.onnxruntime;

import java.util.List;
import java.util.Map;

/**
 * An immutable collection which can be accessed by list index or name.
 * @param <V> the type in the collection
 * @since 1.0.0
 */
public interface NamedCollection<V> {

    /**
     * Get by name.
     * @param name the name of the desired input or output
     * @return a value in the collection, or null if it is not present
     */
    V get(String name);

    /**
     * Get by index.
     * @param index the index of the desired input or output
     * @return a value in the collection
     * @throws IndexOutOfBoundsException if the index is out of bounds
     */
    V get(int index);

    /**
     * Get a map view.
     * @return an unmodifiable map
     */
    Map<String, V> getMap();

    /**
     * Get a list view.
     * @return an unmodifiable list
     */
    List<V> getList();

    /**
     * Get the size
     * @return size
     */
    int size();
}

--------------------------------------------------------------------------------
/src/main/java/com/jyuzawa/onnxruntime/NamedCollectionImpl.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/)
 * SPDX-License-Identifier: MIT
 */
package com.jyuzawa.onnxruntime;

import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * {@link NamedCollection} backed by an insertion-ordered map. The list view holds the map's values in
 * iteration order, so index {@code i} refers to the {@code i}-th map entry.
 *
 * @param <V> the element type
 */
final class NamedCollectionImpl<V> implements NamedCollection<V> {

    private final Map<String, V> map;
    private final List<V> list;

    NamedCollectionImpl(LinkedHashMap<String, V> map) {
        // NOTE(review): the map view wraps the caller's map without copying; callers presumably do not mutate it
        this.map = Collections.unmodifiableMap(map);
        // snapshot the values for O(1) positional access
        this.list = Collections.unmodifiableList(new ArrayList<>(map.values()));
    }

    @Override
    public V get(String name) {
        return map.get(name);
    }

    @Override
    public V get(int index) {
        return list.get(index);
    }

    @Override
    public Map<String, V> getMap() {
        return map;
    }

    @Override
    public List<V> getList() {
        return list;
    }

    @Override
    public String toString() {
        return map.toString();
    }

    @Override
    public int size() {
        return list.size();
    }
}

--------------------------------------------------------------------------------
/src/main/java/com/jyuzawa/onnxruntime/NodeInfo.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/)
 * SPDX-License-Identifier: MIT
 */
package com.jyuzawa.onnxruntime;

/**
 * A tuple of name and type information.
 *
 * @since 1.0.0
 */
public interface NodeInfo {

    /**
     * Get the node's name.
     * @return name
     */
    String getName();

    /**
     * Get the type information for this node.
     * @return type information
     */
    TypeInfo getTypeInfo();
}

--------------------------------------------------------------------------------
/src/main/java/com/jyuzawa/onnxruntime/NodeInfoImpl.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/)
 * SPDX-License-Identifier: MIT
 */
package com.jyuzawa.onnxruntime;

import java.lang.foreign.MemorySegment;

/**
 * Immutable {@link NodeInfo} that also keeps a native copy of the node name.
 */
final class NodeInfoImpl implements NodeInfo {

    private final String name;
    // native form of the name; presumably passed to native run/binding calls — TODO confirm against callers
    final MemorySegment nameSegment;
    private final TypeInfoImpl typeInfo;

    NodeInfoImpl(String name, MemorySegment nameSegment, TypeInfoImpl typeInfo) {
        this.name = name;
        this.nameSegment = nameSegment;
        this.typeInfo = typeInfo;
    }

    @Override
    public String getName() {
        return name;
    }

    @Override
    public TypeInfoImpl getTypeInfo() {
        return typeInfo;
    }

    @Override
    public String toString() {
        return "{NodeInfo: name=" + name + ", typeInfo=" + typeInfo + "}";
    }
}

--------------------------------------------------------------------------------
/src/main/java/com/jyuzawa/onnxruntime/OnnxMap.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/)
 * SPDX-License-Identifier: MIT
 */
package com.jyuzawa.onnxruntime;

import java.util.NoSuchElementException;

/**
 * A map view of an {@link OnnxValue}. Use the {@link MapInfo#getKeyType()} to
 * select a type-safe view. A {@link NoSuchElementException} will be thrown if a view does not exist for this instance's type.
 *
 * @since 1.0.0
 */
public interface OnnxMap {

    /**
     * Get information about the key and value types.
     * @return map's type information
     */
    MapInfo getInfo();

    /**
     * Get a type safe view with long keys
     * @return long map view
     * @throws NoSuchElementException if this typed view does not exist
     */
    OnnxTypedMap<Long> asLongMap();

    /**
     * Get a type safe view with string keys
     * @return string map view
     * @throws NoSuchElementException if this typed view does not exist
     */
    OnnxTypedMap<String> asStringMap();
}

--------------------------------------------------------------------------------
/src/main/java/com/jyuzawa/onnxruntime/OnnxMapImpl.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/)
 * SPDX-License-Identifier: MIT
 */
package com.jyuzawa.onnxruntime;

import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_POINTER;

import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.stream.Stream;

/**
 * Base class for ONNX map values. Entries are kept in insertion order. New entries are created only
 * through {@link #set(Object)}; all {@link Map} mutators delegate to an unmodifiable view and therefore
 * throw {@link UnsupportedOperationException}.
 *
 * @param <K> the Java key type
 * @param <T> the tensor implementation used to marshal keys to and from native memory
 */
abstract class OnnxMapImpl<K, T extends OnnxTensorImpl> extends OnnxValueImpl implements OnnxMap, OnnxTypedMap<K> {

    private final Map<K, OnnxTensorImpl> data;
    private final Map<K, OnnxValue> unmodifiableData;
    protected final MapInfoImpl mapInfo;

    protected OnnxMapImpl(MapInfoImpl mapInfo, ValueContext valueContext, MemorySegment ortValueAddress) {
        super(OnnxType.MAP, valueContext);
        this.data = new LinkedHashMap<>();
        this.mapInfo = mapInfo;
        if (ortValueAddress != null) {
            // the native map is read as two parallel tensors: index 0 holds the keys, index 1 the values
            ApiImpl api = valueContext.api();
            Arena arena = valueContext.arena();
            MemorySegment ortAllocator = valueContext.ortAllocatorAddress();
            MemorySegment keyAddress =
                    api.create(arena, out -> api.GetValue.apply(ortValueAddress, 0, ortAllocator, out));
            valueContext.closeables().add(() -> api.ReleaseValue.apply(keyAddress));
            MemorySegment valueAddress =
                    api.create(arena, out -> api.GetValue.apply(ortValueAddress, 1, ortAllocator, out));
            valueContext.closeables().add(() -> api.ReleaseValue.apply(valueAddress));
            MemorySegment keyInfo = api.create(arena, out -> api.GetTensorTypeAndShape.apply(keyAddress, out));
            int size;
            try {
                size = Math.toIntExact(
                        api.extractLong(arena, out -> api.GetTensorShapeElementCount.apply(keyInfo, out)));
            } finally {
                // release the shape info even if the element-count query fails
                api.ReleaseTensorTypeAndShapeInfo.apply(keyInfo);
            }
            T keyVector = newKeyVector(size, keyAddress);
            OnnxTensorImpl valueVector = newValueVector(size, valueAddress);
            // one entry is created per key, in key-tensor order; getScalars presumably fills entry i from
            // element i of the value tensor — see OnnxTensorImpl
            valueVector.getScalars(explodeKeyVector(keyVector).map(this::set));
        }
        this.unmodifiableData = Collections.unmodifiableMap(data);
    }

    /** Create a value tensor of {@code size} elements using the map's value element type. */
    private OnnxTensorImpl newValueVector(int size, MemorySegment valueAddress) {
        return TensorInfoImpl.of(mapInfo.getValueType().getTensorInfo().getType(), size, valueContext.arena())
                .newValue(valueContext, valueAddress);
    }

    /** Create a key tensor of {@code size} elements using the map's key element type. */
    private T newKeyVector(int size, MemorySegment keyAddress) {
        return newKeyVector(TensorInfoImpl.of(mapInfo.getKeyType(), size, valueContext.arena()), keyAddress);
    }

    /** Create the key tensor; {@code keyAddress} is {@code null} when building a new native value. */
    protected abstract T newKeyVector(TensorInfoImpl tensorInfo, MemorySegment keyAddress);

    /** Write {@code keys}, in iteration order, into the key tensor. */
    protected abstract void implodeKeyVector(T keyVector, Set<K> keys);

    /** Read all keys out of the key tensor, in tensor order. */
    protected abstract Stream<K> explodeKeyVector(T keyVector);

    @Override
    public final String toString() {
        return "{OnnxMap: info=" + mapInfo + ", data=" + data + "}";
    }

    @Override
    public final MapInfo getInfo() {
        return mapInfo;
    }

    @Override
    public final OnnxMap asMap() {
        return this;
    }

    private NoSuchElementException fail() {
        return new NoSuchElementException("Invalid access of OnnxMap with key type " + mapInfo.getKeyType());
    }

    @Override
    public OnnxTypedMap<Long> asLongMap() {
        throw fail();
    }

    @Override
    public OnnxTypedMap<String> asStringMap() {
        throw fail();
    }

    /**
     * Build a native map value from a key tensor and a value tensor whose element order matches the
     * map's iteration order.
     */
    @Override
    public MemorySegment toNative() {
        ApiImpl api = valueContext.api();
        Arena arena = valueContext.arena();
        int size = data.size();
        T keyVector = newKeyVector(size, null);
        implodeKeyVector(keyVector, data.keySet());
        OnnxTensorImpl valueVector = TensorInfoImpl.of(
                        mapInfo.getValueType().getTensorInfo().getType(), size, arena)
                .newValue(valueContext, null);
        valueVector.putScalars(data.values());
        MemorySegment kvArray = arena.allocate(C_POINTER, 2);
        kvArray.setAtIndex(C_POINTER, 0, keyVector.toNative());
        kvArray.setAtIndex(C_POINTER, 1, valueVector.toNative());
        return api.create(arena, out -> api.CreateValue.apply(kvArray, 2, OnnxType.MAP.getNumber(), out));
    }

    /** Create (or replace) the entry for {@code key} with a fresh value tensor and return it for population. */
    @Override
    public final OnnxTensorImpl set(K key) {
        OnnxTensorImpl input = mapInfo.getValueType().getTensorInfo().newValue(valueContext, null);
        data.put(key, input);
        return input;
    }

    @Override
    public final OnnxValue put(K key, OnnxValue value) {
        return unmodifiableData.put(key, value);
    }

    @Override
    public final void putAll(Map<? extends K, ? extends OnnxValue> m) {
        unmodifiableData.putAll(m);
    }

    @Override
    public final int size() {
        return unmodifiableData.size();
    }

    @Override
    public final boolean isEmpty() {
        return unmodifiableData.isEmpty();
    }

    @Override
    public final boolean containsKey(Object key) {
        return unmodifiableData.containsKey(key);
    }

    @Override
    public final boolean containsValue(Object value) {
        return unmodifiableData.containsValue(value);
    }

    @Override
    public final OnnxValue get(Object key) {
        return unmodifiableData.get(key);
    }

    @Override
    public final OnnxValue remove(Object key) {
        return unmodifiableData.remove(key);
    }

    @Override
    public final void clear() {
        unmodifiableData.clear();
    }

    @Override
    public final Set<K> keySet() {
        return unmodifiableData.keySet();
    }

    @Override
    public final Collection<OnnxValue> values() {
        return unmodifiableData.values();
    }

    @Override
    public final Set<Map.Entry<K, OnnxValue>> entrySet() {
        return unmodifiableData.entrySet();
    }
}

--------------------------------------------------------------------------------
/src/main/java/com/jyuzawa/onnxruntime/OnnxMapLongImpl.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/)
 * SPDX-License-Identifier: MIT
 */
package com.jyuzawa.onnxruntime;

import java.lang.foreign.MemorySegment;
import java.nio.LongBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.stream.Stream;

/** An ONNX map with INT64 keys, exposed as {@code Long}. */
final class OnnxMapLongImpl extends OnnxMapImpl<Long, OnnxTensorLongImpl> {

    OnnxMapLongImpl(MapInfoImpl mapInfo, ValueContext valueContext, MemorySegment ortValueAddress) {
        super(mapInfo, valueContext, ortValueAddress);
    }

    @Override
    protected OnnxTensorLongImpl newKeyVector(TensorInfoImpl tensorInfo, MemorySegment keyAddress) {
        return new OnnxTensorLongImpl(tensorInfo, valueContext, keyAddress);
    }

    @Override
    public OnnxTypedMap<Long> asLongMap() {
        return this;
    }

    @Override
    protected void implodeKeyVector(OnnxTensorLongImpl keyVector, Set<Long> keys) {
        LongBuffer buffer = keyVector.buffer;
        for (Long key : keys) {
            buffer.put(key);
        }
        // reset position to 0 so the written keys can be read back from the start
        buffer.flip();
    }

    @Override
    protected Stream<Long> explodeKeyVector(OnnxTensorLongImpl keyVector) {
        LongBuffer buf = keyVector.buffer;
        int count = buf.capacity();
        List<Long> out = new ArrayList<>(count);
        for (int i = 0; i < count; i++) {
            // absolute get: does not move the buffer's position
            out.add(buf.get(i));
        }
        return out.stream();
    }
}

--------------------------------------------------------------------------------
/src/main/java/com/jyuzawa/onnxruntime/OnnxMapStringImpl.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/)
 * SPDX-License-Identifier: MIT
 */
package com.jyuzawa.onnxruntime;

import java.lang.foreign.MemorySegment;
import java.util.Arrays;
import java.util.Set;
import java.util.stream.Stream;

/** An ONNX map with STRING keys. */
final class OnnxMapStringImpl extends OnnxMapImpl<String, OnnxTensorStringImpl> {

    OnnxMapStringImpl(MapInfoImpl mapInfo, ValueContext valueContext, MemorySegment ortValueAddress) {
        super(mapInfo, valueContext, ortValueAddress);
    }

    @Override
    protected OnnxTensorStringImpl newKeyVector(TensorInfoImpl tensorInfo, MemorySegment keyAddress) {
        return new OnnxTensorStringImpl(tensorInfo, valueContext, keyAddress);
    }

    @Override
    public OnnxTypedMap<String> asStringMap() {
        return this;
    }

    @Override
    protected void implodeKeyVector(OnnxTensorStringImpl keyVector, Set<String> keys) {
        String[] buffer = keyVector.getStringBuffer();
        int i = 0;
        for (String key : keys) {
            buffer[i++] = key;
        }
    }

    @Override
    protected Stream<String> explodeKeyVector(OnnxTensorStringImpl keyVector) {
        return Arrays.stream(keyVector.getStringBuffer());
    }
}

--------------------------------------------------------------------------------
/src/main/java/com/jyuzawa/onnxruntime/OnnxOpaque.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | import java.nio.ByteBuffer; 8 | 9 | interface OnnxOpaque { 10 | 11 | OpaqueInfo getInfo(); 12 | 13 | void set(ByteBuffer buffer); 14 | 15 | ByteBuffer get(); 16 | } 17 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/OnnxOptional.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | interface OnnxOptional { 8 | 9 | TypeInfo getInfo(); 10 | 11 | boolean isPresent(); 12 | 13 | OnnxValue get(); 14 | 15 | OnnxValue set(); 16 | } 17 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/OnnxRuntime.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | /** 8 | * This is the entrypoint into this library. It is a singleton that is accessible using {@link #get()}. 9 | * 10 | * @since 1.0.0 11 | */ 12 | public interface OnnxRuntime { 13 | 14 | /** 15 | * The Microsoft ONNX runtime version. 16 | * @return x.y.z version 17 | */ 18 | String getVersion(); 19 | 20 | /** 21 | * This function returns the value of the ORT_API_VERSION constant. 22 | * @return the int version number 23 | * @since 1.2.0 24 | */ 25 | int getApiVersion(); 26 | 27 | /** 28 | * This method provides a view of the latest API. 
29 | * @return a wrapper around the main 30 | */ 31 | Api getApi(); 32 | 33 | /** 34 | * On first use, this attempts to automatically load the library if it is discovered on the classpath. Otherwise, the Java library path will be used. 35 | * @return a singleton loaded instance 36 | */ 37 | static OnnxRuntime get() { 38 | return OnnxRuntimeImpl.INSTANCE; 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/OnnxRuntimeException.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | /** 8 | * An exception thrown from within the ONNX runtime. 9 | * 10 | * @since 1.0.0 11 | */ 12 | public class OnnxRuntimeException extends RuntimeException { 13 | private static final int UNKNOWN_CODE = -1; 14 | private final int code; 15 | 16 | private static final long serialVersionUID = 1322802709564007780L; 17 | 18 | public OnnxRuntimeException(int code) { 19 | super(); 20 | this.code = code; 21 | } 22 | 23 | public OnnxRuntimeException() { 24 | this(UNKNOWN_CODE); 25 | } 26 | 27 | public OnnxRuntimeException(int code, String message) { 28 | super(message + " (code=" + code + ")"); 29 | this.code = code; 30 | } 31 | 32 | public OnnxRuntimeException(String message) { 33 | this(UNKNOWN_CODE, message); 34 | } 35 | 36 | public OnnxRuntimeException(String message, Throwable cause) { 37 | super(message, cause); 38 | this.code = UNKNOWN_CODE; 39 | } 40 | 41 | public int getCode() { 42 | return code; 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/OnnxRuntimeExecutionMode.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 
| * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | /** 8 | * A preference for sequential or parallel execution. 9 | * 10 | * @since 1.0.0 11 | */ 12 | public enum OnnxRuntimeExecutionMode { 13 | SEQUENTIAL(0), 14 | PARALLEL(1); 15 | 16 | private final int number; 17 | 18 | private OnnxRuntimeExecutionMode(int number) { 19 | this.number = number; 20 | } 21 | 22 | public int getNumber() { 23 | return number; 24 | } 25 | 26 | public static final OnnxRuntimeExecutionMode forNumber(int number) { 27 | return switch (number) { 28 | case 0 -> SEQUENTIAL; 29 | case 1 -> PARALLEL; 30 | default -> SEQUENTIAL; 31 | }; 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/OnnxRuntimeImpl.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.ORT_API_VERSION; 8 | import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.OrtGetApiBase; 9 | 10 | import com.jyuzawa.onnxruntime_extern.OrtApiBase; 11 | import java.lang.System.Logger.Level; 12 | import java.lang.foreign.MemorySegment; 13 | 14 | // NOTE: this class actually is more like OrtApiBase 15 | enum OnnxRuntimeImpl implements OnnxRuntime { 16 | INSTANCE; 17 | 18 | private final String version; 19 | private final ApiImpl api; 20 | private final int ortApiVersion; 21 | 22 | private OnnxRuntimeImpl() { 23 | Loader.load(); 24 | MemorySegment segment = OrtGetApiBase(); 25 | this.ortApiVersion = ORT_API_VERSION(); 26 | MemorySegment apiAddress = OrtApiBase.GetApiFunction(segment).apply(ortApiVersion); 27 | if (MemorySegment.NULL.address() == apiAddress.address()) { 28 | throw new UnsatisfiedLinkError( 29 | "Onnxruntime native library present, but does not provide 
API version " + ortApiVersion); 30 | } 31 | this.version = OrtApiBase.GetVersionStringFunction(segment).apply().getString(0); 32 | MemorySegment apiSegment = OrtApiBase.GetApiFunction(segment).apply(ORT_API_VERSION()); 33 | this.api = new ApiImpl(apiSegment); 34 | System.getLogger(OnnxRuntimeImpl.class.getName()) 35 | .log( 36 | Level.DEBUG, 37 | "Version: {0}, API Version: {1}, Build Info: {2}", 38 | version, 39 | ortApiVersion, 40 | api.getBuildString()); 41 | } 42 | 43 | @Override 44 | public String getVersion() { 45 | return version; 46 | } 47 | 48 | @Override 49 | public Api getApi() { 50 | return api; 51 | } 52 | 53 | @Override 54 | public int getApiVersion() { 55 | return ortApiVersion; 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/OnnxRuntimeLoggingLevel.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_INT; 8 | import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_POINTER; 9 | 10 | import java.lang.System.Logger; 11 | import java.lang.System.Logger.Level; 12 | import java.lang.foreign.Arena; 13 | import java.lang.foreign.FunctionDescriptor; 14 | import java.lang.foreign.Linker; 15 | import java.lang.foreign.MemorySegment; 16 | import java.lang.invoke.MethodHandle; 17 | import java.lang.invoke.MethodHandles; 18 | import java.lang.invoke.MethodType; 19 | 20 | /** 21 | * The level for the internal logger within the ONNX runtime. 
22 | * 23 | * @since 1.0.0 24 | */ 25 | public enum OnnxRuntimeLoggingLevel { 26 | VERBOSE(0), 27 | INFO(1), 28 | WARNING(2), 29 | ERROR(3), 30 | FATAL(4); 31 | 32 | private static final Logger LOG = System.getLogger(Environment.class.getName()); 33 | 34 | static final MemorySegment LOG_CALLBACK = createCallback(); 35 | static final OnnxRuntimeLoggingLevel DEFAULT = getDefaultLogLevel(); 36 | 37 | private final int number; 38 | 39 | private OnnxRuntimeLoggingLevel(int number) { 40 | this.number = number; 41 | } 42 | 43 | public int getNumber() { 44 | return number; 45 | } 46 | 47 | /** 48 | * Get a level based off its internal number. 49 | * @param number the internal number of the level 50 | * @return the level, VERBOSE if not found 51 | */ 52 | public static final OnnxRuntimeLoggingLevel forNumber(int number) { 53 | return switch (number) { 54 | case 0 -> VERBOSE; 55 | case 1 -> INFO; 56 | case 2 -> WARNING; 57 | case 3 -> ERROR; 58 | case 4 -> FATAL; 59 | default -> VERBOSE; 60 | }; 61 | } 62 | 63 | @SuppressWarnings("unused") 64 | private static final void logCallback( 65 | MemorySegment parameterAddress, 66 | int level, 67 | MemorySegment categoryAddress, 68 | MemorySegment idAddress, 69 | MemorySegment locationAddress, 70 | MemorySegment messageAddress) { 71 | Level theLevel = 72 | switch (OnnxRuntimeLoggingLevel.forNumber(level)) { 73 | case VERBOSE -> Level.DEBUG; 74 | case INFO -> Level.INFO; 75 | case WARNING -> Level.WARNING; 76 | case FATAL, ERROR -> Level.ERROR; 77 | }; 78 | if (LOG.isLoggable(theLevel)) { 79 | String category = categoryAddress.getString(0); 80 | String id = idAddress.getString(0); 81 | String location = locationAddress.getString(0); 82 | String message = messageAddress.getString(0); 83 | LOG.log(theLevel, category + ' ' + id + ' ' + location + ' ' + message); 84 | } 85 | } 86 | 87 | private static final OnnxRuntimeLoggingLevel getDefaultLogLevel() { 88 | if (LOG.isLoggable(Level.DEBUG)) { 89 | return VERBOSE; 90 | } 91 | if 
(LOG.isLoggable(Level.INFO)) { 92 | return INFO; 93 | } 94 | if (LOG.isLoggable(Level.WARNING)) { 95 | return WARNING; 96 | } 97 | if (LOG.isLoggable(Level.ERROR)) { 98 | return ERROR; 99 | } 100 | return FATAL; 101 | } 102 | 103 | private static final MemorySegment createCallback() { 104 | try { 105 | FunctionDescriptor LOG_CALLBACK_DESCRIPTOR = 106 | FunctionDescriptor.ofVoid(C_POINTER, C_INT, C_POINTER, C_POINTER, C_POINTER, C_POINTER); 107 | MethodHandle LOG_CALLBACK_HANDLE = MethodHandles.lookup() 108 | .findStatic( 109 | OnnxRuntimeLoggingLevel.class, 110 | "logCallback", 111 | MethodType.methodType( 112 | void.class, 113 | MemorySegment.class, 114 | int.class, 115 | MemorySegment.class, 116 | MemorySegment.class, 117 | MemorySegment.class, 118 | MemorySegment.class)); 119 | return Linker.nativeLinker().upcallStub(LOG_CALLBACK_HANDLE, LOG_CALLBACK_DESCRIPTOR, Arena.global()); 120 | } catch (IllegalAccessException | NoSuchMethodException e) { 121 | throw new OnnxRuntimeException("failed to initialize logger", e); 122 | } 123 | } 124 | } 125 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/OnnxRuntimeOptimizationLevel.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | /** 8 | * A level of optimization to target. 
9 | * 10 | * @since 1.0.0 11 | */ 12 | public enum OnnxRuntimeOptimizationLevel { 13 | DISABLE_ALL(0), 14 | ENABLE_BASIC(1), 15 | ENABLE_EXTENDED(2), 16 | ENABLE_ALL(99); 17 | 18 | private final int number; 19 | 20 | private OnnxRuntimeOptimizationLevel(int number) { 21 | this.number = number; 22 | } 23 | 24 | /** 25 | * Get the internal number of this level. 26 | * 27 | * @return the internal number 28 | */ 29 | public int getNumber() { 30 | return number; 31 | } 32 | 33 | /** 34 | * Get a level based off its internal number. 35 | * 36 | * @param number the internal number of the level 37 | * @return the level, DISABLE_ALL if not found 38 | */ 39 | public static final OnnxRuntimeOptimizationLevel forNumber(int number) { 40 | return switch (number) { 41 | case 0 -> DISABLE_ALL; 42 | case 1 -> ENABLE_BASIC; 43 | case 2 -> ENABLE_EXTENDED; 44 | case 99 -> ENABLE_ALL; 45 | default -> DISABLE_ALL; 46 | }; 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/OnnxSequence.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | import java.util.List; 8 | 9 | /** 10 | * A sequence view of an {@link OnnxValue}. Extends {@link java.util.List} to provide an UNMODIFIABLE view. 11 | * Use {@link #add()} to populate the sequence. 12 | * 13 | * @since 1.0.0 14 | */ 15 | public interface OnnxSequence extends List { 16 | 17 | /** 18 | * Get the type information of the elements. 19 | * @return The description of the type of the elements in the sequence. 20 | */ 21 | TypeInfo getInfo(); 22 | 23 | /** 24 | * Create a new element and append it to the end of the sequence. 25 | * @return A new OnnxValue of correct type added to the end of the sequence.
26 | */ 27 | OnnxValue add(); 28 | } 29 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/OnnxSequenceImpl.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_POINTER; 8 | 9 | import java.lang.foreign.Arena; 10 | import java.lang.foreign.MemorySegment; 11 | import java.util.ArrayList; 12 | import java.util.Collection; 13 | import java.util.Collections; 14 | import java.util.Iterator; 15 | import java.util.List; 16 | import java.util.ListIterator; 17 | 18 | final class OnnxSequenceImpl extends OnnxValueImpl implements OnnxSequence { 19 | 20 | private final List data; 21 | private final List unmodifiableData; 22 | private final TypeInfoImpl typeInfo; 23 | 24 | OnnxSequenceImpl(TypeInfoImpl typeInfo, ValueContext valueContext, MemorySegment ortValueAddress) { 25 | super(OnnxType.SEQUENCE, valueContext); 26 | if (ortValueAddress == null) { 27 | this.data = new ArrayList<>(); 28 | } else { 29 | ApiImpl api = valueContext.api(); 30 | Arena arena = valueContext.arena(); 31 | int outputs = Math.toIntExact(api.extractLong(arena, out -> api.GetValueCount.apply(ortValueAddress, out))); 32 | this.data = new ArrayList<>(outputs); 33 | for (int i = 0; i < outputs; i++) { 34 | final int index = i; 35 | MemorySegment valueAddress = api.create( 36 | arena, 37 | out -> api.GetValue.apply(ortValueAddress, index, valueContext.ortAllocatorAddress(), out)); 38 | valueContext.closeables().add(() -> api.ReleaseValue.apply(valueAddress)); 39 | OnnxValueImpl value = typeInfo.newValue(valueContext, valueAddress); 40 | data.add(value); 41 | } 42 | } 43 | this.unmodifiableData = Collections.unmodifiableList(data); 44 | this.typeInfo = typeInfo; 45 | } 46 | 47 |
@Override 48 | public String toString() { 49 | return "{OnnxSequence: info=" + typeInfo + ", data=" + data + "}"; 50 | } 51 | 52 | @Override 53 | public OnnxSequence asSequence() { 54 | return this; 55 | } 56 | 57 | @Override 58 | public TypeInfo getInfo() { 59 | return typeInfo; 60 | } 61 | 62 | private OnnxValueImpl newItem() { 63 | return typeInfo.newValue(valueContext, null); 64 | } 65 | 66 | @Override 67 | public OnnxValue add() { 68 | OnnxValueImpl item = newItem(); 69 | data.add(item); 70 | return item; 71 | } 72 | 73 | @Override 74 | public MemorySegment toNative() { 75 | ApiImpl api = valueContext.api(); 76 | Arena arena = valueContext.arena(); 77 | MemorySegment valuesArray = arena.allocate(C_POINTER, data.size()); 78 | for (int i = 0; i < data.size(); i++) { 79 | MemorySegment element = data.get(i).toNative(); 80 | valueContext.closeables().add(() -> api.ReleaseValue.apply(element)); // CreateValue only reads its inputs (const OrtValue*), so each element handle must still be released by us 81 | valuesArray.setAtIndex(C_POINTER, i, element); 82 | } 83 | return api.create(arena, out -> api.CreateValue.apply(valuesArray, data.size(), OnnxType.SEQUENCE.getNumber(), out)); 84 | } 85 | 86 | @Override 87 | public boolean add(OnnxValue e) { 88 | return unmodifiableData.add(e); 89 | } 90 | 91 | @Override 92 | public boolean addAll(Collection c) { 93 | return unmodifiableData.addAll(c); 94 | } 95 | 96 | @Override 97 | public boolean addAll(int index, Collection c) { 98 | return unmodifiableData.addAll(index, c); 99 | } 100 | 101 | @Override 102 | public OnnxValue set(int index, OnnxValue element) { 103 | return unmodifiableData.set(index, element); 104 | } 105 | 106 | @Override 107 | public void add(int index, OnnxValue element) { 108 | unmodifiableData.add(index, element); 109 | } 110 | 111 | @Override 112 | public int size() { 113 | return unmodifiableData.size(); 114 | } 115 | 116 | @Override 117 | public boolean isEmpty() { 118 | return unmodifiableData.isEmpty(); 119 | } 120 | 121 | @Override 122 | public boolean contains(Object o) { 123 | return unmodifiableData.contains(o); 124 | } 125 | 126 | @Override 127 | public Iterator
iterator() { 128 | return unmodifiableData.iterator(); 129 | } 130 | 131 | @Override 132 | public Object[] toArray() { 133 | return unmodifiableData.toArray(); 134 | } 135 | 136 | @Override 137 | public T[] toArray(T[] a) { 138 | return unmodifiableData.toArray(a); 139 | } 140 | 141 | @Override 142 | public boolean remove(Object o) { 143 | return unmodifiableData.remove(o); 144 | } 145 | 146 | @Override 147 | public boolean containsAll(Collection c) { 148 | return unmodifiableData.containsAll(c); 149 | } 150 | 151 | @Override 152 | public boolean removeAll(Collection c) { 153 | return unmodifiableData.removeAll(c); 154 | } 155 | 156 | @Override 157 | public boolean retainAll(Collection c) { 158 | return unmodifiableData.retainAll(c); 159 | } 160 | 161 | @Override 162 | public void clear() { 163 | unmodifiableData.clear(); 164 | } 165 | 166 | @Override 167 | public OnnxValue get(int index) { 168 | return unmodifiableData.get(index); 169 | } 170 | 171 | @Override 172 | public OnnxValue remove(int index) { 173 | return unmodifiableData.remove(index); 174 | } 175 | 176 | @Override 177 | public int indexOf(Object o) { 178 | return unmodifiableData.indexOf(o); 179 | } 180 | 181 | @Override 182 | public int lastIndexOf(Object o) { 183 | return unmodifiableData.lastIndexOf(o); 184 | } 185 | 186 | @Override 187 | public ListIterator listIterator() { 188 | return unmodifiableData.listIterator(); 189 | } 190 | 191 | @Override 192 | public ListIterator listIterator(int index) { 193 | return unmodifiableData.listIterator(index); 194 | } 195 | 196 | @Override 197 | public List subList(int fromIndex, int toIndex) { 198 | return unmodifiableData.subList(fromIndex, toIndex); 199 | } 200 | } 201 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/OnnxSparseTensor.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa
(https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | interface OnnxSparseTensor { 8 | 9 | TensorInfo getInfo(); 10 | 11 | // TODO: representation 12 | } 13 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/OnnxTensor.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | import java.nio.ByteBuffer; 8 | import java.nio.DoubleBuffer; 9 | import java.nio.FloatBuffer; 10 | import java.nio.IntBuffer; 11 | import java.nio.LongBuffer; 12 | import java.nio.ShortBuffer; 13 | import java.util.NoSuchElementException; 14 | 15 | /** 16 | * A representation of a dense tensor. Use {@link #getInfo()} to select the proper buffer type. A {@link NoSuchElementException} will be thrown if a view does not exist for this instance's type. 17 | * 18 | * @since 1.0.0 19 | */ 20 | public interface OnnxTensor { 21 | 22 | /** 23 | * Get tensor information: shape, type, etc. 24 | * @return type info 25 | */ 26 | TensorInfo getInfo(); 27 | 28 | /** 29 | * Get a byte view. 30 | * @return buffer 31 | * @throws NoSuchElementException if this typed view does not exist 32 | */ 33 | ByteBuffer getByteBuffer(); 34 | 35 | /** 36 | * Get a short view. 37 | * @return buffer 38 | * @throws NoSuchElementException if this typed view does not exist 39 | */ 40 | ShortBuffer getShortBuffer(); 41 | 42 | /** 43 | * Get an int view. 44 | * @return buffer 45 | * @throws NoSuchElementException if this typed view does not exist 46 | */ 47 | IntBuffer getIntBuffer(); 48 | 49 | /** 50 | * Get a long view. 51 | * @return buffer 52 | * @throws NoSuchElementException if this typed view does not exist 53 | */ 54 | LongBuffer getLongBuffer(); 55 | 56 | /** 57 | * Get a float view.
58 | * @return buffer 59 | * @throws NoSuchElementException if this typed view does not exist 60 | */ 61 | FloatBuffer getFloatBuffer(); 62 | 63 | /** 64 | * Get a double view. 65 | * @return buffer 66 | * @throws NoSuchElementException if this typed view does not exist 67 | */ 68 | DoubleBuffer getDoubleBuffer(); 69 | 70 | /** 71 | * Get a string view. 72 | * @return string array 73 | * @throws NoSuchElementException if this typed view does not exist 74 | */ 75 | String[] getStringBuffer(); 76 | } 77 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/OnnxTensorBufferImpl.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | import java.lang.foreign.Arena; 8 | import java.lang.foreign.MemorySegment; 9 | import java.nio.Buffer; 10 | import java.nio.ByteBuffer; 11 | import java.nio.ByteOrder; 12 | import java.util.function.Function; 13 | 14 | abstract class OnnxTensorBufferImpl extends OnnxTensorImpl { 15 | 16 | private final MemorySegment memorySegment; 17 | protected final T buffer; 18 | 19 | protected OnnxTensorBufferImpl( 20 | TensorInfoImpl tensorInfo, 21 | ValueContext valueContext, 22 | MemorySegment ortValueAddress, 23 | Function convert) { 24 | super(tensorInfo, valueContext); 25 | Arena arena = valueContext.arena(); 26 | if (ortValueAddress == null) { 27 | this.memorySegment = arena.allocate(tensorInfo.getType().getValueLayout(), tensorInfo.getElementCount()); 28 | } else { 29 | ApiImpl api = valueContext.api(); 30 | MemorySegment tensorData = api.create(arena, out -> api.GetTensorMutableData.apply(ortValueAddress, out)); 31 | this.memorySegment = tensorData.reinterpret(tensorInfo.getByteCount()); 32 | } 33 | this.buffer = convert.apply(memorySegment.asByteBuffer().order(ByteOrder.nativeOrder()));
34 | } 35 | 36 | @Override 37 | public final String toString() { 38 | return "{OnnxTensor: info=" + tensorInfo + ", buffer=" + buffer + "}"; 39 | } 40 | 41 | @Override 42 | public final MemorySegment toNative() { 43 | ApiImpl api = valueContext.api(); 44 | return api.create( 45 | valueContext.arena(), 46 | out -> api.CreateTensorWithDataAsOrtValue.apply( 47 | valueContext.memoryInfoAddress(), 48 | memorySegment, 49 | memorySegment.byteSize(), 50 | tensorInfo.shapeData, 51 | tensorInfo.getShape().size(), 52 | tensorInfo.getType().getNumber(), 53 | out)); 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/OnnxTensorByteImpl.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | import java.lang.foreign.MemorySegment; 8 | import java.nio.ByteBuffer; 9 | import java.util.Collection; 10 | import java.util.function.Function; 11 | import java.util.stream.Stream; 12 | 13 | final class OnnxTensorByteImpl extends OnnxTensorBufferImpl { 14 | 15 | OnnxTensorByteImpl(TensorInfoImpl tensorInfo, ValueContext valueContext, MemorySegment ortValueAddress) { 16 | super(tensorInfo, valueContext, ortValueAddress, Function.identity()); 17 | } 18 | 19 | @Override 20 | public ByteBuffer getByteBuffer() { 21 | return buffer; 22 | } 23 | 24 | @Override 25 | public void putScalars(Collection scalars) { 26 | for (OnnxTensorImpl scalar : scalars) { 27 | buffer.put(scalar.getByteBuffer().flip().get()); 28 | } 29 | } 30 | 31 | @Override 32 | public void getScalars(Stream scalars) { 33 | scalars.forEach(scalar -> scalar.getByteBuffer().put(buffer.get()).flip()); 34 | } 35 | } 36 | --------------------------------------------------------------------------------
/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorDoubleImpl.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | import java.lang.foreign.MemorySegment; 8 | import java.nio.ByteBuffer; 9 | import java.nio.DoubleBuffer; 10 | import java.util.Collection; 11 | import java.util.function.Function; 12 | import java.util.stream.Stream; 13 | 14 | /** 15 | * A dense tensor whose elements are exposed through a {@link DoubleBuffer} view. 16 | */ 17 | final class OnnxTensorDoubleImpl extends OnnxTensorBufferImpl { 18 | 19 | /** Reinterprets the tensor's bytes as doubles. */ 20 | private static final Function CONVERT = bytes -> bytes.asDoubleBuffer(); 21 | 22 | OnnxTensorDoubleImpl(TensorInfoImpl tensorInfo, ValueContext valueContext, MemorySegment ortValueAddress) { 23 | super(tensorInfo, valueContext, ortValueAddress, CONVERT); 24 | } 25 | 26 | @Override 27 | public DoubleBuffer getDoubleBuffer() { 28 | return buffer; 29 | } 30 | 31 | /** 32 | * For each scalar tensor, flip its buffer and append its first element to this tensor's buffer. 33 | */ 34 | @Override 35 | public void putScalars(Collection scalars) { 36 | for (OnnxTensorImpl source : scalars) { 37 | DoubleBuffer view = source.getDoubleBuffer(); 38 | view.flip(); 39 | buffer.put(view.get()); 40 | } 41 | } 42 | 43 | /** 44 | * Write the next element of this tensor's buffer into each scalar tensor, then flip the scalar's buffer. 45 | */ 46 | @Override 47 | public void getScalars(Stream scalars) { 48 | scalars.forEach(target -> { 49 | DoubleBuffer view = target.getDoubleBuffer(); 50 | view.put(buffer.get()); 51 | view.flip(); 52 | }); 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/OnnxTensorElementDataType.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_CHAR; 8 | import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_DOUBLE; 9 | import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_FLOAT; 10 | import static
com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_INT; 11 | import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_LONG; 12 | import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_SHORT; 13 | 14 | import java.lang.foreign.ValueLayout; 15 | 16 | /** 17 | * A tensor type from ONNX. 18 | * 19 | * @since 1.0.0 20 | */ 21 | public enum OnnxTensorElementDataType { 22 | UNDEFINED(0, null), 23 | FLOAT(1, C_FLOAT), 24 | UINT8(2, null), 25 | INT8(3, C_CHAR), 26 | UINT16(4, null), 27 | INT16(5, C_SHORT), 28 | INT32(6, C_INT), 29 | INT64(7, C_LONG), 30 | STRING(8, null), 31 | BOOL(9, C_CHAR), 32 | FLOAT16(10, null), 33 | DOUBLE(11, C_DOUBLE), 34 | UINT32(12, null), 35 | UINT64(13, null), 36 | COMPLEX64(14, null), 37 | COMPLEX128(15, null), 38 | BFLOAT16(16, null); 39 | 40 | private final int number; 41 | private final ValueLayout valueLayout; 42 | 43 | private OnnxTensorElementDataType(int number, ValueLayout valueLayout) { 44 | this.number = number; 45 | this.valueLayout = valueLayout; 46 | } 47 | 48 | public int getNumber() { 49 | return number; 50 | } 51 | 52 | /** 53 | * Get a level based off its internal number. 
54 | * 55 | * @param number 56 | * the internal number of the level 57 | * @return the level, UNDEFINED if not found 58 | */ 59 | public static final OnnxTensorElementDataType forNumber(int number) { 60 | return switch (number) { 61 | case 0 -> UNDEFINED; 62 | case 1 -> FLOAT; 63 | case 2 -> UINT8; 64 | case 3 -> INT8; 65 | case 4 -> UINT16; 66 | case 5 -> INT16; 67 | case 6 -> INT32; 68 | case 7 -> INT64; 69 | case 8 -> STRING; 70 | case 9 -> BOOL; 71 | case 10 -> FLOAT16; 72 | case 11 -> DOUBLE; 73 | case 12 -> UINT32; 74 | case 13 -> UINT64; 75 | case 14 -> COMPLEX64; 76 | case 15 -> COMPLEX128; 77 | case 16 -> BFLOAT16; 78 | default -> UNDEFINED; 79 | }; 80 | } 81 | 82 | ValueLayout getValueLayout() { 83 | return valueLayout; 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/OnnxTensorFloatImpl.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | import java.lang.foreign.MemorySegment; 8 | import java.nio.ByteBuffer; 9 | import java.nio.FloatBuffer; 10 | import java.util.Collection; 11 | import java.util.function.Function; 12 | import java.util.stream.Stream; 13 | 14 | final class OnnxTensorFloatImpl extends OnnxTensorBufferImpl { 15 | 16 | private static final Function CONVERT = ByteBuffer::asFloatBuffer; 17 | 18 | OnnxTensorFloatImpl(TensorInfoImpl tensorInfo, ValueContext valueContext, MemorySegment ortValueAddress) { 19 | super(tensorInfo, valueContext, ortValueAddress, CONVERT); 20 | } 21 | 22 | @Override 23 | public FloatBuffer getFloatBuffer() { 24 | return buffer; 25 | } 26 | 27 | @Override 28 | public void putScalars(Collection scalars) { 29 | for (OnnxTensorImpl scalar : scalars) { 30 | buffer.put(scalar.getFloatBuffer().flip().get()); 31 | } 32 | } 33 | 34 | @Override 35 | 
public void getScalars(Stream scalars) { 36 | scalars.forEach(scalar -> scalar.getFloatBuffer().put(buffer.get()).flip()); 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/OnnxTensorImpl.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | import java.nio.ByteBuffer; 8 | import java.nio.DoubleBuffer; 9 | import java.nio.FloatBuffer; 10 | import java.nio.IntBuffer; 11 | import java.nio.LongBuffer; 12 | import java.nio.ShortBuffer; 13 | import java.util.Collection; 14 | import java.util.NoSuchElementException; 15 | import java.util.stream.Stream; 16 | 17 | abstract class OnnxTensorImpl extends OnnxValueImpl implements OnnxTensor { 18 | 19 | protected final TensorInfoImpl tensorInfo; 20 | 21 | protected OnnxTensorImpl(TensorInfoImpl tensorInfo, ValueContext valueContext) { 22 | super(OnnxType.TENSOR, valueContext); 23 | this.tensorInfo = tensorInfo; 24 | } 25 | 26 | @Override 27 | public final OnnxTensor asTensor() { 28 | return this; 29 | } 30 | 31 | @Override 32 | public final TensorInfo getInfo() { 33 | return tensorInfo; 34 | } 35 | 36 | private final NoSuchElementException fail() { 37 | return new NoSuchElementException("Invalid access of OnnxTensor with type " + tensorInfo.getType()); 38 | } 39 | 40 | @Override 41 | public ByteBuffer getByteBuffer() { 42 | throw fail(); 43 | } 44 | 45 | @Override 46 | public ShortBuffer getShortBuffer() { 47 | throw fail(); 48 | } 49 | 50 | @Override 51 | public IntBuffer getIntBuffer() { 52 | throw fail(); 53 | } 54 | 55 | @Override 56 | public LongBuffer getLongBuffer() { 57 | throw fail(); 58 | } 59 | 60 | @Override 61 | public FloatBuffer getFloatBuffer() { 62 | throw fail(); 63 | } 64 | 65 | @Override 66 | public DoubleBuffer getDoubleBuffer() { 
67 | throw fail(); 68 | } 69 | 70 | @Override 71 | public String[] getStringBuffer() { 72 | throw fail(); 73 | } 74 | 75 | abstract void putScalars(Collection scalars); 76 | 77 | abstract void getScalars(Stream scalars); 78 | } 79 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/OnnxTensorIntImpl.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | import java.lang.foreign.MemorySegment; 8 | import java.nio.ByteBuffer; 9 | import java.nio.IntBuffer; 10 | import java.util.Collection; 11 | import java.util.function.Function; 12 | import java.util.stream.Stream; 13 | 14 | final class OnnxTensorIntImpl extends OnnxTensorBufferImpl { 15 | 16 | private static final Function CONVERT = ByteBuffer::asIntBuffer; 17 | 18 | OnnxTensorIntImpl(TensorInfoImpl tensorInfo, ValueContext valueContext, MemorySegment ortValueAddress) { 19 | super(tensorInfo, valueContext, ortValueAddress, CONVERT); 20 | } 21 | 22 | @Override 23 | public IntBuffer getIntBuffer() { 24 | return buffer; 25 | } 26 | 27 | @Override 28 | public void putScalars(Collection scalars) { 29 | for (OnnxTensorImpl scalar : scalars) { 30 | buffer.put(scalar.getIntBuffer().flip().get()); 31 | } 32 | } 33 | 34 | @Override 35 | public void getScalars(Stream scalars) { 36 | scalars.forEach(scalar -> scalar.getIntBuffer().put(buffer.get()).flip()); 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/OnnxTensorLongImpl.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | import 
java.lang.foreign.MemorySegment; 8 | import java.nio.ByteBuffer; 9 | import java.nio.LongBuffer; 10 | import java.util.Collection; 11 | import java.util.function.Function; 12 | import java.util.stream.Stream; 13 | 14 | final class OnnxTensorLongImpl extends OnnxTensorBufferImpl { 15 | 16 | private static final Function CONVERT = ByteBuffer::asLongBuffer; 17 | 18 | OnnxTensorLongImpl(TensorInfoImpl tensorInfo, ValueContext valueContext, MemorySegment ortValueAddress) { 19 | super(tensorInfo, valueContext, ortValueAddress, CONVERT); 20 | } 21 | 22 | @Override 23 | public LongBuffer getLongBuffer() { 24 | return buffer; 25 | } 26 | 27 | @Override 28 | public void putScalars(Collection scalars) { 29 | for (OnnxTensorImpl scalar : scalars) { 30 | buffer.put(scalar.getLongBuffer().flip().get()); 31 | } 32 | } 33 | 34 | @Override 35 | public void getScalars(Stream scalars) { 36 | scalars.forEach(scalar -> scalar.getLongBuffer().put(buffer.get()).flip()); 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/OnnxTensorShortImpl.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | import java.lang.foreign.MemorySegment; 8 | import java.nio.ByteBuffer; 9 | import java.nio.ShortBuffer; 10 | import java.util.Collection; 11 | import java.util.function.Function; 12 | import java.util.stream.Stream; 13 | 14 | final class OnnxTensorShortImpl extends OnnxTensorBufferImpl { 15 | 16 | private static final Function CONVERT = ByteBuffer::asShortBuffer; 17 | 18 | OnnxTensorShortImpl(TensorInfoImpl tensorInfo, ValueContext valueContext, MemorySegment ortValueAddress) { 19 | super(tensorInfo, valueContext, ortValueAddress, CONVERT); 20 | } 21 | 22 | @Override 23 | public ShortBuffer getShortBuffer() { 24 | return
buffer; 25 | } 26 | 27 | @Override 28 | public void putScalars(Collection scalars) { 29 | for (OnnxTensorImpl scalar : scalars) { 30 | buffer.put(scalar.getShortBuffer().flip().get()); 31 | } 32 | } 33 | 34 | @Override 35 | public void getScalars(Stream scalars) { 36 | scalars.forEach(scalar -> scalar.getShortBuffer().put(buffer.get()).flip()); 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/OnnxTensorStringImpl.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_CHAR; 8 | import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_POINTER; 9 | 10 | import java.lang.foreign.Arena; 11 | import java.lang.foreign.MemorySegment; 12 | import java.util.Arrays; 13 | import java.util.Collection; 14 | import java.util.Iterator; 15 | import java.util.stream.Stream; 16 | 17 | final class OnnxTensorStringImpl extends OnnxTensorImpl { 18 | 19 | private final String[] buffer; 20 | 21 | OnnxTensorStringImpl(TensorInfoImpl tensorInfo, ValueContext valueContext, MemorySegment ortValueAddress) { 22 | super(tensorInfo, valueContext); 23 | this.buffer = new String[Math.toIntExact(tensorInfo.getElementCount())]; 24 | if (ortValueAddress != null) { 25 | ApiImpl api = valueContext.api(); 26 | Arena arena = valueContext.arena(); 27 | int numOutputs = buffer.length; 28 | for (int i = 0; i < numOutputs; i++) { 29 | final long index = i; 30 | long length = api.extractLong( 31 | arena, out -> api.GetStringTensorElementLength.apply(ortValueAddress, index, out)); 32 | // add a byte for the null termination 33 | MemorySegment output = arena.allocate(C_CHAR, length + 1); 34 | api.checkStatus(api.GetStringTensorElement.apply(ortValueAddress, length,
index, output)); 35 | buffer[i] = output.getString(0); 36 | } 37 | } 38 | } 39 | 40 | @Override 41 | public String toString() { 42 | return "{OnnxTensor: info=" + tensorInfo + ", buffer=" + Arrays.toString(buffer) + "}"; 43 | } 44 | 45 | @Override 46 | public String[] getStringBuffer() { 47 | return buffer; 48 | } 49 | 50 | @Override 51 | public MemorySegment toNative() { 52 | int numOutputs = buffer.length; 53 | ApiImpl api = valueContext.api(); 54 | Arena arena = valueContext.arena(); 55 | MemorySegment stringArray = arena.allocate(C_POINTER, numOutputs); 56 | for (int i = 0; i < numOutputs; i++) { 57 | stringArray.setAtIndex(C_POINTER, i, arena.allocateFrom(buffer[i] == null ? "" : buffer[i])); // elements never set by the caller are still null; send them as empty strings instead of failing with a NullPointerException 58 | } 59 | MemorySegment tensor = api.create( 60 | arena, 61 | out -> api.CreateTensorAsOrtValue.apply( 62 | valueContext.ortAllocatorAddress(), 63 | tensorInfo.shapeData, 64 | tensorInfo.getShape().size(), 65 | tensorInfo.getType().getNumber(), 66 | out)); 67 | api.checkStatus(api.FillStringTensor.apply(tensor, stringArray, numOutputs)); 68 | return tensor; 69 | } 70 | 71 | @Override 72 | public void putScalars(Collection scalars) { 73 | int i = 0; 74 | for (OnnxTensorImpl scalar : scalars) { 75 | buffer[i++] = scalar.getStringBuffer()[0]; 76 | } 77 | } 78 | 79 | @Override 80 | public void getScalars(Stream scalars) { 81 | int i = 0; 82 | Iterator iter = scalars.iterator(); 83 | while (iter.hasNext()) { 84 | OnnxTensorImpl scalar = iter.next(); 85 | scalar.getStringBuffer()[0] = buffer[i++]; 86 | } 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/OnnxType.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | /** 8 | * A type from ONNX.
9 | * 10 | * @since 1.0.0 11 | */ 12 | public enum OnnxType { 13 | UNKNOWN(0), 14 | TENSOR(1), 15 | SEQUENCE(2), 16 | MAP(3), 17 | OPAQUE(4), 18 | SPARSETENSOR(5), 19 | OPTIONAL(6); 20 | 21 | private final int number; 22 | 23 | private OnnxType(int number) { 24 | this.number = number; 25 | } 26 | 27 | public int getNumber() { 28 | return number; 29 | } 30 | 31 | /** 32 | * Get a type based off its internal number. 33 | * @param number the internal number of the type 34 | * @return the type, UNKNOWN if not found 35 | */ 36 | public static final OnnxType forNumber(int number) { 37 | return switch (number) { 38 | case 0 -> UNKNOWN; 39 | case 1 -> TENSOR; 40 | case 2 -> SEQUENCE; 41 | case 3 -> MAP; 42 | case 4 -> OPAQUE; 43 | case 5 -> SPARSETENSOR; 44 | case 6 -> OPTIONAL; 45 | default -> UNKNOWN; 46 | }; 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/OnnxTypedMap.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | import java.util.Map; 8 | 9 | /** 10 | * A type-safe map of an {@link OnnxValue}. Extends {@link java.util.Map} to provide an UNMODIFIABLE view. Use {@link #set(Object)} to populate the map. 11 | * 12 | * @param key type 13 | * @since 1.0.0 14 | */ 15 | public interface OnnxTypedMap extends Map { 16 | 17 | /** 18 | * Create a new value and map the key to it. 19 | * @param key the key to map to the new value 20 | * @return A new OnnxValue of correct type, which the key is mapped towards.
21 | */ 22 | OnnxValue set(K key); 23 | } 24 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/OnnxValue.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | import java.util.NoSuchElementException; 8 | 9 | /** 10 | * A representation of a value provided as an input or output. Use {@link #getType()} to find out which more specific view is present. A {@link NoSuchElementException} will be thrown if a view does not exist for this instance's type. 11 | * 12 | * @since 1.0.0 13 | */ 14 | public interface OnnxValue { 15 | 16 | /** 17 | * Get the type of this value. 18 | * @return type enum value 19 | */ 20 | OnnxType getType(); 21 | 22 | /** 23 | * A view as a tensor 24 | * @return tensor 25 | * @throws NoSuchElementException if the value is not a tensor 26 | */ 27 | OnnxTensor asTensor(); 28 | 29 | /** 30 | * A view as a sequence 31 | * @return sequence 32 | * @throws NoSuchElementException if the value is not a sequence 33 | */ 34 | OnnxSequence asSequence(); 35 | 36 | /** 37 | * A view as a map 38 | * @return map 39 | * @throws NoSuchElementException if the value is not a map 40 | */ 41 | OnnxMap asMap(); 42 | 43 | // OnnxOpaque asOpaque(); 44 | 45 | // OnnxSparseTensor asSparseTensor(); 46 | 47 | // OnnxOptional asOptional(); 48 | } 49 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/OnnxValueImpl.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | import java.lang.foreign.MemorySegment; 8 | import java.util.NoSuchElementException; 9 | 10 | abstract class
OnnxValueImpl implements OnnxValue { 11 | 12 | protected final OnnxType type; 13 | protected final ValueContext valueContext; 14 | 15 | protected OnnxValueImpl(OnnxType type, ValueContext valueContext) { 16 | this.type = type; 17 | this.valueContext = valueContext; 18 | } 19 | 20 | @Override 21 | public final OnnxType getType() { 22 | return type; 23 | } 24 | 25 | @Override 26 | public OnnxTensor asTensor() { 27 | throw new NoSuchElementException("OnnxValue is not a tensor"); 28 | } 29 | 30 | @Override 31 | public OnnxSequence asSequence() { 32 | throw new NoSuchElementException("OnnxValue is not a sequence"); 33 | } 34 | 35 | @Override 36 | public OnnxMap asMap() { 37 | throw new NoSuchElementException("OnnxValue is not a map"); 38 | } 39 | 40 | // @Override 41 | // public OnnxOpaque asOpaque() { 42 | // throw new NoSuchElementException("OnnxValue is not an opaque"); 43 | // } 44 | // 45 | // @Override 46 | // public OnnxSparseTensor asSparseTensor() { 47 | // throw new NoSuchElementException("OnnxValue is not a sparse tensor"); 48 | // } 49 | // 50 | // @Override 51 | // public OnnxOptional asOptional() { 52 | // throw new NoSuchElementException("OnnxValue is not an optional"); 53 | // } 54 | 55 | abstract MemorySegment toNative(); 56 | } 57 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/OpaqueInfo.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | interface OpaqueInfo { 8 | 9 | String getDomain(); 10 | 11 | String getName(); 12 | } 13 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/Session.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa
(https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | import java.io.IOException; 8 | import java.nio.ByteBuffer; 9 | import java.nio.file.Path; 10 | import java.util.Collections; 11 | import java.util.Map; 12 | 13 | /** 14 | * An evaluation session loaded from an ONNX file or bytes. This class is thread safe. 15 | * 16 | * @since 1.0.0 17 | */ 18 | public interface Session extends AutoCloseable { 19 | 20 | /** 21 | * Release native resources. 22 | */ 23 | @Override 24 | void close(); 25 | 26 | /** 27 | * Get additional information about this session. 28 | * @return a read-only view of the session's metadata 29 | */ 30 | ModelMetadata getModelMetadata(); 31 | 32 | /** 33 | * Get the session's named inputs. 34 | * @return a named collection to access by name or index 35 | */ 36 | NamedCollection getInputs(); 37 | 38 | /** 39 | * Get the session's named outputs. 40 | * @return a named collection to access by name or index 41 | */ 42 | NamedCollection getOutputs(); 43 | 44 | /** 45 | * Get a session's named initializers. 46 | * @return a named collection to access by name or index 47 | */ 48 | NamedCollection getOverridableInitializers(); 49 | 50 | /** 51 | * A high precision monotonically increasing clock for use in profiling. 52 | * @return nanoseconds 53 | */ 54 | long getProfilingStartTimeInNs(); 55 | 56 | /** 57 | * Stop profile and get the resulting JSON file's location. 58 | * @return a path to the output 59 | */ 60 | Path endProfiling(); 61 | 62 | /** 63 | * Create a new transaction. 64 | * @return a builder 65 | */ 66 | Transaction.Builder newTransaction(); 67 | 68 | /** 69 | * Create a new I/O Binding. 
70 | * @return a builder 71 | * @since 1.4.0 72 | */ 73 | IoBinding.Builder newIoBinding(); 74 | 75 | /** 76 | * Set DynamicOptions for EPs (Execution Providers) 77 | * @param epDynamicOptions 78 | * @since 2.0.0 79 | */ 80 | void setEpDynamicOptions(Map epDynamicOptions); 81 | 82 | /** 83 | * A builder of a {@link Session}. Must provide either bytes or a path. 84 | * 85 | * @since 1.0.0 86 | */ 87 | public interface Builder { 88 | /** 89 | * Load the session from a file. 90 | * @param path 91 | * @return the builder 92 | */ 93 | Builder setPath(Path path); 94 | 95 | /** 96 | * Load the session from a byte array. 97 | * @param bytes 98 | * @return the builder 99 | */ 100 | Builder setByteArray(byte[] bytes); 101 | 102 | /** 103 | * Load the session from a buffer. 104 | * @param byteBuffer 105 | * @return the builder 106 | */ 107 | Builder setByteBuffer(ByteBuffer byteBuffer); 108 | 109 | /** 110 | * Use the global thread pool. 111 | * @return the builder 112 | */ 113 | Builder disablePerSessionThreads(); 114 | 115 | /** 116 | * Set the prefix for the unique profile output files. 117 | * @param prefix the prefix from which a timestamp will be appended 118 | * @return the builder 119 | */ 120 | Builder setProfilingOutputPath(Path prefix); 121 | 122 | /** 123 | * Set the severity for logging for this entire session. 124 | * Can override the environment's logger's severity. 125 | * @param level 126 | * @return the builder 127 | */ 128 | Builder setLogSeverityLevel(OnnxRuntimeLoggingLevel level); 129 | 130 | /** 131 | * Set the verbosity for logging for this entire session. 132 | * Can override the environment's logger's verbosity. 133 | * @param level 134 | * @return the builder 135 | */ 136 | Builder setLogVerbosityLevel(int level); 137 | 138 | /** 139 | * Set custom parameters for this session. 140 | * @param config 141 | * @return the builder 142 | */ 143 | Builder setConfigMap(Map config); 144 | 145 | /** 146 | * Set sequential vs parallel execution. 
147 | * @param mode 148 | * @return the builder 149 | */ 150 | Builder setExecutionMode(OnnxRuntimeExecutionMode mode); 151 | 152 | /** 153 | * Set the number of inter op threads in this session's own thread pool. 154 | * @param numThreads 155 | * @return the builder 156 | */ 157 | Builder setInterOpNumThreads(int numThreads); 158 | 159 | /** 160 | * Set the number of intra op threads in this session's own thread pool. 161 | * @param numThreads 162 | * @return the builder 163 | */ 164 | Builder setIntraOpNumThreads(int numThreads); 165 | 166 | /** 167 | * Set the logging identifier. 168 | * 169 | * @param id the identifier 170 | * @return the builder 171 | */ 172 | Builder setLogId(String id); 173 | 174 | /** 175 | * Enable or disable memory pattern optimization. 176 | * @param memoryPatternOptimization 177 | * @return the builder 178 | */ 179 | Builder setMemoryPatternOptimization(boolean memoryPatternOptimization); 180 | 181 | /** 182 | * Set the desired level of optimization. 183 | * @param level 184 | * @return the builder 185 | */ 186 | Builder setOptimizationLevel(OnnxRuntimeOptimizationLevel level); 187 | 188 | /** 189 | * Set the location of the optimized model. 190 | * @param optimizedFile 191 | * @return the builder 192 | */ 193 | Builder setOptimizationOutputPath(Path optimizedFile); 194 | 195 | /** 196 | * Enable deterministic compute. 197 | * @param value 198 | * @return the builder 199 | * @since 1.5.0 200 | */ 201 | Builder setDeterministicCompute(boolean value); 202 | 203 | /** 204 | * Add an execution provider with custom properties. 205 | * @param provider 206 | * @param properties 207 | * @return the builder 208 | */ 209 | Builder addProvider(ExecutionProvider provider, Map properties); 210 | 211 | default Builder addProvider(ExecutionProvider provider) { 212 | return addProvider(provider, Collections.emptyMap()); 213 | } 214 | 215 | /** 216 | * Add a library of custom operators, loaded from the file at the given path.
217 | * @param path 218 | * @return the builder. 219 | */ 220 | Builder addCustomOpsLibrary(Path path); 221 | 222 | /** 223 | * Construct a {@link Session}. 224 | * @return a new instance 225 | * @throws IOException 226 | */ 227 | Session build() throws IOException; 228 | } 229 | } 230 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/TensorInfo.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | import java.util.List; 8 | 9 | /** 10 | * A description of a tensor: type and shape. 11 | * 12 | * @since 1.0.0 13 | */ 14 | public interface TensorInfo { 15 | 16 | OnnxTensorElementDataType getType(); 17 | 18 | List getShape(); 19 | 20 | /** 21 | * Get the number of elements in the buffer. 22 | * @return derived from shape 23 | */ 24 | long getElementCount(); 25 | 26 | /** 27 | * Get the size of the buffer from a memory standpoint. 28 | * @return the number of bytes the buffer takes up, based on shape and bytes for the type. 
29 | */ 30 | long getByteCount(); 31 | } 32 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/TensorInfoImpl.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_LONG; 8 | 9 | import java.lang.foreign.Arena; 10 | import java.lang.foreign.MemorySegment; 11 | import java.util.ArrayList; 12 | import java.util.Collections; 13 | import java.util.List; 14 | 15 | final class TensorInfoImpl implements TensorInfo { 16 | 17 | // TODO: symbolic dims 18 | private final OnnxTensorElementDataType type; 19 | final MemorySegment shapeData; 20 | private final List shape; 21 | private final long elementCount; 22 | 23 | TensorInfoImpl(OnnxTensorElementDataType type, MemorySegment shapeData, int dimCount, long elementCount) { 24 | this.type = type; 25 | List shape = new ArrayList<>(dimCount); 26 | for (int i = 0; i < dimCount; i++) { 27 | shape.add(shapeData.getAtIndex(C_LONG, i)); 28 | } 29 | this.shape = Collections.unmodifiableList(shape); 30 | this.shapeData = shapeData; 31 | this.elementCount = elementCount; 32 | } 33 | 34 | static TensorInfoImpl of(OnnxTensorElementDataType type, long elementCount, Arena arena) { 35 | MemorySegment shapeData = arena.allocateFrom(C_LONG, new long[] {elementCount}); 36 | return new TensorInfoImpl(type, shapeData, 1, elementCount); 37 | } 38 | 39 | @Override 40 | public OnnxTensorElementDataType getType() { 41 | return type; 42 | } 43 | 44 | @Override 45 | public List getShape() { 46 | return shape; 47 | } 48 | 49 | @Override 50 | public long getElementCount() { 51 | return elementCount; 52 | } 53 | 54 | @Override 55 | public String toString() { 56 | return "tensor(" + type + shape + ")"; 57 | } 58 | 59 | @Override 60 | 
public long getByteCount() { 61 | // TODO: handle missing valueLayout 62 | return elementCount * type.getValueLayout().byteSize(); 63 | } 64 | 65 | final OnnxTensorImpl newValue(ValueContext valueContext, MemorySegment ortValueAddress) { 66 | switch (type) { 67 | case BOOL: 68 | case INT8: 69 | return new OnnxTensorByteImpl(this, valueContext, ortValueAddress); 70 | case INT16: 71 | return new OnnxTensorShortImpl(this, valueContext, ortValueAddress); 72 | case INT32: 73 | return new OnnxTensorIntImpl(this, valueContext, ortValueAddress); 74 | case INT64: 75 | return new OnnxTensorLongImpl(this, valueContext, ortValueAddress); 76 | case FLOAT: 77 | return new OnnxTensorFloatImpl(this, valueContext, ortValueAddress); 78 | case DOUBLE: 79 | return new OnnxTensorDoubleImpl(this, valueContext, ortValueAddress); 80 | case STRING: 81 | return new OnnxTensorStringImpl(this, valueContext, ortValueAddress); 82 | default: 83 | throw new UnsupportedOperationException("OnnxTensor with type " + type + " is not supported"); 84 | } 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/Transaction.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | import java.util.Map; 8 | 9 | /** 10 | * A representation of a model evaluation. Should NOT be reused. This class is NOT thread-safe. 11 | * 12 | * @since 1.0.0 13 | */ 14 | public interface Transaction extends AutoCloseable { 15 | /** 16 | * Add an input and get an OnnxValue to populate. 17 | * @param name 18 | * @return the value to be populated 19 | */ 20 | OnnxValue addInput(String name); 21 | 22 | /** 23 | * Add an input and get an OnnxValue to populate. 
24 | * @param index 25 | * @return the value to be populated 26 | */ 27 | OnnxValue addInput(int index); 28 | 29 | /** 30 | * Request a specific output to be produced. 31 | * @param name 32 | */ 33 | void addOutput(String name); 34 | 35 | /** 36 | * Request a specific output to be produced. 37 | * @param index 38 | */ 39 | void addOutput(int index); 40 | 41 | /** 42 | * Run the model evaluation. 43 | * 44 | * @return the results 45 | */ 46 | NamedCollection run(); 47 | 48 | /** 49 | * Frees the native resources (typically buffers) associated with this transaction. 50 | */ 51 | @Override 52 | void close(); 53 | 54 | // void cancel(); 55 | 56 | /** 57 | * A builder of a {@link Transaction}. Should NOT be reused. This class is NOT thread-safe. 58 | * 59 | * @since 1.0.0 60 | */ 61 | public interface Builder { 62 | /** 63 | * Set the severity for logging for this specific transaction. 64 | * Can override the environment's or session's logger's severity. 65 | * @param level 66 | * @return the builder 67 | */ 68 | Builder setLogSeverityLevel(OnnxRuntimeLoggingLevel level); 69 | 70 | /** 71 | * Set the verbosity for logging for this specific transaction. 72 | * Can override the environment's or session's logger's verbosity. 73 | * @param level 74 | * @return the builder 75 | */ 76 | Builder setLogVerbosityLevel(int level); 77 | 78 | /** 79 | * Set the run tag (which is the logger id) 80 | * @param runTag 81 | * @return the builder 82 | */ 83 | Builder setRunTag(String runTag); 84 | 85 | /** 86 | * Set custom parameters for this transaction. 87 | * @param config 88 | * @return the builder 89 | */ 90 | Builder setConfigMap(Map config); 91 | 92 | /** 93 | * Construct a {@link Transaction}. 
94 | * 95 | * @return a new instance 96 | */ 97 | Transaction build(); 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/TypeInfo.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | import java.util.NoSuchElementException; 8 | 9 | /** 10 | * A description of the type of an input or output. Use {@link #getType()} to determine what type of additional information is present. A {@link NoSuchElementException} will be thrown from getters which are not valid for the instance's type. 11 | * 12 | * @since 1.0.0 13 | */ 14 | public interface TypeInfo { 15 | 16 | OnnxType getType(); 17 | 18 | TensorInfo getTensorInfo(); 19 | 20 | MapInfo getMapInfo(); 21 | 22 | TypeInfo getSequenceInfo(); 23 | 24 | // TypeInfo getOptionalInfo(); 25 | 26 | // OpaqueInfo getOpaqueInfo(); 27 | } 28 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/TypeInfoImpl.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_LONG; 8 | 9 | import java.lang.foreign.Arena; 10 | import java.lang.foreign.MemorySegment; 11 | import java.util.NoSuchElementException; 12 | 13 | final class TypeInfoImpl implements TypeInfo { 14 | 15 | // TODO: denotation 16 | private final OnnxType type; 17 | private final TensorInfoImpl tensorInfo; 18 | private final MapInfoImpl mapInfo; 19 | private final TypeInfoImpl sequenceInfo; 20 | 21 | TypeInfoImpl(ApiImpl api, MemorySegment typeInfo, Arena arena, Arena sessionArena, 
MemorySegment ortAllocator) { 22 | try { 23 | this.type = 24 | OnnxType.forNumber(api.extractInt(arena, out -> api.GetOnnxTypeFromTypeInfo.apply(typeInfo, out))); 25 | TensorInfoImpl tensorInfo = null; 26 | MapInfoImpl mapInfo = null; 27 | TypeInfoImpl sequenceInfo = null; 28 | 29 | if (type == OnnxType.TENSOR || type == OnnxType.SPARSETENSOR) { 30 | MemorySegment ortTensorInfo = 31 | api.create(arena, out -> api.CastTypeInfoToTensorInfo.apply(typeInfo, out)); 32 | OnnxTensorElementDataType dataType = OnnxTensorElementDataType.forNumber( 33 | api.extractInt(arena, out -> api.GetTensorElementType.apply(ortTensorInfo, out))); 34 | int dimCount = api.extractInt(arena, out -> api.GetDimensionsCount.apply(ortTensorInfo, out)); 35 | MemorySegment dims = sessionArena.allocate(C_LONG, dimCount); 36 | api.checkStatus(api.GetDimensions.apply(ortTensorInfo, dims, dimCount)); 37 | long elementCount = 38 | api.extractInt(arena, out -> api.GetTensorShapeElementCount.apply(ortTensorInfo, out)); 39 | tensorInfo = new TensorInfoImpl(dataType, dims, dimCount, elementCount); 40 | } else if (type == OnnxType.MAP) { 41 | MemorySegment ortMapInfo = api.create(arena, out -> api.CastTypeInfoToMapTypeInfo.apply(typeInfo, out)); 42 | OnnxTensorElementDataType keyType = OnnxTensorElementDataType.forNumber( 43 | api.extractInt(arena, out -> api.GetMapKeyType.apply(ortMapInfo, out))); 44 | MemorySegment valueTypeAddress = api.create(arena, out -> api.GetMapValueType.apply(ortMapInfo, out)); 45 | mapInfo = new MapInfoImpl( 46 | keyType, new TypeInfoImpl(api, valueTypeAddress, arena, sessionArena, ortAllocator)); 47 | } else if (type == OnnxType.SEQUENCE) { 48 | MemorySegment ortSequenceInfo = 49 | api.create(arena, out -> api.CastTypeInfoToSequenceTypeInfo.apply(typeInfo, out)); 50 | MemorySegment valueTypeAddress = 51 | api.create(arena, out -> api.GetSequenceElementType.apply(ortSequenceInfo, out)); 52 | sequenceInfo = new TypeInfoImpl(api, valueTypeAddress, arena, sessionArena, 
ortAllocator); 53 | } else { 54 | throw new UnsupportedOperationException("unsupported type: " + type); 55 | } 56 | this.tensorInfo = tensorInfo; 57 | this.mapInfo = mapInfo; 58 | this.sequenceInfo = sequenceInfo; 59 | } finally { 60 | api.ReleaseTypeInfo.apply(typeInfo); 61 | } 62 | } 63 | 64 | @Override 65 | public OnnxType getType() { 66 | return type; 67 | } 68 | 69 | @Override 70 | public TensorInfoImpl getTensorInfo() { 71 | if (tensorInfo == null) { 72 | throw new NoSuchElementException("tensor"); 73 | } 74 | return tensorInfo; 75 | } 76 | 77 | @Override 78 | public MapInfoImpl getMapInfo() { 79 | if (mapInfo == null) { 80 | throw new NoSuchElementException("map"); 81 | } 82 | return mapInfo; 83 | } 84 | 85 | @Override 86 | public String toString() { 87 | if (tensorInfo != null) { 88 | return tensorInfo.toString(); 89 | } else if (mapInfo != null) { 90 | return mapInfo.toString(); 91 | } else if (sequenceInfo != null) { 92 | return "sequence(" + sequenceInfo + ")"; 93 | } 94 | return "{TypeInfo: type=" + type + "}"; 95 | } 96 | 97 | @Override 98 | public TypeInfoImpl getSequenceInfo() { 99 | if (sequenceInfo == null) { 100 | throw new NoSuchElementException("sequence"); 101 | } 102 | return sequenceInfo; 103 | } 104 | 105 | final OnnxValueImpl newValue(ValueContext valueContext, MemorySegment ortValueAddress) { 106 | switch (type) { 107 | case TENSOR: 108 | return tensorInfo.newValue(valueContext, ortValueAddress); 109 | case SEQUENCE: 110 | return new OnnxSequenceImpl(sequenceInfo, valueContext, ortValueAddress); 111 | case MAP: 112 | return mapInfo.newValue(valueContext, ortValueAddress); 113 | // case OPAQUE: 114 | // return new OnnxOpaqueImpl(typeInfo.getOpaqueInfo()); 115 | // case OPTIONAL: 116 | // return new OnnxOptionalImpl(typeInfo.getOptionalInfo()); 117 | default: 118 | throw new UnsupportedOperationException("OnnxValue with type " + type + " is not supported"); 119 | } 120 | } 121 | } 122 | 
-------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/ValueContext.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | import java.lang.foreign.Arena; 8 | import java.lang.foreign.MemorySegment; 9 | import java.util.List; 10 | 11 | record ValueContext( 12 | ApiImpl api, 13 | Arena arena, 14 | MemorySegment ortAllocatorAddress, 15 | MemorySegment memoryInfoAddress, 16 | List closeables) {} 17 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime/package-info.java: -------------------------------------------------------------------------------- 1 | /** 2 | * A Java binding of Microsoft's ONNX Runtime. 3 | * 4 | * Use {@link com.jyuzawa.onnxruntime.OnnxRuntime#getApi()} to access this library. 
5 | * 6 | * @since 1.0.0 7 | */ 8 | package com.jyuzawa.onnxruntime; 9 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime_extern/OrtCustomCreateThreadFn.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2024 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime_extern; 6 | 7 | import static java.lang.foreign.MemoryLayout.PathElement.*; 8 | import static java.lang.foreign.ValueLayout.*; 9 | 10 | import java.lang.foreign.*; 11 | import java.lang.invoke.*; 12 | import java.util.*; 13 | import java.util.function.*; 14 | import java.util.stream.*; 15 | 16 | /** 17 | * {@snippet lang=c : 18 | * typedef OrtCustomThreadHandle (*OrtCustomCreateThreadFn)(void *, OrtThreadWorkerFn, void *) 19 | * } 20 | */ 21 | public class OrtCustomCreateThreadFn { 22 | 23 | OrtCustomCreateThreadFn() { 24 | // Should not be called directly 25 | } 26 | 27 | /** 28 | * The function pointer signature, expressed as a functional interface 29 | */ 30 | public interface Function { 31 | MemorySegment apply( 32 | MemorySegment ort_custom_thread_creation_options, 33 | MemorySegment ort_thread_worker_fn, 34 | MemorySegment ort_worker_fn_param); 35 | } 36 | 37 | private static final FunctionDescriptor $DESC = FunctionDescriptor.of( 38 | onnxruntime_all_h.C_POINTER, 39 | onnxruntime_all_h.C_POINTER, 40 | onnxruntime_all_h.C_POINTER, 41 | onnxruntime_all_h.C_POINTER); 42 | 43 | /** 44 | * The descriptor of this function pointer 45 | */ 46 | public static FunctionDescriptor descriptor() { 47 | return $DESC; 48 | } 49 | 50 | private static final MethodHandle UP$MH = 51 | onnxruntime_all_h.upcallHandle(OrtCustomCreateThreadFn.Function.class, "apply", $DESC); 52 | 53 | /** 54 | * Allocates a new upcall stub, whose implementation is defined by {@code fi}. 
55 | * The lifetime of the returned segment is managed by {@code arena} 56 | */ 57 | public static MemorySegment allocate(OrtCustomCreateThreadFn.Function fi, Arena arena) { 58 | return Linker.nativeLinker().upcallStub(UP$MH.bindTo(fi), $DESC, arena); 59 | } 60 | 61 | private static final MethodHandle DOWN$MH = Linker.nativeLinker().downcallHandle($DESC); 62 | 63 | /** 64 | * Invoke the upcall stub {@code funcPtr}, with given parameters 65 | */ 66 | public static MemorySegment invoke( 67 | MemorySegment funcPtr, 68 | MemorySegment ort_custom_thread_creation_options, 69 | MemorySegment ort_thread_worker_fn, 70 | MemorySegment ort_worker_fn_param) { 71 | try { 72 | return (MemorySegment) DOWN$MH.invokeExact( 73 | funcPtr, ort_custom_thread_creation_options, ort_thread_worker_fn, ort_worker_fn_param); 74 | } catch (Throwable ex$) { 75 | throw new AssertionError("should not reach here", ex$); 76 | } 77 | } 78 | 79 | /** 80 | * Get an implementation of the function interface from a function pointer. 
81 | */ 82 | public static OrtCustomCreateThreadFn.Function function(MemorySegment funcPtr) { 83 | return (ort_custom_thread_creation_options, ort_thread_worker_fn, ort_worker_fn_param) -> 84 | invoke(funcPtr, ort_custom_thread_creation_options, ort_thread_worker_fn, ort_worker_fn_param); 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime_extern/OrtCustomHandleType.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2024 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime_extern; 6 | 7 | import static java.lang.foreign.MemoryLayout.PathElement.*; 8 | import static java.lang.foreign.ValueLayout.*; 9 | 10 | import java.lang.foreign.*; 11 | import java.lang.invoke.*; 12 | import java.util.*; 13 | import java.util.function.*; 14 | import java.util.stream.*; 15 | 16 | /** 17 | * {@snippet lang=c : 18 | * struct OrtCustomHandleType { 19 | * char __place_holder; 20 | * } 21 | * } 22 | */ 23 | public class OrtCustomHandleType { 24 | 25 | OrtCustomHandleType() { 26 | // Should not be called directly 27 | } 28 | 29 | private static final GroupLayout $LAYOUT = MemoryLayout.structLayout( 30 | onnxruntime_all_h.C_CHAR.withName("__place_holder")) 31 | .withName("OrtCustomHandleType"); 32 | 33 | /** 34 | * The layout of this struct 35 | */ 36 | public static final GroupLayout layout() { 37 | return $LAYOUT; 38 | } 39 | 40 | private static final OfByte __place_holder$LAYOUT = (OfByte) $LAYOUT.select(groupElement("__place_holder")); 41 | 42 | /** 43 | * Layout for field: 44 | * {@snippet lang=c : 45 | * char __place_holder 46 | * } 47 | */ 48 | public static final OfByte __place_holder$layout() { 49 | return __place_holder$LAYOUT; 50 | } 51 | 52 | private static final long __place_holder$OFFSET = 0; 53 | 54 | /** 55 | * Offset for field: 56 | * {@snippet lang=c : 57 | * 
char __place_holder 58 | * } 59 | */ 60 | public static final long __place_holder$offset() { 61 | return __place_holder$OFFSET; 62 | } 63 | 64 | /** 65 | * Getter for field: 66 | * {@snippet lang=c : 67 | * char __place_holder 68 | * } 69 | */ 70 | public static byte __place_holder(MemorySegment struct) { 71 | return struct.get(__place_holder$LAYOUT, __place_holder$OFFSET); 72 | } 73 | 74 | /** 75 | * Setter for field: 76 | * {@snippet lang=c : 77 | * char __place_holder 78 | * } 79 | */ 80 | public static void __place_holder(MemorySegment struct, byte fieldValue) { 81 | struct.set(__place_holder$LAYOUT, __place_holder$OFFSET, fieldValue); 82 | } 83 | 84 | /** 85 | * Obtains a slice of {@code arrayParam} which selects the array element at {@code index}. 86 | * The returned segment has address {@code arrayParam.address() + index * layout().byteSize()} 87 | */ 88 | public static MemorySegment asSlice(MemorySegment array, long index) { 89 | return array.asSlice(layout().byteSize() * index); 90 | } 91 | 92 | /** 93 | * The size (in bytes) of this struct 94 | */ 95 | public static long sizeof() { 96 | return layout().byteSize(); 97 | } 98 | 99 | /** 100 | * Allocate a segment of size {@code layout().byteSize()} using {@code allocator} 101 | */ 102 | public static MemorySegment allocate(SegmentAllocator allocator) { 103 | return allocator.allocate(layout()); 104 | } 105 | 106 | /** 107 | * Allocate an array of size {@code elementCount} using {@code allocator}. 108 | * The returned segment has size {@code elementCount * layout().byteSize()}. 109 | */ 110 | public static MemorySegment allocateArray(long elementCount, SegmentAllocator allocator) { 111 | return allocator.allocate(MemoryLayout.sequenceLayout(elementCount, layout())); 112 | } 113 | 114 | /** 115 | * Reinterprets {@code addr} using target {@code arena} and {@code cleanupAction} (if any). 
116 | * The returned segment has size {@code layout().byteSize()} 117 | */ 118 | public static MemorySegment reinterpret(MemorySegment addr, Arena arena, Consumer cleanup) { 119 | return reinterpret(addr, 1, arena, cleanup); 120 | } 121 | 122 | /** 123 | * Reinterprets {@code addr} using target {@code arena} and {@code cleanupAction} (if any). 124 | * The returned segment has size {@code elementCount * layout().byteSize()} 125 | */ 126 | public static MemorySegment reinterpret( 127 | MemorySegment addr, long elementCount, Arena arena, Consumer cleanup) { 128 | return addr.reinterpret(layout().byteSize() * elementCount, arena, cleanup); 129 | } 130 | } 131 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime_extern/OrtCustomJoinThreadFn.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2024 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime_extern; 6 | 7 | import static java.lang.foreign.MemoryLayout.PathElement.*; 8 | import static java.lang.foreign.ValueLayout.*; 9 | 10 | import java.lang.foreign.*; 11 | import java.lang.invoke.*; 12 | import java.util.*; 13 | import java.util.function.*; 14 | import java.util.stream.*; 15 | 16 | /** 17 | * {@snippet lang=c : 18 | * typedef void (*OrtCustomJoinThreadFn)(OrtCustomThreadHandle) 19 | * } 20 | */ 21 | public class OrtCustomJoinThreadFn { 22 | 23 | OrtCustomJoinThreadFn() { 24 | // Should not be called directly 25 | } 26 | 27 | /** 28 | * The function pointer signature, expressed as a functional interface 29 | */ 30 | public interface Function { 31 | void apply(MemorySegment ort_custom_thread_handle); 32 | } 33 | 34 | private static final FunctionDescriptor $DESC = FunctionDescriptor.ofVoid(onnxruntime_all_h.C_POINTER); 35 | 36 | /** 37 | * The descriptor of this function pointer 38 | */ 39 | public static 
FunctionDescriptor descriptor() { 40 | return $DESC; 41 | } 42 | 43 | private static final MethodHandle UP$MH = 44 | onnxruntime_all_h.upcallHandle(OrtCustomJoinThreadFn.Function.class, "apply", $DESC); 45 | 46 | /** 47 | * Allocates a new upcall stub, whose implementation is defined by {@code fi}. 48 | * The lifetime of the returned segment is managed by {@code arena} 49 | */ 50 | public static MemorySegment allocate(OrtCustomJoinThreadFn.Function fi, Arena arena) { 51 | return Linker.nativeLinker().upcallStub(UP$MH.bindTo(fi), $DESC, arena); 52 | } 53 | 54 | private static final MethodHandle DOWN$MH = Linker.nativeLinker().downcallHandle($DESC); 55 | 56 | /** 57 | * Invoke the upcall stub {@code funcPtr}, with given parameters 58 | */ 59 | public static void invoke(MemorySegment funcPtr, MemorySegment ort_custom_thread_handle) { 60 | try { 61 | DOWN$MH.invokeExact(funcPtr, ort_custom_thread_handle); 62 | } catch (Throwable ex$) { 63 | throw new AssertionError("should not reach here", ex$); 64 | } 65 | } 66 | 67 | /** 68 | * Get an implementation of the function interface from a function pointer. 
69 | */ 70 | public static OrtCustomJoinThreadFn.Function function(MemorySegment funcPtr) { 71 | return (ort_custom_thread_handle) -> invoke(funcPtr, ort_custom_thread_handle); 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime_extern/OrtLoggingFunction.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2024 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime_extern; 6 | 7 | import static java.lang.foreign.MemoryLayout.PathElement.*; 8 | import static java.lang.foreign.ValueLayout.*; 9 | 10 | import java.lang.foreign.*; 11 | import java.lang.invoke.*; 12 | import java.util.*; 13 | import java.util.function.*; 14 | import java.util.stream.*; 15 | 16 | /** 17 | * {@snippet lang=c : 18 | * typedef void (*OrtLoggingFunction)(void *, OrtLoggingLevel, const char *, const char *, const char *, const char *) 19 | * } 20 | */ 21 | public class OrtLoggingFunction { 22 | 23 | OrtLoggingFunction() { 24 | // Should not be called directly 25 | } 26 | 27 | /** 28 | * The function pointer signature, expressed as a functional interface 29 | */ 30 | public interface Function { 31 | void apply( 32 | MemorySegment param, 33 | int severity, 34 | MemorySegment category, 35 | MemorySegment logid, 36 | MemorySegment code_location, 37 | MemorySegment message); 38 | } 39 | 40 | private static final FunctionDescriptor $DESC = FunctionDescriptor.ofVoid( 41 | onnxruntime_all_h.C_POINTER, 42 | onnxruntime_all_h.C_INT, 43 | onnxruntime_all_h.C_POINTER, 44 | onnxruntime_all_h.C_POINTER, 45 | onnxruntime_all_h.C_POINTER, 46 | onnxruntime_all_h.C_POINTER); 47 | 48 | /** 49 | * The descriptor of this function pointer 50 | */ 51 | public static FunctionDescriptor descriptor() { 52 | return $DESC; 53 | } 54 | 55 | private static final MethodHandle UP$MH = 56 | 
onnxruntime_all_h.upcallHandle(OrtLoggingFunction.Function.class, "apply", $DESC); 57 | 58 | /** 59 | * Allocates a new upcall stub, whose implementation is defined by {@code fi}. 60 | * The lifetime of the returned segment is managed by {@code arena} 61 | */ 62 | public static MemorySegment allocate(OrtLoggingFunction.Function fi, Arena arena) { 63 | return Linker.nativeLinker().upcallStub(UP$MH.bindTo(fi), $DESC, arena); 64 | } 65 | 66 | private static final MethodHandle DOWN$MH = Linker.nativeLinker().downcallHandle($DESC); 67 | 68 | /** 69 | * Invoke the upcall stub {@code funcPtr}, with given parameters 70 | */ 71 | public static void invoke( 72 | MemorySegment funcPtr, 73 | MemorySegment param, 74 | int severity, 75 | MemorySegment category, 76 | MemorySegment logid, 77 | MemorySegment code_location, 78 | MemorySegment message) { 79 | try { 80 | DOWN$MH.invokeExact(funcPtr, param, severity, category, logid, code_location, message); 81 | } catch (Throwable ex$) { 82 | throw new AssertionError("should not reach here", ex$); 83 | } 84 | } 85 | 86 | /** 87 | * Get an implementation of the function interface from a function pointer. 
88 | */ 89 | public static OrtLoggingFunction.Function function(MemorySegment funcPtr) { 90 | return (param, severity, category, logid, code_location, message) -> 91 | invoke(funcPtr, param, severity, category, logid, code_location, message); 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime_extern/OrtThreadWorkerFn.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2024 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime_extern; 6 | 7 | import static java.lang.foreign.MemoryLayout.PathElement.*; 8 | import static java.lang.foreign.ValueLayout.*; 9 | 10 | import java.lang.foreign.*; 11 | import java.lang.invoke.*; 12 | import java.util.*; 13 | import java.util.function.*; 14 | import java.util.stream.*; 15 | 16 | /** 17 | * {@snippet lang=c : 18 | * typedef void (*OrtThreadWorkerFn)(void *) 19 | * } 20 | */ 21 | public class OrtThreadWorkerFn { 22 | 23 | OrtThreadWorkerFn() { 24 | // Should not be called directly 25 | } 26 | 27 | /** 28 | * The function pointer signature, expressed as a functional interface 29 | */ 30 | public interface Function { 31 | void apply(MemorySegment ort_worker_fn_param); 32 | } 33 | 34 | private static final FunctionDescriptor $DESC = FunctionDescriptor.ofVoid(onnxruntime_all_h.C_POINTER); 35 | 36 | /** 37 | * The descriptor of this function pointer 38 | */ 39 | public static FunctionDescriptor descriptor() { 40 | return $DESC; 41 | } 42 | 43 | private static final MethodHandle UP$MH = 44 | onnxruntime_all_h.upcallHandle(OrtThreadWorkerFn.Function.class, "apply", $DESC); 45 | 46 | /** 47 | * Allocates a new upcall stub, whose implementation is defined by {@code fi}. 
48 | * The lifetime of the returned segment is managed by {@code arena} 49 | */ 50 | public static MemorySegment allocate(OrtThreadWorkerFn.Function fi, Arena arena) { 51 | return Linker.nativeLinker().upcallStub(UP$MH.bindTo(fi), $DESC, arena); 52 | } 53 | 54 | private static final MethodHandle DOWN$MH = Linker.nativeLinker().downcallHandle($DESC); 55 | 56 | /** 57 | * Invoke the native function pointer {@code funcPtr} through a downcall handle, with the given parameters. Any {@code Throwable} raised by the call is rethrown wrapped in an {@link AssertionError}. 58 | */ 59 | public static void invoke(MemorySegment funcPtr, MemorySegment ort_worker_fn_param) { 60 | try { 61 | DOWN$MH.invokeExact(funcPtr, ort_worker_fn_param); 62 | } catch (Throwable ex$) { 63 | throw new AssertionError("should not reach here", ex$); 64 | } 65 | } 66 | 67 | /** 68 | * Get an implementation of the function interface from a function pointer. 69 | */ 70 | public static OrtThreadWorkerFn.Function function(MemorySegment funcPtr) { 71 | return (ort_worker_fn_param) -> invoke(funcPtr, ort_worker_fn_param); 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime_extern/RegisterCustomOpsFn.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2024 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime_extern; 6 | 7 | import static java.lang.foreign.MemoryLayout.PathElement.*; 8 | import static java.lang.foreign.ValueLayout.*; 9 | 10 | import java.lang.foreign.*; 11 | import java.lang.invoke.*; 12 | import java.util.*; 13 | import java.util.function.*; 14 | import java.util.stream.*; 15 | 16 | /** 17 | * {@snippet lang=c : 18 | * typedef OrtStatus *(*RegisterCustomOpsFn)(OrtSessionOptions *, const OrtApiBase *) 19 | * } 20 | */ 21 | public class RegisterCustomOpsFn { 22 | 23 | RegisterCustomOpsFn() { 24 | // Should not be called directly 25 | } 26 | 27 | /** 28 | * The function pointer signature, expressed as a functional
interface 29 | */ 30 | public interface Function { 31 | MemorySegment apply(MemorySegment options, MemorySegment api); 32 | } 33 | 34 | private static final FunctionDescriptor $DESC = FunctionDescriptor.of( 35 | onnxruntime_all_h.C_POINTER, onnxruntime_all_h.C_POINTER, onnxruntime_all_h.C_POINTER); 36 | 37 | /** 38 | * The descriptor of this function pointer 39 | */ 40 | public static FunctionDescriptor descriptor() { 41 | return $DESC; 42 | } 43 | 44 | private static final MethodHandle UP$MH = 45 | onnxruntime_all_h.upcallHandle(RegisterCustomOpsFn.Function.class, "apply", $DESC); 46 | 47 | /** 48 | * Allocates a new upcall stub, whose implementation is defined by {@code fi}. 49 | * The lifetime of the returned segment is managed by {@code arena} 50 | */ 51 | public static MemorySegment allocate(RegisterCustomOpsFn.Function fi, Arena arena) { 52 | return Linker.nativeLinker().upcallStub(UP$MH.bindTo(fi), $DESC, arena); 53 | } 54 | 55 | private static final MethodHandle DOWN$MH = Linker.nativeLinker().downcallHandle($DESC); 56 | 57 | /** 58 | * Invoke the native function pointer {@code funcPtr} through a downcall handle, with the given parameters, and return the pointer it produces (an {@code OrtStatus *} per the C typedef). Any {@code Throwable} raised by the call is rethrown wrapped in an {@link AssertionError}. 59 | */ 60 | public static MemorySegment invoke(MemorySegment funcPtr, MemorySegment options, MemorySegment api) { 61 | try { 62 | return (MemorySegment) DOWN$MH.invokeExact(funcPtr, options, api); 63 | } catch (Throwable ex$) { 64 | throw new AssertionError("should not reach here", ex$); 65 | } 66 | } 67 | 68 | /** 69 | * Get an implementation of the function interface from a function pointer.
70 | */ 71 | public static RegisterCustomOpsFn.Function function(MemorySegment funcPtr) { 72 | return (options, api) -> invoke(funcPtr, options, api); 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /src/main/java/com/jyuzawa/onnxruntime_extern/RunAsyncCallbackFn.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2024 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime_extern; 6 | 7 | import static java.lang.foreign.MemoryLayout.PathElement.*; 8 | import static java.lang.foreign.ValueLayout.*; 9 | 10 | import java.lang.foreign.*; 11 | import java.lang.invoke.*; 12 | import java.util.*; 13 | import java.util.function.*; 14 | import java.util.stream.*; 15 | 16 | /** 17 | * {@snippet lang=c : 18 | * typedef void (*RunAsyncCallbackFn)(void *, OrtValue **, size_t, OrtStatusPtr) 19 | * } 20 | */ 21 | public class RunAsyncCallbackFn { 22 | 23 | RunAsyncCallbackFn() { 24 | // Should not be called directly 25 | } 26 | 27 | /** 28 | * The function pointer signature, expressed as a functional interface 29 | */ 30 | public interface Function { 31 | void apply(MemorySegment user_data, MemorySegment outputs, long num_outputs, MemorySegment status); 32 | } 33 | 34 | private static final FunctionDescriptor $DESC = FunctionDescriptor.ofVoid( 35 | onnxruntime_all_h.C_POINTER, 36 | onnxruntime_all_h.C_POINTER, 37 | onnxruntime_all_h.C_LONG, /* NOTE(review): the C parameter is size_t; confirm C_LONG is 64 bits on every supported platform (e.g. Windows/LLP64) */ 38 | onnxruntime_all_h.C_POINTER); 39 | 40 | /** 41 | * The descriptor of this function pointer 42 | */ 43 | public static FunctionDescriptor descriptor() { 44 | return $DESC; 45 | } 46 | 47 | private static final MethodHandle UP$MH = 48 | onnxruntime_all_h.upcallHandle(RunAsyncCallbackFn.Function.class, "apply", $DESC); 49 | 50 | /** 51 | * Allocates a new upcall stub, whose implementation is defined by {@code fi}.
52 | * The lifetime of the returned segment is managed by {@code arena} 53 | */ 54 | public static MemorySegment allocate(RunAsyncCallbackFn.Function fi, Arena arena) { 55 | return Linker.nativeLinker().upcallStub(UP$MH.bindTo(fi), $DESC, arena); 56 | } 57 | 58 | private static final MethodHandle DOWN$MH = Linker.nativeLinker().downcallHandle($DESC); 59 | 60 | /** 61 | * Invoke the native function pointer {@code funcPtr} through a downcall handle, with the given parameters. Any {@code Throwable} raised by the call is rethrown wrapped in an {@link AssertionError}. 62 | */ 63 | public static void invoke( 64 | MemorySegment funcPtr, 65 | MemorySegment user_data, 66 | MemorySegment outputs, 67 | long num_outputs, 68 | MemorySegment status) { 69 | try { 70 | DOWN$MH.invokeExact(funcPtr, user_data, outputs, num_outputs, status); 71 | } catch (Throwable ex$) { 72 | throw new AssertionError("should not reach here", ex$); 73 | } 74 | } 75 | 76 | /** 77 | * Get an implementation of the function interface from a function pointer. 78 | */ 79 | public static RunAsyncCallbackFn.Function function(MemorySegment funcPtr) { 80 | return (user_data, outputs, num_outputs, status) -> invoke(funcPtr, user_data, outputs, num_outputs, status); 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /src/main/java/module-info.java: -------------------------------------------------------------------------------- 1 | /** 2 | * A Java binding of Microsoft's ONNX Runtime. 3 | *

4 | * This project's goals are to provide a type-safe, lightweight, and performant binding which abstracts a lot of the 5 | * native and C API intricacies away behind a Java-friendly interface. This is loosely coupled to the upstream project 6 | * and built off of the public (and stable) C API. 7 | *

8 | * This uses Java's Foreign Function & Memory API. This will 9 | * require the runtime to have the {@code --enable-native-access=ALL-UNNAMED} or similar JVM options. 10 | *

    11 | *
  • The {@code onnxruntime-cpu} artifact provides support for several common operating systems / CPU architecture 12 | * combinations. For use as an optional runtime dependency. Include one of the OS/Architecture classifiers like 13 | * {@code osx-x86_64} to provide specific support. 14 | *
  • The {@code onnxruntime} artifact contains only bindings and no libraries. This means the native library will need 15 | * to be provided. Use this artifact as a compile dependency if you want to allow your project's users to bring 16 | * {@code onnxruntime-cpu} or their own native library as a dependency provided at runtime. 17 | *
18 | * 19 | * @since 1.0.0 20 | */ 21 | module com.jyuzawa.onnxruntime { 22 | exports com.jyuzawa.onnxruntime; 23 | } 24 | -------------------------------------------------------------------------------- /src/test/java/com/jyuzawa/onnxruntime/ConcurrencyTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 James Yuzawa (https://www.jyuzawa.com/) 3 | * SPDX-License-Identifier: MIT 4 | */ 5 | package com.jyuzawa.onnxruntime; 6 | 7 | import static org.junit.jupiter.api.Assertions.*; 8 | 9 | import java.io.IOException; 10 | import java.lang.System.Logger; 11 | import java.lang.System.Logger.Level; 12 | import java.util.Arrays; 13 | import java.util.Map; 14 | import java.util.concurrent.ExecutorService; 15 | import java.util.concurrent.Executors; 16 | import java.util.concurrent.TimeUnit; 17 | import onnx.OnnxMl.GraphProto; 18 | import onnx.OnnxMl.ModelProto; 19 | import onnx.OnnxMl.NodeProto; 20 | import onnx.OnnxMl.OperatorSetIdProto; 21 | import onnx.OnnxMl.TensorProto.DataType; 22 | import onnx.OnnxMl.TensorShapeProto; 23 | import onnx.OnnxMl.TensorShapeProto.Dimension; 24 | import onnx.OnnxMl.TypeProto; 25 | import onnx.OnnxMl.TypeProto.Tensor; 26 | import onnx.OnnxMl.ValueInfoProto; 27 | import org.junit.jupiter.api.AfterAll; 28 | import org.junit.jupiter.api.BeforeAll; 29 | import org.junit.jupiter.api.Test; 30 | 31 | public class ConcurrencyTest { 32 | 33 | private static final Logger LOG = System.getLogger(ConcurrencyTest.class.getName()); 34 | private static Environment environment; 35 | private static byte[] model; 36 | 37 | @BeforeAll 38 | public static void setup() { 39 | OnnxRuntime apiBase = OnnxRuntime.get(); 40 | Api api = apiBase.getApi(); 41 | environment = api.newEnvironment() 42 | .setLogSeverityLevel(OnnxRuntimeLoggingLevel.WARNING) 43 | .setGlobalDenormalAsZero(false) 44 | .setGlobalInterOpNumThreads(0) 45 | .setGlobalIntraOpNumThreads(0) 46 | .setGlobalSpinControl(false) 47 |
.build(); 48 | TypeProto type = TypeProto.newBuilder() 49 | .setTensorType(Tensor.newBuilder() 50 | .setElemType(DataType.FLOAT_VALUE) 51 | .setShape(TensorShapeProto.newBuilder() 52 | .addDim(Dimension.newBuilder().setDimValue(1)) 53 | .addDim(Dimension.newBuilder().setDimValue(3)))) 54 | .build(); 55 | model = ModelProto.newBuilder() 56 | .setIrVersion(8) 57 | .addOpsetImport(OperatorSetIdProto.newBuilder().setVersion(15)) 58 | .setGraph(GraphProto.newBuilder() 59 | .addNode(NodeProto.newBuilder() 60 | .addInput("input") 61 | .addOutput("output") 62 | .setOpType("Identity")) 63 | .addInput(ValueInfoProto.newBuilder().setName("input").setType(type)) 64 | .addOutput(ValueInfoProto.newBuilder().setName("output").setType(type))) 65 | .build() 66 | .toByteArray(); 67 | } 68 | 69 | @AfterAll 70 | public static void tearDown() { 71 | environment.close(); 72 | } 73 | 74 | @Test 75 | public void loadUnloadTest() throws IOException { 76 | for (int i = 0; i < 10; i++) { 77 | LOG.log(Level.INFO, "Loading session " + i); 78 | try (Session session = environment 79 | .newSession() 80 | .setByteArray(model) 81 | .disablePerSessionThreads() 82 | .setMemoryPatternOptimization(false) 83 | .setConfigMap(Map.of("session.use_env_allocators", "1")) 84 | .build()) { 85 | for (int j = 0; j < 10; j++) { 86 | try (Transaction txn = session.newTransaction().build()) { 87 | float[] rawInput = new float[] {554354, 52345234, 143646}; 88 | txn.addInput(0).asTensor().getFloatBuffer().put(rawInput); 89 | txn.addOutput(0); 90 | NamedCollection output = txn.run(); 91 | float[] rawOutput = new float[3]; 92 | OnnxValue outputValue = output.get(0); 93 | OnnxTensor outputTensor = outputValue.asTensor(); 94 | outputTensor.getFloatBuffer().get(rawOutput); 95 | assertArrayEquals(rawInput, rawOutput); 96 | } 97 | } 98 | } 99 | } 100 | } 101 | 102 | @Test 103 | public void poolTest() throws Exception { 104 | int numThreads = 4; 105 | try (Session session = environment 106 | .newSession() 107 |
.setByteArray(model) 108 | .disablePerSessionThreads() 109 | .setMemoryPatternOptimization(false) 110 | .addProvider(ExecutionProvider.CPU_EXECUTION_PROVIDER, Map.of("use_arena", "0")) 111 | .setConfigMap(Map.of("session.use_env_allocators", "1")) 112 | .build()) { 113 | ExecutorService executor = Executors.newFixedThreadPool(numThreads); 114 | // Keep each worker's Future: an exception thrown inside a submitted task is captured by it and otherwise lost. 115 | var futures = new java.util.ArrayList<java.util.concurrent.Future<?>>(); 116 | for (int i = 0; i < numThreads; i++) { 117 | futures.add(executor.submit(() -> { 118 | for (int j = 0; j < 10000; j++) { 119 | try (Transaction txn = session.newTransaction().build()) { 120 | float[] rawInput = new float[] {554354, 52345234, 143646}; 121 | txn.addInput(0).asTensor().getFloatBuffer().put(rawInput); 122 | txn.addOutput(0); 123 | NamedCollection output = txn.run(); 124 | float[] rawOutput = new float[3]; 125 | OnnxValue outputValue = output.get(0); 126 | OnnxTensor outputTensor = outputValue.asTensor(); 127 | outputTensor.getFloatBuffer().get(rawOutput); 128 | assertTrue(Arrays.equals(rawInput, rawOutput)); 129 | } 130 | } 131 | })); 132 | } 133 | executor.shutdown(); 134 | // Every worker must finish before try-with-resources closes the session they share. 135 | assertTrue(executor.awaitTermination(1, TimeUnit.MINUTES), "pool workers did not finish within 1 minute"); 136 | for (var future : futures) { 137 | // get() rethrows a worker's failure wrapped in an ExecutionException, failing this test. 138 | future.get(); 139 | } 140 | } 141 | } 142 | } 143 | -------------------------------------------------------------------------------- /src/test/resources/libcustom_op_library.dylib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuzawa-san/onnxruntime-java/922a424b9c103e4f235e7b8f21afa1be520faa11/src/test/resources/libcustom_op_library.dylib -------------------------------------------------------------------------------- /src/test/resources/log4j2.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | --------------------------------------------------------------------------------