├── .gitignore ├── LICENSE ├── README.md ├── build.gradle ├── examples ├── build.gradle ├── gradlew ├── gradlew.bat └── src │ └── main │ └── java │ └── ai │ └── djl │ └── jsr381 │ ├── classification │ ├── BinaryClassifierExample.java │ ├── CatDogRecognition.java │ └── ImageClassifierExample.java │ └── detection │ └── ObjectDetectorExample.java ├── gradle.properties ├── gradle └── wrapper │ ├── .gitignore │ ├── GradleWrapperDownloader.java │ └── gradle-wrapper.properties ├── gradlew ├── gradlew.bat ├── jsr381 ├── build.gradle ├── gradlew ├── gradlew.bat └── src │ ├── main │ ├── java │ │ └── ai │ │ │ └── djl │ │ │ └── jsr381 │ │ │ ├── classification │ │ │ ├── SimpleBinaryClassifier.java │ │ │ └── SimpleImageClassifier.java │ │ │ ├── dataset │ │ │ └── CsvDataset.java │ │ │ ├── detection │ │ │ └── SimpleObjectDetector.java │ │ │ └── spi │ │ │ ├── DjlBinaryClassifierFactory.java │ │ │ ├── DjlImageClassifierFactory.java │ │ │ ├── DjlImageFactoryService.java │ │ │ ├── DjlImplementationService.java │ │ │ └── DjlServiceProvider.java │ └── resources │ │ └── META-INF │ │ └── services │ │ ├── javax.visrec.spi.BinaryClassifierFactory │ │ ├── javax.visrec.spi.ImageClassifierFactory │ │ └── javax.visrec.spi.ServiceProvider │ └── test │ ├── java │ └── ai │ │ └── djl │ │ └── jsr381 │ │ ├── classification │ │ ├── BinaryClassifierTest.java │ │ └── ImageClassifierTest.java │ │ └── detection │ │ └── ObjectDetectorTest.java │ └── resources │ ├── 0.png │ ├── mlp │ ├── mlp-0000.params │ ├── mlp-symbol.json │ └── synset.txt │ └── spam.csv ├── settings.gradle └── tools └── gradle └── formatter.gradle /.gitignore: -------------------------------------------------------------------------------- 1 | .gradle 2 | .DS_Store 3 | .idea 4 | *.iml 5 | .ipynb_checkpoints 6 | build 7 | libs 8 | *.gz 9 | *.params 10 | *.zip 11 | *.jar 12 | *.so 13 | *.dylib 14 | *.dll 15 | *.class 16 | datasets/ 17 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VisRec API JSR381 implementation using DJL 2 | 3 | The repository contains the source code of the reference implementation of JSR 381, 4 | a Java standard for Visual Recognition. 5 | 6 | The Visual Recognition API JSR #381 is a software development standard recognized by the Java 7 | Community Process (JCP) that simplifies and standardizes a set of APIs familiar to Java developers 8 | for classifying and recognizing objects in images using machine learning. Besides classes specific 9 | for visual recognition tasks, it provides general abstractions for machine learning tasks like 10 | classification, regression and data set, and reusable design which can be applied for machine 11 | learning systems in other domains. At the current stage it provides basic hello world examples 12 | for supported machine learning tasks (classification and regression) and image classification.
This reference implementation is based on Deep Java Library (DJL) available at

https://github.com/awslabs/djl

Specification for VisRec API is available at

https://github.com/JavaVisRec/visrec-api

## Getting Started Guide
For step by step guide and additional info see getting started guide at

https://github.com/JavaVisRec/visrec-api/wiki/Getting-Started-Guide

-------------------------------------------------------------------------------- /build.gradle: --------------------------------------------------------------------------------
allprojects {
    repositories {
        mavenCentral()
        // Sonatype snapshots repository, presumably for pre-release (SNAPSHOT) dependencies
        // such as the VisRec API — TODO confirm it is still required before removing.
        maven {
            url "https://oss.sonatype.org/content/repositories/snapshots/"
        }
    }

    // Point IntelliJ IDEA's module output at the directories the Gradle Java plugin compiles
    // into (build/classes/java/main and build/classes/java/test), so IDE and Gradle builds
    // share the same class output.
    apply plugin: 'idea'
    idea {
        module {
            outputDir = file('build/classes/java/main')
            testOutputDir = file('build/classes/java/test')
        }
    }
}

-------------------------------------------------------------------------------- /examples/build.gradle: --------------------------------------------------------------------------------
plugins {
    id "java"
    id "application"
}

dependencies {
    implementation project(":jsr381")
    implementation "org.slf4j:slf4j-simple:1.7.36"

    testImplementation('org.testng:testng:7.6.0') {
        exclude group: "junit", module: "junit"
    }
}

test {
    // Use TestNG for unit tests
    useTestNG()
}

application {
    // Select the example to run with -Dmain=<fully.qualified.ClassName>;
    // defaults to the object-detection example.
    mainClass = System.getProperty("main", "ai.djl.jsr381.detection.ObjectDetectorExample")
}

run {
    // Forward the Gradle JVM's system properties (including any -D flags) to the example JVM...
    systemProperties System.getProperties()
    // ...except user.dir, presumably so the forked JVM keeps its own working directory.
    systemProperties.remove("user.dir")
    systemProperty("file.encoding", "UTF-8")
}
tasks.distTar.enabled = false

