├── .circleci.settings.xml ├── .deepsource.toml ├── .github ├── release-drafter-config.yml └── workflows │ ├── release-drafter.yml │ └── version-and-release.yml ├── .gitignore ├── LICENSE ├── README.md ├── pom.xml └── src ├── main └── java │ └── com │ └── redislabs │ └── redisai │ ├── Backend.java │ ├── Command.java │ ├── Dag.java │ ├── DagRunCommands.java │ ├── DataType.java │ ├── Device.java │ ├── Keyword.java │ ├── Model.java │ ├── RedisAI.java │ ├── RedisAIException.java │ ├── Script.java │ ├── Tensor.java │ └── exceptions │ └── JRedisAIRunTimeException.java └── test ├── java └── com │ └── redislabs │ └── redisai │ ├── ChunkTest.java │ ├── DagTest.java │ ├── DagV2Test.java │ ├── DataTypeTest.java │ ├── ModelTest.java │ ├── RedisAITest.java │ ├── ScriptTest.java │ └── TensorTest.java └── resources └── test_data ├── creditcard_10K.csv ├── creditcardfraud.pb ├── graph.pb ├── graph_v2.pb ├── imagenet_class_index.json ├── linear_iris.onnx ├── logreg_iris.onnx ├── mnist.onnx ├── mnist_batched.onnx ├── mnist_model_quant.tflite ├── mobilenet_v2_1.4_224_frozen.pb ├── one.png ├── one.raw ├── onnx_batch.py ├── panda-224x224.jpg ├── panda.jpg ├── pt-minimal.pt ├── pt_minimal.py ├── script.txt ├── script_v2.txt ├── tf-minimal.py └── tf2-minimal.py /.circleci.settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | ossrh 5 | ${env.OSSH_USERNAME} 6 | ${env.OSSH_PASSWORD} 7 | 8 | 9 | gpg.passphrase 10 | ${env.GPG_PASSPHRASE} 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /.deepsource.toml: -------------------------------------------------------------------------------- 1 | version = 1 2 | 3 | [[analyzers]] 4 | name = "java" 5 | enabled = true 6 | 7 | [analyzers.meta] 8 | runtime_version = "8" 9 | -------------------------------------------------------------------------------- /.github/release-drafter-config.yml: 
-------------------------------------------------------------------------------- 1 | name-template: 'Version $NEXT_PATCH_VERSION' 2 | tag-template: 'v$NEXT_PATCH_VERSION' 3 | categories: 4 | - title: 'Features' 5 | labels: 6 | - 'feature' 7 | - 'enhancement' 8 | - title: 'Bug Fixes' 9 | labels: 10 | - 'fix' 11 | - 'bugfix' 12 | - 'bug' 13 | - title: 'Maintenance' 14 | label: 'chore' 15 | change-template: '- $TITLE (#$NUMBER)' 16 | exclude-labels: 17 | - 'skip-changelog' 18 | template: | 19 | ## Changes 20 | 21 | $CHANGES 22 | -------------------------------------------------------------------------------- /.github/workflows/release-drafter.yml: -------------------------------------------------------------------------------- 1 | name: Release Drafter 2 | 3 | on: 4 | push: 5 | # branches to consider in the event; optional, defaults to all 6 | branches: 7 | - master 8 | 9 | jobs: 10 | update_release_draft: 11 | runs-on: ubuntu-latest 12 | steps: 13 | # Drafts your next Release notes as Pull Requests are merged into "master" 14 | - uses: release-drafter/release-drafter@v5 15 | with: 16 | # (Optional) specify config name to use, relative to .github/. 
# Release workflow: when a GitHub release is published, derive the Maven version
# from the release tag, sign the artifacts and deploy them with the `release` profile.
name: Release

on:
  release:
    types: [published]

jobs:
  build:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v2

      - name: get version from tag
        id: get_version
        run: |
          # Strip the "refs/tags/" prefix, then ONLY a single leading "v".
          # (The previous "${realversion//v/}" removed every "v" in the tag,
          # which corrupts tags such as "v1.0.0-preview".)
          realversion="${GITHUB_REF#refs/tags/}"
          realversion="${realversion#v}"
          # "::set-output" is deprecated; step outputs are written to $GITHUB_OUTPUT.
          echo "VERSION=$realversion" >> "$GITHUB_OUTPUT"

      - name: Set up publishing to maven central
        uses: actions/setup-java@v2
        with:
          java-version: '8'
          distribution: 'adopt'
          server-id: ossrh
          server-username: MAVEN_USERNAME
          server-password: MAVEN_PASSWORD

      - name: mvn versions
        run: mvn versions:set -DnewVersion=${{ steps.get_version.outputs.VERSION }}

      - name: Install gpg key
        run: |
          cat <(echo -e "${{ secrets.OSSH_GPG_SECRET_KEY }}") | gpg --batch --import
          gpg --list-secret-keys --keyid-format LONG

      - name: Publish
        run: |
          mvn --no-transfer-progress \
            --batch-mode \
            -Dgpg.passphrase='${{ secrets.OSSH_GPG_SECRET_KEY_PASSWORD }}' \
            -DskipTests deploy -P release
        env:
          MAVEN_USERNAME: ${{secrets.OSSH_USERNAME}}
          MAVEN_PASSWORD: ${{secrets.OSSH_TOKEN}}
