├── menoh ├── src │ ├── test │ │ ├── resources │ │ │ └── models │ │ │ │ ├── invalid_format.onnx │ │ │ │ ├── and_op.onnx │ │ │ │ ├── unsupported_onnx_opset_version.onnx │ │ │ │ └── make_onnx.py │ │ └── java │ │ │ └── jp │ │ │ └── preferred │ │ │ └── menoh │ │ │ ├── TestUtils.java │ │ │ ├── DTypeTest.java │ │ │ ├── MenohExceptionTest.java │ │ │ ├── ErrorCodeTest.java │ │ │ ├── ModelDataTest.java │ │ │ ├── ModelRunnerTest.java │ │ │ ├── BufferUtilsTest.java │ │ │ ├── VariableProfileTableTest.java │ │ │ └── ModelTest.java │ └── main │ │ └── java │ │ └── jp │ │ └── preferred │ │ └── menoh │ │ ├── MenohRunnerException.java │ │ ├── VariableProfile.java │ │ ├── DType.java │ │ ├── ErrorCode.java │ │ ├── ModelData.java │ │ ├── Variable.java │ │ ├── MenohException.java │ │ ├── VariableProfileTable.java │ │ ├── MenohNative.java │ │ ├── VariableProfileTableBuilder.java │ │ ├── Model.java │ │ ├── BufferUtils.java │ │ ├── ModelRunner.java │ │ ├── ModelBuilder.java │ │ └── ModelRunnerBuilder.java └── pom.xml ├── .gitignore ├── LICENSE ├── menoh-examples ├── README.md ├── pom.xml └── src │ └── main │ └── java │ └── jp │ └── preferred │ └── menoh │ └── examples │ └── Vgg16.java ├── LICENSE_THIRD_PARTY ├── README.md ├── pom.xml └── config └── checkstyle └── checkstyle.xml /menoh/src/test/resources/models/invalid_format.onnx: -------------------------------------------------------------------------------- 1 | Here is an invalid ONNX model file. 
-------------------------------------------------------------------------------- /menoh/src/test/resources/models/and_op.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet-research/menoh-java/HEAD/menoh/src/test/resources/models/and_op.onnx -------------------------------------------------------------------------------- /menoh/src/test/resources/models/unsupported_onnx_opset_version.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet-research/menoh-java/HEAD/menoh/src/test/resources/models/unsupported_onnx_opset_version.onnx -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled class file 2 | *.class 3 | target/ 4 | cppbuild/ 5 | 6 | # Log file 7 | *.log 8 | 9 | # BlueJ files 10 | *.ctxt 11 | 12 | # Mobile Tools for Java (J2ME) 13 | .mtj.tmp/ 14 | 15 | # Package Files # 16 | *.jar 17 | *.war 18 | *.ear 19 | *.zip 20 | *.tar.gz 21 | *.rar 22 | 23 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml 24 | hs_err_pid* 25 | 26 | # eclipse specific 27 | .metadata 28 | jrebel.lic 29 | .settings 30 | .classpath 31 | .project 32 | 33 | .ensime* 34 | *.sublime-* 35 | .cache 36 | 37 | # intellij 38 | *.eml 39 | *.iml 40 | *.ipr 41 | *.iws 42 | .*.sw? 
43 | .idea 44 | 45 | # image files 46 | *.bmp 47 | *.jpg 48 | *.png 49 | -------------------------------------------------------------------------------- /menoh/src/main/java/jp/preferred/menoh/MenohRunnerException.java: -------------------------------------------------------------------------------- 1 | package jp.preferred.menoh; 2 | 3 | public class MenohRunnerException extends RuntimeException { 4 | public MenohRunnerException(String message) { 5 | super(message); 6 | } 7 | 8 | public MenohRunnerException(String message, Throwable cause) { 9 | super(message, cause); 10 | } 11 | 12 | public MenohRunnerException(Throwable cause) { 13 | super(cause); 14 | } 15 | 16 | protected MenohRunnerException( 17 | String message, 18 | Throwable cause, 19 | boolean enableSuppression, 20 | boolean writableStackTrace) { 21 | super(message, cause, enableSuppression, writableStackTrace); 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /menoh/src/main/java/jp/preferred/menoh/VariableProfile.java: -------------------------------------------------------------------------------- 1 | package jp.preferred.menoh; 2 | 3 | /** 4 | * A variable profile added to {@link VariableProfileTable}. 5 | */ 6 | public class VariableProfile { 7 | private final DType dtype; 8 | private final int[] dims; 9 | 10 | VariableProfile(DType dtype, int[] dims) { 11 | this.dtype = dtype; 12 | this.dims = dims; 13 | } 14 | 15 | /** 16 | * A data type of the variable. 17 | */ 18 | public DType dtype() { 19 | return this.dtype; 20 | } 21 | 22 | /** 23 | * An array of dimensions. 
24 | */ 25 | public int[] dims() { 26 | if (dims != null) { 27 | return this.dims.clone(); 28 | } else { 29 | return new int[0]; 30 | } 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /menoh/src/test/java/jp/preferred/menoh/TestUtils.java: -------------------------------------------------------------------------------- 1 | package jp.preferred.menoh; 2 | 3 | import java.io.FileNotFoundException; 4 | import java.io.IOException; 5 | import java.net.URISyntaxException; 6 | import java.net.URL; 7 | import java.nio.file.Paths; 8 | 9 | public class TestUtils { 10 | /** 11 | * Convert resource name into file path. 12 | */ 13 | public static String getResourceFilePath(String name) throws IOException, URISyntaxException { 14 | // ref. https://stackoverflow.com/a/17870390/1014818 15 | URL url = TestUtils.class.getClassLoader().getResource(name); 16 | if (url != null) { 17 | return Paths.get(url.toURI()).toFile().getCanonicalPath(); 18 | } else { 19 | throw new FileNotFoundException("The specified resource not found: " + name); 20 | } 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Preferred Networks, Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /menoh/src/test/java/jp/preferred/menoh/DTypeTest.java: -------------------------------------------------------------------------------- 1 | package jp.preferred.menoh; 2 | 3 | import static org.junit.jupiter.api.Assertions.assertAll; 4 | import static org.junit.jupiter.api.Assertions.assertEquals; 5 | import static org.junit.jupiter.api.Assertions.assertThrows; 6 | 7 | import org.junit.jupiter.api.Test; 8 | 9 | public class DTypeTest { 10 | @Test 11 | public void floatDType() { 12 | assertAll("DType.FLOAT", 13 | () -> assertEquals(0, DType.FLOAT.getId()), 14 | () -> assertEquals(4, DType.FLOAT.size()) 15 | ); 16 | } 17 | 18 | @Test 19 | public void valueOfValidDType() { 20 | assertEquals(DType.FLOAT, DType.valueOf(0)); 21 | } 22 | 23 | @Test 24 | public void valueOfIntMinDType() { 25 | MenohException e = assertThrows(MenohException.class, () -> DType.valueOf(Integer.MIN_VALUE)); 26 | assertEquals(ErrorCode.UNDEFINED, e.getErrorCode()); 27 | } 28 | 29 | @Test 30 | public void valueOfIntMaxDType() { 31 | MenohException e = assertThrows(MenohException.class, () -> DType.valueOf(Integer.MAX_VALUE)); 32 | assertEquals(ErrorCode.UNDEFINED, e.getErrorCode()); 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /menoh/src/main/java/jp/preferred/menoh/DType.java: -------------------------------------------------------------------------------- 1 | 
package jp.preferred.menoh; 2 | 3 | /** 4 | * A data type. 5 | */ 6 | public enum DType { 7 | UNDEFINED(-1), // values[0] 8 | FLOAT(0); 9 | 10 | private final int id; 11 | 12 | // cache values() to avoid cloning the backing array every time 13 | private static final DType[] values = values(); 14 | 15 | DType(int id) { 16 | this.id = id; 17 | } 18 | 19 | public int getId() { 20 | return id; 21 | } 22 | 23 | /** 24 | * The byte size of this {@link DType}. 25 | */ 26 | public int size() throws MenohException { 27 | switch (this) { 28 | case FLOAT: 29 | return 4; 30 | default: 31 | throw new MenohException( 32 | ErrorCode.UNDEFINED, String.format("the size of dtype is unknown: %d", id)); 33 | } 34 | } 35 | 36 | /** 37 | * Returns the enum constant of the specified enum type with the specified ID. 38 | */ 39 | public static DType valueOf(int value) throws MenohException { 40 | final int index = value + 1; 41 | if (1 <= index && index < values.length) { 42 | final DType ret = values[index]; 43 | assert ret.id == value; 44 | 45 | return ret; 46 | } else { 47 | throw new MenohException(ErrorCode.UNDEFINED, "undefined dtype: " + value); 48 | } 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /menoh/src/test/java/jp/preferred/menoh/MenohExceptionTest.java: -------------------------------------------------------------------------------- 1 | package jp.preferred.menoh; 2 | 3 | import static jp.preferred.menoh.MenohException.checkError; 4 | 5 | // CHECKSTYLE:OFF 6 | import static org.junit.jupiter.api.Assertions.*; 7 | // CHECKSTYLE:ON 8 | 9 | import org.junit.jupiter.api.Test; 10 | 11 | public class MenohExceptionTest { 12 | @Test 13 | public void checkErrorSuccess() { 14 | checkError(ErrorCode.SUCCESS.getId()); 15 | } 16 | 17 | @Test 18 | public void checkErrorFailed() { 19 | MenohException e = assertThrows(MenohException.class, () -> checkError(ErrorCode.STD_ERROR.getId())); 20 | assertAll("ErrorCode.STD_ERROR", 21 | () -> 
assertEquals(ErrorCode.STD_ERROR, e.getErrorCode()), 22 | () -> assertTrue(e.getMessage().endsWith(" (std_error)"), 23 | String.format("%s doesn't end with \"(std_error)\".", e.getMessage())) 24 | ); 25 | } 26 | 27 | @Test 28 | public void checkErrorUnknownErrorCode() { 29 | MenohException e = assertThrows(MenohException.class, () -> checkError(Integer.MAX_VALUE)); 30 | assertAll("invalid ErrorCode", 31 | () -> assertEquals(ErrorCode.UNDEFINED, e.getErrorCode()), 32 | () -> assertTrue(e.getMessage().endsWith(" (2147483647)"), 33 | String.format("%s doesn't end with \"(2147483647)\".", e.getMessage())) 34 | ); 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /menoh-examples/README.md: -------------------------------------------------------------------------------- 1 | # menoh-examples 2 | *Non*-adversarial code examples in menoh-java. 3 | 4 | ## Requirements 5 | The examples requires the native Menoh Core library in the JNA search path. See the ["Getting Started"](../README.md#getting-started) section in README document. 6 | 7 | ## Usage 8 | 9 | ### VGG16 10 | Classify an input image by using VGG16, a pre-trained CNN model. See the [Menoh tutorial](https://pfnet-research.github.io/menoh/md_tutorial.html) for the details. 11 | 12 | ```bash 13 | $ cd menoh-examples 14 | $ mvn compile 15 | $ mkdir data 16 | $ wget -qN -P data/ https://upload.wikimedia.org/wikipedia/commons/5/54/Light_sussex_hen.jpg 17 | 18 | $ mvn exec:java -Dexec.mainClass=jp.preferred.menoh.examples.Vgg16 -Dexec.args="data/Light_sussex_hen.jpg" 19 | ... 
20 | 21 | -21.302792 -34.792274 -11.198693 23.983116 0.8887066 -6.290651 -26.45818 -24.885696 -4.925243 14.338675 22 | top 5 categories are 23 | index: 8, score: 0.9705545, category: n01514859 hen 24 | index: 7, score: 0.028219773, category: n01514668 cock 25 | index: 86, score: 9.522993E-4, category: n01807496 partridge 26 | index: 82, score: 1.6639252E-4, category: n01797886 ruffed grouse, partridge, Bonasa umbellus 27 | index: 97, score: 2.0030011E-5, category: n01847000 drake 28 | ``` 29 | 30 | ## License 31 | `menoh-examples` package uses `VGG16.onnx`. It is generated by [ONNX-Chainer](https://github.com/chainer/onnx-chainer) from the pre-trained model which is available from the [website](http://www.robots.ox.ac.uk/~vgg/research/very_deep/) of Visual Geometry Group from University of Oxford under [Creative Commons Attribution License](https://creativecommons.org/licenses/by/4.0/). 32 | -------------------------------------------------------------------------------- /menoh/src/main/java/jp/preferred/menoh/ErrorCode.java: -------------------------------------------------------------------------------- 1 | package jp.preferred.menoh; 2 | 3 | /** 4 | * A representation of menoh_error_code. 
5 | */ 6 | public enum ErrorCode { 7 | UNDEFINED(Integer.MIN_VALUE), // values[0] 8 | SUCCESS(0), 9 | STD_ERROR(1), 10 | UNKNOWN_ERROR(2), 11 | INVALID_FILENAME(3), 12 | UNSUPPORTED_ONNX_OPSET_VERSION(4), 13 | ONNX_PARSE_ERROR(5), 14 | INVALID_DTYPE(6), 15 | INVALID_ATTRIBUTE_TYPE(7), 16 | UNSUPPORTED_OPERATOR_ATTRIBUTE(8), 17 | DIMENSION_MISMATCH(9), 18 | VARIABLE_NOT_FOUND(10), 19 | INDEX_OUT_OF_RANGE(11), 20 | JSON_PARSE_ERROR(12), 21 | INVALID_BACKEND_NAME(13), 22 | UNSUPPORTED_OPERATOR(14), 23 | FAILED_TO_CONFIGURE_OPERATOR(15), 24 | BACKEND_ERROR(16), 25 | SAME_NAMED_VARIABLE_ALREADY_EXIST(17), 26 | UNSUPPORTED_INPUT_DIMS(18); 27 | 28 | private final int id; 29 | 30 | // cache values() to avoid cloning the backing array every time 31 | private static final ErrorCode[] values = values(); 32 | 33 | ErrorCode(int id) { 34 | this.id = id; 35 | } 36 | 37 | public int getId() { 38 | return this.id; 39 | } 40 | 41 | /** 42 | * Returns the enum constant of the specified enum type with the specified ID. 43 | */ 44 | public static ErrorCode valueOf(int value) throws MenohException { 45 | final int index = value + 1; 46 | if (1 <= index && index < values.length) { 47 | final ErrorCode ret = values[index]; 48 | assert ret.id == value; 49 | 50 | return ret; 51 | } else { 52 | throw new MenohException(UNDEFINED, "The error code is undefined: " + value); 53 | } 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /menoh/src/main/java/jp/preferred/menoh/ModelData.java: -------------------------------------------------------------------------------- 1 | package jp.preferred.menoh; 2 | 3 | import static jp.preferred.menoh.MenohException.checkError; 4 | 5 | import com.sun.jna.Pointer; 6 | import com.sun.jna.ptr.PointerByReference; 7 | 8 | /** 9 | *

Model data loaded from an ONNX file.

10 | * 11 | *

Make sure to {@link #close()} this object once you are done with it, to free the underlying memory 12 | * in the native heap.

13 | */ 14 | public class ModelData implements AutoCloseable { 15 | private Pointer handle; 16 | 17 | private ModelData(Pointer handle) { 18 | this.handle = handle; 19 | } 20 | 21 | Pointer nativeHandle() { 22 | return this.handle; 23 | } 24 | 25 | /** 26 | * Optimizes this model data. 27 | * 28 | * @return this object 29 | */ 30 | public ModelData optimize(VariableProfileTable vpt) throws MenohException { 31 | checkError(MenohNative.INSTANCE.menoh_model_data_optimize(handle, vpt.nativeHandle())); 32 | 33 | return this; 34 | } 35 | 36 | @Override 37 | public void close() { 38 | synchronized (this) { 39 | if (handle != Pointer.NULL) { 40 | MenohNative.INSTANCE.menoh_delete_model_data(handle); 41 | handle = Pointer.NULL; 42 | } 43 | } 44 | } 45 | 46 | /** 47 | * Loads an ONNX model from the specified file. 48 | */ 49 | public static ModelData fromOnnxFile(String onnxModelPath) throws MenohException { 50 | final PointerByReference handle = new PointerByReference(); 51 | checkError(MenohNative.INSTANCE.menoh_make_model_data_from_onnx(onnxModelPath, handle)); 52 | 53 | return new ModelData(handle.getValue()); 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /menoh/src/main/java/jp/preferred/menoh/Variable.java: -------------------------------------------------------------------------------- 1 | package jp.preferred.menoh; 2 | 3 | import com.sun.jna.Pointer; 4 | 5 | import java.nio.ByteBuffer; 6 | 7 | /** 8 | * An input or output variable in the model. 9 | */ 10 | public class Variable { 11 | private final DType dtype; 12 | private final int[] dims; 13 | private final Pointer bufferHandle; 14 | 15 | Variable(DType dtype, int[] dims, Pointer bufferHandle) { 16 | this.dtype = dtype; 17 | this.dims = dims; 18 | this.bufferHandle = bufferHandle; 19 | } 20 | 21 | /** 22 | * A data type of the variable. 23 | */ 24 | public DType dtype() { 25 | return this.dtype; 26 | } 27 | 28 | /** 29 | * An array of dimensions. 
30 | */ 31 | public int[] dims() { 32 | if (dims != null) { 33 | return this.dims.clone(); 34 | } else { 35 | return new int[0]; 36 | } 37 | } 38 | 39 | /** 40 | * A direct {@link ByteBuffer} which points to the native buffer of the variable. The buffer can be read 41 | * and written via the methods of ByteBuffer before and after running the model. 42 | */ 43 | public ByteBuffer buffer() throws MenohException { 44 | return this.bufferHandle.getByteBuffer(0, bufferLength()); 45 | } 46 | 47 | /** 48 | * The length of the buffer in bytes. 49 | */ 50 | long bufferLength() { 51 | if (dims.length > 0) { 52 | final long elementSize = dtype.size(); 53 | long length = 1; 54 | for (int d : dims) { 55 | length *= d; 56 | } 57 | 58 | if (length <= 0) { 59 | throw new MenohException(ErrorCode.UNDEFINED, "buffer is empty"); 60 | } 61 | 62 | return elementSize * length; 63 | } else { 64 | throw new MenohException(ErrorCode.UNDEFINED, "buffer is empty"); 65 | } 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /menoh/src/test/java/jp/preferred/menoh/ErrorCodeTest.java: -------------------------------------------------------------------------------- 1 | package jp.preferred.menoh; 2 | 3 | import static org.junit.jupiter.api.Assertions.assertAll; 4 | import static org.junit.jupiter.api.Assertions.assertEquals; 5 | import static org.junit.jupiter.api.Assertions.assertThrows; 6 | 7 | import org.junit.jupiter.api.Test; 8 | 9 | public class ErrorCodeTest { 10 | @Test 11 | public void undefinedErrorCode() { 12 | assertEquals(Integer.MIN_VALUE, ErrorCode.UNDEFINED.getId()); 13 | } 14 | 15 | @Test 16 | public void successErrorCode() { 17 | assertEquals(0, ErrorCode.SUCCESS.getId()); 18 | } 19 | 20 | @Test 21 | public void valueOfValidErrorCode() { 22 | assertAll("valid ErrorCodes", 23 | () -> assertEquals(ErrorCode.SUCCESS, ErrorCode.valueOf(0)), 24 | () -> assertEquals(ErrorCode.UNSUPPORTED_INPUT_DIMS, ErrorCode.valueOf(18)) 25 | ); 26 | } 
27 | 28 | @Test 29 | public void valueOfInvalidErrorCode() { 30 | assertAll("invalid ErrorCodes", 31 | () -> { 32 | MenohException e = assertThrows(MenohException.class, () -> ErrorCode.valueOf(-1)); 33 | assertEquals(ErrorCode.UNDEFINED, e.getErrorCode()); 34 | }, 35 | () -> { 36 | MenohException e = assertThrows(MenohException.class, () -> ErrorCode.valueOf(18 + 1)); 37 | assertEquals(ErrorCode.UNDEFINED, e.getErrorCode()); 38 | }); 39 | } 40 | 41 | @Test 42 | public void valueOfIntMinErrorCode() { 43 | MenohException e = assertThrows(MenohException.class, () -> ErrorCode.valueOf(Integer.MIN_VALUE)); 44 | assertEquals(ErrorCode.UNDEFINED, e.getErrorCode()); 45 | } 46 | 47 | @Test 48 | public void valueOfIntMaxErrorCode() { 49 | MenohException e = assertThrows(MenohException.class, () -> ErrorCode.valueOf(Integer.MAX_VALUE)); 50 | assertEquals(ErrorCode.UNDEFINED, e.getErrorCode()); 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /menoh/src/main/java/jp/preferred/menoh/MenohException.java: -------------------------------------------------------------------------------- 1 | package jp.preferred.menoh; 2 | 3 | import java.util.Locale; 4 | 5 | public class MenohException extends RuntimeException { 6 | private final ErrorCode errorCode; 7 | 8 | public MenohException(ErrorCode errorCode, String message) { 9 | super(message); 10 | this.errorCode = errorCode; 11 | } 12 | 13 | public MenohException(ErrorCode errorCode, String message, Throwable cause) { 14 | super(message, cause); 15 | this.errorCode = errorCode; 16 | } 17 | 18 | public MenohException(ErrorCode errorCode, Throwable cause) { 19 | super(cause); 20 | this.errorCode = errorCode; 21 | } 22 | 23 | protected MenohException( 24 | ErrorCode errorCode, 25 | String message, 26 | Throwable cause, 27 | boolean enableSuppression, 28 | boolean writableStackTrace) { 29 | super(message, cause, enableSuppression, writableStackTrace); 30 | this.errorCode = errorCode; 31 | 
} 32 | 33 | public ErrorCode getErrorCode() { 34 | return this.errorCode; 35 | } 36 | 37 | /** 38 | * Throws {@link MenohException} if the errorCode of a native Menoh function is 39 | * non-zero. This method must be called right after the invocation because it uses 40 | * menoh_get_last_error_message. 41 | * 42 | * @param errorCode an error code returned from the Menoh function 43 | */ 44 | static void checkError(int errorCode) throws MenohException { 45 | if (errorCode != ErrorCode.SUCCESS.getId()) { 46 | final String errorMessage = MenohNative.INSTANCE.menoh_get_last_error_message(); 47 | 48 | ErrorCode ec; 49 | try { 50 | ec = ErrorCode.valueOf(errorCode); 51 | } catch (MenohException e) { 52 | // ErrorCode.valueOf() throws MenohException if the error code is undefined 53 | throw new MenohException(ErrorCode.UNDEFINED, String.format("%s (%d)", errorMessage, errorCode), e); 54 | } 55 | 56 | final String errorCodeName = ec.toString().toLowerCase(Locale.ENGLISH); 57 | throw new MenohException(ec, String.format("%s (%s)", errorMessage, errorCodeName)); 58 | } 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /menoh/src/main/java/jp/preferred/menoh/VariableProfileTable.java: -------------------------------------------------------------------------------- 1 | package jp.preferred.menoh; 2 | 3 | import static jp.preferred.menoh.MenohException.checkError; 4 | 5 | import com.sun.jna.Pointer; 6 | import com.sun.jna.ptr.IntByReference; 7 | import com.sun.jna.ptr.PointerByReference; 8 | 9 | /** 10 | *

A table that holds {@link VariableProfile}s for the {@link Model}. It is built by 11 | * {@link VariableProfileTableBuilder}.

12 | * 13 | *

Make sure to {@link #close()} this object once you are done with it, to free the underlying memory 14 | * in the native heap.

15 | */ 16 | public class VariableProfileTable implements AutoCloseable { 17 | private Pointer handle; 18 | 19 | VariableProfileTable(Pointer handle) { 20 | this.handle = handle; 21 | } 22 | 23 | Pointer nativeHandle() { 24 | return this.handle; 25 | } 26 | 27 | @Override 28 | public void close() { 29 | synchronized (this) { 30 | if (handle != Pointer.NULL) { 31 | MenohNative.INSTANCE.menoh_delete_variable_profile_table(handle); 32 | handle = Pointer.NULL; 33 | } 34 | } 35 | } 36 | 37 | /** 38 | * Creates a {@link VariableProfileTableBuilder}. 39 | */ 40 | public static VariableProfileTableBuilder builder() throws MenohException { 41 | final PointerByReference ref = new PointerByReference(); 42 | checkError(MenohNative.INSTANCE.menoh_make_variable_profile_table_builder(ref)); 43 | 44 | return new VariableProfileTableBuilder(ref.getValue()); 45 | } 46 | 47 | /** 48 | * A {@link VariableProfile} with the specified name. 49 | */ 50 | public VariableProfile variableProfile(String variableName) throws MenohException { 51 | final IntByReference dtype = new IntByReference(0); 52 | checkError(MenohNative.INSTANCE.menoh_variable_profile_table_get_dtype(handle, variableName, dtype)); 53 | 54 | final IntByReference dimsSize = new IntByReference(); 55 | checkError(MenohNative.INSTANCE.menoh_variable_profile_table_get_dims_size(handle, variableName, dimsSize)); 56 | 57 | final int[] dims = new int[dimsSize.getValue()]; 58 | 59 | for (int i = 0; i < dimsSize.getValue(); ++i) { 60 | final IntByReference d = new IntByReference(); 61 | checkError(MenohNative.INSTANCE.menoh_variable_profile_table_get_dims_at(handle, variableName, i, d)); 62 | dims[i] = d.getValue(); 63 | } 64 | 65 | return new VariableProfile(DType.valueOf(dtype.getValue()), dims); 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /menoh/src/main/java/jp/preferred/menoh/MenohNative.java: -------------------------------------------------------------------------------- 
1 | package jp.preferred.menoh; 2 | 3 | import com.sun.jna.Library; 4 | import com.sun.jna.Native; 5 | import com.sun.jna.Pointer; 6 | import com.sun.jna.ptr.IntByReference; 7 | import com.sun.jna.ptr.PointerByReference; 8 | 9 | // CHECKSTYLE:OFF 10 | interface MenohNative extends Library { 11 | MenohNative INSTANCE = (MenohNative) Native.loadLibrary("menoh", MenohNative.class); 12 | 13 | String menoh_get_last_error_message(); 14 | 15 | int menoh_make_model_data_from_onnx(String onnx_filename, PointerByReference dst_handle); 16 | 17 | void menoh_delete_model_data(Pointer model_data); 18 | 19 | int menoh_model_data_optimize(Pointer model_data, Pointer variable_profile_table); 20 | 21 | int menoh_make_variable_profile_table_builder(PointerByReference dst_handle); 22 | 23 | void menoh_delete_variable_profile_table_builder(Pointer builder); 24 | 25 | int menoh_variable_profile_table_builder_add_input_profile_dims_2(Pointer builder, String name, int dtype, int num, int size); 26 | 27 | int menoh_variable_profile_table_builder_add_input_profile_dims_4(Pointer builder, String name, int dtype, int num, int channel, int height, int width); 28 | 29 | int menoh_variable_profile_table_builder_add_output_profile(Pointer builder, String name, int dtype); 30 | 31 | int menoh_build_variable_profile_table(Pointer builder, Pointer model_data, PointerByReference dst_handle); 32 | 33 | void menoh_delete_variable_profile_table(Pointer variable_profile_table); 34 | 35 | int menoh_variable_profile_table_get_dtype(Pointer variable_profile_table, String variable_name, IntByReference dst_dtype); 36 | 37 | int menoh_variable_profile_table_get_dims_size(Pointer variable_profile_table, String variable_name, IntByReference dst_size); 38 | 39 | int menoh_variable_profile_table_get_dims_at(Pointer variable_profile_table, String variable_name, int index, IntByReference dst_size); 40 | 41 | int menoh_make_model_builder(Pointer variable_profile_table, PointerByReference dst_handle); 42 | 43 | void 
menoh_delete_model_builder(Pointer model_builder); 44 | 45 | int menoh_model_builder_attach_external_buffer(Pointer builder, String variable_name, Pointer buffer_handle); 46 | 47 | int menoh_build_model(Pointer builder, Pointer model_data, String backend_name, String backend_config, PointerByReference dst_model_handle); 48 | 49 | void menoh_delete_model(Pointer model); 50 | 51 | int menoh_model_get_variable_dtype(Pointer model, String variable_name, IntByReference dst_dtype); 52 | 53 | int menoh_model_run(Pointer model); 54 | 55 | int menoh_model_get_variable_dims_size(Pointer model, String variable_name, IntByReference dst_size); 56 | 57 | int menoh_model_get_variable_dims_at(Pointer model, String variable_name, int index, IntByReference dst_size); 58 | 59 | int menoh_model_get_variable_buffer_handle(Pointer model, String variable_name, PointerByReference dst_data); 60 | } 61 | // CHECKSTYLE:ON 62 | -------------------------------------------------------------------------------- /menoh/src/main/java/jp/preferred/menoh/VariableProfileTableBuilder.java: -------------------------------------------------------------------------------- 1 | package jp.preferred.menoh; 2 | 3 | import static jp.preferred.menoh.MenohException.checkError; 4 | 5 | import com.sun.jna.Pointer; 6 | import com.sun.jna.ptr.PointerByReference; 7 | 8 | /** 9 | *

A builder object for {@link VariableProfileTable}.

10 | * 11 | *

Make sure to {@link #close()} this object once you are done with it, to free the underlying memory 12 | * in the native heap.

13 | */ 14 | public class VariableProfileTableBuilder implements AutoCloseable { 15 | private Pointer handle; 16 | 17 | VariableProfileTableBuilder(Pointer handle) { 18 | this.handle = handle; 19 | } 20 | 21 | Pointer nativeHandle() { 22 | return this.handle; 23 | } 24 | 25 | @Override 26 | public void close() { 27 | synchronized (this) { 28 | if (handle != Pointer.NULL) { 29 | MenohNative.INSTANCE.menoh_delete_variable_profile_table_builder(handle); 30 | handle = Pointer.NULL; 31 | } 32 | } 33 | } 34 | 35 | /** 36 | * Adds an input profile to configure the specified variable in the model. 37 | * 38 | * @return this object 39 | */ 40 | public VariableProfileTableBuilder addInputProfile(String name, DType dtype, int[] dims) throws MenohException { 41 | if (dims.length == 2) { 42 | checkError( 43 | MenohNative.INSTANCE.menoh_variable_profile_table_builder_add_input_profile_dims_2( 44 | handle, name, dtype.getId(), dims[0], dims[1])); 45 | } else if (dims.length == 4) { 46 | checkError( 47 | MenohNative.INSTANCE.menoh_variable_profile_table_builder_add_input_profile_dims_4( 48 | handle, name, dtype.getId(), dims[0], dims[1], dims[2], dims[3])); 49 | } else { 50 | throw new MenohException( 51 | ErrorCode.UNDEFINED, 52 | String.format("%s has an invalid dims size: %d (it must be 2 or 4)", name, dims.length)); 53 | } 54 | 55 | return this; 56 | } 57 | 58 | /** 59 | * Adds an output profile to configure the specified variable in the model. 60 | * 61 | * @return this object 62 | */ 63 | public VariableProfileTableBuilder addOutputProfile(String name, DType dtype) throws MenohException { 64 | checkError(MenohNative.INSTANCE.menoh_variable_profile_table_builder_add_output_profile( 65 | handle, name, dtype.getId())); 66 | 67 | return this; 68 | } 69 | 70 | /** 71 | * Builds a {@link VariableProfileTable}. It is used for making {@link ModelBuilder}. 
72 | */ 73 | public VariableProfileTable build(ModelData modelData) throws MenohException { 74 | final PointerByReference ref = new PointerByReference(); 75 | checkError(MenohNative.INSTANCE.menoh_build_variable_profile_table( 76 | this.handle, modelData.nativeHandle(), ref)); 77 | 78 | return new VariableProfileTable(ref.getValue()); 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /menoh/src/main/java/jp/preferred/menoh/Model.java: -------------------------------------------------------------------------------- 1 | package jp.preferred.menoh; 2 | 3 | import static jp.preferred.menoh.MenohException.checkError; 4 | 5 | import com.sun.jna.Pointer; 6 | import com.sun.jna.ptr.IntByReference; 7 | import com.sun.jna.ptr.PointerByReference; 8 | 9 | import java.util.List; 10 | 11 | /** 12 | *

A representation of the model. It is built by {@link ModelBuilder}.

13 | * 14 | *

Make sure to {@link #close()} this object once you are done with it, to free the underlying memory 15 | * in the native heap.

16 | */ 17 | public class Model implements AutoCloseable { 18 | private Pointer handle; 19 | 20 | /** 21 | * A reference to the pointers to prevent them from getting garbage collected. 22 | */ 23 | private final List externalBuffers; 24 | 25 | Model(Pointer handle, List externalBuffers) { 26 | this.handle = handle; 27 | this.externalBuffers = externalBuffers; 28 | } 29 | 30 | Pointer nativeHandle() { 31 | return this.handle; 32 | } 33 | 34 | List externalBuffers() { 35 | return this.externalBuffers; 36 | } 37 | 38 | @Override 39 | public void close() { 40 | synchronized (this) { 41 | if (handle != Pointer.NULL) { 42 | MenohNative.INSTANCE.menoh_delete_model(handle); 43 | handle = Pointer.NULL; 44 | externalBuffers.clear(); 45 | } 46 | } 47 | } 48 | 49 | /** 50 | * Creates a {@link ModelBuilder}. 51 | */ 52 | public static ModelBuilder builder(VariableProfileTable vpt) throws MenohException { 53 | final PointerByReference ref = new PointerByReference(); 54 | checkError(MenohNative.INSTANCE.menoh_make_model_builder(vpt.nativeHandle(), ref)); 55 | 56 | return new ModelBuilder(ref.getValue()); 57 | } 58 | 59 | /** 60 | * Returns a {@link Variable} with the specified name. 
61 | */ 62 | public Variable variable(String variableName) throws MenohException { 63 | final IntByReference dtype = new IntByReference(); 64 | 65 | checkError(MenohNative.INSTANCE.menoh_model_get_variable_dtype(handle, variableName, dtype)); 66 | 67 | final IntByReference dimsSize = new IntByReference(); 68 | checkError(MenohNative.INSTANCE.menoh_model_get_variable_dims_size(handle, variableName, dimsSize)); 69 | 70 | final int[] dims = new int[dimsSize.getValue()]; 71 | 72 | for (int i = 0; i < dimsSize.getValue(); ++i) { 73 | final IntByReference d = new IntByReference(); 74 | checkError(MenohNative.INSTANCE.menoh_model_get_variable_dims_at(handle, variableName, i, d)); 75 | dims[i] = d.getValue(); 76 | } 77 | 78 | final PointerByReference buffer = new PointerByReference(); 79 | checkError(MenohNative.INSTANCE.menoh_model_get_variable_buffer_handle(handle, variableName, buffer)); 80 | 81 | return new Variable(DType.valueOf(dtype.getValue()), dims, buffer.getValue()); 82 | } 83 | 84 | /** 85 | * Run this model. 
86 | */ 87 | public void run() throws MenohException { 88 | checkError(MenohNative.INSTANCE.menoh_model_run(handle)); 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /menoh-examples/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 4.0.0 6 | 7 | jp.preferred.menoh 8 | menoh-examples 9 | 0.1.0 10 | jar 11 | 12 | Menoh Java Examples 13 | https://github.com/pfnet-research/menoh-java 14 | 15 | 16 | ${project.version} 17 | UTF-8 18 | 1.7 19 | 1.7 20 | 21 | 22 | 23 | 24 | jp.preferred.menoh 25 | menoh 26 | ${menoh.version} 27 | 28 | 29 | junit 30 | junit 31 | test 32 | 4.11 33 | 34 | 35 | 36 | 37 | 38 | 39 | org.codehaus.mojo 40 | wagon-maven-plugin 41 | 2.0.0 42 | 43 | 44 | download-vgg16-onnx 45 | process-resources 46 | 47 | download-single 48 | 49 | 50 | https://www.dropbox.com 51 | s/bjfn9kehukpbmcm/VGG16.onnx?dl=1 52 | ${project.build.outputDirectory}/data/VGG16.onnx 53 | true 54 | 55 | 56 | 57 | download-synset-words-txt 58 | process-resources 59 | 60 | download-single 61 | 62 | 63 | https://raw.githubusercontent.com 64 | HoldenCaulfieldRye/caffe/master/data/ilsvrc12/synset_words.txt 65 | ${project.build.outputDirectory}/data/synset_words.txt 66 | true 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | -------------------------------------------------------------------------------- /menoh/src/main/java/jp/preferred/menoh/BufferUtils.java: -------------------------------------------------------------------------------- 1 | package jp.preferred.menoh; 2 | 3 | import com.sun.jna.Memory; 4 | import com.sun.jna.Native; 5 | import com.sun.jna.Pointer; 6 | 7 | import java.nio.ByteBuffer; 8 | import java.nio.ByteOrder; 9 | 10 | class BufferUtils { 11 | /** 12 | *

Copies a buffer to a native memory.

13 | * 14 | *

If the buffer is direct, it will be attached to the model directly without copying. 15 | * Otherwise, it copies the content of the buffer to a newly allocated memory in the native heap ranging 16 | * from position() to (limit() - 1) without changing its position.

17 | * 18 | *

Note that the order() of the buffer should be {@link ByteOrder#nativeOrder()} because 19 | * the native byte order of your platform may differ from JVM.

20 | * 21 | * @param buffer the non-empty byte buffer from which to copy 22 | * 23 | * @return the pointer to the allocated native memory 24 | * @throws IllegalArgumentException if buffer is null or empty 25 | */ 26 | static Pointer copyToNativeMemory(final ByteBuffer buffer) { 27 | if (buffer == null || buffer.remaining() <= 0) { 28 | throw new IllegalArgumentException("buffer must not be null or empty"); 29 | } 30 | 31 | final int length = buffer.remaining(); 32 | 33 | if (buffer.isDirect()) { 34 | final int offset = buffer.position(); 35 | 36 | // return a pointer to the direct buffer without copying 37 | return Native.getDirectBufferPointer(buffer).share(offset, length); 38 | } else { 39 | final Memory mem = new Memory(length); 40 | 41 | int index; 42 | byte[] bytes; 43 | if (buffer.hasArray()) { // it is array-backed and not read-only 44 | index = buffer.arrayOffset() + buffer.position(); 45 | bytes = buffer.array(); 46 | } else { 47 | index = 0; 48 | // use duplicated buffer to avoid changing `position` 49 | bytes = new byte[length]; 50 | buffer.duplicate().get(bytes); 51 | } 52 | mem.write(0, bytes, index, length); 53 | 54 | return mem.share(0, length); 55 | } 56 | } 57 | 58 | /** 59 | *

Copies the array to a newly allocated memory in the native heap.

60 | * 61 | * @param values the non-empty array from which to copy 62 | * @param offset the array index from which to start copying 63 | * @param length the number of elements from values that must be copied 64 | * 65 | * @return the pointer to the allocated native memory, spanning length * 4 bytes 66 | * @throws IllegalArgumentException if values is null or empty 67 | */ 68 | static Pointer copyToNativeMemory(final float[] values, final int offset, final int length) { 69 | if (values == null || values.length <= 0) { 70 | throw new IllegalArgumentException("values must not be null or empty"); 71 | } 72 | // 4 bytes per float, computed in long so the byte count cannot overflow int 73 | final Memory mem = new Memory((long) length * 4); 74 | mem.write(0, values, offset, length); 75 | // share() takes a size in bytes: expose the whole allocation (length * 4 bytes), not just length bytes 76 | return mem.share(0, mem.size()); 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /menoh/src/test/resources/models/make_onnx.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | 4 | import numpy as np 5 | from unittest import mock 6 | 7 | import chainer 8 | import chainer.functions as F 9 | import chainer.links as L 10 | 11 | import onnx_chainer 12 | 13 | 14 | # e.g. [[0, 0], [0, 1], [1, 0], [1, 1]] -> [[0], [0], [0], [1]] 15 | class AND(chainer.Chain): 16 | def __init__(self): 17 | super().__init__() 18 | with self.init_scope(): 19 | self.fc1 = L.Linear(2, 1) 20 | 21 | self.fc1.W.array[:] = np.array([[1.0, 1.0]]) 22 | self.fc1.b.array[:] = np.array([[-1.0]]) 23 | 24 | def __call__(self, x): 25 | x.node._onnx_name = 'input' 26 | h = F.absolute(F.relu(self.fc1(x))) 27 | h.node._onnx_name = 'output' 28 | return h 29 | 30 | 31 | # e.g.
[[0, 1, 2], [3, 4, 5]] -> [[0, 0, 15, 96, 177], [0, 0, 51, 312, 573]] 32 | class MLP(chainer.Chain): 33 | 34 | def __init__(self): 35 | super().__init__() 36 | with self.init_scope(): 37 | self.fc1 = L.Linear(3, 4) 38 | self.fc2 = L.Linear(4, 5) 39 | 40 | self.fc1.W.array[:] = np.arange(-6, 6).reshape((4, 3)) 41 | self.fc1.b.array[:] = np.arange(-2, 2) 42 | 43 | self.fc2.W.array[:] = np.arange(-10, 10).reshape((5, 4)) 44 | self.fc2.b.array[:] = np.arange(-2, 3) 45 | 46 | def __call__(self, x): 47 | x.node._onnx_name = 'input' 48 | h = F.relu(self.fc1(x)) 49 | h.node._onnx_name = 'fc1' 50 | h = F.relu(self.fc2(h)) 51 | h.node._onnx_name = 'fc2' 52 | return h 53 | 54 | 55 | class IDGenerator(object): 56 | 57 | def __init__(self): 58 | # keep original 59 | self._id = id 60 | 61 | def __call__(self, obj): 62 | return getattr(obj, '_onnx_name', self._id(obj)) 63 | 64 | 65 | def main(logger): 66 | parser = argparse.ArgumentParser() 67 | parser.add_argument('--model') 68 | parser.add_argument('--out', required=False) 69 | args = parser.parse_args() 70 | 71 | if args.out is None: 72 | output_filename = args.model + ".onnx" 73 | else: 74 | output_filename = args.out 75 | 76 | logger.info("Generating `{}` model and save it to `{}`".format(args.model, output_filename)) 77 | 78 | try: 79 | if args.model == 'and_op': 80 | model = AND() 81 | x = np.empty((1, 2), dtype=np.float32) 82 | with chainer.using_config('train', False), \ 83 | mock.patch('builtins.id', IDGenerator()): 84 | onnx_chainer.export(model, x, filename=output_filename) 85 | 86 | elif args.model == 'mlp': 87 | model = MLP() 88 | x = np.empty((1, 3), dtype=np.float32) 89 | with chainer.using_config('train', False), \ 90 | mock.patch('builtins.id', IDGenerator()): 91 | onnx_chainer.export(model, x, filename=output_filename) 92 | 93 | except Exception: 94 | logger.exception("An error occurred during generation of the model") 95 | 96 | 97 | if __name__ == '__main__': 98 | logging.basicConfig() 99 | logger = 
logging.getLogger(__name__) 100 | logger.setLevel(logging.INFO) 101 | 102 | main(logger) 103 | -------------------------------------------------------------------------------- /LICENSE_THIRD_PARTY: -------------------------------------------------------------------------------- 1 | This software depends on 2 | 3 | ==== Intel(R) MKL-DNN ==== 4 | 5 | Intel(R) MKL-DNN distributed under Apache License 2.0 licence 6 | 7 | https://github.com/intel/mkl-dnn 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | ==== Xbyak ==== 12 | 13 | Xbyak distributed under 3-clause BSD licence 14 | 15 | https://github.com/intel/mkl-dnn/blob/master/src/cpu/xbyak/COPYRIGHT 16 | 17 | ==== mklml ==== 18 | 19 | Copyright (c) 2016-2018, Intel Corporation 20 | All rights reserved. 21 | 22 | Redistribution and use in source and binary forms, with or without 23 | modification, are permitted provided that the following conditions are met: 24 | 25 | * Redistributions of source code must retain the above copyright notice, 26 | this list of conditions and the following disclaimer. 27 | * Redistributions in binary form must reproduce the above copyright 28 | notice, this list of conditions and the following disclaimer in the 29 | documentation and/or other materials provided with the distribution. 30 | * Neither the name of Intel Corporation nor the names of its contributors 31 | may be used to endorse or promote products derived from this software 32 | without specific prior written permission. 33 | 34 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 35 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 36 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 37 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE 38 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 39 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 40 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 41 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 42 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 43 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 44 | 45 | ==== Protocol Buffers ==== 46 | 47 | Copyright 2008 Google Inc. All rights reserved. 48 | 49 | Redistribution and use in source and binary forms, with or without 50 | modification, are permitted provided that the following conditions are 51 | met: 52 | 53 | * Redistributions of source code must retain the above copyright 54 | notice, this list of conditions and the following disclaimer. 55 | * Redistributions in binary form must reproduce the above 56 | copyright notice, this list of conditions and the following disclaimer 57 | in the documentation and/or other materials provided with the 58 | distribution. 59 | * Neither the name of Google Inc. nor the names of its 60 | contributors may be used to endorse or promote products derived from 61 | this software without specific prior written permission. 62 | 63 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 64 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 65 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 66 | A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT 67 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 68 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 69 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 70 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 71 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 72 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 73 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 74 | 75 | Code generated by the Protocol Buffer compiler is owned by the owner 76 | of the input file used when generating it. This code is not 77 | standalone and requires a support library to be linked with it. This 78 | support library is itself covered by the above license. 79 | -------------------------------------------------------------------------------- /menoh/src/test/java/jp/preferred/menoh/ModelDataTest.java: -------------------------------------------------------------------------------- 1 | package jp.preferred.menoh; 2 | 3 | // CHECKSTYLE:OFF 4 | import static jp.preferred.menoh.TestUtils.*; 5 | import static org.junit.jupiter.api.Assertions.*; 6 | // CHECKSTYLE:ON 7 | 8 | import org.junit.jupiter.api.Test; 9 | 10 | public class ModelDataTest { 11 | @Test 12 | public void makeFromValidOnnxFile() throws Exception { 13 | final String path = getResourceFilePath("models/and_op.onnx"); 14 | 15 | try (ModelData modelData = ModelData.fromOnnxFile(path)) { 16 | assertNotNull(modelData.nativeHandle()); 17 | } 18 | } 19 | 20 | @Test 21 | public void closeModelData() throws Exception { 22 | final String path = getResourceFilePath("models/and_op.onnx"); 23 | 24 | final ModelData modelData = ModelData.fromOnnxFile(path); 25 | try { 26 | assertNotNull(modelData.nativeHandle()); 27 | } finally { 28 | modelData.close(); 29 | assertNull(modelData.nativeHandle()); 30 | 31 | // close() is an idempotent operation 32 | 
modelData.close(); 33 | } 34 | } 35 | 36 | @Test 37 | public void makeFromNonExistentOnnxFile() { 38 | MenohException e = assertThrows( 39 | MenohException.class, () -> ModelData.fromOnnxFile("__NON_EXISTENT_FILENAME__")); 40 | assertAll("non-existent onnx file", 41 | () -> assertEquals(ErrorCode.INVALID_FILENAME, e.getErrorCode()), 42 | () -> assertEquals( 43 | "menoh invalid filename error: __NON_EXISTENT_FILENAME__ (invalid_filename)", 44 | e.getMessage()) 45 | ); 46 | } 47 | 48 | @Test 49 | public void makeFromInvalidOnnxFile() throws Exception { 50 | final String path = getResourceFilePath("models/invalid_format.onnx"); 51 | 52 | MenohException e = assertThrows(MenohException.class, () -> ModelData.fromOnnxFile(path)); 53 | assertAll("invalid onnx file", 54 | () -> assertEquals(ErrorCode.ONNX_PARSE_ERROR, e.getErrorCode()), 55 | () -> assertEquals( 56 | String.format("menoh onnx parse error: %s (onnx_parse_error)", path), 57 | e.getMessage()) 58 | ); 59 | } 60 | 61 | @Test 62 | public void makeFromUnsupportedOnnxOpsetVersionFile() throws Exception { 63 | // Note: This file is a copy of and_op.onnx whose last byte was edited to 127 (0x7f) 64 | final String path = getResourceFilePath("models/unsupported_onnx_opset_version.onnx"); 65 | 66 | MenohException e = assertThrows(MenohException.class, () -> ModelData.fromOnnxFile(path)); 67 | assertAll("invalid onnx file", 68 | () -> assertEquals(ErrorCode.UNSUPPORTED_ONNX_OPSET_VERSION, e.getErrorCode()), 69 | () -> assertEquals( 70 | String.format( 71 | "menoh unsupported onnx opset version error: %s has " 72 | + "onnx opset version %d > %d (unsupported_onnx_opset_version)", 73 | path, 127, 7), 74 | e.getMessage()) 75 | ); 76 | } 77 | 78 | @Test 79 | public void optimizeModelData() throws Exception { 80 | final String path = getResourceFilePath("models/and_op.onnx"); 81 | 82 | try ( 83 | ModelData modelData = ModelData.fromOnnxFile(path); 84 | VariableProfileTableBuilder vptBuilder = VariableProfileTable.builder()
85 | .addInputProfile("input", DType.FLOAT, new int[] {1, 2}) 86 | .addOutputProfile("output", DType.FLOAT); 87 | VariableProfileTable vpt = vptBuilder.build(modelData) 88 | ) { 89 | modelData.optimize(vpt); 90 | } 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /menoh/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 4.0.0 6 | 7 | 8 | jp.preferred.menoh 9 | menoh-parent 10 | 0.2.0-SNAPSHOT 11 | 12 | 13 | jp.preferred.menoh 14 | menoh 15 | jar 16 | 17 | Menoh Java 18 | A Java binding for Menoh DNN inference library 19 | https://github.com/pfnet-research/menoh-java 20 | 21 | 22 | ${project.parent.basedir}/config/checkstyle/checkstyle.xml 23 | ${project.build.directory}/site/checkstyle/checkstyle-result.xml 24 | 25 | 26 | 27 | 28 | net.java.dev.jna 29 | jna 30 | 4.5.2 31 | 32 | 33 | com.google.code.findbugs 34 | findbugs-annotations 35 | provided 36 | 37 | 38 | org.junit.jupiter 39 | junit-jupiter-engine 40 | test 41 | 42 | 43 | 44 | 45 | 46 | 47 | org.apache.maven.plugins 48 | maven-checkstyle-plugin 49 | 50 | 51 | com.puppycrawl.tools 52 | checkstyle 53 | ${checkstyle.version} 54 | 55 | 56 | 57 | ${checkstyle.config.location} 58 | error 59 | UTF-8 60 | true 61 | true 62 | ${checkstyle.config.outputFile} 63 | 64 | 65 | 66 | verify 67 | 68 | check 69 | 70 | 71 | 72 | 73 | 74 | org.codehaus.mojo 75 | findbugs-maven-plugin 76 | 77 | Max 78 | Low 79 | true 80 | true 81 | UTF-8 82 | target/site/findbugs 83 | target/site/findbugs 84 | 85 | 86 | 87 | verify 88 | 89 | check 90 | 91 | 92 | 93 | 94 | 95 | org.apache.maven.plugins 96 | maven-source-plugin 97 | 98 | 99 | attach-sources 100 | 101 | jar-no-fork 102 | 103 | 104 | 105 | 106 | 107 | org.apache.maven.plugins 108 | maven-javadoc-plugin 109 | 110 | 111 | attach-javadocs 112 | 113 | jar 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | org.apache.maven.plugins 125 | maven-jxr-plugin 126 | 2.5 127 | 
128 | 129 | 130 | 131 | -------------------------------------------------------------------------------- /menoh/src/main/java/jp/preferred/menoh/ModelRunner.java: -------------------------------------------------------------------------------- 1 | package jp.preferred.menoh; 2 | 3 | import java.nio.ByteBuffer; 4 | import java.nio.ByteOrder; 5 | import java.util.Collections; 6 | import java.util.HashMap; 7 | import java.util.Map; 8 | 9 | /** 10 | *

A convenient wrapper for building and running {@link Model}.

11 | * 12 | *

Make sure to {@link #close()} this object after finishing the process to free the underlying memory 13 | * in the native heap.

14 | */ 15 | public class ModelRunner implements AutoCloseable { 16 | private final Model model; 17 | 18 | private static final String DEFAULT_BACKEND_NAME = "mkldnn"; 19 | 20 | private static final String DEFAULT_BACKEND_CONFIG = ""; 21 | 22 | ModelRunner(Model model) { 23 | this.model = model; 24 | } 25 | 26 | /** 27 | * Returns the underlying {@link Model}. 28 | */ 29 | Model model() { 30 | return this.model; 31 | } 32 | 33 | @Override 34 | public void close() { 35 | model.close(); 36 | } 37 | 38 | /** 39 | * Loads an ONNX model from the specified file. 40 | */ 41 | public static ModelRunnerBuilder fromOnnxFile(String path) { 42 | ModelData modelData = null; 43 | VariableProfileTableBuilder vptBuilder = null; 44 | try { 45 | modelData = ModelData.fromOnnxFile(path); 46 | vptBuilder = VariableProfileTable.builder(); 47 | 48 | return new ModelRunnerBuilder( 49 | modelData, 50 | vptBuilder, 51 | DEFAULT_BACKEND_NAME, 52 | DEFAULT_BACKEND_CONFIG, 53 | new HashMap()); 54 | } catch (Throwable t) { 55 | if (modelData != null) { 56 | modelData.close(); 57 | } 58 | if (vptBuilder != null) { 59 | vptBuilder.close(); 60 | } 61 | 62 | throw t; 63 | } 64 | } 65 | 66 | /** 67 | * Returns a {@link Variable} with the specified name. 68 | */ 69 | public Variable variable(String variableName) throws MenohException { 70 | return model.variable(variableName); 71 | } 72 | 73 | /** 74 | *

Run this model after assigning a non-empty array to the specified variable.

75 | * 76 | * @param name the name of the input variable 77 | * @param values the values to be copied to the input variable 78 | */ 79 | public void run(String name, float[] values) { 80 | run(name, values, 0, values.length); 81 | } 82 | 83 | /** 84 | *

Run this model after assigning a non-empty array to the specified variable. It copies the content 85 | * ranging from offset to (offset + length - 1).

86 | * 87 | * @param name the name of the input variable 88 | * @param values the values to be copied to the input variable 89 | */ 90 | public void run(String name, float[] values, int offset, int length) { 91 | final ByteBuffer buffer = ByteBuffer.allocateDirect(length * 4).order(ByteOrder.nativeOrder()); 92 | buffer.asFloatBuffer().put(values, offset, length); 93 | 94 | run(name, buffer); 95 | } 96 | 97 | /** 98 | *

Run this model after assigning a non-empty buffer to the specified variable. It copies the content 99 | * ranging from position() to (limit() - 1) without changing them, even if 100 | * the buffer is direct unlike {@link ModelRunnerBuilder#attachExternalBuffer(String, ByteBuffer)}. 101 | *

102 | * 103 | *

Note that the order() of the buffer should be {@link ByteOrder#nativeOrder()} because 104 | * the native byte order of your platform may differ from JVM.

105 | * 106 | * @param name the name of the input variable 107 | * @param buffer the buffer to be copied to the input variable 108 | */ 109 | public void run(String name, ByteBuffer buffer) { 110 | run(Collections.singletonMap(name, buffer)); 111 | } 112 | 113 | /** 114 | *

Run this model after assigning non-empty buffers to the specified variables. It copies the content 115 | * ranging from position() to (limit() - 1) without changing them, even if 116 | * the buffer is direct unlike {@link ModelRunnerBuilder#attachExternalBuffer(String, ByteBuffer)}. 117 | *

118 | * 119 | *

Note that the order() of the buffer should be {@link ByteOrder#nativeOrder()} because 120 | * the native byte order of your platform may differ from JVM.

121 | * 122 | * @param buffers the buffers to be copied to the variables 123 | */ 124 | public void run(final Map buffers) { 125 | assignToVariables(buffers); 126 | model.run(); 127 | } 128 | 129 | /** 130 | * Run this model. 131 | */ 132 | public void run() { 133 | model.run(); 134 | } 135 | 136 | /** 137 | * Assign data to the variables in the model. 138 | */ 139 | private void assignToVariables(final Map data) { 140 | for (Map.Entry e : data.entrySet()) { 141 | final String name = e.getKey(); 142 | final ByteBuffer dataBuf = e.getValue(); 143 | final long dataLen = dataBuf.remaining(); 144 | 145 | final Variable v = model.variable(name); 146 | final long varLen = v.bufferLength(); 147 | 148 | if (varLen < dataLen) { 149 | throw new MenohRunnerException(String.format( 150 | "The data with length > %d can't be assigned to the variable `%s`.", varLen, name)); 151 | } 152 | 153 | final ByteBuffer varBuf = v.buffer(); 154 | varBuf.clear(); 155 | varBuf.put(dataBuf.duplicate()).rewind(); 156 | } 157 | } 158 | } 159 | -------------------------------------------------------------------------------- /menoh/src/main/java/jp/preferred/menoh/ModelBuilder.java: -------------------------------------------------------------------------------- 1 | package jp.preferred.menoh; 2 | 3 | import static jp.preferred.menoh.BufferUtils.copyToNativeMemory; 4 | import static jp.preferred.menoh.MenohException.checkError; 5 | 6 | import com.sun.jna.Pointer; 7 | import com.sun.jna.ptr.PointerByReference; 8 | 9 | import java.nio.ByteBuffer; 10 | import java.nio.ByteOrder; 11 | import java.util.ArrayList; 12 | import java.util.List; 13 | 14 | /** 15 | *

A builder object for {@link Model}.

16 | * 17 | *

Make sure to {@link #close()} this object after finishing the process to free the underlying memory 18 | * in the native heap.

19 | */ 20 | public class ModelBuilder implements AutoCloseable { 21 | private Pointer handle; 22 | 23 | /** 24 | * A reference to the pointers to prevent them from getting garbage collected. 25 | */ 26 | private final List externalBuffers = new ArrayList<>(); 27 | 28 | ModelBuilder(Pointer handle) { 29 | this.handle = handle; 30 | } 31 | 32 | Pointer nativeHandle() { 33 | return this.handle; 34 | } 35 | 36 | List externalBuffers() { 37 | return this.externalBuffers; 38 | } 39 | 40 | /** 41 | * Releases the model builder. 42 | */ 43 | @Override 44 | public void close() { 45 | synchronized (this) { 46 | if (handle != Pointer.NULL) { 47 | MenohNative.INSTANCE.menoh_delete_model_builder(handle); 48 | handle = Pointer.NULL; 49 | externalBuffers.clear(); 50 | } 51 | } 52 | } 53 | 54 | /** 55 | *

Attaches a non-empty external buffer to the specified variable.

56 | * 57 | *

If the specified buffer is direct, it will be attached to the model directly without 58 | * copying. Otherwise, it copies the content to a newly allocated buffer in the native heap ranging from 59 | * position() to (limit() - 1) without changing its position.

60 | * 61 | *

The attached buffer can be accessed through {@link Model#variable(String)}.

62 | * 63 | *

Note that the order() of the buffer should be {@link ByteOrder#nativeOrder()} because 64 | * the native byte order of your platform may differ from JVM.

65 | * 66 | * @param variableName the name of the variable 67 | * @param buffer the byte buffer from which to copy 68 | * @return this object 69 | * 70 | * @throws IllegalArgumentException if buffer is null or empty 71 | */ 72 | public ModelBuilder attachExternalBuffer(String variableName, ByteBuffer buffer) throws MenohException { 73 | final Pointer bufferHandle = copyToNativeMemory(buffer); 74 | synchronized (this) { 75 | externalBuffers.add(bufferHandle); 76 | } 77 | 78 | return attachImpl(variableName, bufferHandle); 79 | } 80 | 81 | /** 82 | *

Attaches a non-empty external buffer to the specified variable. It also copies the content of the 83 | * values to a newly allocated buffer in the native heap.

84 | * 85 | *

The buffer can be accessed through {@link Model#variable(String)}.

86 | * 87 | * @param variableName the name of the variable 88 | * @param values the float values from which to copy 89 | * @return this object 90 | * 91 | * @throws IllegalArgumentException if values is null or empty 92 | */ 93 | public ModelBuilder attachExternalBuffer(String variableName, float[] values) throws MenohException { 94 | return attachExternalBuffer(variableName, values, 0, values == null ? 0 : values.length); 95 | } 96 | 97 | /** 98 | *

Attaches a non-empty external buffer to the specified variable. It also copies the content of the 99 | * values to a newly allocated buffer in the native heap ranging from offset 100 | * to (offset + length - 1).

101 | * 102 | *

The buffer can be accessed through {@link Model#variable(String)}.

103 | * 104 | * @param variableName the name of the variable 105 | * @param values the float values from which to copy 106 | * @param offset the array index from which to start copying 107 | * @param length the number of elements from values that must be copied 108 | * @return this object 109 | * 110 | * @throws IllegalArgumentException if values is null or empty 111 | */ 112 | public ModelBuilder attachExternalBuffer( 113 | String variableName, float[] values, int offset, int length) throws MenohException { 114 | final Pointer bufferHandle = copyToNativeMemory(values, offset, length); 115 | synchronized (this) { 116 | externalBuffers.add(bufferHandle); 117 | } 118 | 119 | return attachImpl(variableName, bufferHandle); 120 | } 121 | // Callers add bufferHandle to externalBuffers before this, so it stays referenced even on failure. 122 | private ModelBuilder attachImpl(String variableName, Pointer bufferHandle) throws MenohException { 123 | checkError(MenohNative.INSTANCE.menoh_model_builder_attach_external_buffer( 124 | handle, variableName, bufferHandle)); 125 | 126 | return this; 127 | } 128 | 129 | /** 130 | *

Builds a {@link Model} to run() by using the specified backend (e.g. "mkldnn").

131 | * 132 | *

Menoh will allocate a new buffer for input and output variables to which an external buffer is not 133 | * attached. It can be accessed via {@link Model#variable(String)}.

134 | */ 135 | public Model build(ModelData modelData, String backendName, String backendConfig) throws MenohException { 136 | final PointerByReference ref = new PointerByReference(); 137 | checkError(MenohNative.INSTANCE.menoh_build_model( 138 | this.handle, modelData.nativeHandle(), backendName, backendConfig, ref)); 139 | 140 | synchronized (this) { 141 | return new Model(ref.getValue(), new ArrayList<>(this.externalBuffers)); 142 | } 143 | } 144 | } 145 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # menoh-java 2 | *menoh-java* enables you to build your own Deep Neural Network (DNN) application with a few lines of code in Java. 3 | 4 | This is a Java binding for [Menoh](https://github.com/pfnet-research/menoh/) DNN inference library, which supports [ONNX](http://onnx.ai/) model format. 5 | 6 | ## Getting Started 7 | 8 | ### Add a dependency 9 | Using Gradle: 10 | 11 | ```groovy 12 | dependencies { 13 | implementation 'jp.preferred.menoh:menoh:0.1.0' 14 | } 15 | ``` 16 | 17 | Using Maven: 18 | 19 | ```xml 20 | 21 | jp.preferred.menoh 22 | menoh 23 | 0.1.0 24 | 25 | ``` 26 | 27 | ### Install native libraries 28 | menoh-java requires [Menoh Core](https://github.com/pfnet-research/menoh/) and its dependent native shared libraries to maximize the utilization of hardware resources. You need to [install](https://github.com/pfnet-research/menoh/blob/master/README.md#installation-using-package-manager-or-binary-packages) them to the JVM classpath or the system library path before running. 29 | 30 | ## Examples 31 | Please see [menoh-examples](menoh-examples) directory in this repository. 32 | 33 | ## Usage 34 | menoh-java provides two types of APIs: high-level and low-level. The high-level API is a simple wrapper of the low-level API. 
You should choose the high-level API in most cases because it manages the lifecycle of the low-level objects on your behalf. 35 | 36 | ### High-level API 37 | `ModelRunner` is the high-level API of menoh-java. All you need to do is configure a `ModelRunnerBuilder` with your ONNX model and `build()` a runner object. 38 | 39 | Note that you must `close()` both the runner and its builder explicitly: this frees the memory they hold in the native heap, which is not garbage collected by the JVM. 40 | 41 | ```java 42 | import jp.preferred.menoh.ModelRunner; 43 | import jp.preferred.menoh.ModelRunnerBuilder; 44 | 45 | try ( 46 | ModelRunnerBuilder builder = ModelRunner 47 | // Load ONNX model data 48 | .fromOnnxFile(onnxModelPath) 49 | 50 | // Define input profile (name, dtype, dims) and output profile (name, dtype) 51 | // Menoh calculates dims of outputs automatically at build time 52 | .addInputProfile(conv11InName, DType.FLOAT, new int[] {batchSize, channelNum, height, width}) 53 | .addOutputProfile(fc6OutName, DType.FLOAT) 54 | .addOutputProfile(softmaxOutName, DType.FLOAT) 55 | 56 | // Configure backend 57 | .backendName("mkldnn") 58 | .backendConfig(""); 59 | ModelRunner runner = builder.build() 60 | ) { 61 | // The builder can be deleted explicitly after building a model runner 62 | builder.close(); 63 | ... 64 | ``` 65 | 66 | Once you have created the `ModelRunner`, you can `run()` the model with new input data repeatedly: 67 | 68 | ```java 69 | // Run the inference 70 | runner.run(conv11InName, imageData); 71 | 72 | final Variable softmaxOut = runner.variable(softmaxOutName); 73 | 74 | final int[] softmaxOutDims = softmaxOut.dims(); 75 | final ByteBuffer softmaxOutBuf = softmaxOut.buffer(); 76 | 77 | // Note: use `get()` instead of `array()` because it is a direct buffer 78 | final float[] scores = new float[softmaxOutDims[1]]; 79 | softmaxOutBuf.asFloatBuffer().get(scores); 80 | ...
81 | ``` 82 | 83 | ### Low-level API 84 | The low-level API consists of `ModelData`, `VariableProfileTable` and `Model`. In most cases you don't need it, unless you want to manage the lifecycle of the builder objects and the variable buffers by hand. 85 | 86 | ## Building from Source 87 | ```bash 88 | $ git clone https://github.com/pfnet-research/menoh-java.git 89 | $ cd menoh-java 90 | $ mvn package 91 | ``` 92 | 93 | Note that `mvn test` requires Menoh Core to be available in the [JNA search path](http://java-native-access.github.io/jna/4.5.2/javadoc/com/sun/jna/NativeLibrary.html). 94 | 95 | ## FAQ 96 | 97 | ### menoh-java fails with `java.lang.UnsatisfiedLinkError` 98 | menoh-java depends on the native Menoh Core library. You'll get `java.lang.UnsatisfiedLinkError` at startup if it isn't located in the [JNA search path](http://java-native-access.github.io/jna/4.5.2/javadoc/com/sun/jna/NativeLibrary.html), even if it exists on the local system. 99 | 100 | ``` 101 | java.lang.UnsatisfiedLinkError: Unable to load library 'menoh': Native library (win32-x86-64/menoh.dll) not found in resource path ([file:/C:/workspace/menoh-java/menoh-examples/target/classes/, file:/C:/Users/user/.m2/repository/jp/preferred/menoh/menoh/1.0.0-SNAPSHOT/menoh-1.0.0-SNAPSHOT.jar, file:/C:/Users/user/.m2/repository/net/java/dev/jna/jna/4.5.2/jna-4.5.2.jar]) 102 | at com.sun.jna.NativeLibrary.loadLibrary(NativeLibrary.java:303) 103 | at com.sun.jna.NativeLibrary.getInstance(NativeLibrary.java:427) 104 | at com.sun.jna.Library$Handler.<init>(Library.java:179) 105 | at com.sun.jna.Native.loadLibrary(Native.java:569) 106 | at com.sun.jna.Native.loadLibrary(Native.java:544) 107 | at jp.preferred.menoh.MenohNative.<clinit>(MenohNative.java:11) 108 | ...
 109 | ``` 110 | 111 | If you run into this situation, either set the system property `jna.library.path` or install the library file into the JVM classpath or the system library path, which depends on your platform (`PATH` on Windows, `LD_LIBRARY_PATH` on Linux and `DYLD_LIBRARY_PATH` on macOS). See the [JNA documentation](https://github.com/java-native-access/jna/blob/master/www/GettingStarted.md) for more details. 112 | 113 | To investigate the problem, set the system property `jna.debug_load=true` to see what is going wrong: 114 | 115 | ``` 116 | $ mvn exec:java ... -Djna.debug_load=true 117 | Looking in classpath from java.net.URLClassLoader@3e997fa1 for /com/sun/jna/win32-x86-64/jnidispatch.dll 118 | Found library resource at jar:file:/C:/Users/.../.m2/repository/net/java/dev/jna/jna/4.5.2/jna-4.5.2.jar!/com/sun/jna/win32-x86-64/jnidispatch.dll 119 | Looking for library 'menoh' 120 | Adding paths from jna.library.path: null 121 | Trying menoh.dll 122 | Adding system paths: [] 123 | Trying menoh.dll 124 | Looking for lib- prefix 125 | Trying libmenoh.dll 126 | Looking in classpath from java.net.URLClassLoader@3e997fa1 for menoh 127 | ``` 128 | 129 | You can also print `jna.platform.library.path`: 130 | 131 | ```java 132 | NativeLibrary.getProcess(); // a trick to initialize `NativeLibrary` explicitly in this place 133 | System.err.println("jna.library.path: " + System.getProperty("jna.library.path")); 134 | System.err.println("jna.platform.library.path: " + System.getProperty("jna.platform.library.path")); 135 | ``` 136 | 137 | ## Limitation 138 | This library currently works only on 64-bit architectures. 139 | 140 | ## License 141 | menoh-java is released under the MIT License. Please see the [LICENSE](LICENSE) file for details.
142 | -------------------------------------------------------------------------------- /menoh/src/test/java/jp/preferred/menoh/ModelRunnerTest.java: --------------------------------------------------------------------------------
package jp.preferred.menoh;

// CHECKSTYLE:OFF
import static jp.preferred.menoh.TestUtils.*;
import static org.junit.jupiter.api.Assertions.*;
// CHECKSTYLE:ON

import org.junit.jupiter.api.Test;

/**
 * Tests for {@link ModelRunner} and {@link ModelRunnerBuilder} using the {@code and_op.onnx} model.
 * Judging by the expected values below, the model computes a logical AND over the two features
 * in each row of a {@code {batchSize, 2}} input and yields a {@code {batchSize, 1}} output.
 */
public class ModelRunnerTest {
    @Test
    public void runModelRunner() throws Exception {
        final String path = getResourceFilePath("models/and_op.onnx");
        final int batchSize = 4;
        final int inputDim = 2;
        final float[] inputData1 = new float[] {0f, 0f, 0f, 1f, 1f, 0f, 1f, 1f};
        final float[] inputData2 = new float[] {1f, 1f, 1f, 0f, 0f, 1f, 0f, 0f};
        final int outputDim = 1;
        final float[] expectedOutput1 = new float[] {0f, 0f, 0f, 1f};
        final float[] expectedOutput2 = new float[] {1f, 0f, 0f, 0f};

        try (
            ModelRunnerBuilder builder = ModelRunner
                .fromOnnxFile(path)
                .addInputProfile("input", DType.FLOAT, new int[] {batchSize, inputDim})
                .addOutputProfile("output", DType.FLOAT)
                .attachExternalBuffer("input", inputData1)
                .backendName("mkldnn")
                .backendConfig("");
            ModelRunner runner = builder.build()
        ) {
            assertAll("model data in builder",
                () -> assertNotNull(builder.modelData()),
                () -> assertNotNull(builder.modelData().nativeHandle())
            );
            assertAll("vpt builder in builder",
                () -> assertNotNull(builder.vptBuilder()),
                () -> assertNotNull(builder.vptBuilder().nativeHandle())
            );
            assertAll("backend config in builder",
                () -> assertEquals("mkldnn", builder.backendName()),
                () -> assertEquals("", builder.backendConfig())
            );
            assertAll("attached external buffers in builder",
                () -> assertNotNull(builder.externalBuffers()),
                () -> assertNotNull(builder.externalBuffers().get("input"))
            );

            assertAll("model in runner",
                () -> assertNotNull(runner.model()),
                () -> assertNotNull(runner.model().nativeHandle())
            );

            // run the model with the attached external buffer
            runner.run();

            final Model model = runner.model();
            final Variable outputVar = model.variable("output");
            assertAll("output variable",
                () -> assertEquals(DType.FLOAT, outputVar.dtype()),
                () -> assertArrayEquals(new int[] {batchSize, outputDim}, outputVar.dims())
            );
            assertArrayEquals(expectedOutput1, readFloats(outputVar));

            // rewrite the internal buffer and run again
            runner.run("input", inputData2);

            final Variable outputVar2 = model.variable("output");
            assertAll("output variable",
                () -> assertEquals(DType.FLOAT, outputVar2.dtype()),
                () -> assertArrayEquals(new int[] {batchSize, outputDim}, outputVar2.dims())
            );
            // read from the variable fetched after the second run, not the one from the first run
            assertArrayEquals(expectedOutput2, readFloats(outputVar2));
        }
    }

    @Test
    public void closeModelRunner() throws Exception {
        final String path = getResourceFilePath("models/and_op.onnx");
        final int batchSize = 4;
        final int inputDim = 2;
        final float[] inputData = new float[] {0f, 0f, 0f, 1f, 1f, 0f, 1f, 1f};

        final ModelRunnerBuilder builder = ModelRunner
            .fromOnnxFile(path)
            .addInputProfile("input", DType.FLOAT, new int[] {batchSize, inputDim})
            .addOutputProfile("output", DType.FLOAT)
            .attachExternalBuffer("input", inputData)
            .backendName("mkldnn")
            .backendConfig("");
        final ModelRunner runner = builder.build();
        try {
            assertNotNull(builder);
            assertAll("model data in builder",
                () -> assertNotNull(builder.modelData()),
                () -> assertNotNull(builder.modelData().nativeHandle())
            );
            assertAll("vpt builder in builder",
                () -> assertNotNull(builder.vptBuilder()),
                () -> assertNotNull(builder.vptBuilder().nativeHandle())
            );
            assertAll("backend config in builder",
                () -> assertEquals("mkldnn", builder.backendName()),
                () -> assertEquals("", builder.backendConfig())
            );
            assertAll("attached external buffers in builder",
                () -> assertNotNull(builder.externalBuffers()),
                () -> assertNotNull(builder.externalBuffers().get("input"))
            );

            assertNotNull(runner);
            assertAll("model in runner",
                () -> assertNotNull(runner.model()),
                () -> assertNotNull(runner.model().nativeHandle())
            );
        } finally {
            builder.close();
            assertAll("model data in builder",
                () -> assertNotNull(builder.modelData()),
                () -> assertNull(builder.modelData().nativeHandle())
            );
            assertAll("vpt builder in builder",
                () -> assertNotNull(builder.vptBuilder()),
                () -> assertNull(builder.vptBuilder().nativeHandle())
            );
            assertAll("attached external buffers in builder",
                () -> assertNotNull(builder.externalBuffers()),
                () -> assertTrue(builder.externalBuffers().isEmpty(),
                    "externalBuffers should be empty")
            );

            runner.close();
            assertAll("model",
                () -> assertNotNull(runner.model()),
                () -> assertNull(runner.model().nativeHandle())
            );

            // close() is an idempotent operation
            builder.close();
            runner.close();
        }
    }

    /**
     * Copies the whole content of a float variable into a new array.
     *
     * <p>The element count is the product of all dims, so it works for any rank.</p>
     */
    private static float[] readFloats(Variable variable) {
        int size = 1;
        for (int d : variable.dims()) {
            size *= d;
        }
        final float[] values = new float[size];
        // use get() instead of array() because the variable's buffer is a direct buffer
        variable.buffer().asFloatBuffer().get(values);
        return values;
    }
}
-------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 4.0.0 6 | 7 | jp.preferred.menoh 8 | menoh-parent 9 | 0.2.0-SNAPSHOT 10 | pom 11 | 12 | Menoh Java Parent 13 | A Java binding for Menoh DNN inference library 14 |
https://github.com/pfnet-research/menoh-java 15 | 16 | 17 | menoh 18 | 19 | 20 | 21 | 22 | durswd 23 | swd 24 | 25 | 26 | okapies 27 | Yuta Okamoto 28 | okapies@gmail.com 29 | 30 | 31 | 32 | 33 | https://github.com/pfnet-research/menoh-java 34 | scm:git:https://github.com/pfnet-research/menoh-java.git 35 | scm:git:git@github.com:pfnet-research/menoh-java.git 36 | HEAD 37 | 38 | 39 | 40 | UTF-8 41 | 1.7 42 | 1.7 43 | 1.8 44 | 1.8 45 | 8.11 46 | 47 | 48 | 49 | 50 | 51 | com.google.code.findbugs 52 | findbugs-annotations 53 | 3.0.1 54 | 55 | 56 | org.junit.jupiter 57 | junit-jupiter-engine 58 | 5.2.0 59 | 60 | 61 | 62 | 63 | 64 | 65 | ossrh 66 | https://oss.sonatype.org/content/repositories/snapshots 67 | 68 | 69 | ossrh 70 | https://oss.sonatype.org/service/local/staging/deploy/maven2 71 | 72 | 73 | 74 | 75 | 76 | MIT License 77 | https://opensource.org/licenses/mit-license.php 78 | repo 79 | 80 | 81 | 82 | 83 | 84 | 85 | org.apache.maven.plugins 86 | maven-gpg-plugin 87 | 88 | 89 | sign-artifacts 90 | verify 91 | 92 | sign 93 | 94 | 95 | 96 | 97 | 98 | org.sonatype.plugins 99 | nexus-staging-maven-plugin 100 | 101 | 102 | 103 | 104 | 105 | maven-clean-plugin 106 | 3.0.0 107 | 108 | 109 | 110 | maven-resources-plugin 111 | 3.0.2 112 | 113 | 114 | maven-compiler-plugin 115 | 3.7.0 116 | 117 | 118 | maven-surefire-plugin 119 | 2.22.0 120 | 121 | 122 | maven-jar-plugin 123 | 3.0.2 124 | 125 | 126 | maven-install-plugin 127 | 2.5.2 128 | 129 | 130 | maven-deploy-plugin 131 | 2.8.2 132 | 133 | 134 | org.apache.maven.plugins 135 | maven-checkstyle-plugin 136 | 3.0.0 137 | 138 | 139 | org.codehaus.mojo 140 | findbugs-maven-plugin 141 | 3.0.5 142 | 143 | 144 | org.apache.maven.plugins 145 | maven-source-plugin 146 | 3.0.1 147 | 148 | 149 | org.apache.maven.plugins 150 | maven-javadoc-plugin 151 | 3.0.1 152 | 153 | 154 | org.apache.maven.plugins 155 | maven-gpg-plugin 156 | 1.6 157 | 158 | 159 | org.sonatype.plugins 160 | nexus-staging-maven-plugin 161 | 1.6.8 162 | true 163 
| 164 | ossrh 165 | https://oss.sonatype.org/ 166 | false 167 | 168 | 169 | 170 | 171 | 172 | 173 | -------------------------------------------------------------------------------- /menoh/src/main/java/jp/preferred/menoh/ModelRunnerBuilder.java: --------------------------------------------------------------------------------
package jp.preferred.menoh;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Map;

/**
 * <p>A builder object for {@link ModelRunner}.</p>
 *
 * <p>Make sure to {@link #close()} this object after finishing the process to free the underlying memory
 * in the native heap.</p>
 */
public class ModelRunnerBuilder implements AutoCloseable {
    /** The number of bytes occupied by a single {@code float} value. */
    private static final int FLOAT_BYTES = Float.SIZE / Byte.SIZE;

    private final ModelData modelData;

    private final VariableProfileTableBuilder vptBuilder;

    private String backendName;

    private String backendConfig;

    private final Map<String, ByteBuffer> externalBuffers;

    ModelRunnerBuilder(
            ModelData modelData,
            VariableProfileTableBuilder vptBuilder,
            String backendName,
            String backendConfig,
            Map<String, ByteBuffer> externalBuffers) {
        this.modelData = modelData;
        this.vptBuilder = vptBuilder;
        this.backendName = backendName;
        this.backendConfig = backendConfig;
        this.externalBuffers = externalBuffers;
    }

    ModelData modelData() {
        return this.modelData;
    }

    VariableProfileTableBuilder vptBuilder() {
        return this.vptBuilder;
    }

    /**
     * Returns the name of the backend used by {@link #build()}.
     */
    public String backendName() {
        return this.backendName;
    }

    /**
     * Sets the name of the backend used by {@link #build()} (e.g. "mkldnn").
     *
     * @return this object
     */
    public ModelRunnerBuilder backendName(String backendName) {
        this.backendName = backendName;
        return this;
    }

    /**
     * Returns the backend configuration string passed to the backend by {@link #build()}.
     */
    public String backendConfig() {
        return this.backendConfig;
    }

    /**
     * Sets the backend configuration string passed to the backend by {@link #build()}.
     *
     * @return this object
     */
    public ModelRunnerBuilder backendConfig(String backendConfig) {
        this.backendConfig = backendConfig;
        return this;
    }

    Map<String, ByteBuffer> externalBuffers() {
        return this.externalBuffers;
    }

    /**
     * Closes the underlying variable profile table builder and model data. The method also drops
     * the references to the attached external buffers.
     */
    @Override
    public void close() {
        vptBuilder.close();
        modelData.close();

        // allow the attached external buffers to GC its allocated memory
        externalBuffers.clear();
    }

    /**
     * Adds an input profile to configure the specified variable in the model.
     *
     * @return this object
     */
    public ModelRunnerBuilder addInputProfile(String name, DType dtype, int[] dims) {
        vptBuilder.addInputProfile(name, dtype, dims);
        return this;
    }

    /**
     * Adds an output profile to configure the specified variable in the model.
     *
     * @return this object
     */
    public ModelRunnerBuilder addOutputProfile(String name, DType dtype) {
        vptBuilder.addOutputProfile(name, dtype);
        return this;
    }

    /**
     * <p>Attaches a non-empty external buffer to the specified variable.</p>
     *
     * <p>If the specified buffer is direct, it will be attached to the model directly without
     * copying. Otherwise, it copies the content to a newly allocated buffer in the native heap ranging from
     * {@code position()} to {@code (limit() - 1)} without changing its position.</p>
     *
     * <p>The buffer can be accessed through {@link Model#variable(String)}.</p>
     *
     * <p>Note that the {@code order()} of the buffer should be {@link ByteOrder#nativeOrder()} because
     * the native byte order of your platform may differ from JVM.</p>
     *
     * @param variableName the name of the variable
     * @param buffer the byte buffer to attach
     * @return this object
     *
     * @throws IllegalArgumentException if {@code buffer} is {@code null} or has zero capacity
     */
    public ModelRunnerBuilder attachExternalBuffer(String variableName, ByteBuffer buffer) throws MenohException {
        // Only zero-capacity buffers are rejected here. A direct buffer is attached as a whole, so its
        // position/limit are not inspected. A non-direct buffer's remaining range is copied later, in build().
        if (buffer == null || buffer.capacity() == 0) {
            throw new IllegalArgumentException("buffer must not be null or empty");
        }
        externalBuffers.put(variableName, buffer);
        return this;
    }

    /**
     * <p>Attaches a non-empty external buffer to the specified variable. It also copies the content of the
     * {@code values} to a newly allocated buffer in the native heap.</p>
     *
     * <p>The buffer can be accessed through {@link Model#variable(String)}.</p>
     *
     * @param variableName the name of the variable
     * @param values the float array from which to copy
     * @return this object
     *
     * @throws IllegalArgumentException if {@code values} is {@code null} or empty
     */
    public ModelRunnerBuilder attachExternalBuffer(String variableName, float[] values) throws MenohException {
        if (values == null) {
            // checked here because values.length below would otherwise raise a NullPointerException
            throw new IllegalArgumentException("values must not be null or empty");
        }
        return attachExternalBuffer(variableName, values, 0, values.length);
    }

    /**
     * <p>Attaches a non-empty external buffer to the specified variable. It also copies the content of the
     * {@code values} to a newly allocated buffer in the native heap ranging from {@code offset}
     * to {@code (offset + length - 1)}.</p>
     *
     * <p>The buffer can be accessed through {@link Model#variable(String)}.</p>
     *
     * @param variableName the name of the variable
     * @param values the float array from which to copy
     * @param offset the array index from which to start copying
     * @param length the number of elements from {@code values} that must be copied
     * @return this object
     *
     * @throws IllegalArgumentException if {@code values} is {@code null} or empty, if {@code length} is not
     *     positive, or if the copied range does not fit into a single direct buffer
     * @throws IndexOutOfBoundsException if {@code offset} and {@code length} do not describe a range
     *     inside {@code values}
     */
    public ModelRunnerBuilder attachExternalBuffer(
            String variableName,
            float[] values,
            int offset,
            int length) throws MenohException {
        if (values == null || values.length == 0) {
            throw new IllegalArgumentException("values must not be null or empty");
        }
        if (length <= 0) {
            throw new IllegalArgumentException("length must be positive: " + length);
        }
        // same exception type FloatBuffer.put(float[], int, int) raises, but checked before allocating
        if (offset < 0 || offset > values.length - length) {
            throw new IndexOutOfBoundsException(
                    "offset: " + offset + ", length: " + length + ", values.length: " + values.length);
        }
        // a direct buffer's capacity is an int, so the byte size must not overflow
        if (length > Integer.MAX_VALUE / FLOAT_BYTES) {
            throw new IllegalArgumentException("too many values for a single buffer: " + length);
        }

        // copy the array to a direct buffer in the native byte order
        final ByteBuffer buffer = ByteBuffer.allocateDirect(length * FLOAT_BYTES).order(ByteOrder.nativeOrder());
        buffer.asFloatBuffer().put(values, offset, length);

        externalBuffers.put(variableName, buffer);
        return this;
    }

    /**
     * <p>Builds a {@link ModelRunner} to {@code run()} by using the specified backend (e.g. "mkldnn").</p>
     *
     * <p>Menoh will allocate a new buffer for input and output variables to which an external buffer is not
     * attached. It can be accessed via {@link Model#variable(String)} in the ModelRunner object.</p>
     */
    public ModelRunner build() {
        try (
            VariableProfileTable vpt = vptBuilder.build(modelData);
            ModelBuilder modelBuilder = Model.builder(vpt)
        ) {
            for (Map.Entry<String, ByteBuffer> e : externalBuffers.entrySet()) {
                modelBuilder.attachExternalBuffer(e.getKey(), e.getValue());
            }

            // reduce the memory footprint of the model data
            modelData.optimize(vpt);

            final Model model = modelBuilder.build(modelData, backendName, backendConfig);
            return new ModelRunner(model);
        }
    }
}
-------------------------------------------------------------------------------- /menoh/src/test/java/jp/preferred/menoh/BufferUtilsTest.java: --------------------------------------------------------------------------------
package jp.preferred.menoh;

// CHECKSTYLE:OFF
import static jp.preferred.menoh.TestUtils.*;
import static org.junit.jupiter.api.Assertions.*;
// CHECKSTYLE:ON

import com.sun.jna.Pointer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

import org.junit.jupiter.api.Test;

/**
 * Tests for {@link BufferUtils#copyToNativeMemory}. Each test checks the copied bytes or values. Where a
 * buffer is involved, it also checks that the source buffer's position and limit are left unchanged.
 */
public class BufferUtilsTest {
    @Test
    public void copyDirectBufferToNativeMemory() {
        final byte[] bytes = new byte[] {0, 1, 2, 3, 4, 5, 6, 7};

        final ByteBuffer buf = ByteBuffer.allocateDirect(bytes.length);
        buf.put(bytes).rewind();
        assertEquals(8, buf.remaining());

        final Pointer ptr = BufferUtils.copyToNativeMemory(buf);
        assertNotNull(ptr);
        assertArrayEquals(bytes, ptr.getByteArray(0, bytes.length));
        assertAll("buffer's state",
            () -> assertEquals(0, buf.position()),
            () -> assertEquals(8, buf.limit())
        );
    }

    @Test
    public void copySlicedDirectBufferToNativeMemory() {
        final byte[] bytes = new byte[] {0, 1, 2, 3, 4, 5, 6, 7};
        final byte[] slicedBytes = new byte[] {2, 3, 4, 5};

        final ByteBuffer buf = ByteBuffer.allocateDirect(bytes.length);
        buf.put(bytes).rewind();

        // slice 4 bytes
        buf.position(2).limit(6);
        assertEquals(4, buf.remaining());

        final Pointer ptr = BufferUtils.copyToNativeMemory(buf);
        assertNotNull(ptr);
        assertArrayEquals(slicedBytes, ptr.getByteArray(0, slicedBytes.length));
        assertAll("buffer's state",
            () -> assertEquals(2, buf.position()),
            () -> assertEquals(6, buf.limit())
        );
    }

    @Test
    public void copyEmptyDirectBufferToNativeMemory() {
        final ByteBuffer buf = ByteBuffer.allocateDirect(0);
        assertEquals(0, buf.remaining());

        assertThrows(
            IllegalArgumentException.class,
            () -> BufferUtils.copyToNativeMemory(buf));
    }

    @Test
    public void copyArrayBackedBufferToNativeMemory() {
        final byte[] bytes = new byte[] {0, 1, 2, 3, 4, 5, 6, 7};

        final ByteBuffer buf = ByteBuffer.wrap(bytes);
        assertEquals(8, buf.remaining());

        final Pointer ptr = BufferUtils.copyToNativeMemory(buf);
        assertNotNull(ptr);
        assertArrayEquals(bytes, ptr.getByteArray(0, bytes.length));
        assertAll("buffer's state",
            () -> assertEquals(0, buf.position()),
            () -> assertEquals(8, buf.limit())
        );
    }

    @Test
    public void copySlicedArrayBackedBufferToNativeMemory() {
        final byte[] bytes = new byte[] {0, 1, 2, 3, 4, 5, 6, 7};
        final byte[] slicedBytes = new byte[] {2, 3, 4, 5};

        // a slice whose backing array starts at index 1 (i.e. arrayOffset() != 0)
        final ByteBuffer buf = ByteBuffer.wrap(bytes, 1, 6).slice();
        assertAll("non-zero array offset",
            () -> assertEquals(1, buf.arrayOffset()),
            () -> assertEquals(0, buf.position()),
            () -> assertEquals(6, buf.limit())
        );

        // slice 4 bytes
        buf.position(1).limit(5);
        assertEquals(4, buf.remaining());

        final Pointer ptr = BufferUtils.copyToNativeMemory(buf);
        assertNotNull(ptr);
        assertArrayEquals(slicedBytes, ptr.getByteArray(0, slicedBytes.length));
        assertAll("buffer's state",
            () -> assertEquals(1, buf.position()),
            () -> assertEquals(5, buf.limit())
        );
    }

    @Test
    public void copyReadOnlyBufferToNativeMemory() {
        final byte[] bytes = new byte[] {0, 1, 2, 3, 4, 5, 6, 7};

        final ByteBuffer buf = ByteBuffer.wrap(bytes).asReadOnlyBuffer();
        assertEquals(8, buf.remaining());

        final Pointer ptr = BufferUtils.copyToNativeMemory(buf);
        assertNotNull(ptr);
        assertArrayEquals(bytes, ptr.getByteArray(0, bytes.length));
        assertAll("buffer's state",
            () -> assertEquals(0, buf.position()),
            () -> assertEquals(8, buf.limit())
        );
    }

    @Test
    public void copySlicedReadOnlyBufferToNativeMemory() {
        final byte[] bytes = new byte[] {0, 1, 2, 3, 4, 5, 6, 7};
        final byte[] slicedBytes = new byte[] {2, 3, 4, 5};

        final ByteBuffer buf = ByteBuffer.wrap(bytes).asReadOnlyBuffer();

        // slice 4 bytes
        buf.position(2).limit(6);
        assertEquals(4, buf.remaining());

        final Pointer ptr = BufferUtils.copyToNativeMemory(buf);
        assertNotNull(ptr);
        assertArrayEquals(slicedBytes, ptr.getByteArray(0, slicedBytes.length));
        assertAll("buffer's state",
            () -> assertEquals(2, buf.position()),
            () -> assertEquals(6, buf.limit())
        );
    }

    @Test
    public void copyFloatArrayToNativeMemory() {
        final float[] values = new float[] {0f, 1f, 2f, 3f};
        final ByteBuffer valuesBuf = ByteBuffer.allocate(values.length * 4).order(ByteOrder.nativeOrder());
        valuesBuf.asFloatBuffer().put(values);
        final int offset = 0;
        final int length = values.length;

        final Pointer ptr = BufferUtils.copyToNativeMemory(values, offset, length);
        assertNotNull(ptr);
        assertAll("copied values in native memory",
            () -> assertArrayEquals(values, ptr.getFloatArray(0, length)),
            // check the byte order
            () -> assertArrayEquals(valuesBuf.array(), ptr.getByteArray(0, length * 4))
        );
    }

    @Test
    public void copySlicedFloatArrayToNativeMemory() {
        final float[] values = new float[] {0f, 1f, 2f, 3f};
        final float[] slicedValues = new float[] {1f, 2f};
        final ByteBuffer slicedValuesBuf = ByteBuffer.allocate(slicedValues.length * 4).order(ByteOrder.nativeOrder());
        slicedValuesBuf.asFloatBuffer().put(slicedValues);
        final int offset = 1;
        final int length = slicedValues.length;

        // slice 2 elements (8 bytes)
        final Pointer ptr = BufferUtils.copyToNativeMemory(values, offset, length);
        assertNotNull(ptr);
        assertAll("copied values in native memory",
            () -> assertArrayEquals(slicedValues, ptr.getFloatArray(0, length)),
            // check the byte order
            () -> assertArrayEquals(slicedValuesBuf.array(), ptr.getByteArray(0, length * 4))
        );
    }

    @Test
    public void copyEmptyFloatArrayToNativeMemory() {
        final float[] values = new float[] {}; // test case
        final int offset = 0;
        final int length = values.length;

        assertThrows(
            IllegalArgumentException.class,
            () -> BufferUtils.copyToNativeMemory(values, offset, length));
    }

    @Test
    public void copyNullFloatArrayToNativeMemory() {
        final float[] values = null; // test case
        final int offset = 0;
        final int length = 0;

        assertThrows(
            IllegalArgumentException.class,
            () -> BufferUtils.copyToNativeMemory(values, offset, length));
    }

    @Test
    public void copyNegativeOffsetFloatArrayToNativeMemory() {
        final float[] values = new float[] {0f, 1f, 2f, 3f};
        final int offset = -1; // test case
        final int length = values.length;

        assertThrows(
            ArrayIndexOutOfBoundsException.class,
            () -> BufferUtils.copyToNativeMemory(values, offset, length));
    }

    @Test
    public void copyNegativeLengthFloatArrayToNativeMemory() {
        final float[] values = new float[] {0f, 1f, 2f, 3f};
        final int offset = 0;
        final int length = -1; // test case

        assertThrows(
            IllegalArgumentException.class,
            () -> BufferUtils.copyToNativeMemory(values, offset, length));
    }
}
-------------------------------------------------------------------------------- /menoh-examples/src/main/java/jp/preferred/menoh/examples/Vgg16.java: --------------------------------------------------------------------------------
package jp.preferred.menoh.examples;

import com.sun.jna.NativeLibrary;
import jp.preferred.menoh.*;

import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import java.awt.image.Raster;
import java.io.*;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.FloatBuffer;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import javax.imageio.ImageIO;

/**
 * <p>Classify an input image by using VGG16, a pre-trained CNN model.</p>
 *
 * <p>See the Menoh tutorial for the details.</p>
 */
public class Vgg16 {
    public static void main(String[] args) throws Exception {
        if (args.length == 0) {
            System.err.println("You must specify a filename of the input image in the argument.");
            System.exit(1);
        }

        System.err.println("Working Directory = " + System.getProperty("user.dir"));
        if (Boolean.getBoolean("jna.debug_load")) {
            NativeLibrary.getProcess(); // a trick to initialize `NativeLibrary` explicitly in this place
            System.err.println("jna.library.path: " + System.getProperty("jna.library.path"));
            System.err.println("jna.platform.library.path: " + System.getProperty("jna.platform.library.path"));
        }

        final String conv11InName = "140326425860192";
        final String fc6OutName = "140326200777584";
        final String softmaxOutName = "140326200803680";

        final int batchSize = 1;
        final int channelNum = 3;
        final int height = 224;
        final int width = 224;

        final String inputImagePath = args[0];
        final String onnxModelPath = getResourceFilePath("data/VGG16.onnx");
        final String synsetWordsPath = getResourceFilePath("data/synset_words.txt");

        // Read and pre-process an input
        final BufferedImage image = ImageIO.read(new File(inputImagePath));
        if (image == null) {
            throw new Exception("Invalid input image path: " + inputImagePath);
        }
        final float[] imageData = renderToNchw(cropAndResize(image, width, height));

        // Build a ModelRunner
        try (
            // Note: You must `close()` the runner and builder to free the native memory explicitly
            ModelRunnerBuilder builder = ModelRunner
                // Load ONNX model data
                .fromOnnxFile(onnxModelPath)

                // Define input profile (name, dtype, dims) and output profile (name, dtype)
                // Menoh calculates dims of outputs automatically at build time
                .addInputProfile(conv11InName, DType.FLOAT, new int[] {batchSize, channelNum, height, width})
                .addOutputProfile(fc6OutName, DType.FLOAT)
                .addOutputProfile(softmaxOutName, DType.FLOAT)

                // Configure backend
                .backendName("mkldnn")
                .backendConfig("");
            ModelRunner runner = builder.build()
        ) {
            // The builder can be closed explicitly right after building a model runner
            builder.close();

            // Run the inference
            runner.run(conv11InName, imageData);

            // Get output of `fc6` (the first hidden FC layer after CNNs in VGG16)
            final Variable fc6Out = runner.variable(fc6OutName);

            // print (at most) the first 10 values without reading past the end of the buffer
            final FloatBuffer fc6OutFloatBuf = fc6Out.buffer().asFloatBuffer();
            final int numToPrint = Math.min(10, fc6OutFloatBuf.limit());
            for (int i = 0; i < numToPrint; i++) {
                System.out.print(Float.toString(fc6OutFloatBuf.get(i)) + " ");
            }
            System.out.println();

            // Get output of `softmax` (the output layer in VGG16)
            final Variable softmaxOut = runner.variable(softmaxOutName);

            final int[] softmaxOutDims = softmaxOut.dims();
            final float[] scores = new float[softmaxOutDims[1]];

            // Note: use `get()` instead of `array()` because it is a direct buffer
            softmaxOut.buffer().asFloatBuffer().get(scores);

            final int topK = 5;
            final List<String> categories = loadCategories(synsetWordsPath);

            System.out.println("top " + topK + " categories are");

            final List<Score> topKIndices = extractTopKIndices(scores, 0, softmaxOutDims[1], topK);
            for (Score s : topKIndices) {
                int ki = s.index;
                System.out.println("index: " + ki + ", score: " + scores[ki] + ", category: " + categories.get(ki));
            }
        }
    }

    /**
     * Resolves a classpath resource to its canonical file-system path.
     *
     * @throws FileNotFoundException if the resource cannot be found on the classpath
     */
    private static String getResourceFilePath(String name) throws IOException, URISyntaxException {
        // ref. https://stackoverflow.com/a/17870390/1014818
        final URL url = Vgg16.class.getClassLoader().getResource(name);
        if (url != null) {
            return Paths.get(url.toURI()).toFile().getCanonicalPath();
        } else {
            throw new FileNotFoundException("The specified resource not found: " + name);
        }
    }

    /**
     * <p>Scales {@code image} to {@code width} x {@code height}.</p>
     *
     * <p>The result is always a {@link BufferedImage#TYPE_INT_RGB} image with exactly three bands.
     * This normalizes grayscale, alpha-channel and {@code TYPE_CUSTOM} inputs. For those inputs,
     * {@code new BufferedImage(w, h, image.getType())} either throws (for {@code TYPE_CUSTOM}) or yields
     * a band count that {@link #renderToNchw(BufferedImage)} cannot handle.</p>
     */
    private static BufferedImage resizeImage(BufferedImage image, int width, int height) {
        final BufferedImage destImage = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
        final Graphics2D g = destImage.createGraphics();
        try {
            g.drawImage(image, 0, 0, width, height, null);
        } finally {
            g.dispose();
        }

        return destImage;
    }

    /**
     * Crop the central square of an input image and resize it to the size required by the VGG16 model.
     */
    private static BufferedImage cropAndResize(BufferedImage image, int width, int height) {
        final int shortEdge = Math.min(image.getWidth(), image.getHeight());
        final int x0 = (image.getWidth() - shortEdge) / 2;
        final int y0 = (image.getHeight() - shortEdge) / 2;
        final int width0 = shortEdge;
        final int height0 = shortEdge;

        BufferedImage cropped = image.getSubimage(x0, y0, width0, height0);
        BufferedImage resized = resizeImage(cropped, width, height);

        return resized;
    }

    /**
     * <p>Convert a HWC (Height x Width x Channels) format image into a NCHW (N x Channels x Height x Width)
     * format with N = 1.</p>
     *
     * <p>The channels are written in reverse band order. For the RGB images produced by
     * {@link #resizeImage} that means B, G, R, presumably the channel order the VGG16 model
     * expects — TODO confirm against the model's preprocessing.</p>
     */
    private static float[] renderToNchw(BufferedImage image) {
        final int height = image.getHeight();
        final int width = image.getWidth();
        final int channelSize = 3;
        final int planeSize = height * width;
        final float[] data = new float[channelSize * planeSize];

        final Raster raster = image.getData();
        // getPixel() writes one sample per band, so this requires a 3-band image (see resizeImage())
        final int[] buf = new int[channelSize];

        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                raster.getPixel(x, y, buf);
                data[0 * planeSize + y * width + x] = buf[2];
                data[1 * planeSize + y * width + x] = buf[1];
                data[2 * planeSize + y * width + x] = buf[0];
            }
        }

        return data;
    }

    /** A pair of a category index and its score. */
    static class Score {
        final int index;
        final float score;

        public Score(int index, float score) {
            this.index = index;
            this.score = score;
        }
    }

    /**
     * Returns the {@code k} highest-scoring entries of {@code scores[offset .. offset + length - 1]},
     * ordered from the highest score to the lowest. If the range holds fewer than {@code k} entries,
     * all of them are returned.
     */
    private static List<Score> extractTopKIndices(float[] scores, int offset, int length, int k) {
        final List<Score> q = new ArrayList<>();
        for (int i = offset; i < offset + length; i++) {
            q.add(new Score(i, scores[i]));
        }

        Collections.sort(
            q,
            new Comparator<Score>() {
                @Override
                public int compare(Score obj1, Score obj2) {
                    // descending by score; Float.compare is a total order, so NaN can't break the sort
                    return Float.compare(obj2.score, obj1.score);
                }
            }
        );

        // honor k, and clamp it so that a short score range does not throw
        return q.subList(0, Math.max(0, Math.min(k, q.size())));
    }

    /**
     * Reads the category labels, one per line, from {@code synsetWordsPath} as UTF-8 text.
     */
    private static List<String> loadCategories(String synsetWordsPath) throws IOException {
        final List<String> ret = new ArrayList<>();
        // decode explicitly as UTF-8 instead of relying on the platform default charset (FileReader)
        try (
            InputStream in = new FileInputStream(synsetWordsPath);
            BufferedReader br = new BufferedReader(new InputStreamReader(in, "UTF-8"))
        ) {
            String line;
            while ((line = br.readLine()) != null) {
                ret.add(line);
            }
        }

        return ret;
    }
}
--------------------------------------------------------------------------------
/menoh/src/test/java/jp/preferred/menoh/VariableProfileTableTest.java:
--------------------------------------------------------------------------------
package jp.preferred.menoh;

// CHECKSTYLE:OFF
import static jp.preferred.menoh.TestUtils.*;
import static org.junit.jupiter.api.Assertions.*;
// CHECKSTYLE:ON

import org.junit.jupiter.api.Test;

/**
 * Tests for {@link VariableProfileTableBuilder} and {@link VariableProfileTable}.
 */
public class VariableProfileTableTest {
    @Test
    public void makeVptBuilderIsSuccessful() {
        final VariableProfileTableBuilder builder = VariableProfileTable.builder();
        try {
            assertNotNull(builder.nativeHandle());
        } finally {
            builder.close();
            assertNull(builder.nativeHandle());

            // close() is an idempotent operation
            builder.close();
        }
    }

    @Test
    public void addValidInputProfile() {
        try (VariableProfileTableBuilder builder = VariableProfileTable.builder()) {
            builder.addInputProfile("foo", DType.FLOAT, new int[] {1, 1});
        }
    }

    /**
     * An input profile with too many dims must be rejected when it is added.
     * (Despite the method name, the profile under test is invalid.)
     */
    @Test
    public void addValidInputProfileWithInvalidDims() {
        final int[] invalidDims = new int[] {1, 1, 1, 1, 1}; // test case: dims.length == 5
        final String expectedPrefix = "foo has an invalid dims size: " + invalidDims.length;

        try (VariableProfileTableBuilder builder = VariableProfileTable.builder()) {
            MenohException e = assertThrows(MenohException.class,
                () -> builder.addInputProfile("foo", DType.FLOAT, invalidDims));
            assertAll("invalid dim size",
                () -> assertEquals(ErrorCode.UNDEFINED, e.getErrorCode()),
                () -> assertTrue(e.getMessage().startsWith(expectedPrefix),
                    String.format("%s doesn't start with \"%s\"", e.getMessage(), expectedPrefix)));
        }
    }

    @Test
    public void addValidOutputProfile() {
        try (VariableProfileTableBuilder builder = VariableProfileTable.builder()) {
            builder.addOutputProfile("foo", DType.FLOAT);
        }
    }

    @Test
    public void buildVariableProfileTable() throws Exception {
        assertVariableProfilesForBatchSize(1);
    }

    @Test
    public void buildVariableProfileTableWithBatchedInput() throws Exception {
        assertVariableProfilesForBatchSize(16); // test case
    }

    /**
     * Build a table for {@code models/and_op.onnx} whose input has {@code batchSize} rows and
     * check the dtype and dims reported for both the input and the output variable.
     */
    private static void assertVariableProfilesForBatchSize(int batchSize) throws Exception {
        final String path = getResourceFilePath("models/and_op.onnx");
        final int inputDim = 2;
        final int outputDim = 1;

        try (
            ModelData modelData = ModelData.fromOnnxFile(path);
            VariableProfileTableBuilder vptBuilder = VariableProfileTable.builder()
                .addInputProfile("input", DType.FLOAT, new int[] {batchSize, inputDim})
                .addOutputProfile("output", DType.FLOAT);
            VariableProfileTable vpt = vptBuilder.build(modelData)
        ) {
            assertNotNull(vpt.nativeHandle());

            // check variable profiles
            VariableProfile inputVp = vpt.variableProfile("input");
            assertAll("input variable profile",
                () -> assertEquals(DType.FLOAT, inputVp.dtype()),
                () -> assertArrayEquals(new int[] {batchSize, inputDim}, inputVp.dims())
            );
            VariableProfile outputVp = vpt.variableProfile("output");
            assertAll("output variable profile",
                () -> assertEquals(DType.FLOAT, outputVp.dtype()),
                () -> assertArrayEquals(new int[] {batchSize, outputDim}, outputVp.dims())
            );
        }
    }

    @Test
    public void closeVariableProfileTable() throws Exception {
        final String path = getResourceFilePath("models/and_op.onnx");
        final int batchSize = 1;
        final int inputDim = 2;

        try (
            ModelData modelData = ModelData.fromOnnxFile(path);
            VariableProfileTableBuilder vptBuilder = VariableProfileTable.builder()
                .addInputProfile("input", DType.FLOAT, new int[] {batchSize, inputDim})
                .addOutputProfile("output", DType.FLOAT)
        ) {
            final VariableProfileTable vpt = vptBuilder.build(modelData);
            try {
                assertNotNull(vpt.nativeHandle());
            } finally {
                vpt.close();
                assertNull(vpt.nativeHandle());

                // close() is an idempotent operation
                vpt.close();
            }
        }
    }

    @Test
    public void buildVariableProfileTableIfInputProfileNameNotFound() throws Exception {
        final String path = getResourceFilePath("models/and_op.onnx");
        final int batchSize = 1;
        final int inputDim = 2;
        final String inputProfileName = "__non_existent_variable__"; // test case
        final String inputVariableNameInModel = "input";
        final String outputProfileName = "output";

        try (
            ModelData modelData = ModelData.fromOnnxFile(path);
            VariableProfileTableBuilder vptBuilder = VariableProfileTable.builder()
                .addInputProfile(inputProfileName, DType.FLOAT, new int[] {batchSize, inputDim})
                .addOutputProfile(outputProfileName, DType.FLOAT)
        ) {
            MenohException e = assertThrows(MenohException.class, () -> vptBuilder.build(modelData));
            assertAll("input profile name not found",
                () -> assertEquals(ErrorCode.VARIABLE_NOT_FOUND, e.getErrorCode()),
                () -> assertEquals(
                    String.format("menoh variable not found error: %s (variable_not_found)",
                        inputVariableNameInModel), // not `inputProfileName`
                    e.getMessage())
            );
        }
    }

    @Test
    public void buildVariableProfileTableIfInputProfileDimsMismatched() throws Exception {
        final String path = getResourceFilePath("models/and_op.onnx");
        final int batchSize = 1;
        final int inputDim = 3; // test case
        final int validInputDim = 2; // the size the model's Gemm weight[1] accepts
        final String inputProfileName = "input";
        final String outputProfileName = "output";

        try (
            ModelData modelData = ModelData.fromOnnxFile(path);
            VariableProfileTableBuilder vptBuilder = VariableProfileTable.builder()
                .addInputProfile(inputProfileName, DType.FLOAT, new int[] {batchSize, inputDim})
                .addOutputProfile(outputProfileName, DType.FLOAT)
        ) {
            MenohException e = assertThrows(MenohException.class, () -> vptBuilder.build(modelData));
            assertAll("mismatched input dims",
                () -> assertEquals(ErrorCode.DIMENSION_MISMATCH, e.getErrorCode()),
                () -> assertEquals(
                    String.format(
                        "menoh dimension mismatch error: Gemm issuing \"%s\": input[1] and weight[1] "
                            + "actual value: %d valid value: %d (dimension_mismatch)",
                        "140211424823896", // the name of output[0] in Gemm layer
                        inputDim, validInputDim),
                    e.getMessage())
            );
        }
    }

    @Test
    public void buildVariableProfileTableIfOutputProfileNameNotFound() throws Exception {
        final String path = getResourceFilePath("models/and_op.onnx");
        final int batchSize = 1;
        final int inputDim = 2;
        final String inputProfileName = "input";
        final String outputProfileName = "__non_existent_variable__"; // test case

        try (
            ModelData modelData = ModelData.fromOnnxFile(path);
            VariableProfileTableBuilder vptBuilder = VariableProfileTable.builder()
                .addInputProfile(inputProfileName, DType.FLOAT, new int[] {batchSize, inputDim})
                .addOutputProfile(outputProfileName, DType.FLOAT)
        ) {
            MenohException e = assertThrows(MenohException.class, () -> vptBuilder.build(modelData));
            assertAll("output profile name not found",
                () -> assertEquals(ErrorCode.VARIABLE_NOT_FOUND, e.getErrorCode()),
                () -> assertEquals(
                    String.format("menoh variable not found error: %s (variable_not_found)", outputProfileName),
                    e.getMessage())
            );
        }
    }
}

--------------------------------------------------------------------------------
/config/checkstyle/checkstyle.xml:
--------------------------------------------------------------------------------
1 | 2 | 5 | 6 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 37 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 56 | 57 | 58 | 59 | 60 | 61 | 64 | 65 | 66 | 67 | 68 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 79 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 123 | 124 | 125 | 127 | 128 | 129 | 130 | 132 | 133 | 134 | 135 | 137 | 138 | 139 | 140 | 142 | 143 | 144 | 145 | 147 | 148 | 149 | 150 | 151 | 153 | 154 | 155 | 156 | 158 | 159 | 160 | 161 | 163 | 164 | 165 | 166 | 168 | 169 | 170 | 171 | 173 | 175 | 177 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 203 | 204 | 205 | 206 | 207 | 208 | 211 | 212 | 213 | 214 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 227 | 228 | 229 | 230 | 231 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | 242 | 243 | 244 | 245 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | 255 | 256 | 257 | 258 |
--------------------------------------------------------------------------------
/menoh/src/test/java/jp/preferred/menoh/ModelTest.java:
--------------------------------------------------------------------------------
package jp.preferred.menoh;

// CHECKSTYLE:OFF
import static jp.preferred.menoh.TestUtils.*;
import static org.junit.jupiter.api.Assertions.*;
// CHECKSTYLE:ON

import java.nio.ByteBuffer;
import java.nio.ByteOrder;

import org.junit.jupiter.api.Test;

/**
 * Tests for {@link ModelBuilder} and {@link Model}, using the AND-gate model in
 * {@code models/and_op.onnx}: [[0, 0], [0, 1], [1, 0], [1, 1]] -> [[0], [0], [0], [1]].
 */
public class ModelTest {
    @Test
    public void makeModelBuilderIsSuccessful() throws Exception {
        final String path = getResourceFilePath("models/and_op.onnx");
        final int batchSize = 4; // inputData below holds 4 rows of inputDim values
        final int inputDim = 2;
        final float[] inputData = new float[] {0f, 0f, 0f, 1f, 1f, 0f, 1f, 1f};

        try (
            ModelData modelData = ModelData.fromOnnxFile(path);
            VariableProfileTableBuilder vptBuilder = VariableProfileTable.builder()
                .addInputProfile("input", DType.FLOAT, new int[] {batchSize, inputDim})
                .addOutputProfile("output", DType.FLOAT);
            VariableProfileTable vpt = vptBuilder.build(modelData);
            ModelBuilder modelBuilder = Model.builder(vpt)
        ) {
            modelBuilder.attachExternalBuffer("input", inputData);

            assertAll("model builder",
                () -> assertNotNull(modelBuilder.nativeHandle()),
                () -> assertNotNull(modelBuilder.externalBuffers()),
                () -> assertFalse(modelBuilder.externalBuffers().isEmpty(),
                    "externalBuffers should not be empty")
            );
        }
    }

    @Test
    public void closeModelBuilder() throws Exception {
        final String path = getResourceFilePath("models/and_op.onnx");
        final int batchSize = 4;
        final int inputDim = 2;
        final float[] inputData = new float[] {0f, 0f, 0f, 1f, 1f, 0f, 1f, 1f};

        try (
            ModelData modelData = ModelData.fromOnnxFile(path);
            VariableProfileTableBuilder vptBuilder = VariableProfileTable.builder()
                .addInputProfile("input", DType.FLOAT, new int[] {batchSize, inputDim})
                .addOutputProfile("output", DType.FLOAT);
            VariableProfileTable vpt = vptBuilder.build(modelData)
        ) {
            final ModelBuilder modelBuilder = Model.builder(vpt);
            try {
                modelBuilder.attachExternalBuffer("input", inputData);

                assertAll("model builder",
                    () -> assertNotNull(modelBuilder.nativeHandle()),
                    () -> assertNotNull(modelBuilder.externalBuffers()),
                    () -> assertFalse(modelBuilder.externalBuffers().isEmpty(),
                        "externalBuffers should not be empty")
                );
            } finally {
                modelBuilder.close();
                assertAll("model builder",
                    () -> assertNull(modelBuilder.nativeHandle()),
                    () -> assertNotNull(modelBuilder.externalBuffers()),
                    () -> assertTrue(modelBuilder.externalBuffers().isEmpty(),
                        "externalBuffers should be empty")
                );

                // close() is an idempotent operation
                modelBuilder.close();
            }
        }
    }

    @Test
    public void buildModelIsSuccessful() throws Exception {
        final String path = getResourceFilePath("models/and_op.onnx");
        final int batchSize = 4;
        final int inputDim = 2;
        final String backendName = "mkldnn";
        final String backendConfig = "";

        try (
            ModelData modelData = ModelData.fromOnnxFile(path);
            VariableProfileTableBuilder vptBuilder = VariableProfileTable.builder()
                .addInputProfile("input", DType.FLOAT, new int[] {batchSize, inputDim})
                .addOutputProfile("output", DType.FLOAT);
            VariableProfileTable vpt = vptBuilder.build(modelData);
            ModelBuilder modelBuilder = Model.builder(vpt)
        ) {
            modelBuilder.attachExternalBuffer("input", new float[] {0f, 0f, 0f, 1f, 1f, 0f, 1f, 1f});
            assertFalse(modelBuilder.externalBuffers().isEmpty(), "externalBuffers should not be empty");

            try (Model model = modelBuilder.build(modelData, backendName, backendConfig)) {
                assertAll("model",
                    () -> assertNotNull(model.nativeHandle()),
                    () -> assertNotNull(model.externalBuffers()),
                    () -> assertArrayEquals(
                        modelBuilder.externalBuffers().toArray(), model.externalBuffers().toArray())
                );
            }
        }
    }

    @Test
    public void closeModelBeforeModelBuilder() throws Exception {
        final String path = getResourceFilePath("models/and_op.onnx");
        final int batchSize = 4;
        final int inputDim = 2;
        final float[] input = new float[] {0f, 0f, 0f, 1f, 1f, 0f, 1f, 1f};
        final String backendName = "mkldnn";
        final String backendConfig = "";

        try (
            ModelData modelData = ModelData.fromOnnxFile(path);
            VariableProfileTableBuilder vptBuilder = VariableProfileTable.builder()
                .addInputProfile("input", DType.FLOAT, new int[] {batchSize, inputDim})
                .addOutputProfile("output", DType.FLOAT);
            VariableProfileTable vpt = vptBuilder.build(modelData);
            ModelBuilder modelBuilder = Model.builder(vpt)
        ) {
            modelBuilder.attachExternalBuffer("input", input);
            assertFalse(modelBuilder.externalBuffers().isEmpty(), "externalBuffers should not be empty");

            final Model model = modelBuilder.build(modelData, backendName, backendConfig);
            try {
                assertAll("model",
                    () -> assertNotNull(model.nativeHandle()),
                    () -> assertNotNull(model.externalBuffers()),
                    () -> assertArrayEquals(
                        modelBuilder.externalBuffers().toArray(), model.externalBuffers().toArray())
                );
            } finally {
                model.close();
                assertAll("model",
                    () -> assertNull(model.nativeHandle()),
                    () -> assertNotNull(model.externalBuffers()),
                    () -> assertTrue(model.externalBuffers().isEmpty(), "externalBuffers should be empty")
                );
                // closing the model must not release the buffers still held by its builder
                assertAll("model builder",
                    () -> assertNotNull(modelBuilder.externalBuffers()),
                    () -> assertFalse(modelBuilder.externalBuffers().isEmpty(),
                        "externalBuffers should not be empty")
                );

                // close() is an idempotent operation
                model.close();
            }
        }
    }

    @Test
    public void closeModelBuilderBeforeModel() throws Exception {
        final String path = getResourceFilePath("models/and_op.onnx");
        final int batchSize = 4;
        final int inputDim = 2;
        final float[] input = new float[] {0f, 0f, 0f, 1f, 1f, 0f, 1f, 1f};
        final String backendName = "mkldnn";
        final String backendConfig = "";

        try (
            ModelData modelData = ModelData.fromOnnxFile(path);
            VariableProfileTableBuilder vptBuilder = VariableProfileTable.builder()
                .addInputProfile("input", DType.FLOAT, new int[] {batchSize, inputDim})
                .addOutputProfile("output", DType.FLOAT);
            VariableProfileTable vpt = vptBuilder.build(modelData)
        ) {
            try (ModelBuilder modelBuilder = Model.builder(vpt)) {
                modelBuilder.attachExternalBuffer("input", input);
                assertFalse(modelBuilder.externalBuffers().isEmpty(), "externalBuffers should not be empty");

                final Model model = modelBuilder.build(modelData, backendName, backendConfig);
                try {
                    // close model builder before model (try-with-resources closes it again later,
                    // which is fine because close() is idempotent)
                    modelBuilder.close();

                    assertAll("model builder",
                        () -> assertNotNull(modelBuilder.externalBuffers()),
                        () -> assertTrue(modelBuilder.externalBuffers().isEmpty(),
                            "externalBuffers should be empty")
                    );
                    assertAll("model",
                        () -> assertNotNull(model.nativeHandle()),
                        () -> assertNotNull(model.externalBuffers()),
                        () -> assertFalse(model.externalBuffers().isEmpty(),
                            "externalBuffers should not be empty")
                    );
                } finally {
                    model.close();
                }
            }
        }
    }

    @Test
    public void buildModelIfBackendNotFound() throws Exception {
        final String path = getResourceFilePath("models/and_op.onnx");
        final int batchSize = 4;
        final int inputDim = 2;
        final String backendName = "__non_existent_backend__"; // test case
        final String backendConfig = "";

        try (
            ModelData modelData = ModelData.fromOnnxFile(path);
            VariableProfileTableBuilder vptBuilder = VariableProfileTable.builder()
                .addInputProfile("input", DType.FLOAT, new int[] {batchSize, inputDim})
                .addOutputProfile("output", DType.FLOAT);
            VariableProfileTable vpt = vptBuilder.build(modelData);
            ModelBuilder modelBuilder = Model.builder(vpt)
        ) {
            modelBuilder.attachExternalBuffer("input", new float[] {0f, 0f, 0f, 1f, 1f, 0f, 1f, 1f});

            MenohException e = assertThrows(
                MenohException.class, () -> modelBuilder.build(modelData, backendName, backendConfig));
            assertAll("backendName is invalid",
                () -> assertEquals(ErrorCode.INVALID_BACKEND_NAME, e.getErrorCode()),
                () -> assertEquals(
                    String.format("menoh invalid backend name error: %s (invalid_backend_name)", backendName),
                    e.getMessage())
            );
        }
    }

    @Test
    public void buildAndRunModelWithoutAttachingInput() throws Exception {
        // [[0, 0], [0, 1], [1, 0], [1, 1]] -> [[0], [0], [0], [1]]
        final String path = getResourceFilePath("models/and_op.onnx");
        final int batchSize = 4;
        final int inputDim = 2;
        final float[] inputData1 = new float[] {0f, 0f, 0f, 1f, 1f, 0f, 1f, 1f};
        final float[] inputData2 = new float[] {1f, 1f, 1f, 0f, 0f, 1f, 0f, 0f};
        final int outputDim = 1;
        final float[] expectedOutput1 = new float[] {0f, 0f, 0f, 1f};
        final float[] expectedOutput2 = new float[] {1f, 0f, 0f, 0f};

        try (
            ModelData modelData = ModelData.fromOnnxFile(path);
            VariableProfileTableBuilder vptBuilder = VariableProfileTable.builder()
                .addInputProfile("input", DType.FLOAT, new int[] {batchSize, inputDim})
                .addOutputProfile("output", DType.FLOAT);
            VariableProfileTable vpt = vptBuilder.build(modelData);
            ModelBuilder modelBuilder = Model.builder(vpt); // test case
            Model model = modelBuilder.build(modelData, "mkldnn", "")
        ) {
            // you can delete modelData explicitly after building a model
            modelData.close();

            final Variable inputVar = model.variable("input");
            assertAll("input variable",
                () -> assertEquals(DType.FLOAT, inputVar.dtype()),
                () -> assertArrayEquals(new int[] {batchSize, inputDim}, inputVar.dims())
            );

            // write input data to the internal buffer allocated by native Menoh
            inputVar.buffer().asFloatBuffer().put(inputData1);
            model.run();

            final Variable outputVar = model.variable("output");
            assertAll("output variable",
                () -> assertEquals(DType.FLOAT, outputVar.dtype()),
                () -> assertArrayEquals(new int[] {batchSize, outputDim}, outputVar.dims())
            );
            assertArrayEquals(expectedOutput1, readFloats(outputVar));

            // rewrite the internal buffer and run again
            inputVar.buffer().asFloatBuffer().put(inputData2);
            model.run();

            final Variable outputVar2 = model.variable("output");
            assertAll("output variable",
                () -> assertEquals(DType.FLOAT, outputVar2.dtype()),
                () -> assertArrayEquals(new int[] {batchSize, outputDim}, outputVar2.dims())
            );
            assertArrayEquals(expectedOutput2, readFloats(outputVar2));
        }
    }

    @Test
    public void buildAndRunModelIfInputIsDirectBuffer() throws Exception {
        // [[0, 0], [0, 1], [1, 0], [1, 1]] -> [[0], [0], [0], [1]]
        final String path = getResourceFilePath("models/and_op.onnx");
        final int batchSize = 4;
        final int inputDim = 2;
        final int inputLen = batchSize * inputDim;
        final ByteBuffer inputDataBuf =
            ByteBuffer.allocateDirect(inputLen * Float.BYTES).order(ByteOrder.nativeOrder()); // test case
        final float[] inputData1 = new float[] {0f, 0f, 0f, 1f, 1f, 0f, 1f, 1f};
        final float[] inputData2 = new float[] {1f, 1f, 1f, 0f, 0f, 1f, 0f, 0f};
        final int outputDim = 1;
        final float[] expectedOutput1 = new float[] {0f, 0f, 0f, 1f};
        final float[] expectedOutput2 = new float[] {1f, 0f, 0f, 0f};

        try (
            ModelData modelData = ModelData.fromOnnxFile(path);
            VariableProfileTableBuilder vptBuilder = VariableProfileTable.builder()
                .addInputProfile("input", DType.FLOAT, new int[] {batchSize, inputDim})
                .addOutputProfile("output", DType.FLOAT);
            VariableProfileTable vpt = vptBuilder.build(modelData);
            ModelBuilder modelBuilder =
                Model.builder(vpt).attachExternalBuffer("input", inputDataBuf);
            Model model = modelBuilder.build(modelData, "mkldnn", "")
        ) {
            // you can delete modelData explicitly after building a model
            modelData.close();

            inputDataBuf.asFloatBuffer().put(inputData1);
            model.run();

            final Variable inputVar = model.variable("input");
            assertAll("input variable",
                () -> assertEquals(DType.FLOAT, inputVar.dtype()),
                () -> assertArrayEquals(new int[] {batchSize, inputDim}, inputVar.dims())
            );
            assertArrayEquals(inputData1, readFloats(inputVar));

            final Variable outputVar = model.variable("output");
            assertAll("output variable",
                () -> assertEquals(DType.FLOAT, outputVar.dtype()),
                () -> assertArrayEquals(new int[] {batchSize, outputDim}, outputVar.dims())
            );
            assertArrayEquals(expectedOutput1, readFloats(outputVar));

            // rewrite the direct buffer and run again
            inputDataBuf.asFloatBuffer().put(inputData2);
            model.run();

            final Variable outputVar2 = model.variable("output");
            assertAll("output variable",
                () -> assertEquals(DType.FLOAT, outputVar2.dtype()),
                () -> assertArrayEquals(new int[] {batchSize, outputDim}, outputVar2.dims())
            );
            assertArrayEquals(expectedOutput2, readFloats(outputVar2));
        }
    }

    @Test
    public void buildAndRunModelIfInputIsArrayBackedBuffer() throws Exception {
        // [[0, 0], [0, 1], [1, 0], [1, 1]] -> [[0], [0], [0], [1]]
        final String path = getResourceFilePath("models/and_op.onnx");
        final int batchSize = 4;
        final int inputDim = 2;
        final float[] inputData = new float[] {0f, 0f, 0f, 1f, 1f, 0f, 1f, 1f};
        final ByteBuffer inputDataBuf =
            ByteBuffer.allocate(inputData.length * Float.BYTES).order(ByteOrder.nativeOrder()); // test case
        inputDataBuf.asFloatBuffer().put(inputData);
        final int outputDim = 1;
        final float[] expectedOutput = new float[] {0f, 0f, 0f, 1f};

        try (
            ModelData modelData = ModelData.fromOnnxFile(path);
            VariableProfileTableBuilder vptBuilder = VariableProfileTable.builder()
                .addInputProfile("input", DType.FLOAT, new int[] {batchSize, inputDim})
                .addOutputProfile("output", DType.FLOAT);
            VariableProfileTable vpt = vptBuilder.build(modelData);
            ModelBuilder modelBuilder =
                Model.builder(vpt).attachExternalBuffer("input", inputDataBuf);
            Model model = modelBuilder.build(modelData, "mkldnn", "")
        ) {
            // you can delete modelData explicitly after building a model
            modelData.close();

            model.run();

            final Variable inputVar = model.variable("input");
            assertAll("input variable",
                () -> assertEquals(DType.FLOAT, inputVar.dtype()),
                () -> assertArrayEquals(new int[] {batchSize, inputDim}, inputVar.dims())
            );
            assertArrayEquals(inputData, readFloats(inputVar));

            final Variable outputVar = model.variable("output");
            assertAll("output variable",
                () -> assertEquals(DType.FLOAT, outputVar.dtype()),
                () -> assertArrayEquals(new int[] {batchSize, outputDim}, outputVar.dims())
            );
            assertArrayEquals(expectedOutput, readFloats(outputVar));
        }
    }

    @Test
    public void buildAndRunModelIfInputIsReadOnlyBuffer() throws Exception {
        // [[0, 0], [0, 1], [1, 0], [1, 1]] -> [[0], [0], [0], [1]]
        final String path = getResourceFilePath("models/and_op.onnx");
        final int batchSize = 4;
        final int inputDim = 2;
        final float[] inputData = new float[] {0f, 0f, 0f, 1f, 1f, 0f, 1f, 1f};
        final ByteBuffer inputDataBuf =
            ByteBuffer.allocate(inputData.length * Float.BYTES).order(ByteOrder.nativeOrder());
        inputDataBuf.asFloatBuffer().put(inputData);
        final ByteBuffer readOnlyInputDataBuf = inputDataBuf.asReadOnlyBuffer(); // test case
        final int outputDim = 1;
        final float[] expectedOutput = new float[] {0f, 0f, 0f, 1f};

        try (
            ModelData modelData = ModelData.fromOnnxFile(path);
            VariableProfileTableBuilder vptBuilder = VariableProfileTable.builder()
                .addInputProfile("input", DType.FLOAT, new int[] {batchSize, inputDim})
                .addOutputProfile("output", DType.FLOAT);
            VariableProfileTable vpt = vptBuilder.build(modelData);
            ModelBuilder modelBuilder =
                Model.builder(vpt).attachExternalBuffer("input", readOnlyInputDataBuf);
            Model model = modelBuilder.build(modelData, "mkldnn", "")
        ) {
            // you can delete modelData explicitly after building a model
            modelData.close();

            model.run();

            final Variable inputVar = model.variable("input");
            assertAll("input variable",
                () -> assertEquals(DType.FLOAT, inputVar.dtype()),
                () -> assertArrayEquals(new int[] {batchSize, inputDim}, inputVar.dims())
            );
            assertArrayEquals(inputData, readFloats(inputVar));

            final Variable outputVar = model.variable("output");
            assertAll("output variable",
                () -> assertEquals(DType.FLOAT, outputVar.dtype()),
                () -> assertArrayEquals(new int[] {batchSize, outputDim}, outputVar.dims())
            );
            assertArrayEquals(expectedOutput, readFloats(outputVar));
        }
    }

    @Test
    public void buildAndRunModelIfInputIsFloatArray() throws Exception {
        // [[0, 0], [0, 1], [1, 0], [1, 1]] -> [[0], [0], [0], [1]]
        final String path = getResourceFilePath("models/and_op.onnx");
        final int batchSize = 4;
        final int inputDim = 2;
        final float[] inputData = new float[] {0f, 0f, 0f, 1f, 1f, 0f, 1f, 1f};
        final int outputDim = 1;
        final float[] expectedOutput = new float[] {0f, 0f, 0f, 1f};

        try (
            ModelData modelData = ModelData.fromOnnxFile(path);
            VariableProfileTableBuilder vptBuilder = VariableProfileTable.builder()
                .addInputProfile("input", DType.FLOAT, new int[] {batchSize, inputDim})
                .addOutputProfile("output", DType.FLOAT);
            VariableProfileTable vpt = vptBuilder.build(modelData);
            ModelBuilder modelBuilder =
                Model.builder(vpt).attachExternalBuffer("input", inputData);
            Model model = modelBuilder.build(modelData, "mkldnn", "")
        ) {
            // you can delete modelData explicitly after building a model
            modelData.close();

            model.run();

            final Variable inputVar = model.variable("input");
            assertAll("input variable",
                () -> assertEquals(DType.FLOAT, inputVar.dtype()),
                () -> assertArrayEquals(new int[] {batchSize, inputDim}, inputVar.dims())
            );
            assertArrayEquals(inputData, readFloats(inputVar));

            final Variable outputVar = model.variable("output");
            assertAll("output variable",
                () -> assertEquals(DType.FLOAT, outputVar.dtype()),
                () -> assertArrayEquals(new int[] {batchSize, outputDim}, outputVar.dims())
            );
            assertArrayEquals(expectedOutput, readFloats(outputVar));
        }
    }

    /**
     * Copy the whole contents of a float variable into a new array whose length is the product
     * of the variable's dims (so every dimension, not just the batch size, is accounted for).
     */
    private static float[] readFloats(Variable variable) {
        int length = 1;
        for (int dim : variable.dims()) {
            length *= dim;
        }
        final float[] values = new float[length];
        variable.buffer().asFloatBuffer().get(values);
        return values;
    }
}

--------------------------------------------------------------------------------