apply from: file("${rootProject.projectDir}/tools/gradle/formatter.gradle")
-------------------------------------------------------------------------------- /examples/gradlew: -------------------------------------------------------------------------------- 1 | ../gradlew -------------------------------------------------------------------------------- /examples/gradlew.bat: -------------------------------------------------------------------------------- 1 | ../gradlew.bat -------------------------------------------------------------------------------- /examples/src/main/java/ai/djl/jsr381/classification/BinaryClassifierExample.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance 5 | * with the License. A copy of the License is located at 6 | * 7 | * http://aws.amazon.com/apache2.0/ 8 | * 9 | * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES 10 | * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions 11 | * and limitations under the License. 
12 | */ 13 | package ai.djl.jsr381.classification; 14 | 15 | import java.nio.file.Path; 16 | import java.nio.file.Paths; 17 | 18 | import javax.visrec.ml.classification.BinaryClassifier; 19 | import javax.visrec.ml.classification.NeuralNetBinaryClassifier; 20 | import javax.visrec.ml.model.ModelCreationException; 21 | 22 | public class BinaryClassifierExample { 23 | 24 | public static void main(String[] args) throws ModelCreationException { 25 | Path trainingFile = Paths.get("../jsr381/src/test/resources/spam.csv"); 26 | BinaryClassifier spamClassifier = 27 | NeuralNetBinaryClassifier.builder() 28 | .inputClass(float[].class) 29 | .inputsNum(57) 30 | .hiddenLayers(5) 31 | .maxEpochs(2) 32 | .trainingPath(trainingFile) 33 | .build(); 34 | 35 | // create test email feature 36 | float[] emailFeatures = new float[57]; 37 | emailFeatures[56] = 1; 38 | 39 | Float result = spamClassifier.classify(emailFeatures); 40 | System.out.println(result); 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /examples/src/main/java/ai/djl/jsr381/classification/CatDogRecognition.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance 5 | * with the License. A copy of the License is located at 6 | * 7 | * http://aws.amazon.com/apache2.0/ 8 | * 9 | * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES 10 | * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions 11 | * and limitations under the License. 
*/
package ai.djl.jsr381.classification;

import ai.djl.util.ZipUtils;

import java.awt.image.BufferedImage;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Comparator;
import java.util.Iterator;
import java.util.Map;
import java.util.stream.Stream;

import javax.visrec.ml.classification.ImageClassifier;
import javax.visrec.ml.classification.NeuralNetImageClassifier;
import javax.visrec.ml.model.ModelCreationException;

/**
 * Trains a 128x128 image classifier on the cats-and-dogs data set (downloaded on first run),
 * exports the trained model to {@code build/model}, and classifies one training image.
 */
public class CatDogRecognition {

    public static void main(String[] args) throws IOException, ModelCreationException {
        Path trainingFile = downloadTrainingData();
        Path modelDir = Paths.get("build/model");

        ImageClassifier<BufferedImage> classifier =
                NeuralNetImageClassifier.builder()
                        .inputClass(BufferedImage.class)
                        .imageHeight(128)
                        .imageWidth(128)
                        .trainingFile(trainingFile)
                        .exportModel(modelDir)
                        .maxEpochs(20)
                        .build();

        Path input = trainingFile.resolve("cat/cat_1.png");
        Map<String, Float> result = classifier.classify(input);
        for (Map.Entry<String, Float> entry : result.entrySet()) {
            System.out.println(entry.getKey() + ": " + entry.getValue());
        }
    }

    /**
     * Downloads and extracts the training archive into {@code datasets/cats_and_dogs} unless that
     * directory already exists.
     *
     * <p>The archive is extracted into a sibling staging directory that is renamed into place only
     * after extraction succeeded. An interrupted run therefore leaves no {@code cats_and_dogs}
     * directory behind, and the next run downloads again instead of reusing partial data.
     *
     * @return the {@code training} sub-directory of the extracted data set
     * @throws IOException if downloading, extracting or renaming fails
     */
    private static Path downloadTrainingData() throws IOException {
        String link =
                "https://github.com/JavaVisRec/jsr381-examples-datasets/raw/master/cats_and_dogs_training_data_png.zip";
        URL url = new URL(link);
        Path dir = Paths.get("datasets", "cats_and_dogs");
        if (!Files.exists(dir)) {
            // Stage in the same parent directory so the final move is a plain rename.
            Path parent = dir.toAbsolutePath().getParent();
            Files.createDirectories(parent);
            Path staging = Files.createTempDirectory(parent, "cats_and_dogs");
            try (InputStream is = url.openStream()) {
                ZipUtils.unzip(is, staging);
                Files.move(staging, dir);
            } catch (IOException | RuntimeException e) {
                deleteQuietly(staging);
                throw e;
            }
        }
        return dir.resolve("training");
    }

    /** Best-effort recursive delete used to discard a partially extracted staging directory. */
    private static void deleteQuietly(Path root) {
        try (Stream<Path> paths = Files.walk(root)) {
            // Reverse path order visits children before their parent directory.
            Iterator<Path> it = paths.sorted(Comparator.reverseOrder()).iterator();
            while (it.hasNext()) {
                Files.deleteIfExists(it.next());
            }
        } catch (IOException | UncheckedIOException ignored) {
            // Cleanup is best-effort; the caller rethrows the original failure.
        }
    }
}
-------------------------------------------------------------------------------- /examples/src/main/java/ai/djl/jsr381/classification/ImageClassifierExample.java:
--------------------------------------------------------------------------------
/*
 * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.jsr381.classification;

import java.awt.image.BufferedImage;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Map;

import javax.visrec.ml.classification.ImageClassifier;
import javax.visrec.ml.classification.NeuralNetImageClassifier;
import javax.visrec.ml.model.ModelCreationException;

/**
 * Classifies the {@code 0.png} test image with the pre-trained MLP model shipped in the jsr381
 * test resources, configured for 28x28 input images, and prints every label with its score.
 */
public class ImageClassifierExample {

    public static void main(String[] args) throws ModelCreationException {
        Path input = Paths.get("../jsr381/src/test/resources/0.png");

        // use pre-trained mlp model
        Path modelDir = Paths.get("../jsr381/src/test/resources/mlp");

        ImageClassifier<BufferedImage> classifier =
                NeuralNetImageClassifier.builder()
                        .inputClass(BufferedImage.class)
                        .imageHeight(28)
                        .imageWidth(28)
                        .importModel(modelDir)
                        .build();

        // Typed Map/Map.Entry are required: iterating a raw entrySet() yields Object and the
        // enhanced for loop below would not compile.
        Map<String, Float> result = classifier.classify(input);
        for (Map.Entry<String, Float> entry : result.entrySet()) {
            System.out.println(entry.getKey() + ": " + entry.getValue());
        }
    }
}
-------------------------------------------------------------------------------- /examples/src/main/java/ai/djl/jsr381/detection/ObjectDetectorExample.java:
-------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance 5 | * with the License. A copy of the License is located at 6 | * 7 | * http://aws.amazon.com/apache2.0/ 8 | * 9 | * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES 10 | * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions 11 | * and limitations under the License. 12 | */ 13 | package ai.djl.jsr381.detection; 14 | 15 | import ai.djl.Application; 16 | import ai.djl.ModelException; 17 | import ai.djl.modality.cv.Image; 18 | import ai.djl.modality.cv.output.DetectedObjects; 19 | import ai.djl.repository.zoo.Criteria; 20 | import ai.djl.repository.zoo.ZooModel; 21 | 22 | import java.awt.image.BufferedImage; 23 | import java.io.IOException; 24 | import java.net.URL; 25 | import java.util.List; 26 | import java.util.Map; 27 | 28 | import javax.imageio.ImageIO; 29 | import javax.visrec.ml.detection.BoundingBox; 30 | 31 | public class ObjectDetectorExample { 32 | 33 | public static void main(String[] args) throws IOException, ModelException { 34 | Criteria criteria = 35 | Criteria.builder() 36 | .setTypes(Image.class, DetectedObjects.class) 37 | .optApplication(Application.CV.OBJECT_DETECTION) 38 | .optArtifactId("yolo") 39 | .optArgument("threshold", 0.3) 40 | .build(); 41 | try (ZooModel model = criteria.loadModel()) { 42 | SimpleObjectDetector objectDetector = new SimpleObjectDetector(model); 43 | URL imageUrl = 44 | new URL("https://djl-ai.s3.amazonaws.com/resources/images/dog_bike_car.jpg"); 45 | BufferedImage inputImage = ImageIO.read(imageUrl); 46 | Map> result = objectDetector.detectObject(inputImage); 47 | 48 | for (List boundingBoxes : 
result.values()) { 49 | for (BoundingBox boundingBox : boundingBoxes) { 50 | System.out.println(boundingBox.toString()); 51 | } 52 | } 53 | } 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /gradle.properties: -------------------------------------------------------------------------------- 1 | org.gradle.daemon=true 2 | org.gradle.jvmargs=-Xmx1024M 3 | 4 | systemProp.org.gradle.internal.http.socketTimeout=120000 5 | systemProp.org.gradle.internal.http.connectionTimeout=60000 6 | 7 | # FIXME: Workaround gradle publish issue: https://github.com/gradle/gradle/issues/11308 8 | systemProp.org.gradle.internal.publish.checksums.insecure=true 9 | -------------------------------------------------------------------------------- /gradle/wrapper/.gitignore: -------------------------------------------------------------------------------- 1 | gradle-wrapper.jar 2 | -------------------------------------------------------------------------------- /gradle/wrapper/GradleWrapperDownloader.java: -------------------------------------------------------------------------------- 1 | /* 2 | Licensed to the Apache Software Foundation (ASF) under one 3 | or more contributor license agreements. See the NOTICE file 4 | distributed with this work for additional information 5 | regarding copyright ownership. The ASF licenses this file 6 | to you under the Apache License, Version 2.0 (the 7 | "License"); you may not use this file except in compliance 8 | with the License. You may obtain a copy of the License at 9 | 10 | http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | Unless required by applicable law or agreed to in writing, 13 | software distributed under the License is distributed on an 14 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | KIND, either express or implied. See the License for the 16 | specific language governing permissions and limitations 17 | under the License. 
*/

import java.io.File;
import java.io.FileOutputStream;
import java.net.URL;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;

/**
 * Fallback used by the {@code gradlew} script to fetch {@code gradle-wrapper.jar} when neither
 * {@code wget} nor {@code curl} is available.
 *
 * <p>Usage: {@code java GradleWrapperDownloader <base-directory>}; the jar is written to
 * {@code <base-directory>/gradle/wrapper/gradle-wrapper.jar}. Exits with status 0 on success and 1
 * on failure.
 */
public class GradleWrapperDownloader {

    /** URL to download the gradle-wrapper.jar from. */
    private static final String DEFAULT_DOWNLOAD_URL =
            "https://raw.githubusercontent.com/gradle/gradle/master/gradle/wrapper/gradle-wrapper.jar";

    /** Path, relative to the base directory, where the gradle-wrapper.jar will be saved to. */
    private static final String GRADLE_WRAPPER_JAR_PATH = "gradle/wrapper/gradle-wrapper.jar";

    public static void main(String[] args) {
        System.out.println("- Downloader started");
        if (args.length < 1) {
            System.out.println("- ERROR: missing base directory argument");
            System.out.println("  Usage: java GradleWrapperDownloader <base-directory>");
            System.exit(1);
        }
        File baseDirectory = new File(args[0]);
        System.out.println("- Using base directory: " + baseDirectory.getAbsolutePath());

        String url = DEFAULT_DOWNLOAD_URL;
        System.out.println("- Downloading from: " + url);

        File outputFile = new File(baseDirectory.getAbsolutePath(), GRADLE_WRAPPER_JAR_PATH);
        File outputDir = outputFile.getParentFile();
        if (!outputDir.exists() && !outputDir.mkdirs()) {
            System.out.println(
                    "- ERROR creating output directory '" + outputDir.getAbsolutePath() + "'");
        }
        System.out.println("- Downloading to: " + outputFile.getAbsolutePath());
        try {
            downloadFileFromURL(url, outputFile);
            System.out.println("Done");
            System.exit(0);
        } catch (Throwable e) {
            System.out.println("- Error downloading");
            e.printStackTrace();
            // gradlew only checks that the jar is readable, so a truncated file left here would
            // never be re-downloaded; remove it so the next run retries.
            if (outputFile.exists() && !outputFile.delete()) {
                System.out.println(
                        "- WARNING could not delete partial download '"
                                + outputFile.getAbsolutePath()
                                + "'");
            }
            System.exit(1);
        }
    }

    /**
     * Streams the resource at {@code urlString} into {@code destination}, overwriting it.
     *
     * @throws Exception if the URL is malformed or any I/O step fails; both the network channel
     *     and the output file are closed in every case
     */
    private static void downloadFileFromURL(String urlString, File destination) throws Exception {
        URL website = new URL(urlString);
        try (ReadableByteChannel rbc = Channels.newChannel(website.openStream());
                FileOutputStream fos = new FileOutputStream(destination)) {
            fos.getChannel().transferFrom(rbc, 0, Long.MAX_VALUE);
        }
    }
}
-------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.properties: --------------------------------------------------------------------------------
#Wed Jan 16 22:46:13 PST 2019
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-7.4.2-bin.zip

-------------------------------------------------------------------------------- /gradlew: --------------------------------------------------------------------------------
#!/usr/bin/env sh

##############################################################################
##
##  Gradle start up script for UN*X
##
##############################################################################

