├── menoh
├── src
│ ├── test
│ │ ├── resources
│ │ │ └── models
│ │ │ │ ├── invalid_format.onnx
│ │ │ │ ├── and_op.onnx
│ │ │ │ ├── unsupported_onnx_opset_version.onnx
│ │ │ │ └── make_onnx.py
│ │ └── java
│ │ │ └── jp
│ │ │ └── preferred
│ │ │ └── menoh
│ │ │ ├── TestUtils.java
│ │ │ ├── DTypeTest.java
│ │ │ ├── MenohExceptionTest.java
│ │ │ ├── ErrorCodeTest.java
│ │ │ ├── ModelDataTest.java
│ │ │ ├── ModelRunnerTest.java
│ │ │ ├── BufferUtilsTest.java
│ │ │ ├── VariableProfileTableTest.java
│ │ │ └── ModelTest.java
│ └── main
│ │ └── java
│ │ └── jp
│ │ └── preferred
│ │ └── menoh
│ │ ├── MenohRunnerException.java
│ │ ├── VariableProfile.java
│ │ ├── DType.java
│ │ ├── ErrorCode.java
│ │ ├── ModelData.java
│ │ ├── Variable.java
│ │ ├── MenohException.java
│ │ ├── VariableProfileTable.java
│ │ ├── MenohNative.java
│ │ ├── VariableProfileTableBuilder.java
│ │ ├── Model.java
│ │ ├── BufferUtils.java
│ │ ├── ModelRunner.java
│ │ ├── ModelBuilder.java
│ │ └── ModelRunnerBuilder.java
└── pom.xml
├── .gitignore
├── LICENSE
├── menoh-examples
├── README.md
├── pom.xml
└── src
│ └── main
│ └── java
│ └── jp
│ └── preferred
│ └── menoh
│ └── examples
│ └── Vgg16.java
├── LICENSE_THIRD_PARTY
├── README.md
├── pom.xml
└── config
└── checkstyle
└── checkstyle.xml
/menoh/src/test/resources/models/invalid_format.onnx:
--------------------------------------------------------------------------------
1 | Here is an invalid ONNX model file.
--------------------------------------------------------------------------------
/menoh/src/test/resources/models/and_op.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/menoh-java/HEAD/menoh/src/test/resources/models/and_op.onnx
--------------------------------------------------------------------------------
/menoh/src/test/resources/models/unsupported_onnx_opset_version.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/menoh-java/HEAD/menoh/src/test/resources/models/unsupported_onnx_opset_version.onnx
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Compiled class file
2 | *.class
3 | target/
4 | cppbuild/
5 |
6 | # Log file
7 | *.log
8 |
9 | # BlueJ files
10 | *.ctxt
11 |
12 | # Mobile Tools for Java (J2ME)
13 | .mtj.tmp/
14 |
15 | # Package Files #
16 | *.jar
17 | *.war
18 | *.ear
19 | *.zip
20 | *.tar.gz
21 | *.rar
22 |
23 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml
24 | hs_err_pid*
25 |
26 | # eclipse specific
27 | .metadata
28 | jrebel.lic
29 | .settings
30 | .classpath
31 | .project
32 |
33 | .ensime*
34 | *.sublime-*
35 | .cache
36 |
37 | # intellij
38 | *.eml
39 | *.iml
40 | *.ipr
41 | *.iws
42 | .*.sw?
43 | .idea
44 |
45 | # image files
46 | *.bmp
47 | *.jpg
48 | *.png
49 |
--------------------------------------------------------------------------------
/menoh/src/main/java/jp/preferred/menoh/MenohRunnerException.java:
--------------------------------------------------------------------------------
1 | package jp.preferred.menoh;
2 |
/**
 * An unchecked exception of menoh-java. Unlike {@link MenohException} it carries no {@link ErrorCode};
 * every constructor simply forwards to the corresponding {@link RuntimeException} constructor.
 */
public class MenohRunnerException extends RuntimeException {
    /**
     * Creates an exception that wraps {@code cause}; the detail message is derived from the cause
     * as done by {@link RuntimeException#RuntimeException(Throwable)}.
     */
    public MenohRunnerException(Throwable cause) {
        super(cause);
    }

    /**
     * Creates an exception with the given detail message and no cause.
     */
    public MenohRunnerException(String message) {
        super(message);
    }

    /**
     * Creates an exception with both a detail message and a cause.
     */
    public MenohRunnerException(String message, Throwable cause) {
        super(message, cause);
    }

    /**
     * Creates an exception with explicit control over suppression and stack-trace capture.
     */
    protected MenohRunnerException(
            String message,
            Throwable cause,
            boolean enableSuppression,
            boolean writableStackTrace) {
        super(message, cause, enableSuppression, writableStackTrace);
    }
}
24 |
--------------------------------------------------------------------------------
/menoh/src/main/java/jp/preferred/menoh/VariableProfile.java:
--------------------------------------------------------------------------------
1 | package jp.preferred.menoh;
2 |
3 | /**
4 | * A variable profile added to {@link VariableProfileTable}.
5 | */
6 | public class VariableProfile {
7 | private final DType dtype;
8 | private final int[] dims;
9 |
10 | VariableProfile(DType dtype, int[] dims) {
11 | this.dtype = dtype;
12 | this.dims = dims;
13 | }
14 |
15 | /**
16 | * A data type of the variable.
17 | */
18 | public DType dtype() {
19 | return this.dtype;
20 | }
21 |
22 | /**
23 | * An array of dimensions.
24 | */
25 | public int[] dims() {
26 | if (dims != null) {
27 | return this.dims.clone();
28 | } else {
29 | return new int[0];
30 | }
31 | }
32 | }
33 |
--------------------------------------------------------------------------------
/menoh/src/test/java/jp/preferred/menoh/TestUtils.java:
--------------------------------------------------------------------------------
1 | package jp.preferred.menoh;
2 |
3 | import java.io.FileNotFoundException;
4 | import java.io.IOException;
5 | import java.net.URISyntaxException;
6 | import java.net.URL;
7 | import java.nio.file.Paths;
8 |
/**
 * Helpers shared by the unit tests.
 */