*.rar 22 | 23 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml 24 | hs_err_pid* 25 | /bin/ 26 | 27 | # eclipse 28 | .classpath 29 | .project 30 | .settings/ 31 | .pydevproject 32 | 33 | *.iml 34 | /.idea/ 35 | 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, RedisAI 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![license](https://img.shields.io/github/license/RedisAI/JRedisAI.svg)](https://github.com/RedisAI/JRedisAI) 2 | [![CircleCI](https://circleci.com/gh/RedisAI/JRedisAI/tree/master.svg?style=svg)](https://circleci.com/gh/RedisAI/JRedisAI/tree/master) 3 | [![GitHub issues](https://img.shields.io/github/release/RedisAI/JRedisAI.svg)](https://github.com/RedisAI/JRedisAI/releases/latest) 4 | [![Maven Central](https://maven-badges.herokuapp.com/maven-central/com.redislabs/jredisai/badge.svg)](https://maven-badges.herokuapp.com/maven-central/com.redislabs/jredisai) 5 | [![Javadocs](https://www.javadoc.io/badge/com.redislabs/jredisai.svg)](https://www.javadoc.io/doc/com.redislabs/jredisai) 6 | [![codecov](https://codecov.io/gh/RedisAI/JRedisAI/branch/master/graph/badge.svg?token=cC4H2TvQHs)](https://codecov.io/gh/RedisAI/JRedisAI) 7 | [![Language grade: Java](https://img.shields.io/lgtm/grade/java/g/RedisAI/JRedisAI.svg?logo=lgtm&logoWidth=18)](https://lgtm.com/projects/g/RedisAI/JRedisAI/context:java) 8 | [![Known Vulnerabilities](https://snyk.io/test/github/RedisAI/JRedisAI/badge.svg?targetFile=pom.xml)](https://snyk.io/test/github/RedisAI/JRedisAI?targetFile=pom.xml) 9 | 10 | 11 | # JRedisAI 12 | 
[![Forum](https://img.shields.io/badge/Forum-RedisAI-blue)](https://forum.redislabs.com/c/modules/redisai) 13 | [![Discord](https://img.shields.io/discord/697882427875393627?style=flat-square)](https://discord.gg/rTQm7UZ) 14 | 15 | Java client for RedisAI 16 | 17 | ### Official Releases 18 | 19 | ```xml 20 | 21 | 22 | com.redislabs 23 | jredisai 24 | 0.9.0 25 | 26 | 27 | ``` 28 | 29 | ### Snapshots 30 | 31 | ```xml 32 | 33 | 34 | snapshots-repo 35 | https://oss.sonatype.org/content/repositories/snapshots 36 | 37 | 38 | ``` 39 | 40 | and 41 | 42 | ```xml 43 | 44 | 45 | com.redislabs 46 | jredisai 47 | 1.0.0-SNAPSHOT 48 | 49 | 50 | ``` 51 | 52 | # Example: Using the Java Client 53 | 54 | ```java 55 | RedisAI client = new RedisAI("localhost", 6379); 56 | client.setModel("model", Backend.TF, Device.CPU, new String[] {"a", "b"}, new String[] {"mul"}, "graph.pb"); 57 | 58 | client.setTensor("a", new float[] {2, 3}, new int[]{2}); 59 | client.setTensor("b", new float[] {2, 3}, new int[]{2}); 60 | 61 | client.runModel("model", new String[] {"a", "b"}, new String[] {"c"}); 62 | ``` 63 | 64 | ## Note 65 | 66 | **Chunk size:** Since version `0.10.0`, the chunk size of a model blob defaults to 512 MB (536870912 bytes), based on 67 | the default Redis configuration. This can be changed by setting the `redisai.blob.chunkSize` system property before the 68 | `Model` class is first used (the value is read once, when the class is initialized). For example, the chunk size can be limited to 8 MB by setting `-Dredisai.blob.chunkSize=8388608` or 69 | `System.setProperty(Model.BLOB_CHUNK_SIZE_PROPERTY, "8388608");`. A limit of 0 (zero) would disable chunking. 70 | 71 | **Socket timeout:** Operations with large data and/or long processing time may require a higher socket timeout. 72 | The following constructor may come in handy for that purpose. 
73 | ``` 74 | HostAndPort hostAndPort = new HostAndPort(host, port); 75 | JedisClientConfig clientConfig = DefaultJedisClientConfig.builder().socketTimeoutMillis(largeTimeout).build(); 76 | new RedisAI(hostAndPort, clientConfig); 77 | ``` 78 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 4 | 5 | 6 | org.sonatype.oss 7 | oss-parent 8 | 7 9 | 10 | 4.0.0 11 | jar 12 | com.redislabs 13 | jredisai 14 | 1.0.0-SNAPSHOT 15 | JRedisAI 16 | Java Client for RedisAI. 17 | https://github.com/RedisAI/JRedisAI 18 | 19 | 20 | 21 | JRedisAI Mailing List 22 | 23 | http://groups.google.com/group/redisai 24 | 25 | 26 | 27 | 28 | 29 | BSD 3 Clause 30 | https://opensource.org/licenses/BSD-3-Clause 31 | repo 32 | 33 | 34 | 35 | 36 | github 37 | http://github.com/RedisAI/JRedisAI/issues 38 | 39 | 40 | 41 | scm:git:git@github.com:RedisAI/JRedisAI.git 42 | scm:git:git@github.com:RedisAI/JRedisAI.git 43 | scm:git:git@github.com:RedisAI/JRedisAI.git 44 | 45 | 46 | 47 | 48 | 49 | junit 50 | junit 51 | 4.13.1 52 | test 53 | 54 | 55 | redis.clients 56 | jedis 57 | 3.7.1 58 | compile 59 | 60 | 61 | commons-io 62 | commons-io 63 | 2.8.0 64 | test 65 | 66 | 67 | 68 | 69 | 70 | ossrh 71 | https://oss.sonatype.org/content/repositories/snapshots 72 | 73 | 74 | ossrh 75 | https://oss.sonatype.org/service/local/staging/deploy/maven2/ 76 | 77 | 78 | 79 | 80 | 81 | 82 | org.codehaus.mojo 83 | cobertura-maven-plugin 84 | 2.7 85 | 86 | 87 | html 88 | xml 89 | 90 | 91 | 92 | 93 | 94 | org.apache.maven.plugins 95 | maven-compiler-plugin 96 | 3.1 97 | 98 | 1.8 99 | 1.8 100 | 101 | 102 | 103 | org.apache.maven.plugins 104 | maven-surefire-plugin 105 | 2.19.1 106 | 107 | 108 | org.apache.maven.plugins 109 | maven-source-plugin 110 | 2.2.1 111 | 112 | true 113 | 114 | 115 | 116 | attach-sources 117 | 118 | jar 119 | 120 | 121 | 122 | 123 | 124 | org.apache.maven.plugins 125 | 
maven-javadoc-plugin 126 | 2.9.1 127 | 128 | true 129 | -Xdoclint:none 130 | 131 | 132 | 133 | attach-javadoc 134 | 135 | jar 136 | 137 | 138 | 139 | 140 | 141 | org.apache.maven.plugins 142 | maven-release-plugin 143 | 2.4.2 144 | 145 | 146 | org.sonatype.plugins 147 | nexus-staging-maven-plugin 148 | 1.6.7 149 | true 150 | 151 | ossrh 152 | https://oss.sonatype.org/ 153 | true 154 | 155 | 156 | 157 | com.cosium.code 158 | git-code-format-maven-plugin 159 | 2.4 160 | 161 | 162 | 163 | install-formatter-hook 164 | 165 | install-hooks 166 | 167 | 168 | 170 | 171 | validate-code-format 172 | 173 | validate-code-format 174 | 175 | 176 | 177 | 178 | 179 | maven-jar-plugin 180 | 2.6 181 | 182 | 183 | ${project.build.outputDirectory}/META-INF/MANIFEST.MF 184 | 185 | 186 | 187 | 188 | org.apache.felix 189 | maven-bundle-plugin 190 | 2.5.3 191 | 192 | 193 | bundle-manifest 194 | process-classes 195 | 196 | manifest 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | release 206 | 207 | 208 | 209 | 210 | maven-gpg-plugin 211 | 3.0.1 212 | 213 | 214 | --pinentry-mode 215 | loopback 216 | 217 | 218 | 219 | 220 | sign-artifacts 221 | verify 222 | 223 | sign 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | -------------------------------------------------------------------------------- /src/main/java/com/redislabs/redisai/Backend.java: -------------------------------------------------------------------------------- 1 | package com.redislabs.redisai; 2 | 3 | import redis.clients.jedis.commands.ProtocolCommand; 4 | import redis.clients.jedis.util.SafeEncoder; 5 | 6 | public enum Backend implements ProtocolCommand { 7 | TF, 8 | TORCH, 9 | TFLITE, 10 | ONNX; 11 | 12 | private final byte[] raw; 13 | 14 | Backend() { 15 | raw = SafeEncoder.encode(this.name()); 16 | } 17 | 18 | @Override 19 | public byte[] getRaw() { 20 | return raw; 21 | } 22 | } 23 | -------------------------------------------------------------------------------- 
/src/main/java/com/redislabs/redisai/Command.java: -------------------------------------------------------------------------------- 1 | package com.redislabs.redisai; 2 | 3 | import redis.clients.jedis.commands.ProtocolCommand; 4 | import redis.clients.jedis.util.SafeEncoder; 5 | 6 | public enum Command implements ProtocolCommand { 7 | TENSOR_GET("AI.TENSORGET"), 8 | TENSOR_SET("AI.TENSORSET"), 9 | MODEL_GET("AI.MODELGET"), 10 | MODEL_SET("AI.MODELSET"), 11 | MODEL_STORE("AI.MODELSTORE"), 12 | MODEL_DEL("AI.MODELDEL"), 13 | MODEL_RUN("AI.MODELRUN"), 14 | MODEL_EXECUTE("AI.MODELEXECUTE"), 15 | SCRIPT_SET("AI.SCRIPTSET"), 16 | SCRIPT_STORE("AI.SCRIPTSTORE"), 17 | SCRIPT_GET("AI.SCRIPTGET"), 18 | SCRIPT_DEL("AI.SCRIPTDEL"), 19 | SCRIPT_RUN("AI.SCRIPTRUN"), 20 | SCRIPT_EXECUTE("AI.SCRIPTEXECUTE"), 21 | DAGRUN("AI.DAGRUN"), 22 | DAGRUN_RO("AI.DAGRUN_RO"), 23 | DAGEXECUTE("AI.DAGEXECUTE"), 24 | DAGEXECUTE_RO("AI.DAGEXECUTE_RO"), 25 | INFO("AI.INFO"), 26 | CONFIG("AI.CONFIG"); 27 | 28 | private final byte[] raw; 29 | 30 | Command(String alt) { 31 | raw = SafeEncoder.encode(alt); 32 | } 33 | 34 | @Override 35 | public byte[] getRaw() { 36 | return raw; 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /src/main/java/com/redislabs/redisai/Dag.java: -------------------------------------------------------------------------------- 1 | package com.redislabs.redisai; 2 | 3 | import java.util.ArrayList; 4 | import java.util.List; 5 | import redis.clients.jedis.util.SafeEncoder; 6 | 7 | public class Dag implements DagRunCommands { 8 | private final List> commands = new ArrayList<>(); 9 | private final List tensorgetflag = new ArrayList<>(); 10 | 11 | /** Direct acyclic graph of operations to run within RedisAI */ 12 | public Dag() {} 13 | 14 | protected List processDagReply(List reply) { 15 | List outputList = new ArrayList<>(reply.size()); 16 | for (int i = 0; i < reply.size(); i++) { 17 | Object obj = reply.get(i); 18 | // TODO: 
Should encode 'OK', 'NA', etc. response 19 | if (obj instanceof Exception) { 20 | Exception ex = (Exception) obj; 21 | outputList.add(new RedisAIException(ex.getMessage(), ex)); 22 | } else if (this.tensorgetflag.get(i)) { 23 | outputList.add(Tensor.createTensorFromRespReply((List) obj)); 24 | } else { 25 | outputList.add(obj); 26 | } 27 | } 28 | return outputList; 29 | } 30 | 31 | @Override 32 | public Dag setTensor(String key, Tensor tensor) { 33 | List args = tensor.tensorSetFlatArgs(key, true); 34 | this.commands.add(args); 35 | this.tensorgetflag.add(false); 36 | return this; 37 | } 38 | 39 | @Override 40 | public Dag getTensor(String key) { 41 | List args = Tensor.tensorGetFlatArgs(key, true); 42 | this.commands.add(args); 43 | this.tensorgetflag.add(true); 44 | return this; 45 | } 46 | 47 | @Override 48 | public Dag runModel(String key, String[] inputs, String[] outputs) { 49 | List args = Model.modelRunFlatArgs(key, inputs, outputs, true); 50 | this.commands.add(args); 51 | this.tensorgetflag.add(false); 52 | return this; 53 | } 54 | 55 | @Override 56 | public Dag executeModel(String key, String[] inputs, String[] outputs) { 57 | List args = Model.modelExecuteCommandArgs(key, inputs, outputs, -1L, true); 58 | this.commands.add(args); 59 | this.tensorgetflag.add(false); 60 | return this; 61 | } 62 | 63 | @Override 64 | public Dag runScript(String key, String function, String[] inputs, String[] outputs) { 65 | List args = Script.scriptRunFlatArgs(key, function, inputs, outputs, true); 66 | this.commands.add(args); 67 | this.tensorgetflag.add(false); 68 | return this; 69 | } 70 | 71 | @Override 72 | public Dag executeScript( 73 | String key, 74 | String function, 75 | List keys, 76 | List inputs, 77 | List args, 78 | List outputs) { 79 | List binary = 80 | Script.scriptExecuteFlatArgs(key, function, keys, inputs, args, outputs, -1L, true); 81 | this.commands.add(binary); 82 | this.tensorgetflag.add(false); 83 | return this; 84 | } 85 | 86 | List 
dagRunFlatArgs(String[] loadKeys, String[] persistKeys) { 87 | List args = new ArrayList<>(); 88 | if (loadKeys != null && loadKeys.length > 0) { 89 | args.add(Keyword.LOAD.getRaw()); 90 | args.add(SafeEncoder.encode(String.valueOf(loadKeys.length))); 91 | for (String key : loadKeys) { 92 | args.add(SafeEncoder.encode(key)); 93 | } 94 | } 95 | if (persistKeys != null && persistKeys.length > 0) { 96 | args.add(Keyword.PERSIST.getRaw()); 97 | args.add(SafeEncoder.encode(String.valueOf(persistKeys.length))); 98 | for (String key : persistKeys) { 99 | args.add(SafeEncoder.encode(key)); 100 | } 101 | } 102 | for (List command : this.commands) { 103 | args.add(Keyword.PIPE.getRaw()); 104 | args.addAll(command); 105 | } 106 | return args; 107 | } 108 | 109 | List dagExecuteFlatArgs( 110 | String[] loadTensors, String[] persistTensors, String routingHint) { 111 | List args = new ArrayList<>(); 112 | if (loadTensors != null && loadTensors.length > 0) { 113 | args.add(Keyword.LOAD.getRaw()); 114 | args.add(SafeEncoder.encode(String.valueOf(loadTensors.length))); 115 | for (String key : loadTensors) { 116 | args.add(SafeEncoder.encode(key)); 117 | } 118 | } 119 | if (persistTensors != null && persistTensors.length > 0) { 120 | args.add(Keyword.PERSIST.getRaw()); 121 | args.add(SafeEncoder.encode(String.valueOf(persistTensors.length))); 122 | for (String key : persistTensors) { 123 | args.add(SafeEncoder.encode(key)); 124 | } 125 | } 126 | 127 | if (routingHint != null) { 128 | args.add(Keyword.ROUTING.getRaw()); 129 | args.add(SafeEncoder.encode(routingHint)); 130 | } 131 | 132 | for (List command : this.commands) { 133 | args.add(Keyword.PIPE.getRaw()); 134 | args.addAll(command); 135 | } 136 | return args; 137 | } 138 | } 139 | -------------------------------------------------------------------------------- /src/main/java/com/redislabs/redisai/DagRunCommands.java: -------------------------------------------------------------------------------- 1 | package 
com.redislabs.redisai; 2 | 3 | import java.util.List; 4 | 5 | interface DagRunCommands { 6 | T setTensor(String key, Tensor tensor); 7 | 8 | T getTensor(String key); 9 | 10 | T runModel(String key, String[] inputs, String[] outputs); 11 | 12 | T executeModel(String key, String[] inputs, String[] outputs); 13 | 14 | T runScript(String key, String function, String[] inputs, String[] outputs); 15 | 16 | T executeScript( 17 | String key, 18 | String function, 19 | List keys, 20 | List inputs, 21 | List args, 22 | List outputs); 23 | } 24 | -------------------------------------------------------------------------------- /src/main/java/com/redislabs/redisai/DataType.java: -------------------------------------------------------------------------------- 1 | package com.redislabs.redisai; 2 | 3 | import java.lang.reflect.Array; 4 | import java.util.ArrayList; 5 | import java.util.HashMap; 6 | import java.util.List; 7 | import redis.clients.jedis.Protocol; 8 | import redis.clients.jedis.commands.ProtocolCommand; 9 | import redis.clients.jedis.util.SafeEncoder; 10 | 11 | public enum DataType implements ProtocolCommand { 12 | INT32 { 13 | @Override 14 | public List toByteArray(Object obj) { 15 | int[] values = (int[]) obj; 16 | List res = new ArrayList<>(values.length); 17 | for (int value : values) { 18 | res.add(Protocol.toByteArray(value)); 19 | } 20 | return res; 21 | } 22 | 23 | @Override 24 | protected Object toObject(List data) { 25 | int[] values = new int[data.size()]; 26 | for (int i = 0; i < data.size(); i++) { 27 | values[i] = Integer.parseInt(SafeEncoder.encode(data.get(i))); 28 | } 29 | return values; 30 | } 31 | }, 32 | INT64 { 33 | @Override 34 | public List toByteArray(Object obj) { 35 | long[] values = (long[]) obj; 36 | List res = new ArrayList<>(values.length); 37 | for (long value : values) { 38 | res.add(Protocol.toByteArray(value)); 39 | } 40 | return res; 41 | } 42 | 43 | @Override 44 | protected Object toObject(List data) { 45 | long[] values = new 
long[data.size()]; 46 | for (int i = 0; i < data.size(); i++) { 47 | values[i] = Long.parseLong(SafeEncoder.encode(data.get(i))); 48 | } 49 | return values; 50 | } 51 | }, 52 | FLOAT { 53 | @Override 54 | public List toByteArray(Object obj) { 55 | float[] values = (float[]) obj; 56 | List res = new ArrayList<>(values.length); 57 | for (float value : values) { 58 | res.add(Protocol.toByteArray(value)); 59 | } 60 | return res; 61 | } 62 | 63 | @Override 64 | protected Object toObject(List data) { 65 | float[] values = new float[data.size()]; 66 | for (int i = 0; i < data.size(); i++) { 67 | values[i] = Float.parseFloat(SafeEncoder.encode(data.get(i))); 68 | } 69 | return values; 70 | } 71 | }, 72 | DOUBLE { 73 | @Override 74 | public List toByteArray(Object obj) { 75 | double[] values = (double[]) obj; 76 | List res = new ArrayList<>(values.length); 77 | for (double value : values) { 78 | res.add(Protocol.toByteArray(value)); 79 | } 80 | return res; 81 | } 82 | 83 | @Override 84 | protected Object toObject(List data) { 85 | double[] values = new double[data.size()]; 86 | for (int i = 0; i < data.size(); i++) { 87 | values[i] = Double.parseDouble(SafeEncoder.encode(data.get(i))); 88 | } 89 | return values; 90 | } 91 | }; 92 | 93 | private static final HashMap, DataType> classDataTypes = new HashMap<>(); 94 | 95 | static { 96 | classDataTypes.put(int.class, DataType.INT32); 97 | classDataTypes.put(Integer.class, DataType.INT32); 98 | classDataTypes.put(long.class, DataType.INT64); 99 | classDataTypes.put(Long.class, DataType.INT64); 100 | classDataTypes.put(float.class, DataType.FLOAT); 101 | classDataTypes.put(Float.class, DataType.FLOAT); 102 | classDataTypes.put(double.class, DataType.DOUBLE); 103 | classDataTypes.put(Double.class, DataType.DOUBLE); 104 | } 105 | 106 | private final byte[] raw; 107 | 108 | DataType() { 109 | raw = SafeEncoder.encode(this.name()); 110 | } 111 | 112 | static DataType getDataTypefromString(String dtypeRaw) { 113 | DataType dt = null; 
114 | if (dtypeRaw.equals(DataType.INT32.name())) { 115 | dt = DataType.INT32; 116 | } 117 | if (dtypeRaw.equals(DataType.INT64.name())) { 118 | dt = DataType.INT64; 119 | } 120 | if (dtypeRaw.equals(DataType.FLOAT.name())) { 121 | dt = DataType.FLOAT; 122 | } 123 | if (dtypeRaw.equals(DataType.DOUBLE.name())) { 124 | dt = DataType.DOUBLE; 125 | } 126 | return dt; 127 | } 128 | 129 | /** The class for the data type to which Java object o corresponds. */ 130 | public static DataType baseObjType(Object o) { 131 | Class c = o.getClass(); 132 | while (c.isArray()) { 133 | c = c.getComponentType(); 134 | } 135 | DataType ret = classDataTypes.get(c); 136 | if (ret != null) { 137 | return ret; 138 | } 139 | throw new IllegalArgumentException("cannot create Tensors of type " + c.getName()); 140 | } 141 | 142 | private static List toByteArray(Object obj, long[] dimensions, int dim, DataType type) { 143 | ArrayList res = new ArrayList<>(); 144 | if (dimensions.length - 1 > dim) { 145 | long dimension = dimensions[dim++]; 146 | for (int i = 0; i < dimension; ++i) { 147 | Object value = Array.get(obj, i); 148 | res.addAll(toByteArray(value, dimensions, dim, type)); 149 | } 150 | } else { 151 | res.addAll(type.toByteArray(obj)); 152 | } 153 | return res; 154 | } 155 | 156 | protected abstract List toByteArray(Object obj); 157 | 158 | protected abstract Object toObject(List data); 159 | 160 | public byte[] getRaw() { 161 | return raw; 162 | } 163 | 164 | public List toByteArray(Object obj, long[] dimensions) { 165 | return toByteArray(obj, dimensions, 0, this); 166 | } 167 | } 168 | -------------------------------------------------------------------------------- /src/main/java/com/redislabs/redisai/Device.java: -------------------------------------------------------------------------------- 1 | package com.redislabs.redisai; 2 | 3 | import redis.clients.jedis.commands.ProtocolCommand; 4 | import redis.clients.jedis.util.SafeEncoder; 5 | 6 | public enum Device implements 
ProtocolCommand { 7 | CPU, 8 | GPU; 9 | 10 | private final byte[] raw; 11 | 12 | Device() { 13 | raw = SafeEncoder.encode(name()); 14 | } 15 | 16 | @Override 17 | public byte[] getRaw() { 18 | return raw; 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /src/main/java/com/redislabs/redisai/Keyword.java: -------------------------------------------------------------------------------- 1 | package com.redislabs.redisai; 2 | 3 | import redis.clients.jedis.commands.ProtocolCommand; 4 | import redis.clients.jedis.util.SafeEncoder; 5 | 6 | public enum Keyword implements ProtocolCommand { 7 | INPUTS, 8 | OUTPUTS, 9 | META, 10 | VALUES, 11 | BLOB, 12 | SOURCE, 13 | RESETSTAT, 14 | TAG, 15 | ENTRY_POINTS, 16 | BATCHSIZE, 17 | MINBATCHSIZE, 18 | MINBATCHTIMEOUT, 19 | TIMEOUT, 20 | BACKENDSPATH, 21 | LOADBACKEND, 22 | LOAD, 23 | PERSIST, 24 | KEYS, 25 | ROUTING, 26 | ARGS, 27 | PIPE("|>"); 28 | 29 | private final byte[] raw; 30 | 31 | Keyword() { 32 | raw = SafeEncoder.encode(this.name()); 33 | } 34 | 35 | Keyword(String encodeStr) { 36 | raw = SafeEncoder.encode(encodeStr); 37 | } 38 | 39 | public byte[] getRaw() { 40 | return raw; 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /src/main/java/com/redislabs/redisai/Model.java: -------------------------------------------------------------------------------- 1 | package com.redislabs.redisai; 2 | 3 | import com.redislabs.redisai.exceptions.JRedisAIRunTimeException; 4 | import java.io.IOException; 5 | import java.net.URI; 6 | import java.nio.file.Files; 7 | import java.nio.file.Paths; 8 | import java.util.ArrayList; 9 | import java.util.Arrays; 10 | import java.util.List; 11 | import redis.clients.jedis.Protocol; 12 | import redis.clients.jedis.util.SafeEncoder; 13 | 14 | /** Direct mapping to RedisAI Model */ 15 | public class Model { 16 | 17 | public static final String BLOB_CHUNK_SIZE_PROPERTY = "redisai.blob.chunkSize"; 18 
| 19 | private static final int BLOB_CHUNK_SIZE = 20 | Integer.parseInt(System.getProperty(BLOB_CHUNK_SIZE_PROPERTY, "536870912")); 21 | 22 | private Backend backend; // TODO: final 23 | private Device device; // TODO: final 24 | private String[] inputs; 25 | private String[] outputs; 26 | private byte[] blob; // TODO: final 27 | private String tag; 28 | private long batchSize; 29 | private long minBatchSize; 30 | private long minBatchTimeout; 31 | 32 | /** 33 | * @param backend - the backend for the model. can be one of TF, TFLITE, TORCH or ONNX 34 | * @param device - the device that will execute the model. can be of CPU or GPU 35 | * @param modelUri - filepath of the Protobuf-serialized model 36 | * @throws java.io.IOException 37 | * @see #Model(com.redislabs.redisai.Backend, com.redislabs.redisai.Device, byte[]) 38 | * @see Files#readAllBytes(java.nio.file.Path) 39 | * @see Paths#get(java.net.URI) 40 | */ 41 | public Model(Backend backend, Device device, URI modelUri) throws IOException { 42 | this(backend, device, Files.readAllBytes(Paths.get(modelUri))); 43 | } 44 | 45 | /** 46 | * @param backend - the backend for the model. can be one of TF, TFLITE, TORCH or ONNX 47 | * @param device - the device that will execute the model. can be of CPU or GPU 48 | * @param blob - the Protobuf-serialized model 49 | */ 50 | public Model(Backend backend, Device device, byte[] blob) { 51 | this.backend = backend; 52 | this.device = device; 53 | this.blob = blob; 54 | } 55 | 56 | /** 57 | * @param backend - the backend for the model. can be one of TF, TFLITE, TORCH or ONNX 58 | * @param device - the device that will execute the model. 
can be of CPU or GPU 59 | * @param inputs - one or more names of the model's input nodes (applicable only for TensorFlow 60 | * models) 61 | * @param outputs - one or more names of the model's output nodes (applicable only for TensorFlow 62 | * models) 63 | * @param blob - the Protobuf-serialized model 64 | */ 65 | public Model(Backend backend, Device device, String[] inputs, String[] outputs, byte[] blob) { 66 | this(backend, device, inputs, outputs, blob, 0, 0); 67 | } 68 | 69 | /** 70 | * @param backend - the backend for the model. can be one of TF, TFLITE, TORCH or ONNX 71 | * @param device - the device that will execute the model. can be of CPU or GPU 72 | * @param inputs - one or more names of the model's input nodes (applicable only for TensorFlow 73 | * models) 74 | * @param outputs - one or more names of the model's output nodes (applicable only for TensorFlow 75 | * models) 76 | * @param blob - the Protobuf-serialized model 77 | * @param batchSize - when provided with an batchsize that is greater than 0, the engine will 78 | * batch incoming requests from multiple clients that use the model with input tensors of the 79 | * same shape. 
80 | * @param minBatchSize - when provided with an minbatchsize that is greater than 0, the engine 81 | * will postpone calls to AI.MODELRUN until the batch's size had reached minbatchsize 82 | */ 83 | public Model( 84 | Backend backend, 85 | Device device, 86 | String[] inputs, 87 | String[] outputs, 88 | byte[] blob, 89 | long batchSize, 90 | long minBatchSize) { 91 | this.backend = backend; 92 | this.device = device; 93 | this.inputs = inputs; 94 | this.outputs = outputs; 95 | this.blob = blob; 96 | this.tag = null; 97 | this.batchSize = batchSize; 98 | this.minBatchSize = minBatchSize; 99 | } 100 | 101 | public static Model createModelFromRespReply(List reply) { 102 | Backend backend = null; 103 | Device device = null; 104 | String tag = null; 105 | byte[] blob = null; 106 | long batchsize = 0; 107 | long minbatchsize = 0; 108 | long minbatchtimeout = 0; 109 | String[] inputs = new String[0]; 110 | String[] outputs = new String[0]; 111 | for (int i = 0; i < reply.size(); i += 2) { 112 | String arrayKey = SafeEncoder.encode((byte[]) reply.get(i)); 113 | switch (arrayKey) { 114 | case "backend": 115 | String backendString = SafeEncoder.encode((byte[]) reply.get(i + 1)); 116 | backend = Backend.valueOf(backendString); 117 | if (backend == null) { 118 | throw new JRedisAIRunTimeException("Unrecognized backend: " + backendString); 119 | } 120 | break; 121 | case "device": 122 | String deviceString = SafeEncoder.encode((byte[]) reply.get(i + 1)); 123 | device = Device.valueOf(deviceString); 124 | if (device == null) { 125 | throw new JRedisAIRunTimeException("Unrecognized device: " + deviceString); 126 | } 127 | break; 128 | case "tag": 129 | tag = SafeEncoder.encode((byte[]) reply.get(i + 1)); 130 | break; 131 | case "blob": 132 | blob = (byte[]) reply.get(i + 1); 133 | break; 134 | case "batchsize": 135 | batchsize = (Long) reply.get(i + 1); 136 | break; 137 | case "minbatchsize": 138 | minbatchsize = (Long) reply.get(i + 1); 139 | break; 140 | case 
"minbatchtimeout": 141 | minbatchtimeout = (Long) reply.get(i + 1); 142 | break; 143 | case "inputs": 144 | List inputsEncoded = (List) reply.get(i + 1); 145 | /* an empty list keeps the zero-length default array */ if (!inputsEncoded.isEmpty()) { 146 | inputs = new String[inputsEncoded.size()]; 147 | for (int j = 0; j < inputsEncoded.size(); j++) { 148 | inputs[j] = SafeEncoder.encode(inputsEncoded.get(j)); 149 | } 150 | } 151 | break; 152 | case "outputs": 153 | List outputsEncoded = (List) reply.get(i + 1); 154 | if (!outputsEncoded.isEmpty()) { 155 | outputs = new String[outputsEncoded.size()]; 156 | for (int j = 0; j < outputsEncoded.size(); j++) { 157 | outputs[j] = SafeEncoder.encode(outputsEncoded.get(j)); 158 | } 159 | } 160 | break; 161 | default: 162 | /* unknown reply fields are skipped */ break; 163 | } 164 | } 165 | 166 | /* backend, device and blob are mandatory; every other field keeps its default */ if (backend == null || device == null || blob == null) { 167 | throw new JRedisAIRunTimeException( 168 | "AI.MODELGET reply did not contained all elements to build the model"); 169 | } 170 | return new Model(backend, device, blob) 171 | .setInputs(inputs) 172 | .setOutputs(outputs) 173 | .setBatchSize(batchsize) 174 | .setMinBatchSize(minbatchsize) 175 | .setMinBatchTimeout(minbatchtimeout) 176 | .setTag(tag); 177 | } 178 | 179 | public String getTag() { 180 | return tag; 181 | } 182 | 183 | /** Sets the tag and returns this model for chaining. */ public Model setTag(String tag) { 184 | this.tag = tag; 185 | return this; 186 | } 187 | 188 | public byte[] getBlob() { 189 | return blob; 190 | } 191 | 192 | /** 193 | * @param blob the serialized model bytes 194 | * @deprecated This field will become final; pass the value to a constructor instead.
195 | */ 196 | @Deprecated 197 | public void setBlob(byte[] blob) { 198 | this.blob = blob; 199 | } 200 | 201 | public String[] getOutputs() { 202 | return outputs; 203 | } 204 | 205 | /** Sets the output names and returns this model for chaining. */ public Model setOutputs(String[] outputs) { 206 | this.outputs = outputs; 207 | return this; 208 | } 209 | 210 | public String[] getInputs() { 211 | return inputs; 212 | } 213 | 214 | /** Sets the input names and returns this model for chaining. */ public Model setInputs(String[] inputs) { 215 | this.inputs = inputs; 216 | return this; 217 | } 218 | 219 | public Device getDevice() { 220 | return device; 221 | } 222 | 223 | /** 224 | * @param device the device that will execute the model 225 | * @deprecated This field will become final; pass the value to a constructor instead. 226 | */ 227 | @Deprecated 228 | public void setDevice(Device device) { 229 | this.device = device; 230 | } 231 | 232 | public Backend getBackend() { 233 | return backend; 234 | } 235 | 236 | /** 237 | * @param backend the backend for the model 238 | * @deprecated This field will become final; pass the value to a constructor instead. 239 | */ 240 | @Deprecated 241 | public void setBackend(Backend backend) { 242 | this.backend = backend; 243 | } 244 | 245 | public long getBatchSize() { 246 | return batchSize; 247 | } 248 | 249 | public Model setBatchSize(long batchsize) { 250 | this.batchSize = batchsize; 251 | return this; 252 | } 253 | 254 | public long getMinBatchSize() { 255 | return minBatchSize; 256 | } 257 | 258 | public Model setMinBatchSize(long minbatchsize) { 259 | this.minBatchSize = minbatchsize; 260 | return this; 261 | } 262 | 263 | public long getMinBatchTimeout() { 264 | return minBatchTimeout; 265 | } 266 | 267 | public Model setMinBatchTimeout(long minBatchTimeout) { 268 | this.minBatchTimeout = minBatchTimeout; 269 | return this; 270 | } 271 | 272 | /** 273 | * Encodes the current model properties into the arguments of an AI.MODELSET command to be stored in the RedisAI server 274 | * 275 | * @param key name of key to store the Model 276 | * @return the command arguments, starting with the encoded key 277 | */ 278 | protected List getModelSetCommandBytes(String key) { 279 | List args = new ArrayList<>(); 280 |
args.add(SafeEncoder.encode(key)); 281 | args.add(backend.getRaw()); 282 | args.add(device.getRaw()); 283 | /* TAG is only sent when a tag was set */ if (tag != null) { 284 | args.add(Keyword.TAG.getRaw()); 285 | args.add(SafeEncoder.encode(tag)); 286 | } 287 | /* MINBATCHSIZE is only considered when BATCHSIZE is positive */ if (batchSize > 0) { 288 | args.add(Keyword.BATCHSIZE.getRaw()); 289 | args.add(Protocol.toByteArray(batchSize)); 290 | if (minBatchSize > 0) { 291 | args.add(Keyword.MINBATCHSIZE.getRaw()); 292 | args.add(Protocol.toByteArray(minBatchSize)); 293 | } 294 | } 295 | /* NOTE(review): inputs and outputs are iterated without a null check here, unlike in getModelStoreCommandArgs */ args.add(Keyword.INPUTS.getRaw()); 296 | for (String input : inputs) { 297 | args.add(SafeEncoder.encode(input)); 298 | } 299 | args.add(Keyword.OUTPUTS.getRaw()); 300 | for (String output : outputs) { 301 | args.add(SafeEncoder.encode(output)); 302 | } 303 | /* the blob is appended as a single argument */ args.add(Keyword.BLOB.getRaw()); 304 | args.add(blob); 305 | return args; 306 | } 307 | 308 | /** 309 | * Encodes the current model properties into the arguments of an AI.MODELSTORE command to store in the RedisAI server. 310 | * 311 | * @param key name of key to store the Model 312 | * @return the command arguments, starting with the encoded key 313 | */ 314 | protected List getModelStoreCommandArgs(String key) { 315 | 316 | List args = new ArrayList<>(); 317 | args.add(SafeEncoder.encode(key)); 318 | 319 | args.add(backend.getRaw()); 320 | args.add(device.getRaw()); 321 | 322 | if (tag != null) { 323 | args.add(Keyword.TAG.getRaw()); 324 | args.add(SafeEncoder.encode(tag)); 325 | } 326 | 327 | /* when BATCHSIZE is positive, MINBATCHSIZE and MINBATCHTIMEOUT are always sent, whatever their values */ if (batchSize > 0) { 328 | args.add(Keyword.BATCHSIZE.getRaw()); 329 | args.add(Protocol.toByteArray(batchSize)); 330 | 331 | args.add(Keyword.MINBATCHSIZE.getRaw()); 332 | args.add(Protocol.toByteArray(minBatchSize)); 333 | 334 | args.add(Keyword.MINBATCHTIMEOUT.getRaw()); 335 | args.add(Protocol.toByteArray(minBatchTimeout)); 336 | } 337 | 338 | /* INPUTS is followed by the number of names and is omitted for a null or empty array */ if (inputs != null && inputs.length > 0) { 339 | args.add(Keyword.INPUTS.getRaw()); 340 | args.add(Protocol.toByteArray(inputs.length)); 341 | for (String input : inputs) { 342 | args.add(SafeEncoder.encode(input)); 343 | } 344 | } 345 | 346 | if (outputs != null && outputs.length > 0) { 347 |
args.add(Keyword.OUTPUTS.getRaw()); 348 | args.add(Protocol.toByteArray(outputs.length)); 349 | for (String output : outputs) { 350 | args.add(SafeEncoder.encode(output)); 351 | } 352 | } 353 | 354 | args.add(Keyword.BLOB.getRaw()); 355 | /* the blob may be appended as several arguments, see collectChunks */ collectChunks(args, blob); 356 | 357 | return args; 358 | } 359 | 360 | /** Adds array to collector as a single element when BLOB_CHUNK_SIZE is not positive or the array fits in one chunk; otherwise adds consecutive copies of at most BLOB_CHUNK_SIZE bytes each, in order. */ private static void collectChunks(List collector, byte[] array) { 361 | final int chunkSize = BLOB_CHUNK_SIZE; 362 | if (chunkSize <= 0 || array.length <= chunkSize) { 363 | collector.add(array); 364 | return; 365 | } 366 | 367 | int from = 0; 368 | while (from < array.length) { 369 | /* the last chunk may be shorter than chunkSize */ int copySize = Math.min(array.length - from, chunkSize); 370 | collector.add(Arrays.copyOfRange(array, from, from + copySize)); 371 | from += copySize; 372 | } 373 | } 374 | 375 | /** Builds the flat argument list: the MODEL_RUN command name when includeCommandName is true, the key, INPUTS followed by the input names, then OUTPUTS followed by the output names (no counts). */ protected static List modelRunFlatArgs( 376 | String key, String[] inputs, String[] outputs, boolean includeCommandName) { 377 | List args = new ArrayList<>(); 378 | if (includeCommandName) { 379 | args.add(Command.MODEL_RUN.getRaw()); 380 | } 381 | args.add(SafeEncoder.encode(key)); 382 | 383 | args.add(Keyword.INPUTS.getRaw()); 384 | for (String input : inputs) { 385 | args.add(SafeEncoder.encode(input)); 386 | } 387 | 388 | args.add(Keyword.OUTPUTS.getRaw()); 389 | for (String output : outputs) { 390 | args.add(SafeEncoder.encode(output)); 391 | } 392 | return args; 393 | } 394 | 395 | /** Builds the argument list: the MODEL_EXECUTE command name when includeCommandName is true, the key, INPUTS with a count and the input names, OUTPUTS with a count and the output names, and TIMEOUT with its value when timeout is not negative. */ protected static List modelExecuteCommandArgs( 396 | String key, String[] inputs, String[] outputs, long timeout, boolean includeCommandName) { 397 | 398 | List args = new ArrayList<>(); 399 | if (includeCommandName) { 400 | args.add(Command.MODEL_EXECUTE.getRaw()); 401 | } 402 | args.add(SafeEncoder.encode(key)); 403 | 404 | args.add(Keyword.INPUTS.getRaw()); 405 | args.add(Protocol.toByteArray(inputs.length)); 406 | for (String input : inputs) { 407 | args.add(SafeEncoder.encode(input)); 408 | } 409 | 410 | args.add(Keyword.OUTPUTS.getRaw()); 411 | args.add(Protocol.toByteArray(outputs.length)); 412 | for (String
output : outputs) { 413 | args.add(SafeEncoder.encode(output)); 414 | } 415 | 416 | /* a negative timeout omits the TIMEOUT argument */ if (timeout >= 0) { 417 | args.add(Keyword.TIMEOUT.getRaw()); 418 | args.add(Protocol.toByteArray(timeout)); 419 | } 420 | return args; 421 | } 422 | } 423 | -------------------------------------------------------------------------------- /src/main/java/com/redislabs/redisai/RedisAI.java: -------------------------------------------------------------------------------- 1 | package com.redislabs.redisai; 2 | 3 | import com.redislabs.redisai.exceptions.JRedisAIRunTimeException; 4 | import java.io.IOException; 5 | import java.nio.file.Files; 6 | import java.nio.file.Paths; 7 | import java.util.HashMap; 8 | import java.util.List; 9 | import java.util.Map; 10 | import org.apache.commons.pool2.impl.GenericObjectPoolConfig; 11 | import redis.clients.jedis.BinaryClient; 12 | import redis.clients.jedis.Client; 13 | import redis.clients.jedis.HostAndPort; 14 | import redis.clients.jedis.Jedis; 15 | import redis.clients.jedis.JedisClientConfig; 16 | import redis.clients.jedis.JedisPool; 17 | import redis.clients.jedis.JedisPoolConfig; 18 | import redis.clients.jedis.exceptions.JedisDataException; 19 | import redis.clients.jedis.util.Pool; 20 | import redis.clients.jedis.util.SafeEncoder; 21 | 22 | public class RedisAI implements AutoCloseable { 23 | 24 | private final Pool pool; 25 | 26 | /** Create a new RedisAI client connected to localhost on port 6379 */ 27 | public RedisAI() { 28 | this("localhost", 6379); 29 | } 30 | 31 | /** 32 | * Create a new RedisAI client with a timeout of 500 and a pool size of 100 33 | * 34 | * @param host the redis host 35 | * @param port the redis port 36 | */ 37 | public RedisAI(String host, int port) { 38 | this(host, port, 500, 100); 39 | } 40 | 41 | /** 42 | * Create a new RedisAI client without a password 43 | * 44 | * @param host the redis host 45 | * @param port the redis port 46 | * @param timeout the timeout value handed to the underlying JedisPool 47 | * @param poolSize the maximum total size of the connection pool 48 | */ 49 | public RedisAI(String host, int port, int timeout, int poolSize) { 50 | this(host,
port, timeout, poolSize, null); 51 | } 52 | 53 | /** 54 | * Create a new RedisAI client backed by a new JedisPool 55 | * 56 | * @param host the redis host 57 | * @param port the redis port 58 | * @param timeout the timeout value handed to the underlying JedisPool 59 | * @param poolSize the maximum total size of the connection pool 60 | * @param password the password for authentication in a password protected Redis server 61 | */ 62 | public RedisAI(String host, int port, int timeout, int poolSize, String password) { 63 | this(new JedisPool(initPoolConfig(poolSize), host, port, timeout, password)); 64 | } 65 | 66 | /** 67 | * Create a new RedisAI client using a default-constructed GenericObjectPoolConfig 68 | * 69 | * @param hostAndPort the address handed to the underlying JedisPool 70 | * @param clientConfig the client configuration handed to the underlying JedisPool 71 | */ 72 | public RedisAI(HostAndPort hostAndPort, JedisClientConfig clientConfig) { 73 | this(new GenericObjectPoolConfig<>(), hostAndPort, clientConfig); 74 | } 75 | 76 | /** 77 | * Create a new RedisAI client whose pool configuration is built from poolSize 78 | * 79 | * @param hostAndPort the address handed to the underlying JedisPool 80 | * @param clientConfig the client configuration handed to the underlying JedisPool 81 | * @param poolSize the maximum total size of the connection pool 82 | */ 83 | public RedisAI(HostAndPort hostAndPort, JedisClientConfig clientConfig, int poolSize) { 84 | this(initPoolConfig(poolSize), hostAndPort, clientConfig); 85 | } 86 | 87 | /** 88 | * Create a new RedisAI client backed by a new JedisPool 89 | * 90 | * @param poolConfig the pool configuration handed to the underlying JedisPool 91 | * @param hostAndPort the address handed to the underlying JedisPool 92 | * @param clientConfig the client configuration handed to the underlying JedisPool 93 | */ 94 | public RedisAI( 95 | GenericObjectPoolConfig poolConfig, 96 | HostAndPort hostAndPort, 97 | JedisClientConfig clientConfig) { 98 | this(new JedisPool(poolConfig, hostAndPort, clientConfig)); 99 | } 100 | 101 | /** 102 | * Create a new RedisAI client on top of an existing pool 103 | * 104 | * @param pool jedis connection pool 105 | */ 106 | public RedisAI(Pool pool) { 107 | this.pool = pool; 108 | } 109 | 110 | /** Closes the underlying connection pool. */ @Override 111 | public void close() { 112 | this.pool.close(); 113 | } 114 | 115 | /** 116 | * Constructs the JedisPoolConfig used by the constructors that take a pool size.
117 | * 118 | * @param poolSize maximum total size of the JedisPool 119 | * @return {@link JedisPoolConfig} with borrow, return, create and idle testing disabled, eviction timings set and fairness enabled 120 | */ 121 | private static JedisPoolConfig initPoolConfig(int poolSize) { 122 | JedisPoolConfig conf = new JedisPoolConfig(); 123 | conf.setMaxTotal(poolSize); 124 | conf.setTestOnBorrow(false); 125 | conf.setTestOnReturn(false); 126 | conf.setTestOnCreate(false); 127 | conf.setTestWhileIdle(false); 128 | conf.setMinEvictableIdleTimeMillis(60000); 129 | conf.setTimeBetweenEvictionRunsMillis(30000); 130 | conf.setNumTestsPerEvictionRun(-1); 131 | conf.setFairness(true); 132 | 133 | return conf; 134 | } 135 | 136 | /** Returns a Jedis resource obtained from the pool. */ private Jedis getConnection() { 137 | return pool.getResource(); 138 | } 139 | 140 | /** Sends command with binary arguments through conn's client and returns that client so the caller can read the reply. */ private BinaryClient sendCommand(Jedis conn, Command command, byte[]... args) { 141 | BinaryClient client = conn.getClient(); 142 | client.sendCommand(command, args); 143 | return client; 144 | } 145 | 146 | /** Sends command with String arguments through conn's client and returns that client so the caller can read the reply. */ private Client sendCommand(Jedis conn, Command command, String...
args) { 147 | Client client = conn.getClient(); 148 | client.sendCommand(command, args); 149 | return client; 150 | } 151 | 152 | /** 153 | * Direct mapping to AI.TENSORSET; the data type is derived from values and the shape is widened to long 154 | * 155 | * @param key name of key to store the Tensor 156 | * @param values multi-dimension numeric data 157 | * @param shape one or more dimensions, or the number of elements per axis, for the tensor 158 | * @return true if Tensor was properly set in RedisAI server 159 | */ 160 | public boolean setTensor(String key, Object values, int[] shape) { 161 | DataType dataType = DataType.baseObjType(values); 162 | long[] shapeL = new long[shape.length]; 163 | for (int i = 0; i < shape.length; i++) { 164 | shapeL[i] = shape[i]; 165 | } 166 | Tensor tensor = new Tensor(dataType, shapeL, values); 167 | return setTensor(key, tensor); 168 | } 169 | 170 | /** 171 | * Direct mapping to AI.TENSORSET 172 | * 173 | * @param key name of key to store the Tensor 174 | * @param tensor Tensor object 175 | * @return true if the server replied OK; a JedisDataException is rethrown as RedisAIException 176 | */ 177 | public boolean setTensor(String key, Tensor tensor) { 178 | try (Jedis conn = getConnection()) { 179 | List args = tensor.tensorSetFlatArgs(key, false); 180 | return sendCommand(conn, Command.TENSOR_SET, args.toArray(new byte[args.size()][])) 181 | .getStatusCodeReply() 182 | .equals("OK"); 183 | 184 | } catch (JedisDataException ex) { 185 | throw new RedisAIException(ex); 186 | } 187 | } 188 | 189 | /** 190 | * Direct mapping to AI.TENSORGET. Unlike the setters, a JedisDataException is not wrapped in RedisAIException here. 191 | * 192 | * @param key name of key to get the Tensor from 193 | * @return the Tensor, or null if the reply is empty 194 | * @throws JRedisAIRunTimeException 195 | */ 196 | public Tensor getTensor(String key) { 197 | try (Jedis conn = getConnection()) { 198 | List args = Tensor.tensorGetFlatArgs(key, false); 199 | List reply = 200 | sendCommand(conn, Command.TENSOR_GET, args.toArray(new byte[args.size()][])) 201 | .getObjectMultiBulkReply(); 202 | if (reply.isEmpty()) { 203 | return null; 204 | } 205 | return
Tensor.createTensorFromRespReply(reply); 206 | } 207 | } 208 | 209 | /** 210 | * Direct mapping to AI.MODELSET; the model blob is read from modelPath and an IOException is rethrown as RedisAIException 211 | * 212 | * @param key name of key to store the Model 213 | * @param backend - the backend for the model. can be one of TF, TFLITE, TORCH or ONNX 214 | * @param device - the device that will execute the model. can be one of CPU or GPU 215 | * @param inputs - one or more names of the model's input nodes (applicable only for TensorFlow 216 | * models) 217 | * @param outputs - one or more names of the model's output nodes (applicable only for TensorFlow 218 | * models) 219 | * @param modelPath - the file path for the Protobuf-serialized model 220 | * @return true if Model was properly set in RedisAI server 221 | */ 222 | public boolean setModel( 223 | String key, 224 | Backend backend, 225 | Device device, 226 | String[] inputs, 227 | String[] outputs, 228 | String modelPath) { 229 | 230 | try { 231 | byte[] blob = Files.readAllBytes(Paths.get(modelPath)); 232 | Model model = new Model(backend, device, inputs, outputs, blob); 233 | return setModel(key, model); 234 | } catch (IOException ex) { 235 | throw new RedisAIException(ex); 236 | } 237 | } 238 | 239 | /** 240 | * Direct mapping to AI.MODELSET 241 | * 242 | * @param key name of key to store the Model 243 | * @param model Model object 244 | * @return true if the server replied OK; a JedisDataException is rethrown as RedisAIException 245 | */ 246 | public boolean setModel(String key, Model model) { 247 | 248 | try (Jedis conn = getConnection()) { 249 | List args = model.getModelSetCommandBytes(key); 250 | return sendCommand(conn, Command.MODEL_SET, args.toArray(new byte[args.size()][])) 251 | .getStatusCodeReply() 252 | .equals("OK"); 253 | } catch (JedisDataException ex) { 254 | throw new RedisAIException(ex); 255 | } 256 | } 257 | 258 | /** 259 | * Direct mapping to AI.MODELSTORE command. 260 | * 261 | *

{@code AI.MODELSTORE key backend device [TAG tag] [BATCHSIZE n [MINBATCHSIZE m [MINBATCHTIMEOUT t]]] 262 | * [INPUTS input_count name ...] [OUTPUTS output_count name ...] BLOB model} 263 | * 264 | * @param key name of key to store the Model 265 | * @param model Model object 266 | * @return true if the server replied OK; a JedisDataException is rethrown as RedisAIException with the same message 267 | */ 268 | public boolean storeModel(String key, Model model) { 269 | try (Jedis conn = getConnection()) { 270 | List args = model.getModelStoreCommandArgs(key); 271 | return sendCommand(conn, Command.MODEL_STORE, args.toArray(new byte[args.size()][])) 272 | .getStatusCodeReply() 273 | .equals("OK"); 274 | } catch (JedisDataException ex) { 275 | throw new RedisAIException(ex.getMessage(), ex); 276 | } 277 | } 278 | 279 | /** 280 | * Direct mapping to AI.MODELGET, requesting both META and BLOB. Unlike the setters, a JedisDataException is not wrapped in RedisAIException here. 281 | * 282 | * @param key name of key to get the Model from RedisAI server 283 | * @return the Model, or null if the reply is empty 284 | * @throws JRedisAIRunTimeException if the reply lacks the backend, device or blob 285 | */ 286 | public Model getModel(String key) { 287 | try (Jedis conn = getConnection()) { 288 | List reply = 289 | sendCommand( 290 | conn, 291 | Command.MODEL_GET, 292 | SafeEncoder.encode(key), 293 | Keyword.META.getRaw(), 294 | Keyword.BLOB.getRaw()) 295 | .getObjectMultiBulkReply(); 296 | if (reply.isEmpty()) { 297 | return null; 298 | } 299 | return Model.createModelFromRespReply(reply); 300 | } 301 | } 302 | 303 | /** 304 | * Direct mapping to AI.MODELDEL 305 | * 306 | * @param key name of key to delete the Model 307 | * @return true if Model was properly deleted in RedisAI server 308 | */ 309 | public boolean delModel(String key) { 310 | 311 | try (Jedis conn = getConnection()) { 312 | return sendCommand(conn, Command.MODEL_DEL, SafeEncoder.encode(key)) 313 | .getStatusCodeReply() 314 | .equals("OK"); 315 | } catch (JedisDataException ex) { 316 | throw new RedisAIException(ex); 317 | } 318 | } 319 | 320 | /** 321 | * Direct mapping to AI.SCRIPTSET 322 | * 323 | * @param key name of key to store the Script in RedisAI server 324 | * @param device - the device that will execute the model.
can be of CPU or GPU 325 | * @param scriptFile - the file path for the script source code 326 | * @return true if Script was properly set in RedisAI server 327 | */ 328 | public boolean setScriptFile(String key, Device device, String scriptFile) { 329 | try { 330 | Script script = new Script(device, Paths.get(scriptFile)); 331 | return setScript(key, script); 332 | } catch (IOException ex) { 333 | /* an IOException from building the Script is surfaced as RedisAIException */ throw new RedisAIException(ex); 334 | } 335 | } 336 | 337 | /** 338 | * Direct mapping to AI.SCRIPTSET 339 | * 340 | * @param key name of key to store the Script in RedisAI server 341 | * @param device - the device that will execute the script. can be one of CPU or GPU 342 | * @param source - the script source code 343 | * @return true if Script was properly set in RedisAI server 344 | */ 345 | public boolean setScript(String key, Device device, String source) { 346 | Script script = new Script(device, source); 347 | return setScript(key, script); 348 | } 349 | 350 | /** 351 | * Direct mapping to AI.SCRIPTSET 352 | * 353 | * @param key name of key to store the Script in RedisAI server 354 | * @param script the Script Object 355 | * @return true if the server replied OK; a JedisDataException is rethrown as RedisAIException 356 | */ 357 | public boolean setScript(String key, Script script) { 358 | try (Jedis conn = getConnection()) { 359 | List args = script.getScriptSetCommandBytes(key); 360 | return sendCommand(conn, Command.SCRIPT_SET, args.toArray(new byte[args.size()][])) 361 | .getStatusCodeReply() 362 | .equals("OK"); 363 | 364 | } catch (JedisDataException ex) { 365 | throw new RedisAIException(ex); 366 | } 367 | } 368 | 369 | /** 370 | * Direct mapping to AI.MODELSTORE command. 371 | * 372 | *

{@code AI.SCRIPTSTORE [TAG tag] ENTRY_POINTS 373 | * [...] SOURCE "