# Attempt to set APP_HOME
# Resolve links: $0 may be a link
PRG="$0"
# Need this for relative symlinks.
while [ -h "$PRG" ] ; do
    ls=`ls -ld "$PRG"`
    link=`expr "$ls" : '.*-> \(.*\)$'`
    if expr "$link" : '/.*' > /dev/null; then
        PRG="$link"
    else
        PRG=`dirname "$PRG"`"/$link"
    fi
done
SAVED="`pwd`"
cd "`dirname \"$PRG\"`/" >/dev/null
APP_HOME="`pwd -P`"
cd "$SAVED" >/dev/null

APP_NAME="Gradle"
APP_BASE_NAME=`basename "$0"`

# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
DEFAULT_JVM_OPTS=""

# Use the maximum available, or set MAX_FD != -1 to use that value.
MAX_FD="maximum"

warn () {
    echo "$*"
}

die () {
    echo
    echo "$*"
    echo
    exit 1
}

# OS specific support (must be 'true' or 'false').
cygwin=false
msys=false
darwin=false
nonstop=false
case "`uname`" in
  CYGWIN* )
    cygwin=true
    ;;
  Darwin* )
    darwin=true
    ;;
  MINGW* )
    msys=true
    ;;
  NONSTOP* )
    nonstop=true
    ;;
esac

CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar

# Determine the Java command to use to start the JVM.
if [ -n "$JAVA_HOME" ] ; then
    if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
        # IBM's JDK on AIX uses strange locations for the executables
        JAVACMD="$JAVA_HOME/jre/sh/java"
    else
        JAVACMD="$JAVA_HOME/bin/java"
    fi
    if [ ! -x "$JAVACMD" ] ; then
        die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME

Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
    fi
else
    JAVACMD="java"
    which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.

Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi

# Increase the maximum file descriptors if we can.
if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
    MAX_FD_LIMIT=`ulimit -H -n`
    if [ $? -eq 0 ] ; then
        if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
            MAX_FD="$MAX_FD_LIMIT"
        fi
        ulimit -n $MAX_FD
        if [ $? -ne 0 ] ; then
            warn "Could not set maximum file descriptor limit: $MAX_FD"
        fi
    else
        warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
    fi
fi

##########################################################################################
# Extension to allow automatically downloading the gradle-wrapper.jar
# This allows using the maven wrapper in projects that prohibit checking in binary data.
#
# NOTE(review): the jar comes from Gradle's master branch, not the version pinned in
# gradle-wrapper.properties, and is not checksum-verified; consider pinning and verifying it.
##########################################################################################
WRAPPER_JAR_PATH="$APP_HOME/gradle/wrapper/gradle-wrapper.jar"
if [ ! -r "${WRAPPER_JAR_PATH}" ]; then
    jarUrl="https://raw.githubusercontent.com/gradle/gradle/master/gradle/wrapper/gradle-wrapper.jar"
    # A failed download must not leave a file behind: the readability check above would
    # otherwise treat a truncated or empty jar as valid on every later run.
    if command -v wget > /dev/null; then
        wget -q "${jarUrl}" -O "${WRAPPER_JAR_PATH}" || rm -f "${WRAPPER_JAR_PATH}"
    elif command -v curl > /dev/null; then
        # -f: fail on HTTP errors instead of saving the error page; -L: follow redirects.
        curl -sfL -o "${WRAPPER_JAR_PATH}" "$jarUrl" || rm -f "${WRAPPER_JAR_PATH}"
    else
        javaClass="$APP_HOME/gradle/wrapper/GradleWrapperDownloader.java"
        if [ -e "$javaClass" ]; then
            if [ ! -e "$APP_HOME/gradle/wrapper/GradleWrapperDownloader.class" ]; then
                # Compiling the Java class
                ("${JAVACMD}c" "$javaClass")
            fi
            if [ -e "$APP_HOME/gradle/wrapper/GradleWrapperDownloader.class" ]; then
                # Absolute classpath: the script may be invoked from outside APP_HOME.
                ("$JAVACMD" -cp "$APP_HOME/gradle/wrapper" GradleWrapperDownloader "$APP_HOME")
            fi
        fi
    fi
    if [ ! -r "${WRAPPER_JAR_PATH}" ]; then
        die "ERROR: Could not download the Gradle wrapper jar from ${jarUrl}"
    fi
fi
##########################################################################################
# End of extension
##########################################################################################


# For Darwin, add options to specify how the application appears in the dock
if $darwin; then
    GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
fi