public class TestUtils {
    /**
     * Convert resource name into file path.
     *
     * <p>Looks {@code name} up through this class's class loader and returns the canonical path
     * of the file it resolves to.</p>
     *
     * @param name a classpath resource name, e.g. {@code models/and_op.onnx}
     * @return the canonical file-system path of the resource
     * @throws FileNotFoundException if the class loader cannot find the resource
     * @throws IOException if the canonical path cannot be computed
     * @throws URISyntaxException if the resource URL cannot be converted to a URI
     */
    public static String getResourceFilePath(String name) throws IOException, URISyntaxException {
        // ref. https://stackoverflow.com/a/17870390/1014818
        final URL resource = TestUtils.class.getClassLoader().getResource(name);
        if (resource == null) {
            throw new FileNotFoundException("The specified resource not found: " + name);
        }

        return Paths.get(resource.toURI()).toFile().getCanonicalPath();
    }
}
23 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Preferred Networks, Inc.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/menoh/src/test/java/jp/preferred/menoh/DTypeTest.java:
--------------------------------------------------------------------------------
1 | package jp.preferred.menoh;
2 |
3 | import static org.junit.jupiter.api.Assertions.assertAll;
4 | import static org.junit.jupiter.api.Assertions.assertEquals;
5 | import static org.junit.jupiter.api.Assertions.assertThrows;
6 |
7 | import org.junit.jupiter.api.Test;
8 |
9 | public class DTypeTest {
10 | @Test
11 | public void floatDType() {
12 | assertAll("DType.FLOAT",
13 | () -> assertEquals(0, DType.FLOAT.getId()),
14 | () -> assertEquals(4, DType.FLOAT.size())
15 | );
16 | }
17 |
18 | @Test
19 | public void valueOfValidDType() {
20 | assertEquals(DType.FLOAT, DType.valueOf(0));
21 | }
22 |
23 | @Test
24 | public void valueOfIntMinDType() {
25 | MenohException e = assertThrows(MenohException.class, () -> DType.valueOf(Integer.MIN_VALUE));
26 | assertEquals(ErrorCode.UNDEFINED, e.getErrorCode());
27 | }
28 |
29 | @Test
30 | public void valueOfIntMaxDType() {
31 | MenohException e = assertThrows(MenohException.class, () -> DType.valueOf(Integer.MAX_VALUE));
32 | assertEquals(ErrorCode.UNDEFINED, e.getErrorCode());
33 | }
34 | }
35 |
--------------------------------------------------------------------------------
/menoh/src/main/java/jp/preferred/menoh/DType.java:
--------------------------------------------------------------------------------
1 | package jp.preferred.menoh;
2 |
3 | /**
4 | * A data type.
5 | */
6 | public enum DType {
7 | UNDEFINED(-1), // values[0]
8 | FLOAT(0);
9 |
10 | private final int id;
11 |
12 | // cache values() to avoid cloning the backing array every time
13 | private static final DType[] values = values();
14 |
15 | DType(int id) {
16 | this.id = id;
17 | }
18 |
19 | public int getId() {
20 | return id;
21 | }
22 |
23 | /**
24 | * The byte size of this {@link DType}.
25 | */
26 | public int size() throws MenohException {
27 | switch (this) {
28 | case FLOAT:
29 | return 4;
30 | default:
31 | throw new MenohException(
32 | ErrorCode.UNDEFINED, String.format("the size of dtype is unknown: %d", id));
33 | }
34 | }
35 |
36 | /**
37 | * Returns the enum constant of the specified enum type with the specified ID.
38 | */
39 | public static DType valueOf(int value) throws MenohException {
40 | final int index = value + 1;
41 | if (1 <= index && index < values.length) {
42 | final DType ret = values[index];
43 | assert ret.id == value;
44 |
45 | return ret;
46 | } else {
47 | throw new MenohException(ErrorCode.UNDEFINED, "undefined dtype: " + value);
48 | }
49 | }
50 | }
51 |
--------------------------------------------------------------------------------
/menoh/src/test/java/jp/preferred/menoh/MenohExceptionTest.java:
--------------------------------------------------------------------------------
1 | package jp.preferred.menoh;
2 |
3 | import static jp.preferred.menoh.MenohException.checkError;
4 |
5 | // CHECKSTYLE:OFF
6 | import static org.junit.jupiter.api.Assertions.*;
7 | // CHECKSTYLE:ON
8 |
9 | import org.junit.jupiter.api.Test;
10 |
11 | public class MenohExceptionTest {
12 | @Test
13 | public void checkErrorSuccess() {
14 | checkError(ErrorCode.SUCCESS.getId());
15 | }
16 |
17 | @Test
18 | public void checkErrorFailed() {
19 | MenohException e = assertThrows(MenohException.class, () -> checkError(ErrorCode.STD_ERROR.getId()));
20 | assertAll("ErrorCode.STD_ERROR",
21 | () -> assertEquals(ErrorCode.STD_ERROR, e.getErrorCode()),
22 | () -> assertTrue(e.getMessage().endsWith(" (std_error)"),
23 | String.format("%s doesn't end with \"(std_error)\".", e.getMessage()))
24 | );
25 | }
26 |
27 | @Test
28 | public void checkErrorUnknownErrorCode() {
29 | MenohException e = assertThrows(MenohException.class, () -> checkError(Integer.MAX_VALUE));
30 | assertAll("invalid ErrorCode",
31 | () -> assertEquals(ErrorCode.UNDEFINED, e.getErrorCode()),
32 | () -> assertTrue(e.getMessage().endsWith(" (2147483647)"),
33 | String.format("%s doesn't end with \"(2147483647)\".", e.getMessage()))
34 | );
35 | }
36 | }
37 |
--------------------------------------------------------------------------------
/menoh-examples/README.md:
--------------------------------------------------------------------------------
1 | # menoh-examples
2 | *Non*-adversarial code examples in menoh-java.
3 |
4 | ## Requirements
5 | The examples require the native Menoh Core library to be in the JNA search path. See the ["Getting Started"](../README.md#getting-started) section in the README document.
6 |
7 | ## Usage
8 |
9 | ### VGG16
10 | Classify an input image using VGG16, a pre-trained CNN model. See the [Menoh tutorial](https://pfnet-research.github.io/menoh/md_tutorial.html) for details.
11 |
12 | ```bash
13 | $ cd menoh-examples
14 | $ mvn compile
15 | $ mkdir data
16 | $ wget -qN -P data/ https://upload.wikimedia.org/wikipedia/commons/5/54/Light_sussex_hen.jpg
17 |
18 | $ mvn exec:java -Dexec.mainClass=jp.preferred.menoh.examples.Vgg16 -Dexec.args="data/Light_sussex_hen.jpg"
19 | ...
20 |
21 | -21.302792 -34.792274 -11.198693 23.983116 0.8887066 -6.290651 -26.45818 -24.885696 -4.925243 14.338675
22 | top 5 categories are
23 | index: 8, score: 0.9705545, category: n01514859 hen
24 | index: 7, score: 0.028219773, category: n01514668 cock
25 | index: 86, score: 9.522993E-4, category: n01807496 partridge
26 | index: 82, score: 1.6639252E-4, category: n01797886 ruffed grouse, partridge, Bonasa umbellus
27 | index: 97, score: 2.0030011E-5, category: n01847000 drake
28 | ```
29 |
30 | ## License
31 | The `menoh-examples` package uses `VGG16.onnx`, which is generated by [ONNX-Chainer](https://github.com/chainer/onnx-chainer) from the pre-trained model available on the [website](http://www.robots.ox.ac.uk/~vgg/research/very_deep/) of the Visual Geometry Group at the University of Oxford under the [Creative Commons Attribution License](https://creativecommons.org/licenses/by/4.0/).
32 |
--------------------------------------------------------------------------------
/menoh/src/main/java/jp/preferred/menoh/ErrorCode.java:
--------------------------------------------------------------------------------
1 | package jp.preferred.menoh;
2 |
3 | /**
4 | * A representation of menoh_error_code.
5 | */
6 | public enum ErrorCode {
7 | UNDEFINED(Integer.MIN_VALUE), // values[0]
8 | SUCCESS(0),
9 | STD_ERROR(1),
10 | UNKNOWN_ERROR(2),
11 | INVALID_FILENAME(3),
12 | UNSUPPORTED_ONNX_OPSET_VERSION(4),
13 | ONNX_PARSE_ERROR(5),
14 | INVALID_DTYPE(6),
15 | INVALID_ATTRIBUTE_TYPE(7),
16 | UNSUPPORTED_OPERATOR_ATTRIBUTE(8),
17 | DIMENSION_MISMATCH(9),
18 | VARIABLE_NOT_FOUND(10),
19 | INDEX_OUT_OF_RANGE(11),
20 | JSON_PARSE_ERROR(12),
21 | INVALID_BACKEND_NAME(13),
22 | UNSUPPORTED_OPERATOR(14),
23 | FAILED_TO_CONFIGURE_OPERATOR(15),
24 | BACKEND_ERROR(16),
25 | SAME_NAMED_VARIABLE_ALREADY_EXIST(17),
26 | UNSUPPORTED_INPUT_DIMS(18);
27 |
28 | private final int id;
29 |
30 | // cache values() to avoid cloning the backing array every time
31 | private static final ErrorCode[] values = values();
32 |
33 | ErrorCode(int id) {
34 | this.id = id;
35 | }
36 |
37 | public int getId() {
38 | return this.id;
39 | }
40 |
41 | /**
42 | * Returns the enum constant of the specified enum type with the specified ID.
43 | */
44 | public static ErrorCode valueOf(int value) throws MenohException {
45 | final int index = value + 1;
46 | if (1 <= index && index < values.length) {
47 | final ErrorCode ret = values[index];
48 | assert ret.id == value;
49 |
50 | return ret;
51 | } else {
52 | throw new MenohException(UNDEFINED, "The error code is undefined: " + value);
53 | }
54 | }
55 | }
56 |
--------------------------------------------------------------------------------
/menoh/src/main/java/jp/preferred/menoh/ModelData.java:
--------------------------------------------------------------------------------
1 | package jp.preferred.menoh;
2 |
3 | import static jp.preferred.menoh.MenohException.checkError;
4 |
5 | import com.sun.jna.Pointer;
6 | import com.sun.jna.ptr.PointerByReference;
7 |
/**
 * <p>A model data.</p>
 *
 * <p>Make sure to {@link #close()} this object after finishing the process to free the underlying memory
 * in the native heap.</p>
 */
public class ModelData implements AutoCloseable {
    private Pointer handle;

    private ModelData(Pointer handle) {
        this.handle = handle;
    }

    Pointer nativeHandle() {
        return this.handle;
    }

    /**
     * Optimizes this model data.
     *
     * @param vpt the {@link VariableProfileTable} passed to the native optimizer
     * @return this object
     * @throws MenohException if the native call reports an error
     */
    public ModelData optimize(VariableProfileTable vpt) throws MenohException {
        checkError(MenohNative.INSTANCE.menoh_model_data_optimize(handle, vpt.nativeHandle()));

        return this;
    }

    /**
     * Releases the native model data. Calling this again after the first call does nothing.
     */
    @Override
    public void close() {
        synchronized (this) {
            if (handle != Pointer.NULL) {
                MenohNative.INSTANCE.menoh_delete_model_data(handle);
                handle = Pointer.NULL;
            }
        }
    }

    /**
     * Loads an ONNX model from the specified file.
     *
     * @param onnxModelPath the path of the ONNX model file
     * @return the loaded model data; the caller is responsible for closing it
     * @throws MenohException if the native call reports an error
     */
    public static ModelData fromOnnxFile(String onnxModelPath) throws MenohException {
        final PointerByReference handle = new PointerByReference();
        checkError(MenohNative.INSTANCE.menoh_make_model_data_from_onnx(onnxModelPath, handle));

        return new ModelData(handle.getValue());
    }
}
--------------------------------------------------------------------------------
/menoh/src/main/java/jp/preferred/menoh/Variable.java:
--------------------------------------------------------------------------------
package jp.preferred.menoh;

import com.sun.jna.Pointer;

import java.nio.ByteBuffer;
import java.util.Arrays;

/**
 * An input or output variable in the model.
 */
public class Variable {
    private final DType dtype;
    private final int[] dims;
    private final Pointer bufferHandle;

    Variable(DType dtype, int[] dims, Pointer bufferHandle) {
        this.dtype = dtype;
        // Defensive copy so the caller cannot mutate this variable's shape afterwards.
        // A null array is normalized to an empty one, matching what dims() returned for null before.
        this.dims = (dims != null) ? dims.clone() : new int[0];
        this.bufferHandle = bufferHandle;
    }

    /**
     * A data type of the variable.
     */
    public DType dtype() {
        return this.dtype;
    }

    /**
     * An array of dimensions.
     *
     * @return a fresh copy of the dimensions; empty if none were given
     */
    public int[] dims() {
        return this.dims.clone();
    }

    /**
     * A direct {@link ByteBuffer} which points to the native buffer of the variable. The buffer can be read
     * and written via the methods of {@code ByteBuffer} before and after running the model.
     *
     * @throws MenohException if the buffer length cannot be computed (see {@link #bufferLength()})
     */
    public ByteBuffer buffer() throws MenohException {
        return this.bufferHandle.getByteBuffer(0, bufferLength());
    }

    /**
     * The length of the buffer in bytes, i.e. the dtype size times the product of all dims.
     *
     * @throws MenohException with {@link ErrorCode#UNDEFINED} if there are no dims, if the product of
     *     the dims is zero, if any dim is negative, if the dtype size is unknown, or if the length
     *     does not fit in a {@code long}
     */
    long bufferLength() throws MenohException {
        if (dims.length == 0) {
            throw new MenohException(ErrorCode.UNDEFINED, "buffer is empty");
        }

        // evaluated before inspecting the dims so an unknown dtype is reported first, as before
        final long elementSize = dtype.size();

        long length = 1;
        for (int d : dims) {
            if (d < 0) {
                // checked per dim: a plain product lets an even number of negative dims look valid
                throw new MenohException(
                        ErrorCode.UNDEFINED, "dims must not be negative: " + Arrays.toString(dims));
            }
            length = multiplyExactOrThrow(length, d);
        }

        if (length == 0) {
            throw new MenohException(ErrorCode.UNDEFINED, "buffer is empty");
        }

        return multiplyExactOrThrow(elementSize, length);
    }

    /** Multiplies two lengths, reporting overflow as a {@link MenohException}. */
    private static long multiplyExactOrThrow(long a, long b) {
        try {
            return Math.multiplyExact(a, b);
        } catch (ArithmeticException e) {
            throw new MenohException(ErrorCode.UNDEFINED, "buffer is too large", e);
        }
    }
}
68 |
--------------------------------------------------------------------------------
/menoh/src/test/java/jp/preferred/menoh/ErrorCodeTest.java:
--------------------------------------------------------------------------------
1 | package jp.preferred.menoh;
2 |
3 | import static org.junit.jupiter.api.Assertions.assertAll;
4 | import static org.junit.jupiter.api.Assertions.assertEquals;
5 | import static org.junit.jupiter.api.Assertions.assertThrows;
6 |
7 | import org.junit.jupiter.api.Test;
8 |
9 | public class ErrorCodeTest {
10 | @Test
11 | public void undefinedErrorCode() {
12 | assertEquals(Integer.MIN_VALUE, ErrorCode.UNDEFINED.getId());
13 | }
14 |
15 | @Test
16 | public void successErrorCode() {
17 | assertEquals(0, ErrorCode.SUCCESS.getId());
18 | }
19 |
20 | @Test
21 | public void valueOfValidErrorCode() {
22 | assertAll("valid ErrorCodes",
23 | () -> assertEquals(ErrorCode.SUCCESS, ErrorCode.valueOf(0)),
24 | () -> assertEquals(ErrorCode.UNSUPPORTED_INPUT_DIMS, ErrorCode.valueOf(18))
25 | );
26 | }
27 |
28 | @Test
29 | public void valueOfInvalidErrorCode() {
30 | assertAll("invalid ErrorCodes",
31 | () -> {
32 | MenohException e = assertThrows(MenohException.class, () -> ErrorCode.valueOf(-1));
33 | assertEquals(ErrorCode.UNDEFINED, e.getErrorCode());
34 | },
35 | () -> {
36 | MenohException e = assertThrows(MenohException.class, () -> ErrorCode.valueOf(18 + 1));
37 | assertEquals(ErrorCode.UNDEFINED, e.getErrorCode());
38 | });
39 | }
40 |
41 | @Test
42 | public void valueOfIntMinErrorCode() {
43 | MenohException e = assertThrows(MenohException.class, () -> ErrorCode.valueOf(Integer.MIN_VALUE));
44 | assertEquals(ErrorCode.UNDEFINED, e.getErrorCode());
45 | }
46 |
47 | @Test
48 | public void valueOfIntMaxErrorCode() {
49 | MenohException e = assertThrows(MenohException.class, () -> ErrorCode.valueOf(Integer.MAX_VALUE));
50 | assertEquals(ErrorCode.UNDEFINED, e.getErrorCode());
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/menoh/src/main/java/jp/preferred/menoh/MenohException.java:
--------------------------------------------------------------------------------
1 | package jp.preferred.menoh;
2 |
3 | import java.util.Locale;
4 |
5 | public class MenohException extends RuntimeException {
6 | private final ErrorCode errorCode;
7 |
8 | public MenohException(ErrorCode errorCode, String message) {
9 | super(message);
10 | this.errorCode = errorCode;
11 | }
12 |
13 | public MenohException(ErrorCode errorCode, String message, Throwable cause) {
14 | super(message, cause);
15 | this.errorCode = errorCode;
16 | }
17 |
18 | public MenohException(ErrorCode errorCode, Throwable cause) {
19 | super(cause);
20 | this.errorCode = errorCode;
21 | }
22 |
23 | protected MenohException(
24 | ErrorCode errorCode,
25 | String message,
26 | Throwable cause,
27 | boolean enableSuppression,
28 | boolean writableStackTrace) {
29 | super(message, cause, enableSuppression, writableStackTrace);
30 | this.errorCode = errorCode;
31 | }
32 |
33 | public ErrorCode getErrorCode() {
34 | return this.errorCode;
35 | }
36 |
37 | /**
38 | * Throws {@link MenohException} if the errorCode of a native Menoh function is
39 | * non-zero. This method must be called right after the invocation because it uses
40 | * menoh_get_last_error_message.
41 | *
42 | * @param errorCode an error code returned from the Menoh function
43 | */
44 | static void checkError(int errorCode) throws MenohException {
45 | if (errorCode != ErrorCode.SUCCESS.getId()) {
46 | final String errorMessage = MenohNative.INSTANCE.menoh_get_last_error_message();
47 |
48 | ErrorCode ec;
49 | try {
50 | ec = ErrorCode.valueOf(errorCode);
51 | } catch (MenohException e) {
52 | // ErrorCode.valueOf() throws MenohException if the error code is undefined
53 | throw new MenohException(ErrorCode.UNDEFINED, String.format("%s (%d)", errorMessage, errorCode), e);
54 | }
55 |
56 | final String errorCodeName = ec.toString().toLowerCase(Locale.ENGLISH);
57 | throw new MenohException(ec, String.format("%s (%s)", errorMessage, errorCodeName));
58 | }
59 | }
60 | }
61 |
--------------------------------------------------------------------------------
/menoh/src/main/java/jp/preferred/menoh/VariableProfileTable.java:
--------------------------------------------------------------------------------
1 | package jp.preferred.menoh;
2 |
3 | import static jp.preferred.menoh.MenohException.checkError;
4 |
5 | import com.sun.jna.Pointer;
6 | import com.sun.jna.ptr.IntByReference;
7 | import com.sun.jna.ptr.PointerByReference;
8 |
9 | /**
10 | * A table which holds {@link VariableProfile}s for the {@link Model}. It will be built by 11 | * {@link VariableProfileTableBuilder}.
12 | * 13 | *Make sure to {@link #close()} this object after finishing the process to free the underlying memory 14 | * in the native heap.
15 | */ 16 | public class VariableProfileTable implements AutoCloseable { 17 | private Pointer handle; 18 | 19 | VariableProfileTable(Pointer handle) { 20 | this.handle = handle; 21 | } 22 | 23 | Pointer nativeHandle() { 24 | return this.handle; 25 | } 26 | 27 | @Override 28 | public void close() { 29 | synchronized (this) { 30 | if (handle != Pointer.NULL) { 31 | MenohNative.INSTANCE.menoh_delete_variable_profile_table(handle); 32 | handle = Pointer.NULL; 33 | } 34 | } 35 | } 36 | 37 | /** 38 | * Creates a {@link VariableProfileTableBuilder}. 39 | */ 40 | public static VariableProfileTableBuilder builder() throws MenohException { 41 | final PointerByReference ref = new PointerByReference(); 42 | checkError(MenohNative.INSTANCE.menoh_make_variable_profile_table_builder(ref)); 43 | 44 | return new VariableProfileTableBuilder(ref.getValue()); 45 | } 46 | 47 | /** 48 | * A {@link VariableProfile} with the specified name. 49 | */ 50 | public VariableProfile variableProfile(String variableName) throws MenohException { 51 | final IntByReference dtype = new IntByReference(0); 52 | checkError(MenohNative.INSTANCE.menoh_variable_profile_table_get_dtype(handle, variableName, dtype)); 53 | 54 | final IntByReference dimsSize = new IntByReference(); 55 | checkError(MenohNative.INSTANCE.menoh_variable_profile_table_get_dims_size(handle, variableName, dimsSize)); 56 | 57 | final int[] dims = new int[dimsSize.getValue()]; 58 | 59 | for (int i = 0; i < dimsSize.getValue(); ++i) { 60 | final IntByReference d = new IntByReference(); 61 | checkError(MenohNative.INSTANCE.menoh_variable_profile_table_get_dims_at(handle, variableName, i, d)); 62 | dims[i] = d.getValue(); 63 | } 64 | 65 | return new VariableProfile(DType.valueOf(dtype.getValue()), dims); 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /menoh/src/main/java/jp/preferred/menoh/MenohNative.java: -------------------------------------------------------------------------------- 
1 | package jp.preferred.menoh; 2 | 3 | import com.sun.jna.Library; 4 | import com.sun.jna.Native; 5 | import com.sun.jna.Pointer; 6 | import com.sun.jna.ptr.IntByReference; 7 | import com.sun.jna.ptr.PointerByReference; 8 | 9 | // CHECKSTYLE:OFF 10 | interface MenohNative extends Library { 11 | MenohNative INSTANCE = (MenohNative) Native.loadLibrary("menoh", MenohNative.class); 12 | 13 | String menoh_get_last_error_message(); 14 | 15 | int menoh_make_model_data_from_onnx(String onnx_filename, PointerByReference dst_handle); 16 | 17 | void menoh_delete_model_data(Pointer model_data); 18 | 19 | int menoh_model_data_optimize(Pointer model_data, Pointer variable_profile_table); 20 | 21 | int menoh_make_variable_profile_table_builder(PointerByReference dst_handle); 22 | 23 | void menoh_delete_variable_profile_table_builder(Pointer builder); 24 | 25 | int menoh_variable_profile_table_builder_add_input_profile_dims_2(Pointer builder, String name, int dtype, int num, int size); 26 | 27 | int menoh_variable_profile_table_builder_add_input_profile_dims_4(Pointer builder, String name, int dtype, int num, int channel, int height, int width); 28 | 29 | int menoh_variable_profile_table_builder_add_output_profile(Pointer builder, String name, int dtype); 30 | 31 | int menoh_build_variable_profile_table(Pointer builder, Pointer model_data, PointerByReference dst_handle); 32 | 33 | void menoh_delete_variable_profile_table(Pointer variable_profile_table); 34 | 35 | int menoh_variable_profile_table_get_dtype(Pointer variable_profile_table, String variable_name, IntByReference dst_dtype); 36 | 37 | int menoh_variable_profile_table_get_dims_size(Pointer variable_profile_table, String variable_name, IntByReference dst_size); 38 | 39 | int menoh_variable_profile_table_get_dims_at(Pointer variable_profile_table, String variable_name, int index, IntByReference dst_size); 40 | 41 | int menoh_make_model_builder(Pointer variable_profile_table, PointerByReference dst_handle); 42 | 43 | void 
menoh_delete_model_builder(Pointer model_builder); 44 | 45 | int menoh_model_builder_attach_external_buffer(Pointer builder, String variable_name, Pointer buffer_handle); 46 | 47 | int menoh_build_model(Pointer builder, Pointer model_data, String backend_name, String backend_config, PointerByReference dst_model_handle); 48 | 49 | void menoh_delete_model(Pointer model); 50 | 51 | int menoh_model_get_variable_dtype(Pointer model, String variable_name, IntByReference dst_dtype); 52 | 53 | int menoh_model_run(Pointer model); 54 | 55 | int menoh_model_get_variable_dims_size(Pointer model, String variable_name, IntByReference dst_size); 56 | 57 | int menoh_model_get_variable_dims_at(Pointer model, String variable_name, int index, IntByReference dst_size); 58 | 59 | int menoh_model_get_variable_buffer_handle(Pointer model, String variable_name, PointerByReference dst_data); 60 | } 61 | // CHECKSTYLE:ON 62 | -------------------------------------------------------------------------------- /menoh/src/main/java/jp/preferred/menoh/VariableProfileTableBuilder.java: -------------------------------------------------------------------------------- 1 | package jp.preferred.menoh; 2 | 3 | import static jp.preferred.menoh.MenohException.checkError; 4 | 5 | import com.sun.jna.Pointer; 6 | import com.sun.jna.ptr.PointerByReference; 7 | 8 | /** 9 | *A builder object for {@link VariableProfileTable}.
10 | * 11 | *Make sure to {@link #close()} this object after finishing the process to free the underlying memory 12 | * in the native heap.
13 | */ 14 | public class VariableProfileTableBuilder implements AutoCloseable { 15 | private Pointer handle; 16 | 17 | VariableProfileTableBuilder(Pointer handle) { 18 | this.handle = handle; 19 | } 20 | 21 | Pointer nativeHandle() { 22 | return this.handle; 23 | } 24 | 25 | @Override 26 | public void close() { 27 | synchronized (this) { 28 | if (handle != Pointer.NULL) { 29 | MenohNative.INSTANCE.menoh_delete_variable_profile_table_builder(handle); 30 | handle = Pointer.NULL; 31 | } 32 | } 33 | } 34 | 35 | /** 36 | * Adds an input profile to configure the specified variable in the model. 37 | * 38 | * @return this object 39 | */ 40 | public VariableProfileTableBuilder addInputProfile(String name, DType dtype, int[] dims) throws MenohException { 41 | if (dims.length == 2) { 42 | checkError( 43 | MenohNative.INSTANCE.menoh_variable_profile_table_builder_add_input_profile_dims_2( 44 | handle, name, dtype.getId(), dims[0], dims[1])); 45 | } else if (dims.length == 4) { 46 | checkError( 47 | MenohNative.INSTANCE.menoh_variable_profile_table_builder_add_input_profile_dims_4( 48 | handle, name, dtype.getId(), dims[0], dims[1], dims[2], dims[3])); 49 | } else { 50 | throw new MenohException( 51 | ErrorCode.UNDEFINED, 52 | String.format("%s has an invalid dims size: %d (it must be 2 or 4)", name, dims.length)); 53 | } 54 | 55 | return this; 56 | } 57 | 58 | /** 59 | * Adds an output profile to configure the specified variable in the model. 60 | * 61 | * @return this object 62 | */ 63 | public VariableProfileTableBuilder addOutputProfile(String name, DType dtype) throws MenohException { 64 | checkError(MenohNative.INSTANCE.menoh_variable_profile_table_builder_add_output_profile( 65 | handle, name, dtype.getId())); 66 | 67 | return this; 68 | } 69 | 70 | /** 71 | * Builds a {@link VariableProfileTable}. It is used for making {@link ModelBuilder}. 
72 | */ 73 | public VariableProfileTable build(ModelData modelData) throws MenohException { 74 | final PointerByReference ref = new PointerByReference(); 75 | checkError(MenohNative.INSTANCE.menoh_build_variable_profile_table( 76 | this.handle, modelData.nativeHandle(), ref)); 77 | 78 | return new VariableProfileTable(ref.getValue()); 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /menoh/src/main/java/jp/preferred/menoh/Model.java: -------------------------------------------------------------------------------- 1 | package jp.preferred.menoh; 2 | 3 | import static jp.preferred.menoh.MenohException.checkError; 4 | 5 | import com.sun.jna.Pointer; 6 | import com.sun.jna.ptr.IntByReference; 7 | import com.sun.jna.ptr.PointerByReference; 8 | 9 | import java.util.List; 10 | 11 | /** 12 | *A representation of the model. It will be built by {@link ModelBuilder}.
13 | * 14 | *Make sure to {@link #close()} this object after finishing the process to free the underlying memory 15 | * in the native heap.
16 | */ 17 | public class Model implements AutoCloseable { 18 | private Pointer handle; 19 | 20 | /** 21 | * A reference to the pointers to prevent them from getting garbage collected. 22 | */ 23 | private final ListCopies a buffer to a native memory.
13 | * 14 | *If the buffer is direct, it will be attached to the model directly without copying.
15 | * Otherwise, it copies the content of the buffer to a newly allocated memory in the native heap ranging
16 | * from position() to (limit() - 1) without changing its position.
Note that the order() of the buffer should be {@link ByteOrder#nativeOrder()} because
19 | * the native byte order of your platform may differ from JVM.
buffer is null or empty
25 | */
26 | static Pointer copyToNativeMemory(final ByteBuffer buffer) {
27 | if (buffer == null || buffer.remaining() <= 0) {
28 | throw new IllegalArgumentException("buffer must not be null or empty");
29 | }
30 |
31 | final int length = buffer.remaining();
32 |
33 | if (buffer.isDirect()) {
34 | final int offset = buffer.position();
35 |
36 | // return a pointer to the direct buffer without copying
37 | return Native.getDirectBufferPointer(buffer).share(offset, length);
38 | } else {
39 | final Memory mem = new Memory(length);
40 |
41 | int index;
42 | byte[] bytes;
43 | if (buffer.hasArray()) { // it is array-backed and not read-only
44 | index = buffer.arrayOffset() + buffer.position();
45 | bytes = buffer.array();
46 | } else {
47 | index = 0;
48 | // use duplicated buffer to avoid changing `position`
49 | bytes = new byte[length];
50 | buffer.duplicate().get(bytes);
51 | }
52 | mem.write(0, bytes, index, length);
53 |
54 | return mem.share(0, length);
55 | }
56 | }
57 |
58 | /**
59 | * Copies the array to a newly allocated memory in the native heap.
60 | * 61 | * @param values the non-empty array from which to copy 62 | * @param offset the array index from which to start copying 63 | * @param length the number of elements fromvalues that must be copied
64 | *
65 | * @return the pointer to the allocated native memory
66 | * @throws IllegalArgumentException if values is null or empty
67 | */
68 | static Pointer copyToNativeMemory(final float[] values, final int offset, final int length) {
69 | if (values == null || values.length <= 0) {
70 | throw new IllegalArgumentException("values must not be null or empty");
71 | }
72 |
73 | final Memory mem = new Memory((long) length * 4);
74 | mem.write(0, values, offset, length);
75 |
76 | return mem.share(0, length);
77 | }
78 | }
79 |
--------------------------------------------------------------------------------
/menoh/src/test/resources/models/make_onnx.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 |
4 | import numpy as np
5 | from unittest import mock
6 |
7 | import chainer
8 | import chainer.functions as F
9 | import chainer.links as L
10 |
11 | import onnx_chainer
12 |
13 |
# e.g. [[0, 0], [0, 1], [1, 0], [1, 1]] -> [[0], [0], [0], [1]]
class AND(chainer.Chain):
    """Single linear layer with fixed weights computing logical AND of two 0/1 inputs.

    The output is ``abs(relu(x0 + x1 - 1))``: 1 when both inputs are 1, 0 otherwise.
    """

    def __init__(self):
        super().__init__()
        with self.init_scope():
            self.fc1 = L.Linear(2, 1)

        # W has shape (out_size, in_size) = (1, 2) and b has shape (out_size,) = (1,);
        # the bias literal matches that shape (as in MLP) instead of relying on NumPy
        # dropping a leading unit dimension during slice assignment.
        self.fc1.W.array[:] = np.array([[1.0, 1.0]])
        self.fc1.b.array[:] = np.array([-1.0])

    def __call__(self, x):
        # `_onnx_name` is picked up by IDGenerator (patched in as `id` during export),
        # giving the exported graph stable input/output names.
        x.node._onnx_name = 'input'
        h = F.absolute(F.relu(self.fc1(x)))
        h.node._onnx_name = 'output'
        return h
29 |
30 |
# e.g. [[0, 1, 2], [3, 4, 5]] -> [[0, 0, 15, 96, 177], [0, 0, 51, 312, 573]]
class MLP(chainer.Chain):
    """Two-layer perceptron (3 -> 4 -> 5) with ReLU after each layer and fixed integer weights."""

    def __init__(self):
        super().__init__()
        with self.init_scope():
            self.fc1 = L.Linear(3, 4)
            self.fc2 = L.Linear(4, 5)

        # Deterministic parameters so that the expected outputs above can be checked by hand.
        first, second = self.fc1, self.fc2
        first.W.array[:] = np.arange(-6, 6).reshape((4, 3))
        first.b.array[:] = np.arange(-2, 2)
        second.W.array[:] = np.arange(-10, 10).reshape((5, 4))
        second.b.array[:] = np.arange(-2, 3)

    def __call__(self, x):
        # Name the input and both layer outputs so IDGenerator can use them during export.
        x.node._onnx_name = 'input'
        hidden = F.relu(self.fc1(x))
        hidden.node._onnx_name = 'fc1'
        out = F.relu(self.fc2(hidden))
        out.node._onnx_name = 'fc2'
        return out
53 |
54 |
class IDGenerator(object):
    """Callable stand-in for the builtin ``id``.

    Objects that carry an ``_onnx_name`` attribute are identified by that name;
    any other object falls back to the original builtin ``id``.
    """

    def __init__(self):
        # Capture the real builtin at construction time, before `builtins.id` is patched.
        self._id = id

    def __call__(self, obj):
        fallback = self._id(obj)
        return getattr(obj, '_onnx_name', fallback)
63 |
64 |
def main(logger):
    """Export one of the test models defined in this script to an ONNX file.

    Command-line arguments:
      --model  name of the model to export ('and_op' or 'mlp'); required.
      --out    output path; defaults to '<model>.onnx'.

    Raises:
      SystemExit: with status 1 if the export fails (the traceback is logged first).
    """
    # model name -> (model class, shape of the dummy input used to trace the graph)
    models = {
        'and_op': (AND, (1, 2)),
        'mlp': (MLP, (1, 3)),
    }

    parser = argparse.ArgumentParser()
    # `required`/`choices` reject a missing or unknown model up front instead of crashing
    # on `None + ".onnx"` or silently exporting nothing.
    parser.add_argument('--model', required=True, choices=sorted(models))
    parser.add_argument('--out', required=False)
    args = parser.parse_args()

    if args.out is None:
        output_filename = args.model + ".onnx"
    else:
        output_filename = args.out

    logger.info("Generating `{}` model and save it to `{}`".format(args.model, output_filename))

    model_class, input_shape = models[args.model]
    try:
        model = model_class()
        # The input values only drive graph tracing; zeros keep the run deterministic.
        x = np.zeros(input_shape, dtype=np.float32)
        # Patch `id` so nodes carrying `_onnx_name` are exported under that name.
        with chainer.using_config('train', False), \
                mock.patch('builtins.id', IDGenerator()):
            onnx_chainer.export(model, x, filename=output_filename)
    except Exception:
        logger.exception("An error occurred during generation of the model")
        # Report the failure to the caller (e.g. a build script) via a non-zero exit status.
        raise SystemExit(1)
95 |
96 |
if __name__ == '__main__':
    # Install the default stderr handler, then let this module's logger emit INFO messages.
    logging.basicConfig()
    script_logger = logging.getLogger(__name__)
    script_logger.setLevel(logging.INFO)

    main(script_logger)
103 |
--------------------------------------------------------------------------------
/LICENSE_THIRD_PARTY:
--------------------------------------------------------------------------------
1 | This software depends on
2 |
3 | ==== Intel(R) MKL-DNN ====
4 |
5 | Intel(R) MKL-DNN distributed under Apache License 2.0 licence
6 |
7 | https://github.com/intel/mkl-dnn
8 |
9 | http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | ==== Xbyak ====
12 |
13 | Xbyak distributed under 3-clause BSD licence
14 |
15 | https://github.com/intel/mkl-dnn/blob/master/src/cpu/xbyak/COPYRIGHT
16 |
17 | ==== mklml ====
18 |
19 | Copyright (c) 2016-2018, Intel Corporation
20 | All rights reserved.
21 |
22 | Redistribution and use in source and binary forms, with or without
23 | modification, are permitted provided that the following conditions are met:
24 |
25 | * Redistributions of source code must retain the above copyright notice,
26 | this list of conditions and the following disclaimer.
27 | * Redistributions in binary form must reproduce the above copyright
28 | notice, this list of conditions and the following disclaimer in the
29 | documentation and/or other materials provided with the distribution.
30 | * Neither the name of Intel Corporation nor the names of its contributors
31 | may be used to endorse or promote products derived from this software
32 | without specific prior written permission.
33 |
34 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
35 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
36 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
37 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
38 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
39 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
40 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
41 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
42 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
43 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
44 |
45 | ==== Protocol Buffers ====
46 |
47 | Copyright 2008 Google Inc. All rights reserved.
48 |
49 | Redistribution and use in source and binary forms, with or without
50 | modification, are permitted provided that the following conditions are
51 | met:
52 |
53 | * Redistributions of source code must retain the above copyright
54 | notice, this list of conditions and the following disclaimer.
55 | * Redistributions in binary form must reproduce the above
56 | copyright notice, this list of conditions and the following disclaimer
57 | in the documentation and/or other materials provided with the
58 | distribution.
59 | * Neither the name of Google Inc. nor the names of its
60 | contributors may be used to endorse or promote products derived from
61 | this software without specific prior written permission.
62 |
63 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
64 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
65 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
66 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
67 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
68 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
69 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
70 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
71 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
72 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
73 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
74 |
75 | Code generated by the Protocol Buffer compiler is owned by the owner
76 | of the input file used when generating it. This code is not
77 | standalone and requires a support library to be linked with it. This
78 | support library is itself covered by the above license.
79 |
--------------------------------------------------------------------------------
/menoh/src/test/java/jp/preferred/menoh/ModelDataTest.java:
--------------------------------------------------------------------------------
1 | package jp.preferred.menoh;
2 |
3 | // CHECKSTYLE:OFF
4 | import static jp.preferred.menoh.TestUtils.*;
5 | import static org.junit.jupiter.api.Assertions.*;
6 | // CHECKSTYLE:ON
7 |
8 | import org.junit.jupiter.api.Test;
9 |
10 | public class ModelDataTest {
11 | @Test
12 | public void makeFromValidOnnxFile() throws Exception {
13 | final String path = getResourceFilePath("models/and_op.onnx");
14 |
15 | try (ModelData modelData = ModelData.fromOnnxFile(path)) {
16 | assertNotNull(modelData.nativeHandle());
17 | }
18 | }
19 |
20 | @Test
21 | public void closeModelData() throws Exception {
22 | final String path = getResourceFilePath("models/and_op.onnx");
23 |
24 | final ModelData modelData = ModelData.fromOnnxFile(path);
25 | try {
26 | assertNotNull(modelData.nativeHandle());
27 | } finally {
28 | modelData.close();
29 | assertNull(modelData.nativeHandle());
30 |
31 | // close() is an idempotent operation
32 | modelData.close();
33 | }
34 | }
35 |
36 | @Test
37 | public void makeFromNonExistentOnnxFile() {
38 | MenohException e = assertThrows(
39 | MenohException.class, () -> ModelData.fromOnnxFile("__NON_EXISTENT_FILENAME__"));
40 | assertAll("non-existent onnx file",
41 | () -> assertEquals(ErrorCode.INVALID_FILENAME, e.getErrorCode()),
42 | () -> assertEquals(
43 | "menoh invalid filename error: __NON_EXISTENT_FILENAME__ (invalid_filename)",
44 | e.getMessage())
45 | );
46 | }
47 |
48 | @Test
49 | public void makeFromInvalidOnnxFile() throws Exception {
50 | final String path = getResourceFilePath("models/invalid_format.onnx");
51 |
52 | MenohException e = assertThrows(MenohException.class, () -> ModelData.fromOnnxFile(path));
53 | assertAll("invalid onnx file",
54 | () -> assertEquals(ErrorCode.ONNX_PARSE_ERROR, e.getErrorCode()),
55 | () -> assertEquals(
56 | String.format("menoh onnx parse error: %s (onnx_parse_error)", path),
57 | e.getMessage())
58 | );
59 | }
60 |
61 | @Test
62 | public void makeFromUnsupportedOnnxOpsetVersionFile() throws Exception {
63 | // Note: This file is a copy of and_op.onnx which is edited the last byte to 127 (0x7e)
64 | final String path = getResourceFilePath("models/unsupported_onnx_opset_version.onnx");
65 |
66 | MenohException e = assertThrows(MenohException.class, () -> ModelData.fromOnnxFile(path));
67 | assertAll("invalid onnx file",
68 | () -> assertEquals(ErrorCode.UNSUPPORTED_ONNX_OPSET_VERSION, e.getErrorCode()),
69 | () -> assertEquals(
70 | String.format(
71 | "menoh unsupported onnx opset version error: %s has "
72 | + "onnx opset version %d > %d (unsupported_onnx_opset_version)",
73 | path, 127, 7),
74 | e.getMessage())
75 | );
76 | }
77 |
78 | @Test
79 | public void optimizeModelData() throws Exception {
80 | final String path = getResourceFilePath("models/and_op.onnx");
81 |
82 | try (
83 | ModelData modelData = ModelData.fromOnnxFile(path);
84 | VariableProfileTableBuilder vptBuilder = VariableProfileTable.builder()
85 | .addInputProfile("input", DType.FLOAT, new int[] {1, 2})
86 | .addOutputProfile("output", DType.FLOAT);
87 | VariableProfileTable vpt = vptBuilder.build(modelData)
88 | ) {
89 | modelData.optimize(vpt);
90 | }
91 | }
92 | }
93 |
--------------------------------------------------------------------------------
/menoh/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 | A convenient wrapper for building and running {@link Model}.
11 | * 12 | *Make sure to {@link #close()} this object after finishing the process to free the underlying memory 13 | * in the native heap.
14 | */ 15 | public class ModelRunner implements AutoCloseable { 16 | private final Model model; 17 | 18 | private static final String DEFAULT_BACKEND_NAME = "mkldnn"; 19 | 20 | private static final String DEFAULT_BACKEND_CONFIG = ""; 21 | 22 | ModelRunner(Model model) { 23 | this.model = model; 24 | } 25 | 26 | /** 27 | * Returns the underlying {@link Model}. 28 | */ 29 | Model model() { 30 | return this.model; 31 | } 32 | 33 | @Override 34 | public void close() { 35 | model.close(); 36 | } 37 | 38 | /** 39 | * Loads an ONNX model from the specified file. 40 | */ 41 | public static ModelRunnerBuilder fromOnnxFile(String path) { 42 | ModelData modelData = null; 43 | VariableProfileTableBuilder vptBuilder = null; 44 | try { 45 | modelData = ModelData.fromOnnxFile(path); 46 | vptBuilder = VariableProfileTable.builder(); 47 | 48 | return new ModelRunnerBuilder( 49 | modelData, 50 | vptBuilder, 51 | DEFAULT_BACKEND_NAME, 52 | DEFAULT_BACKEND_CONFIG, 53 | new HashMapRun this model after assigning a non-empty array to the specified variable.
75 | * 76 | * @param name the name of the input variable 77 | * @param values the values to be copied to the input variable 78 | */ 79 | public void run(String name, float[] values) { 80 | run(name, values, 0, values.length); 81 | } 82 | 83 | /** 84 | *Run this model after assigning a non-empty array to the specified variable. It copies the content
85 | * ranging from offset to (offset + length - 1).
Run this model after assigning a non-empty buffer to the specified variable. It copies the content
99 | * ranging from position() to (limit() - 1) without changing them, even if
100 | * the buffer is direct unlike {@link ModelRunnerBuilder#attachExternalBuffer(String, ByteBuffer)}.
101 | *
Note that the order() of the buffer should be {@link ByteOrder#nativeOrder()} because
104 | * the native byte order of your platform may differ from JVM.
Run this model after assigning non-empty buffers to the specified variables. It copies the content
115 | * ranging from position() to (limit() - 1) without changing them, even if
116 | * the buffer is direct unlike {@link ModelRunnerBuilder#attachExternalBuffer(String, ByteBuffer)}.
117 | *
Note that the order() of the buffer should be {@link ByteOrder#nativeOrder()} because
120 | * the native byte order of your platform may differ from JVM.
A builder object for {@link Model}.
16 | * 17 | *Make sure to {@link #close()} this object after finishing the process to free the underlying memory 18 | * in the native heap.
19 | */ 20 | public class ModelBuilder implements AutoCloseable { 21 | private Pointer handle; 22 | 23 | /** 24 | * A reference to the pointers to prevent them from getting garbage collected. 25 | */ 26 | private final ListAttaches a non-empty external buffer to the specified variable.
56 | * 57 | *If the specified buffer is direct, it will be attached to the model directly without
58 | * copying. Otherwise, it copies the content to a newly allocated buffer in the native heap ranging from
59 | * position() to (limit() - 1) without changing its position.
The attached buffer can be accessed through {@link Model#variable(String)}.
62 | * 63 | *Note that the order() of the buffer should be {@link ByteOrder#nativeOrder()} because
64 | * the native byte order of your platform may differ from JVM.
buffer is null or empty
71 | */
72 | public ModelBuilder attachExternalBuffer(String variableName, ByteBuffer buffer) throws MenohException {
73 | final Pointer bufferHandle = copyToNativeMemory(buffer);
74 | synchronized (this) {
75 | externalBuffers.add(bufferHandle);
76 | }
77 |
78 | return attachImpl(variableName, bufferHandle);
79 | }
80 |
81 | /**
82 | * Attaches a non-empty external buffer to the specified variable. It also copies the content of the
83 | * values to a newly allocated buffer in the native heap.
The buffer can be accessed through {@link Model#variable(String)}.
86 | * 87 | * @param variableName the name of the variable 88 | * @param values the byte buffer from which to copy 89 | * @return this object 90 | * 91 | * @throws IllegalArgumentException ifvalues is null or empty
92 | */
93 | public ModelBuilder attachExternalBuffer(String variableName, float[] values) throws MenohException {
94 | return attachExternalBuffer(variableName, values, 0, values.length);
95 | }
96 |
97 | /**
98 | * Attaches a non-empty external buffer to the specified variable. It also copies the content of the
99 | * values to a newly allocated buffer in the native heap ranging from offset
100 | * to (offset + length - 1).
The buffer can be accessed through {@link Model#variable(String)}.
103 | * 104 | * @param variableName the name of the variable 105 | * @param values the byte buffer from which to copy 106 | * @param offset the array index from which to start copying 107 | * @param length the number of elements fromvalues that must be copied
108 | * @return this object
109 | *
110 | * @throws IllegalArgumentException if values is null or empty
111 | */
112 | public ModelBuilder attachExternalBuffer(
113 | String variableName, float[] values, int offset, int length) throws MenohException {
114 | final Pointer bufferHandle = copyToNativeMemory(values, offset, length);
115 | synchronized (this) {
116 | externalBuffers.add(bufferHandle);
117 | }
118 |
119 | return attachImpl(variableName, bufferHandle);
120 | }
121 |
122 | private ModelBuilder attachImpl(String variableName, Pointer bufferHandle) throws MenohException {
123 | checkError(MenohNative.INSTANCE.menoh_model_builder_attach_external_buffer(
124 | handle, variableName, bufferHandle));
125 |
126 | return this;
127 | }
128 |
129 | /**
130 | * Builds a {@link Model} to run() by using the specified backend (e.g. "mkldnn").
Menoh will allocate a new buffer for input and output variables to which an external buffer is not 133 | * attached. It can be accessed via {@link Model#variable(String)}.
134 | */ 135 | public Model build(ModelData modelData, String backendName, String backendConfig) throws MenohException { 136 | final PointerByReference ref = new PointerByReference(); 137 | checkError(MenohNative.INSTANCE.menoh_build_model( 138 | this.handle, modelData.nativeHandle(), backendName, backendConfig, ref)); 139 | 140 | synchronized (this) { 141 | return new Model(ref.getValue(), new ArrayList<>(this.externalBuffers)); 142 | } 143 | } 144 | } 145 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # menoh-java 2 | *menoh-java* enables you to build your own Deep Neural Network (DNN) application with a few lines of code in Java. 3 | 4 | This is a Java binding for [Menoh](https://github.com/pfnet-research/menoh/) DNN inference library, which supports [ONNX](http://onnx.ai/) model format. 5 | 6 | ## Getting Started 7 | 8 | ### Add a dependency 9 | Using Gradle: 10 | 11 | ```groovy 12 | dependencies { 13 | implementation 'jp.preferred.menoh:menoh:0.1.0' 14 | } 15 | ``` 16 | 17 | Using Maven: 18 | 19 | ```xml 20 |A builder object for {@link ModelRunner}.
9 | * 10 | *Make sure to {@link #close()} this object after finishing the process to free the underlying memory 11 | * in the native heap.
12 | */ 13 | public class ModelRunnerBuilder implements AutoCloseable { 14 | private final ModelData modelData; 15 | 16 | private final VariableProfileTableBuilder vptBuilder; 17 | 18 | private String backendName; 19 | 20 | private String backendConfig; 21 | 22 | private final MapAttaches a non-empty external buffer to the specified variable.
98 | * 99 | *If the specified buffer is direct, it will be attached to the model directly without
100 | * copying. Otherwise, it copies the content to a newly allocated buffer in the native heap ranging from
101 | * position() to (limit() - 1) without changing its position.
The buffer can be accessed through {@link Model#variable(String)}.
104 | * 105 | *Note that the order() of the buffer should be {@link ByteOrder#nativeOrder()} because
106 | * the native byte order of your platform may differ from JVM.
buffer is null or empty
113 | */
114 | public ModelRunnerBuilder attachExternalBuffer(String variableName, ByteBuffer buffer) throws MenohException {
115 | externalBuffers.put(variableName, buffer);
116 | return this;
117 | }
118 |
119 | /**
120 | * Attaches a non-empty external buffer to the specified variable. It also copies the content of the
121 | * values to a newly allocated buffer in the native heap.
The buffer can be accessed through {@link Model#variable(String)}.
124 | * 125 | * @param variableName the name of the variable 126 | * @param values the byte buffer from which to copy 127 | * @return this object 128 | * 129 | * @throws IllegalArgumentException ifvalues is null or empty
130 | */
131 | public ModelRunnerBuilder attachExternalBuffer(String variableName, float[] values) throws MenohException {
132 | return attachExternalBuffer(variableName, values, 0, values.length);
133 | }
134 |
135 | /**
136 | * Attaches a non-empty external buffer to the specified variable. It also copies the content of the
137 | * values to a newly allocated buffer in the native heap ranging from offset
138 | * to (offset + length - 1).
The buffer can be accessed through {@link Model#variable(String)}.
141 | * 142 | * @param variableName the name of the variable 143 | * @param values the byte buffer from which to copy 144 | * @param offset the array index from which to start copying 145 | * @param length the number of elements fromvalues that must be copied
146 | * @return this object
147 | *
148 | * @throws IllegalArgumentException if values is null or empty
149 | */
150 | public ModelRunnerBuilder attachExternalBuffer(
151 | String variableName,
152 | float[] values,
153 | int offset,
154 | int length) throws MenohException {
155 | // copy the array to a direct buffer
156 | final ByteBuffer buffer = ByteBuffer.allocateDirect(length * 4).order(ByteOrder.nativeOrder());
157 | buffer.asFloatBuffer().put(values, offset, length);
158 |
159 | externalBuffers.put(variableName, buffer);
160 | return this;
161 | }
162 |
163 | /**
164 | * Builds a {@link ModelRunner} to run() by using the specified backend (e.g. "mkldnn").
Menoh will allocate a new buffer for input and output variables to which an external buffer is not
167 | * attached. It can be accessed via {@link Model#variable(String)} in the ModelRunner object.
Classify an input image by using VGG16, a pre-trained CNN model.
See the 22 | * Menoh tutorial for the details. 23 | */ 24 | public class Vgg16 { 25 | public static void main(String[] args) throws Exception { 26 | if (args.length <= 0) { 27 | System.err.println("You must specify a filename of the input image in the argument."); 28 | System.exit(1); 29 | } 30 | 31 | System.err.println("Working Directory = " + System.getProperty("user.dir")); 32 | if (Boolean.getBoolean("jna.debug_load")) { 33 | NativeLibrary.getProcess(); // a trick to initialize `NativeLibrary` explicitly in this place 34 | System.err.println("jna.library.path: " + System.getProperty("jna.library.path")); 35 | System.err.println("jna.platform.library.path: " + System.getProperty("jna.platform.library.path")); 36 | } 37 | 38 | final String conv11InName = "140326425860192"; 39 | final String fc6OutName = "140326200777584"; 40 | final String softmaxOutName = "140326200803680"; 41 | 42 | final int batchSize = 1; 43 | final int channelNum = 3; 44 | final int height = 224; 45 | final int width = 224; 46 | 47 | final String inputImagePath = args[0]; 48 | final String onnxModelPath = getResourceFilePath("data/VGG16.onnx"); 49 | final String synsetWordsPath = getResourceFilePath("data/synset_words.txt"); 50 | 51 | // Read and pre-process an input 52 | final BufferedImage image = ImageIO.read(new File(inputImagePath)); 53 | if (image == null) { 54 | throw new Exception("Invalid input image path: " + inputImagePath); 55 | } 56 | final float[] imageData = renderToNctw(cropAndResize(image, width, height)); 57 | 58 | // Build a ModelRunner 59 | try ( 60 | // Note: You must `close()` the runner and builder to free the native memory explicitly 61 | ModelRunnerBuilder builder = ModelRunner 62 | // Load ONNX model data 63 | .fromOnnxFile(onnxModelPath) 64 | 65 | // Define input profile (name, dtype, dims) and output profile (name, dtype) 66 | // Menoh calculates dims of outputs automatically at build time 67 | .addInputProfile(conv11InName, DType.FLOAT, new int[] {batchSize, 
channelNum, height, width}) 68 | .addOutputProfile(fc6OutName, DType.FLOAT) 69 | .addOutputProfile(softmaxOutName, DType.FLOAT) 70 | 71 | // Configure backend 72 | .backendName("mkldnn") 73 | .backendConfig(""); 74 | ModelRunner runner = builder.build() 75 | ) { 76 | // The builder can be deleted explicitly after building a model runner 77 | builder.close(); 78 | 79 | // Run the inference 80 | runner.run(conv11InName, imageData); 81 | 82 | // Get output of `fc6` (the first hidden FC layer after CNNs in VGG16) 83 | final Variable fc6Out = runner.variable(fc6OutName); 84 | 85 | final FloatBuffer fc6OutFloatBuf = fc6Out.buffer().asFloatBuffer(); 86 | for (int i = 0; i < 10; i++) { 87 | System.out.print(Float.toString(fc6OutFloatBuf.get(i)) + " "); 88 | } 89 | System.out.println(); 90 | 91 | // Get output of `softmax` (the output layer in VGG16) 92 | final Variable softmaxOut = runner.variable(softmaxOutName); 93 | 94 | final int[] softmaxOutDims = softmaxOut.dims(); 95 | final float[] scores = new float[softmaxOutDims[1]]; 96 | 97 | // Note: use `get()` instead of `array()` because it is a direct buffer 98 | softmaxOut.buffer().asFloatBuffer().get(scores); 99 | 100 | final int topK = 5; 101 | final List