# For Cygwin, switch paths to Windows format before running java
if $cygwin ; then
    APP_HOME=`cygpath --path --mixed "$APP_HOME"`
    CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
    JAVACMD=`cygpath --unix "$JAVACMD"`

    # We build the pattern for arguments to be converted via cygpath
    ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
    SEP=""
    for dir in $ROOTDIRSRAW ; do
        ROOTDIRS="$ROOTDIRS$SEP$dir"
        SEP="|"
    done
    OURCYGPATTERN="(^($ROOTDIRS))"
    # Add a user-defined pattern to the cygpath arguments
    if [ "$GRADLE_CYGPATTERN" != "" ] ; then
        OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
    fi
    # Now convert the arguments - kludge to limit ourselves to /bin/sh
    i=0
    for arg in "$@" ; do
        CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
        CHECK2=`echo "$arg"|egrep -c "^-"`                                 ### Determine if an option

        if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then                    ### Added a condition
            eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
        else
            eval `echo args$i`="\"$arg\""
        fi
        i=$((i+1))
    done
    case $i in
        (0) set -- ;;
        (1) set -- "$args0" ;;
        (2) set -- "$args0" "$args1" ;;
        (3) set -- "$args0" "$args1" "$args2" ;;
        (4) set -- "$args0" "$args1" "$args2" "$args3" ;;
        (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
        (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
        (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
        (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
        (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
    esac
fi

# Escape application args
save () {
    for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
    echo " "
}
APP_ARGS=$(save "$@")

# Collect all arguments for the java command, following the shell quoting and substitution rules
eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"

# by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong
if [ "$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then
  cd "$(dirname "$0")"
fi

exec "$JAVACMD" "$@"
--------------------------------------------------------------------------------
/gradlew.bat:
--------------------------------------------------------------------------------
@if "%DEBUG%" == "" @echo off
@rem ##########################################################################
@rem
@rem  Gradle startup script for Windows
@rem
@rem ##########################################################################

@rem Set local scope for the variables with windows NT shell
if "%OS%"=="Windows_NT" setlocal

set DIRNAME=%~dp0
if "%DIRNAME%" == "" set DIRNAME=.
set APP_BASE_NAME=%~n0
set APP_HOME=%DIRNAME%

@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
set DEFAULT_JVM_OPTS=

@rem Find java.exe
if defined JAVA_HOME goto findJavaFromJavaHome

set JAVA_EXE=java.exe
%JAVA_EXE% -version >NUL 2>&1
if "%ERRORLEVEL%" == "0" goto init

echo.
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.

goto fail

:findJavaFromJavaHome
set JAVA_HOME=%JAVA_HOME:"=%
set JAVA_EXE=%JAVA_HOME%/bin/java.exe

if exist "%JAVA_EXE%" goto init

echo.
echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.

goto fail

:init
@rem Get command-line arguments, handling Windows variants

if not "%OS%" == "Windows_NT" goto win9xME_args

:win9xME_args
@rem Slurp the command line arguments.
set CMD_LINE_ARGS=
set _SKIP=2

:win9xME_args_slurp
if "x%~1" == "x" goto execute

set CMD_LINE_ARGS=%*

:execute
@rem Setup the command line

set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
@rem Quoting the whole assignment keeps the quotes out of the value, so the URL can be
@rem embedded verbatim in the PowerShell command below.
set "DOWNLOAD_URL=https://raw.githubusercontent.com/gradle/gradle/master/gradle/wrapper/gradle-wrapper.jar"

@rem Extension to allow automatically downloading the gradle-wrapper.jar
@rem This allows using the gradle wrapper in projects that prohibit checking in binary data.
@rem The path is quoted because APP_HOME may contain spaces, and goto is used instead of an
@rem if/else block because a ")" in the path (e.g. "Program Files (x86)") would close the block early.
if exist "%CLASSPATH%" goto wrapperJarFound

echo Couldn't find %CLASSPATH%, downloading it ...
echo Downloading from: %DOWNLOAD_URL%
powershell -Command "[Net.ServicePointManager]::SecurityProtocol = [Net.SecurityProtocolType]::Tls12; (New-Object Net.WebClient).DownloadFile('%DOWNLOAD_URL%', '%CLASSPATH%')"
if errorlevel 1 goto wrapperJarDownloadFailed

echo Finished downloading %CLASSPATH%
goto wrapperJarReady

:wrapperJarDownloadFailed
echo.
echo ERROR: Failed to download %CLASSPATH% from %DOWNLOAD_URL%
@rem Remove any partially written file so the next run retries the download.
if exist "%CLASSPATH%" del "%CLASSPATH%"
goto fail

:wrapperJarFound
echo Found %CLASSPATH%

:wrapperJarReady
@rem End of extension

@rem Execute Gradle
"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS%

:end
@rem End local scope for the variables with windows NT shell
if "%ERRORLEVEL%"=="0" goto mainEnd

:fail
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
rem the _cmd.exe /c_ return code!
if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
exit /b 1

:mainEnd
if "%OS%"=="Windows_NT" endlocal

:omega
--------------------------------------------------------------------------------
/jsr381/build.gradle:
--------------------------------------------------------------------------------
plugins {
    id "java-library"
    id "maven-publish"
    id "signing"
}

repositories {
    mavenCentral()
}

group "ai.djl.jsr381"
boolean isRelease = project.hasProperty("release") || project.hasProperty("staging")
version = "0.8.0" + (isRelease ? "" : "-SNAPSHOT")

dependencies {
    api 'javax.visrec:visrec-api:1.0.5'
    api platform("ai.djl:bom:0.23.0")
    api "ai.djl:api"
    api "ai.djl:basicdataset"
    api "ai.djl:model-zoo"
    api "ai.djl.mxnet:mxnet-model-zoo"
    api 'ai.djl.mxnet:mxnet-engine'

    testImplementation 'org.slf4j:slf4j-simple:2.0.1'
    testImplementation('org.testng:testng:7.6.1') {
        exclude group: "junit", module: "junit"
    }
}

java {
    withJavadocJar()
    withSourcesJar()
}

test {
    // Use TestNG for unit tests
    useTestNG()
}

apply from: file("${rootProject.projectDir}/tools/gradle/formatter.gradle")

project.tasks.withType(GenerateModuleMetadata) {
    enabled = false
}

signing {
    required(project.hasProperty("staging") || project.hasProperty("snapshot"))
    def signingKey = findProperty("signingKey")
    def signingPassword = findProperty("signingPassword")
    useInMemoryPgpKeys(signingKey, signingPassword)
    sign publishing.publications
}

// Publishing is only configured on JDK 1.8; on any other JDK the rest of this script is skipped.
if (JavaVersion.current() != JavaVersion.VERSION_1_8) {
    if (gradle.startParameter.taskNames.contains("publish")) {
        throw new GradleException("JDK 1.8 is required to run publish task.")
    }
    return
}

publishing {
    publications {
        maven(MavenPublication) {
            from components.java
            artifacts = [jar, javadocJar, sourcesJar]
            pom {
                name = "Deep Java Library implementation of JSR-381"
                description = "Deep Java Library implementation of JSR-381"
                url = "https://www.djl.ai/"

                packaging = "jar"

                licenses {
                    license {
                        name = 'The Apache License, Version 2.0'
                        url = 'https://www.apache.org/licenses/LICENSE-2.0'
                    }
                }

                scm {
                    connection = "scm:git:git@github.com:awslabs/djl.git"
                    developerConnection = "scm:git:git@github.com:awslabs/djl.git"
                    url = "https://github.com/awslabs/djl"
                    tag = "HEAD"
                }

                developers {
                    developer {
                        name = "DJL.AI Team"
                        email = "djl-dev@amazon.com"
                        organization = "Amazon AI"
                        organizationUrl = "https://amazon.com"
                    }
                }
            }
        }
    }

    repositories {
        maven {
            if (project.hasProperty("snapshot")) {
                name = "snapshot"
                url = "https://oss.sonatype.org/content/repositories/snapshots/"
                credentials {
                    username = findProperty("ossrhUsername")
                    password = findProperty("ossrhPassword")
                }
            } else if (project.hasProperty("staging")) {
                name = "staging"
                url = "https://oss.sonatype.org/service/local/staging/deploy/maven2/"
                credentials {
                    username = findProperty("ossrhUsername")
                    password = findProperty("ossrhPassword")
                }
            } else {
                name = "local"
                url = "build/repo"
            }
        }
    }
}
--------------------------------------------------------------------------------
/jsr381/gradlew:
--------------------------------------------------------------------------------
../gradlew
--------------------------------------------------------------------------------
/jsr381/gradlew.bat:
--------------------------------------------------------------------------------
../gradlew.bat
--------------------------------------------------------------------------------
/jsr381/src/main/java/ai/djl/jsr381/classification/SimpleBinaryClassifier.java: -------------------------------------------------------------------------------- 1 | package ai.djl.jsr381.classification; 2 | 3 | import ai.djl.inference.Predictor; 4 | import ai.djl.repository.zoo.ZooModel; 5 | import ai.djl.translate.TranslateException; 6 | 7 | import javax.visrec.ml.classification.BinaryClassifier; 8 | import javax.visrec.ml.classification.ClassificationException; 9 | 10 | /** Implementation of a {@link BinaryClassifier} with DJL. */ 11 | public class SimpleBinaryClassifier implements BinaryClassifier { 12 | 13 | private ZooModel model; 14 | 15 | public SimpleBinaryClassifier(ZooModel model) { 16 | this.model = model; 17 | } 18 | 19 | @Override 20 | public Float classify(float[] input) throws ClassificationException { 21 | try (Predictor predictor = model.newPredictor()) { 22 | return predictor.predict(input); 23 | } catch (TranslateException e) { 24 | throw new ClassificationException("Failed to process output", e); 25 | } 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /jsr381/src/main/java/ai/djl/jsr381/classification/SimpleImageClassifier.java: -------------------------------------------------------------------------------- 1 | package ai.djl.jsr381.classification; 2 | 3 | import ai.djl.inference.Predictor; 4 | import ai.djl.modality.Classifications; 5 | import ai.djl.modality.cv.BufferedImageFactory; 6 | import ai.djl.modality.cv.Image; 7 | import ai.djl.repository.zoo.ZooModel; 8 | import ai.djl.translate.TranslateException; 9 | 10 | import java.awt.image.BufferedImage; 11 | import java.io.IOException; 12 | import java.io.InputStream; 13 | import java.nio.file.Path; 14 | import java.util.List; 15 | import java.util.Map; 16 | import java.util.stream.Collectors; 17 | 18 | import javax.imageio.ImageIO; 19 | import javax.visrec.ml.classification.ClassificationException; 20 | import 
javax.visrec.ml.classification.ImageClassifier; 21 | 22 | /** 23 | * Implementation of abstract image classifier for BufferedImage-s using DJL. 24 | * 25 | * @author Frank Liu 26 | */ 27 | public class SimpleImageClassifier implements ImageClassifier { 28 | 29 | private final ZooModel model; 30 | private final int topK; 31 | 32 | public SimpleImageClassifier(ZooModel model, int topK) { 33 | this.model = model; 34 | this.topK = topK; 35 | } 36 | 37 | @Override 38 | public Map classify(Path input) throws ClassificationException { 39 | try { 40 | return classify(ImageIO.read(input.toFile())); 41 | } catch (IOException e) { 42 | throw new ClassificationException("Couldn't transform input into a BufferedImage", e); 43 | } 44 | } 45 | 46 | @Override 47 | public Map classify(InputStream input) throws ClassificationException { 48 | try { 49 | return classify(ImageIO.read(input)); 50 | } catch (IOException e) { 51 | throw new ClassificationException("Couldn't transform input into a BufferedImage", e); 52 | } 53 | } 54 | 55 | @Override 56 | public Map classify(BufferedImage input) throws ClassificationException { 57 | try (Predictor predictor = model.newPredictor()) { 58 | Classifications classifications = 59 | predictor.predict(BufferedImageFactory.getInstance().fromImage(input)); 60 | List list = classifications.topK(topK); 61 | return list.stream() 62 | .collect( 63 | Collectors.toMap( 64 | Classifications.Classification::getClassName, 65 | x -> (float) x.getProbability())); 66 | } catch (TranslateException e) { 67 | throw new ClassificationException("Failed to process output", e); 68 | } 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /jsr381/src/main/java/ai/djl/jsr381/dataset/CsvDataset.java: -------------------------------------------------------------------------------- 1 | package ai.djl.jsr381.dataset; 2 | 3 | import ai.djl.ndarray.NDArray; 4 | import ai.djl.ndarray.NDList; 5 | import ai.djl.ndarray.NDManager; 6 | 
import ai.djl.training.dataset.RandomAccessDataset; 7 | import ai.djl.training.dataset.Record; 8 | import ai.djl.util.Progress; 9 | 10 | import org.apache.commons.csv.CSVFormat; 11 | import org.apache.commons.csv.CSVParser; 12 | import org.apache.commons.csv.CSVRecord; 13 | 14 | import java.io.IOException; 15 | import java.io.Reader; 16 | import java.nio.file.Files; 17 | import java.nio.file.Path; 18 | import java.util.List; 19 | 20 | public class CsvDataset extends RandomAccessDataset { 21 | 22 | private List records; 23 | 24 | private CsvDataset(Builder builder) { 25 | super(builder); 26 | records = builder.records; 27 | } 28 | 29 | public static Builder builder() { 30 | return new Builder(); 31 | } 32 | 33 | @Override 34 | public Record get(NDManager manager, long index) { 35 | CSVRecord record = records.get(Math.toIntExact(index)); 36 | int size = record.size(); 37 | float[] data = new float[size - 1]; 38 | for (int i = 0; i < size - 1; ++i) { 39 | data[i] = Float.parseFloat(record.get(i)); 40 | } 41 | NDArray datum = manager.create(data); 42 | NDArray label = manager.create(Float.parseFloat(record.get(size - 1))); 43 | return new Record(new NDList(datum), new NDList(label)); 44 | } 45 | 46 | @Override 47 | public long availableSize() { 48 | return records.size(); 49 | } 50 | 51 | @Override 52 | public void prepare(Progress progress) {} 53 | 54 | public static final class Builder extends BaseBuilder { 55 | 56 | List records; 57 | 58 | private Path file; 59 | 60 | @Override 61 | protected Builder self() { 62 | return this; 63 | } 64 | 65 | public Builder setCsvFile(Path file) { 66 | this.file = file; 67 | return this; 68 | } 69 | 70 | public CsvDataset build() throws IOException { 71 | try (Reader reader = Files.newBufferedReader(file); 72 | CSVParser csvParser = 73 | new CSVParser( 74 | reader, 75 | CSVFormat.DEFAULT 76 | .builder() 77 | .setHeader() 78 | .setSkipHeaderRecord(true) 79 | .setIgnoreHeaderCase(true) 80 | .setTrim(true) 81 | .build())) { 82 | 
records = csvParser.getRecords(); 83 | } 84 | return new CsvDataset(this); 85 | } 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /jsr381/src/main/java/ai/djl/jsr381/detection/SimpleObjectDetector.java: -------------------------------------------------------------------------------- 1 | package ai.djl.jsr381.detection; 2 | 3 | import ai.djl.inference.Predictor; 4 | import ai.djl.modality.cv.BufferedImageFactory; 5 | import ai.djl.modality.cv.Image; 6 | import ai.djl.modality.cv.output.DetectedObjects; 7 | import ai.djl.modality.cv.output.Rectangle; 8 | import ai.djl.repository.zoo.ZooModel; 9 | import ai.djl.translate.TranslateException; 10 | 11 | import java.awt.image.BufferedImage; 12 | import java.util.ArrayList; 13 | import java.util.List; 14 | import java.util.Map; 15 | import java.util.concurrent.ConcurrentHashMap; 16 | 17 | import javax.visrec.ml.classification.ClassificationException; 18 | import javax.visrec.ml.detection.BoundingBox; 19 | import javax.visrec.ml.detection.ObjectDetector; 20 | 21 | /** A simple object detector implemented with DJL. 
*/ 22 | public class SimpleObjectDetector implements ObjectDetector { 23 | 24 | private final ZooModel model; 25 | 26 | public SimpleObjectDetector(ZooModel model) { 27 | this.model = model; 28 | } 29 | 30 | @Override 31 | public Map> detectObject(BufferedImage image) 32 | throws ClassificationException { 33 | try (Predictor predictor = model.newPredictor()) { 34 | DetectedObjects detectedObjects = 35 | predictor.predict(BufferedImageFactory.getInstance().fromImage(image)); 36 | Map> ret = new ConcurrentHashMap<>(); 37 | 38 | int imageWidth = image.getWidth(); 39 | int imageHeight = image.getHeight(); 40 | 41 | List detections = detectedObjects.items(); 42 | for (DetectedObjects.DetectedObject detection : detections) { 43 | String className = detection.getClassName(); 44 | float probability = (float) detection.getProbability(); 45 | Rectangle rect = detection.getBoundingBox().getBounds(); 46 | 47 | int x = (int) (rect.getX() * imageWidth); 48 | int y = (int) (rect.getY() * imageHeight); 49 | float w = (float) (rect.getWidth() * imageWidth); 50 | float h = (float) (rect.getHeight() * imageHeight); 51 | 52 | ret.compute( 53 | className, 54 | (k, list) -> { 55 | if (list == null) { 56 | list = new ArrayList<>(); 57 | } 58 | list.add(new BoundingBox(className, probability, x, y, w, h)); 59 | return list; 60 | }); 61 | } 62 | 63 | return ret; 64 | } catch (TranslateException e) { 65 | throw new ClassificationException("Failed to process output", e); 66 | } 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /jsr381/src/main/java/ai/djl/jsr381/spi/DjlBinaryClassifierFactory.java: -------------------------------------------------------------------------------- 1 | package ai.djl.jsr381.spi; 2 | 3 | import ai.djl.Model; 4 | import ai.djl.jsr381.classification.SimpleBinaryClassifier; 5 | import ai.djl.jsr381.dataset.CsvDataset; 6 | import ai.djl.metric.Metrics; 7 | import ai.djl.ndarray.NDArray; 8 | import 
ai.djl.ndarray.NDList; 9 | import ai.djl.ndarray.NDManager; 10 | import ai.djl.ndarray.types.Shape; 11 | import ai.djl.nn.Activation; 12 | import ai.djl.nn.Blocks; 13 | import ai.djl.nn.SequentialBlock; 14 | import ai.djl.nn.core.Linear; 15 | import ai.djl.nn.norm.BatchNorm; 16 | import ai.djl.repository.zoo.ZooModel; 17 | import ai.djl.training.DefaultTrainingConfig; 18 | import ai.djl.training.EasyTrain; 19 | import ai.djl.training.Trainer; 20 | import ai.djl.training.dataset.Batch; 21 | import ai.djl.training.dataset.RandomAccessDataset; 22 | import ai.djl.training.evaluator.BinaryAccuracy; 23 | import ai.djl.training.listener.TrainingListener; 24 | import ai.djl.training.loss.Loss; 25 | import ai.djl.translate.Batchifier; 26 | import ai.djl.translate.TranslateException; 27 | import ai.djl.translate.Translator; 28 | import ai.djl.translate.TranslatorContext; 29 | 30 | import java.io.IOException; 31 | 32 | import javax.visrec.ml.classification.BinaryClassifier; 33 | import javax.visrec.ml.classification.NeuralNetBinaryClassifier; 34 | import javax.visrec.ml.model.ModelCreationException; 35 | import javax.visrec.spi.BinaryClassifierFactory; 36 | 37 | public class DjlBinaryClassifierFactory implements BinaryClassifierFactory { 38 | 39 | @Override 40 | public Class getTargetClass() { 41 | return float[].class; 42 | } 43 | 44 | @Override 45 | public BinaryClassifier create(NeuralNetBinaryClassifier.BuildingBlock block) 46 | throws ModelCreationException { 47 | int inputSize = block.getInputsNum(); 48 | int[] hiddenLayers = block.getHiddenLayers(); 49 | int epochs = block.getMaxEpochs(); 50 | int batchSize = 32; 51 | 52 | SequentialBlock mlp = new SequentialBlock().add(Blocks.batchFlattenBlock(inputSize)); 53 | for (int size : hiddenLayers) { 54 | mlp.add(Linear.builder().setUnits(size).build()).add(Activation::relu); 55 | } 56 | mlp.add(BatchNorm.builder().build()) 57 | .add(Linear.builder().setUnits(1).build()) 58 | .add(arrays -> new 
NDList(arrays.singletonOrThrow().flatten())); 59 | 60 | Model model = Model.newInstance("binaryClassifier"); // TODO generate better model name 61 | model.setBlock(mlp); 62 | 63 | RandomAccessDataset[] dataset; 64 | try { 65 | CsvDataset csv = 66 | CsvDataset.builder() 67 | .setCsvFile(block.getTrainingPath()) 68 | .setSampling(batchSize, true) 69 | .build(); 70 | dataset = csv.randomSplit(8, 2); 71 | } catch (IOException | TranslateException e) { 72 | throw new ModelCreationException("Failed to load dataset.", e); 73 | } 74 | 75 | // setup training configuration 76 | DefaultTrainingConfig config = 77 | new DefaultTrainingConfig(Loss.sigmoidBinaryCrossEntropyLoss()) 78 | .addTrainingListeners(TrainingListener.Defaults.logging()) 79 | .addEvaluator(new BinaryAccuracy()); 80 | 81 | try (Trainer trainer = model.newTrainer(config)) { 82 | trainer.setMetrics(new Metrics()); 83 | Shape inputShape = new Shape(1, inputSize); 84 | trainer.initialize(inputShape); 85 | 86 | for (int i = 0; i < epochs; i++) { 87 | for (Batch batch : trainer.iterateDataset(dataset[0])) { 88 | EasyTrain.trainBatch(trainer, batch); 89 | trainer.step(); 90 | batch.close(); 91 | } 92 | 93 | for (Batch batch : trainer.iterateDataset(dataset[1])) { 94 | EasyTrain.validateBatch(trainer, batch); 95 | batch.close(); 96 | } 97 | 98 | // reset training and validation evaluators at end of epoch 99 | trainer.notifyListeners(listener -> listener.onEpoch(trainer)); 100 | } 101 | } catch (IOException | TranslateException e) { 102 | throw new ModelCreationException("Failed to process dataset.", e); 103 | } 104 | 105 | return new SimpleBinaryClassifier(new ZooModel<>(model, new BinaryClassifierTranslator())); 106 | } 107 | 108 | private static final class BinaryClassifierTranslator implements Translator { 109 | 110 | @Override 111 | public NDList processInput(TranslatorContext ctx, float[] input) { 112 | NDManager manager = ctx.getNDManager(); 113 | NDArray array = manager.create(input); 114 | return new 
NDList(array); 115 | } 116 | 117 | @Override 118 | public Float processOutput(TranslatorContext ctx, NDList list) { 119 | return list.singletonOrThrow().getFloat(); 120 | } 121 | 122 | @Override 123 | public Batchifier getBatchifier() { 124 | return Batchifier.STACK; 125 | } 126 | } 127 | } 128 | -------------------------------------------------------------------------------- /jsr381/src/main/java/ai/djl/jsr381/spi/DjlImageClassifierFactory.java: -------------------------------------------------------------------------------- 1 | package ai.djl.jsr381.spi; 2 | 3 | import ai.djl.MalformedModelException; 4 | import ai.djl.Model; 5 | import ai.djl.basicdataset.cv.classification.ImageFolder; 6 | import ai.djl.basicmodelzoo.cv.classification.ResNetV1; 7 | import ai.djl.jsr381.classification.SimpleImageClassifier; 8 | import ai.djl.metric.Metrics; 9 | import ai.djl.modality.Classifications; 10 | import ai.djl.modality.cv.Image; 11 | import ai.djl.modality.cv.Image.Flag; 12 | import ai.djl.modality.cv.transform.CenterCrop; 13 | import ai.djl.modality.cv.transform.Resize; 14 | import ai.djl.modality.cv.transform.ToTensor; 15 | import ai.djl.modality.cv.translator.ImageClassificationTranslator; 16 | import ai.djl.ndarray.NDArray; 17 | import ai.djl.ndarray.types.Shape; 18 | import ai.djl.nn.Block; 19 | import ai.djl.repository.zoo.ZooModel; 20 | import ai.djl.training.DefaultTrainingConfig; 21 | import ai.djl.training.EasyTrain; 22 | import ai.djl.training.Trainer; 23 | import ai.djl.training.dataset.Batch; 24 | import ai.djl.training.dataset.RandomAccessDataset; 25 | import ai.djl.training.evaluator.Accuracy; 26 | import ai.djl.training.listener.TrainingListener; 27 | import ai.djl.training.loss.Loss; 28 | import ai.djl.translate.TranslateException; 29 | import ai.djl.translate.Translator; 30 | 31 | import org.slf4j.Logger; 32 | import org.slf4j.LoggerFactory; 33 | 34 | import java.awt.image.BufferedImage; 35 | import java.io.IOException; 36 | import java.nio.file.Path; 37 
| import java.util.List; 38 | 39 | import javax.visrec.ml.classification.ImageClassifier; 40 | import javax.visrec.ml.classification.NeuralNetImageClassifier; 41 | import javax.visrec.ml.model.ModelCreationException; 42 | import javax.visrec.spi.ImageClassifierFactory; 43 | 44 | public class DjlImageClassifierFactory implements ImageClassifierFactory { 45 | 46 | private static final Logger logger = LoggerFactory.getLogger(DjlImageClassifierFactory.class); 47 | 48 | @Override 49 | public Class getImageClass() { 50 | return BufferedImage.class; 51 | } 52 | 53 | @Override 54 | public ImageClassifier create( 55 | NeuralNetImageClassifier.BuildingBlock block) 56 | throws ModelCreationException { 57 | int width = block.getImageWidth(); 58 | int height = block.getImageHeight(); 59 | 60 | Model model = Model.newInstance("imageClassifier"); // TODO generate better model name 61 | ZooModel zooModel; 62 | 63 | Path modelPath = block.getImportPath(); 64 | if (modelPath != null) { 65 | // load pre-trained model from model zoo 66 | logger.info("Loading pre-trained model ..."); 67 | 68 | try { 69 | model.load(modelPath); 70 | Flag flag = width < 50 ? 
Flag.GRAYSCALE : Flag.COLOR; 71 | Translator translator = 72 | ImageClassificationTranslator.builder() 73 | .optFlag(flag) 74 | .addTransform(new CenterCrop()) 75 | .addTransform(new Resize(width, height)) 76 | .addTransform(new ToTensor()) 77 | .optSynsetArtifactName("synset.txt") 78 | .optApplySoftmax(true) 79 | .build(); 80 | zooModel = new ZooModel<>(model, translator); 81 | } catch (MalformedModelException | IOException e) { 82 | throw new ModelCreationException("Failed load model from model zoo.", e); 83 | } 84 | } else { 85 | try { 86 | zooModel = trainWithResnet(model, block); 87 | } catch (IOException | TranslateException e) { 88 | throw new ModelCreationException("Failed train model.", e); 89 | } 90 | } 91 | return new SimpleImageClassifier(zooModel, 5); 92 | } 93 | 94 | private ZooModel trainWithResnet( 95 | Model model, NeuralNetImageClassifier.BuildingBlock block) 96 | throws IOException, TranslateException { 97 | int width = block.getImageWidth(); 98 | int height = block.getImageHeight(); 99 | int epochs = block.getMaxEpochs(); 100 | int batch = 1; 101 | 102 | Path trainingFile = block.getTrainingPath(); 103 | if (trainingFile == null) { 104 | throw new IllegalArgumentException("TrainingFile is required."); 105 | } 106 | ImageFolder dataset = 107 | ImageFolder.builder() 108 | .setSampling(batch, true) 109 | .setRepositoryPath(trainingFile) 110 | .addTransform(new CenterCrop(width, height)) 111 | .addTransform(new Resize(width, height)) 112 | .addTransform(new ToTensor()) 113 | .build(); 114 | 115 | RandomAccessDataset[] set = dataset.randomSplit(9, 1); 116 | 117 | List synset = dataset.getSynset(); 118 | 119 | Block resNet18 = 120 | ResNetV1.builder() 121 | .setImageShape(new Shape(3, width, height)) 122 | .setNumLayers(18) 123 | .setOutSize(synset.size()) 124 | .build(); 125 | model.setBlock(resNet18); 126 | 127 | Path exportDir = block.getExportPath(); 128 | // setup training configuration 129 | DefaultTrainingConfig config = 130 | new 
DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) 131 | .addEvaluator(new Accuracy()) 132 | .addTrainingListeners(TrainingListener.Defaults.logging()); 133 | 134 | try (Trainer trainer = model.newTrainer(config)) { 135 | trainer.setMetrics(new Metrics()); 136 | // initialize trainer with proper input shape 137 | trainer.initialize(new Shape(1, 3, width, height)); 138 | EasyTrain.fit(trainer, epochs, set[0], set[1]); 139 | } 140 | 141 | if (exportDir != null) { 142 | model.save(exportDir, model.getName()); 143 | } 144 | 145 | Batch b = dataset.getData(model.getNDManager()).iterator().next(); 146 | NDArray array = b.getData().singletonOrThrow(); 147 | 148 | Translator translator = 149 | ImageClassificationTranslator.builder() 150 | .addTransform(new CenterCrop(width, height)) 151 | .addTransform(new Resize(width, height)) 152 | .addTransform(new ToTensor()) 153 | .optSynset(synset) 154 | .optApplySoftmax(true) 155 | .build(); 156 | return new ZooModel<>(model, translator); 157 | } 158 | } 159 | -------------------------------------------------------------------------------- /jsr381/src/main/java/ai/djl/jsr381/spi/DjlImageFactoryService.java: -------------------------------------------------------------------------------- 1 | package ai.djl.jsr381.spi; 2 | 3 | import java.awt.image.BufferedImage; 4 | import java.io.IOException; 5 | import java.io.InputStream; 6 | import java.net.URL; 7 | import java.nio.file.Path; 8 | import java.util.Map; 9 | import java.util.Objects; 10 | import java.util.Optional; 11 | import java.util.concurrent.ConcurrentHashMap; 12 | 13 | import javax.imageio.ImageIO; 14 | import javax.visrec.ImageFactory; 15 | import javax.visrec.spi.ImageFactoryService; 16 | 17 | /** 18 | * DJL implementation of {@link ImageFactoryService} which serves the implementations of {@link 19 | * ImageFactory}. 
*
 * @author Frank Liu
 */
public final class DjlImageFactoryService implements ImageFactoryService {

    // Image types this service can produce, each mapped to the factory producing that type.
    // Filled once by the static initializer below; only read afterwards.
    private static final Map<Class<?>, ImageFactory<?>> IMAGE_FACTORIES = new ConcurrentHashMap<>();

    static {
        IMAGE_FACTORIES.put(BufferedImage.class, new ImageFactoryImpl());
    }

    /**
     * Get the {@link ImageFactory} by image type.
     *
     * <p>Lookup is by exact class; currently only {@link BufferedImage} is registered.
     *
     * @param imageCls image type in {@link Class} object which is able to be processed by the image
     *     factory implementation.
     * @param <T> image type.
     * @return {@link ImageFactory} wrapped in {@link Optional}, or an empty {@link Optional} if no
     *     {@link ImageFactory} is registered for {@code imageCls}.
     * @throws NullPointerException if {@code imageCls} is {@code null}.
     */
    @Override
    @SuppressWarnings("unchecked")
    public <T> Optional<ImageFactory<T>> getByImageType(Class<T> imageCls) {
        Objects.requireNonNull(imageCls, "imageCls == null");
        ImageFactory<?> imageFactory = IMAGE_FACTORIES.get(imageCls);
        // Unchecked but safe: every key in IMAGE_FACTORIES is the exact image class that its
        // factory produces, so a factory found under Class<T> is an ImageFactory<T>.
        return Optional.ofNullable((ImageFactory<T>) imageFactory);
    }

    /** {@link ImageFactory} to provide {@link BufferedImage} as return object. */
    public static final class ImageFactoryImpl implements ImageFactory<BufferedImage> {

        /** {@inheritDoc} */
        @Override
        public BufferedImage getImage(Path file) throws IOException {
            return requireDecoded(ImageIO.read(file.toFile()), "File");
        }

        /** {@inheritDoc} */
        @Override
        public BufferedImage getImage(URL file) throws IOException {
            return requireDecoded(ImageIO.read(file), "URL");
        }

        /**
         * {@inheritDoc}
         *
         * <p>The stream is not closed by this method; the caller remains responsible for it.
         */
        @Override
        public BufferedImage getImage(InputStream file) throws IOException {
            return requireDecoded(ImageIO.read(file), "InputStream");
        }

        /**
         * Passes a decoded image through, or fails if {@link ImageIO} could not decode the input.
         *
         * @param img the value returned by {@code ImageIO.read}; {@code null} when no installed
         *     image reader recognised the data
         * @param source name of the input kind, used in the error message
         * @return {@code img}, never {@code null}
         * @throws IOException if {@code img} is {@code null}
         */
        private static BufferedImage requireDecoded(BufferedImage img, String source)
                throws IOException {
            if (img == null) {
                throw new IOException(
                        "Unable to transform "
                                + source
                                + " into BufferedImage due to unknown image encoding");
            }
            return img;
        }
    }
}
--------------------------------------------------------------------------------
/jsr381/src/main/java/ai/djl/jsr381/spi/DjlImplementationService.java:
--------------------------------------------------------------------------------
package ai.djl.jsr381.spi;

import javax.visrec.spi.ImplementationService;

/**
 * DJL's {@link ImplementationService}, identifying this provider by name and version.
 *
 * @author Frank Liu
 */
public class DjlImplementationService extends ImplementationService {

    /** {@inheritDoc} */
    @Override
    public String getName() {
        return "DJL";
    }

    /**
     * {@inheritDoc}
     *
     * <p>NOTE(review): this is a string literal, presumably meant to track the DJL release the
     * project builds against; it will not update automatically — confirm on every upgrade.
     */
    @Override
    public String getVersion() {
        return "0.8.0";
    }
}
--------------------------------------------------------------------------------
/jsr381/src/main/java/ai/djl/jsr381/spi/DjlServiceProvider.java:
--------------------------------------------------------------------------------
package ai.djl.jsr381.spi;

import javax.visrec.spi.ImageFactoryService;
import javax.visrec.spi.ImplementationService;
import javax.visrec.spi.ServiceProvider;

/**
 * {@link ServiceProvider} implementation with DJL.
 *
 * <p>Each getter returns a newly constructed service instance.
 *
 * @author Frank Liu
 */
public final class DjlServiceProvider extends ServiceProvider {

    /** {@inheritDoc} */
    @Override
    public ImageFactoryService getImageFactoryService() {
        return new DjlImageFactoryService();
    }

    /** {@inheritDoc} */
    @Override
    public ImplementationService getImplementationService() {
        return new DjlImplementationService();
    }
}
--------------------------------------------------------------------------------
/jsr381/src/main/resources/META-INF/services/javax.visrec.spi.BinaryClassifierFactory:
--------------------------------------------------------------------------------
ai.djl.jsr381.spi.DjlBinaryClassifierFactory
--------------------------------------------------------------------------------
/jsr381/src/main/resources/META-INF/services/javax.visrec.spi.ImageClassifierFactory:
--------------------------------------------------------------------------------
ai.djl.jsr381.spi.DjlImageClassifierFactory
--------------------------------------------------------------------------------
/jsr381/src/main/resources/META-INF/services/javax.visrec.spi.ServiceProvider:
--------------------------------------------------------------------------------
ai.djl.jsr381.spi.DjlServiceProvider
--------------------------------------------------------------------------------
/jsr381/src/test/java/ai/djl/jsr381/classification/BinaryClassifierTest.java:
--------------------------------------------------------------------------------
/*
 * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.jsr381.classification;

import org.testng.Assert;
import org.testng.annotations.Test;

import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Objects;

import javax.visrec.ml.classification.BinaryClassifier;
import javax.visrec.ml.classification.NeuralNetBinaryClassifier;
import javax.visrec.ml.model.ModelCreationException;

/** Trains a small {@link NeuralNetBinaryClassifier} on the bundled {@code spam.csv} data. */
public class BinaryClassifierTest {

    @Test
    public void testSpamEmail() throws ModelCreationException, URISyntaxException {
        URL url = Objects.requireNonNull(BinaryClassifierTest.class.getResource("/spam.csv"));
        // URL.toURI() + Paths.get(URI) decodes %-escapes and handles Windows drive letters;
        // Paths.get(url.getFile()) fails for paths with spaces or on Windows ("/C:/...").
        Path trainingFile = Paths.get(url.toURI());
        BinaryClassifier<float[]> spamClassifier =
                NeuralNetBinaryClassifier.builder()
                        .inputClass(float[].class)
                        .inputsNum(57)
                        .hiddenLayers(5)
                        .maxEpochs(2)
                        .trainingPath(trainingFile)
                        .build();

        // create test email feature
        float[] emailFeatures = new float[57];
        emailFeatures[56] = 1;

        Float result = spamClassifier.classify(emailFeatures);
        Assert.assertNotNull(result, "classify() returned null");
        System.out.println(result);
    }
}
--------------------------------------------------------------------------------
/jsr381/src/test/java/ai/djl/jsr381/classification/ImageClassifierTest.java:
--------------------------------------------------------------------------------
/*
 * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.jsr381.classification;

import ai.djl.util.ZipUtils;

import org.testng.Assert;
import org.testng.annotations.Test;

import java.awt.image.BufferedImage;
import java.io.IOException;
import java.io.InputStream;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.Map;
import java.util.Objects;

import javax.visrec.ml.classification.ImageClassifier;
import javax.visrec.ml.classification.NeuralNetImageClassifier;
import javax.visrec.ml.model.ModelCreationException;

/**
 * Tests {@link NeuralNetImageClassifier}: inference with the bundled MLP model, and training on
 * the cats-and-dogs data set (downloaded on first use).
 */
public class ImageClassifierTest {

    @Test
    public void testImageClassifier() throws ModelCreationException, URISyntaxException {
        URL url = Objects.requireNonNull(ImageClassifierTest.class.getResource("/0.png"));
        // See BinaryClassifierTest: toURI() is the portable way to turn a file: URL into a Path.
        Path input = Paths.get(url.toURI());

        Path modelDir = Paths.get("src/test/resources/mlp");

        ImageClassifier<BufferedImage> classifier =
                NeuralNetImageClassifier.builder()
                        .inputClass(BufferedImage.class)
                        .imageHeight(28)
                        .imageWidth(28)
                        .importModel(modelDir)
                        .build();

        Map<String, Float> result = classifier.classify(input);
        Assert.assertFalse(result.isEmpty(), "no class probabilities returned");
        for (Map.Entry<String, Float> entry : result.entrySet()) {
            System.out.println(entry.getKey() + ": " + entry.getValue());
        }
    }

    @Test
    public void testImageClassifierTraining() throws IOException, ModelCreationException {
        Path trainingFile = downloadTrainingData();
        Path modelDir = Paths.get("build/model");

        ImageClassifier<BufferedImage> classifier =
                NeuralNetImageClassifier.builder()
                        .inputClass(BufferedImage.class)
                        .imageHeight(128)
                        .imageWidth(128)
                        .trainingFile(trainingFile)
                        .exportModel(modelDir)
                        .maxEpochs(2)
                        .build();

        Path input = trainingFile.resolve("cat/cat_1.png");
        Map<String, Float> result = classifier.classify(input);
        Assert.assertFalse(result.isEmpty(), "no class probabilities returned");
        for (Map.Entry<String, Float> entry : result.entrySet()) {
            System.out.println(entry.getKey() + ": " + entry.getValue());
        }
    }

    /**
     * Downloads and unzips the cats-and-dogs archive into {@code datasets/cats_and_dogs}, unless
     * that directory already exists.
     *
     * @return the {@code training} sub-directory of the extracted data set
     * @throws IOException if the download, the extraction, or the final move fails
     */
    private Path downloadTrainingData() throws IOException {
        String link =
                "https://github.com/JavaVisRec/jsr381-examples-datasets/raw/master/cats_and_dogs_training_data_png.zip";
        URL url = new URL(link);
        Path dir = Paths.get("datasets", "cats_and_dogs");
        if (!Files.exists(dir)) {
            // Extract into a sibling staging directory and rename it into place only after the
            // archive is fully unpacked, so a failed or interrupted download never leaves behind
            // a directory that later runs would mistake for a complete copy. A leftover staging
            // directory is harmless: it is never consulted.
            Path parent = dir.toAbsolutePath().getParent();
            Files.createDirectories(parent);
            Path staging = Files.createTempDirectory(parent, "cats_and_dogs-");
            try (InputStream is = url.openStream()) {
                ZipUtils.unzip(is, staging);
            }
            Files.move(staging, dir, StandardCopyOption.ATOMIC_MOVE);
        }
        return dir.resolve("training");
    }
}
--------------------------------------------------------------------------------
/jsr381/src/test/java/ai/djl/jsr381/detection/ObjectDetectorTest.java:
--------------------------------------------------------------------------------
/*
 * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.jsr381.detection;

import ai.djl.Application;
import ai.djl.MalformedModelException;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;

import org.testng.annotations.Test;

import java.awt.image.BufferedImage;
import java.io.IOException;
import java.net.URL;
import java.util.List;
import java.util.Map;

import javax.imageio.ImageIO;
import javax.visrec.ml.detection.BoundingBox;

/**
 * Runs {@link SimpleObjectDetector} with the DJL model-zoo artifact {@code "yolo"} on a test image
 * fetched over HTTPS.
 */
public class ObjectDetectorTest {

    @Test
    public void testObjectDetection()
            throws IOException, ModelNotFoundException, MalformedModelException {
        Criteria<Image, DetectedObjects> criteria =
                Criteria.builder()
                        .setTypes(Image.class, DetectedObjects.class)
                        .optApplication(Application.CV.OBJECT_DETECTION)
                        .optArtifactId("yolo")
                        .optArgument("threshold", 0.3)
                        .build();
        try (ZooModel<Image, DetectedObjects> model = ModelZoo.loadModel(criteria)) {
            SimpleObjectDetector objectDetector = new SimpleObjectDetector(model);

            URL imageUrl =
                    new URL("https://djl-ai.s3.amazonaws.com/resources/images/dog_bike_car.jpg");
            BufferedImage inputImage = ImageIO.read(imageUrl);
            Map<String, List<BoundingBox>> result = objectDetector.detectObject(inputImage);

            for (List<BoundingBox> boundingBoxes : result.values()) {
                for (BoundingBox boundingBox : boundingBoxes) {
                    System.out.println(boundingBox.toString());
                }
            }
        }
    }
}
--------------------------------------------------------------------------------
/jsr381/src/test/resources/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JavaVisRec/visrec-djl/e858a6e41a41e68c89262b2d4a7b1f991d141825/jsr381/src/test/resources/0.png
-------------------------------------------------------------------------------- /jsr381/src/test/resources/mlp/mlp-0000.params: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JavaVisRec/visrec-djl/e858a6e41a41e68c89262b2d4a7b1f991d141825/jsr381/src/test/resources/mlp/mlp-0000.params -------------------------------------------------------------------------------- /jsr381/src/test/resources/mlp/mlp-symbol.json: -------------------------------------------------------------------------------- 1 | { 2 | "nodes": [ 3 | { 4 | "op": "null", 5 | "name": "data", 6 | "inputs": [] 7 | }, 8 | { 9 | "op": "null", 10 | "name": "hybridsequential0_dense0_weight", 11 | "attrs": { 12 | "__dtype__": "0", 13 | "__lr_mult__": "1.0", 14 | "__shape__": "(128, 0)", 15 | "__storage_type__": "0", 16 | "__wd_mult__": "1.0" 17 | }, 18 | "inputs": [] 19 | }, 20 | { 21 | "op": "null", 22 | "name": "hybridsequential0_dense0_bias", 23 | "attrs": { 24 | "__dtype__": "0", 25 | "__init__": "zeros", 26 | "__lr_mult__": "1.0", 27 | "__shape__": "(128,)", 28 | "__storage_type__": "0", 29 | "__wd_mult__": "1.0" 30 | }, 31 | "inputs": [] 32 | }, 33 | { 34 | "op": "FullyConnected", 35 | "name": "hybridsequential0_dense0_fwd", 36 | "attrs": { 37 | "flatten": "True", 38 | "no_bias": "False", 39 | "num_hidden": "128" 40 | }, 41 | "inputs": [ 42 | [ 43 | 0, 44 | 0, 45 | 0 46 | ], 47 | [ 48 | 1, 49 | 0, 50 | 0 51 | ], 52 | [ 53 | 2, 54 | 0, 55 | 0 56 | ] 57 | ] 58 | }, 59 | { 60 | "op": "Activation", 61 | "name": "hybridsequential0_dense0_relu_fwd", 62 | "attrs": { 63 | "act_type": "relu" 64 | }, 65 | "inputs": [ 66 | [ 67 | 3, 68 | 0, 69 | 0 70 | ] 71 | ] 72 | }, 73 | { 74 | "op": "null", 75 | "name": "hybridsequential0_dense1_weight", 76 | "attrs": { 77 | "__dtype__": "0", 78 | "__lr_mult__": "1.0", 79 | "__shape__": "(64, 0)", 80 | "__storage_type__": "0", 81 | "__wd_mult__": "1.0" 82 | }, 83 | "inputs": [] 84 | }, 85 | { 86 | "op": "null", 
87 | "name": "hybridsequential0_dense1_bias", 88 | "attrs": { 89 | "__dtype__": "0", 90 | "__init__": "zeros", 91 | "__lr_mult__": "1.0", 92 | "__shape__": "(64,)", 93 | "__storage_type__": "0", 94 | "__wd_mult__": "1.0" 95 | }, 96 | "inputs": [] 97 | }, 98 | { 99 | "op": "FullyConnected", 100 | "name": "hybridsequential0_dense1_fwd", 101 | "attrs": { 102 | "flatten": "True", 103 | "no_bias": "False", 104 | "num_hidden": "64" 105 | }, 106 | "inputs": [ 107 | [ 108 | 4, 109 | 0, 110 | 0 111 | ], 112 | [ 113 | 5, 114 | 0, 115 | 0 116 | ], 117 | [ 118 | 6, 119 | 0, 120 | 0 121 | ] 122 | ] 123 | }, 124 | { 125 | "op": "Activation", 126 | "name": "hybridsequential0_dense1_relu_fwd", 127 | "attrs": { 128 | "act_type": "relu" 129 | }, 130 | "inputs": [ 131 | [ 132 | 7, 133 | 0, 134 | 0 135 | ] 136 | ] 137 | }, 138 | { 139 | "op": "null", 140 | "name": "hybridsequential0_dense2_weight", 141 | "attrs": { 142 | "__dtype__": "0", 143 | "__lr_mult__": "1.0", 144 | "__shape__": "(10, 0)", 145 | "__storage_type__": "0", 146 | "__wd_mult__": "1.0" 147 | }, 148 | "inputs": [] 149 | }, 150 | { 151 | "op": "null", 152 | "name": "hybridsequential0_dense2_bias", 153 | "attrs": { 154 | "__dtype__": "0", 155 | "__init__": "zeros", 156 | "__lr_mult__": "1.0", 157 | "__shape__": "(10,)", 158 | "__storage_type__": "0", 159 | "__wd_mult__": "1.0" 160 | }, 161 | "inputs": [] 162 | }, 163 | { 164 | "op": "FullyConnected", 165 | "name": "hybridsequential0_dense2_fwd", 166 | "attrs": { 167 | "flatten": "True", 168 | "no_bias": "False", 169 | "num_hidden": "10" 170 | }, 171 | "inputs": [ 172 | [ 173 | 8, 174 | 0, 175 | 0 176 | ], 177 | [ 178 | 9, 179 | 0, 180 | 0 181 | ], 182 | [ 183 | 10, 184 | 0, 185 | 0 186 | ] 187 | ] 188 | } 189 | ], 190 | "arg_nodes": [ 191 | 0, 192 | 1, 193 | 2, 194 | 5, 195 | 6, 196 | 9, 197 | 10 198 | ], 199 | "node_row_ptr": [ 200 | 0, 201 | 1, 202 | 2, 203 | 3, 204 | 4, 205 | 5, 206 | 6, 207 | 7, 208 | 8, 209 | 9, 210 | 10, 211 | 11, 212 | 12 213 | ], 214 | "heads": [ 
215 | [ 216 | 11, 217 | 0, 218 | 0 219 | ] 220 | ], 221 | "attrs": { 222 | "mxnet_version": [ 223 | "int", 224 | 10500 225 | ] 226 | } 227 | }
--------------------------------------------------------------------------------
/jsr381/src/test/resources/mlp/synset.txt:
--------------------------------------------------------------------------------
0
1
2
3
4
5
6
7
8
9
--------------------------------------------------------------------------------
/settings.gradle:
--------------------------------------------------------------------------------
include "jsr381", "examples"
--------------------------------------------------------------------------------
/tools/gradle/formatter.gradle:
--------------------------------------------------------------------------------
buildscript {
    repositories {
        mavenCentral()
        maven {
            url "https://plugins.gradle.org/m2/"
        }
    }
    dependencies {
        classpath 'com.google.googlejavaformat:google-java-format:1.15.0'
    }
}

apply plugin: JavaFormatterPlugin

// Make the standard `check` lifecycle task also run the formatting check.
check.dependsOn verifyJava

import com.google.googlejavaformat.java.Formatter
import com.google.googlejavaformat.java.ImportOrderer
import com.google.googlejavaformat.java.JavaFormatterOptions
import com.google.googlejavaformat.java.Main
import com.google.googlejavaformat.java.RemoveUnusedImports

/**
 * Registers two tasks that run google-java-format in AOSP style ({@code -a}) over every
 * {@code .java} file of all source sets, skipping paths containing "generated-src":
 * {@code formatJava} rewrites files in place; {@code verifyJava} fails the build if any file is
 * not already formatted.
 */
class JavaFormatterPlugin implements Plugin<Project> {
    void apply(Project project) {
        project.task('formatJava') {
            doLast {
                Main formatter = newFormatter()
                eachJavaSource(project) { File file ->
                    // -i: rewrite the file in place.
                    if (formatter.format("-a", "-i", file.getAbsolutePath()) != 0) {
                        throw new GradleException("Format java failed: " + file.getAbsolutePath())
                    }
                }
            }
        }

        project.task('verifyJava') {
            doLast {
                Main formatter = newFormatter()
                eachJavaSource(project) { File file ->
                    // -n: dry run (report only); --set-exit-if-changed: non-zero exit code when
                    // the file would be reformatted.
                    if (formatter.format("-a", "-n", "--set-exit-if-changed", file.getAbsolutePath()) != 0) {
                        throw new GradleException("File not formatted: " + file.getAbsolutePath()
                                + System.lineSeparator()
                                + "In order to reformat your code, run './gradlew formatJava' (or './gradlew fJ' for short)"
                                + System.lineSeparator()
                                + "See https://github.com/deepjavalibrary/djl/blob/master/docs/development/development_guideline.md#coding-conventions for more details")
                    }
                }
            }
        }
    }

    /** Creates a google-java-format command-line driver writing to this process's stdout/stderr. */
    static Main newFormatter() {
        return new Main(new PrintWriter(System.out, true), new PrintWriter(System.err, true), System.in)
    }

    /**
     * Calls {@code action} with every {@code .java} file of all source sets of {@code project},
     * skipping any path that contains "generated-src".
     */
    static void eachJavaSource(Project project, Closure action) {
        for (item in project.sourceSets) {
            for (File file : item.getAllSource()) {
                if (!file.getName().endsWith(".java") || file.getAbsolutePath().contains("generated-src")) {
                    continue
                }
                action.call(file)
            }
        }
    }
}
--------------------------------------------------------------------------------