├── .circleci └── config.yml ├── .dockerignore ├── .gitignore ├── .mvn └── wrapper │ ├── MavenWrapperDownloader.java │ └── maven-wrapper.properties ├── README.md ├── ch02 └── ch2_Inside_the_Mind_of_a_Transformer.ipynb ├── config ├── config-docker-docker.properties ├── config-docker-kube.properties ├── config-jvm-docker.properties ├── config-jvm-kube.properties └── torch_server_config.properties ├── cover.png ├── data-management ├── README.md ├── demo.md ├── pom.xml └── src │ ├── main │ ├── java │ │ └── org │ │ │ └── orca3 │ │ │ └── miniAutoML │ │ │ └── dataManagement │ │ │ ├── DataManagementService.java │ │ │ ├── DatasetCompressor.java │ │ │ ├── DatasetIngestion.java │ │ │ ├── models │ │ │ ├── Dataset.java │ │ │ ├── IntentText.java │ │ │ ├── IntentTextCollection.java │ │ │ ├── Label.java │ │ │ └── MemoryStore.java │ │ │ └── transformers │ │ │ ├── DatasetTransformer.java │ │ │ ├── GenericTransformer.java │ │ │ └── IntentTextTransformer.java │ └── resources │ │ └── logback.xml │ └── test │ ├── java │ └── org │ │ └── orca3 │ │ └── miniAutoML │ │ └── dataManagement │ │ ├── DataManagementServiceTest.java │ │ └── transformers │ │ └── IntentTextTransformerTest.java │ └── resources │ ├── config-test.properties │ ├── datasets │ ├── demo-part1.csv │ ├── demo-part2.csv │ ├── demo-part3.csv │ ├── examples-1.csv │ ├── examples-test.csv │ ├── examples-train.csv │ ├── examples-validation.csv │ ├── intent-1.csv │ ├── labels-1.csv │ ├── labels-test.csv │ ├── labels-train.csv │ ├── labels-validation.csv │ ├── test.csv │ ├── train.csv │ └── validation.csv │ ├── genericDatasets │ ├── bar │ ├── coo │ ├── dzz │ ├── ell │ └── foo │ └── logback-test.xml ├── grpc-contract ├── pom.xml └── src │ └── main │ ├── java │ └── org │ │ └── orca3 │ │ └── miniAutoML │ │ └── ServiceBase.java │ ├── proto │ ├── data_management.proto │ ├── metadata_store.proto │ ├── prediction_service.proto │ ├── torch_management.proto │ ├── torch_serve.proto │ └── training_service.proto │ └── scripts │ └── python_code_gen.sh ├── metadata-store ├── pom.xml └── src │ └── main │ ├── java │ └── org │ │ └── orca3 │ │ └── miniAutoML │ │ └── metadataStore │ │ ├── MetadataStoreService.java │ │ └── models │ │ ├── ArtifactInfo.java │ │ ├── ArtifactRepo.java │ │ └── MemoryStore.java │ └── resources │ └── logback.xml ├── mvnw ├── mvnw.cmd ├── pom.xml ├── prediction-service ├── pom.xml └── src │ └── main │ ├── java │ └── org │ │ └── orca3 │ │ └── miniAutoML │ │ └── prediction │ │ ├── CustomGrpcPredictorBackend.java │ │ ├── PredictionService.java │ │ ├── PredictorBackend.java │ │ ├── PredictorConnectionManager.java │ │ └── TorchGrpcPredictorBackend.java │ └── resources │ └── logback.xml ├── predictor ├── Dockerfile ├── README.md ├── predict.py ├── prediction_service_pb2.py ├── prediction_service_pb2_grpc.py ├── sample_models │ └── 1 │ │ ├── intent_80bf0da.mar │ │ ├── manifest.json │ │ ├── model.pth │ │ └── vocab.pth └── utils.py ├── scripts ├── build-images-locally.sh ├── dm-001-start-minio.sh ├── dm-002-start-server.sh ├── dm-003-create-dataset.sh ├── dm-004-add-commits.sh ├── dm-005-prepare-dataset.sh ├── dm-006-prepare-partial-dataset.sh ├── dm-007-fetch-dataset-version.sh ├── env-vars.sh ├── lab-001-start-all.sh ├── lab-002-upload-data.sh ├── lab-003-first-training.sh ├── lab-004-model-serving.sh ├── lab-005-second-training.sh ├── lab-006-model-serving-torchserve.sh ├── lab-999-tear-down.sh ├── ms-001-start-minio.sh ├── ms-002-start-server.sh ├── ms-003-start-run.sh ├── ms-004-post-epoch.sh ├── ms-005-check-run-status.sh ├── ms-006-finish-run.sh ├── 
ms-008-get-artifact.sh ├── prepare_data.py ├── ps-001-start-predictor.sh ├── ps-002-start-server.sh ├── ps-003-predict.sh ├── ts-001-start-server-kube.sh ├── ts-001-start-server.sh ├── ts-002-start-run.sh ├── ts-003-check-run.sh ├── ts-004-start-parallel-run.sh └── ts-005-start-run-as-torch.sh ├── services.dockerfile ├── training-code └── text-classification │ ├── Dockerfile │ ├── Readme.md │ ├── data_management_pb2.py │ ├── data_management_pb2_grpc.py │ ├── examples.csv │ ├── labels.csv │ ├── metadata_store_pb2.py │ ├── metadata_store_pb2_grpc.py │ ├── orca3_utils.py │ ├── prediction_service_pb2.py │ ├── prediction_service_pb2_grpc.py │ ├── torchserve_handler.py │ ├── train.py │ ├── training_service_pb2.py │ ├── training_service_pb2_grpc.py │ └── version.py └── training-service ├── README.md ├── distributed_trainer_demo.md ├── pom.xml ├── single_trainer_demo.md └── src └── main ├── java └── org │ └── orca3 │ └── miniAutoML │ └── training │ ├── TrainingService.java │ ├── models │ ├── ExecutedTrainingJob.java │ └── MemoryStore.java │ └── tracker │ ├── DockerTracker.java │ ├── KubectlTracker.java │ └── Tracker.java └── resources └── logback.xml /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | jobs: 3 | build-and-push: 4 | machine: true 5 | steps: 6 | - checkout 7 | - run: | 8 | echo "$DOCKER_PASS" | docker login --username $DOCKER_USER --password-stdin 9 | 10 | - run: | 11 | if [ "${CIRCLE_BRANCH}" == "main" ]; then 12 | ./scripts/build-images-locally.sh 13 | docker push orca3/services:latest 14 | docker push orca3/intent-classification-predictor:latest 15 | docker push orca3/intent-classification:latest 16 | docker push orca3/intent-classification-torch:latest 17 | fi 18 | 19 | workflows: 20 | build: 21 | jobs: 22 | - build-and-push: 23 | context: 24 | - orca3 25 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | **/*.md 2 | **/src/test/ 3 | 4 | ### JetBrains template 5 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 6 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 7 | 8 | # User-specific stuff 9 | .idea/**/workspace.xml 10 | .idea/**/tasks.xml 11 | .idea/**/usage.statistics.xml 12 | .idea/**/dictionaries 13 | .idea/**/shelf 14 | 15 | # Generated files 16 | .idea/**/contentModel.xml 17 | 18 | # Sensitive or high-churn files 19 | .idea/**/dataSources/ 20 | .idea/**/dataSources.ids 21 | .idea/**/dataSources.local.xml 22 | .idea/**/sqlDataSources.xml 23 | .idea/**/dynamic.xml 24 | .idea/**/uiDesigner.xml 25 | .idea/**/dbnavigator.xml 26 | 27 | # Gradle 28 | .idea/**/gradle.xml 29 | .idea/**/libraries 30 | 31 | # Gradle and Maven with auto-import 32 | # When using Gradle or Maven with auto-import, you should exclude module files, 33 | # since they will be recreated, and may cause churn. Uncomment if using 34 | # auto-import. 
35 | # .idea/artifacts 36 | # .idea/compiler.xml 37 | # .idea/jarRepositories.xml 38 | # .idea/modules.xml 39 | # .idea/*.iml 40 | # .idea/modules 41 | # *.iml 42 | # *.ipr 43 | 44 | # CMake 45 | cmake-build-*/ 46 | 47 | # Mongo Explorer plugin 48 | .idea/**/mongoSettings.xml 49 | 50 | # File-based project format 51 | *.iws 52 | 53 | # IntelliJ 54 | out/ 55 | 56 | # mpeltonen/sbt-idea plugin 57 | .idea_modules/ 58 | 59 | # JIRA plugin 60 | atlassian-ide-plugin.xml 61 | 62 | # Cursive Clojure plugin 63 | .idea/replstate.xml 64 | 65 | # Crashlytics plugin (for Android Studio and IntelliJ) 66 | com_crashlytics_export_strings.xml 67 | crashlytics.properties 68 | crashlytics-build.properties 69 | fabric.properties 70 | 71 | # Editor-based Rest Client 72 | .idea/httpRequests 73 | 74 | # Android studio 3.1+ serialized cache file 75 | .idea/caches/build_file_checksums.ser 76 | 77 | ### Python template 78 | # Byte-compiled / optimized / DLL files 79 | __pycache__/ 80 | *.py[cod] 81 | *$py.class 82 | 83 | # C extensions 84 | *.so 85 | 86 | # Distribution / packaging 87 | .Python 88 | build/ 89 | develop-eggs/ 90 | dist/ 91 | downloads/ 92 | eggs/ 93 | .eggs/ 94 | lib/ 95 | lib64/ 96 | parts/ 97 | sdist/ 98 | var/ 99 | wheels/ 100 | share/python-wheels/ 101 | *.egg-info/ 102 | .installed.cfg 103 | *.egg 104 | MANIFEST 105 | 106 | # PyInstaller 107 | # Usually these files are written by a python script from a template 108 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 109 | *.manifest 110 | *.spec 111 | 112 | # Installer logs 113 | pip-log.txt 114 | pip-delete-this-directory.txt 115 | 116 | # Unit test / coverage reports 117 | htmlcov/ 118 | .tox/ 119 | .nox/ 120 | .coverage 121 | .coverage.* 122 | .cache 123 | nosetests.xml 124 | coverage.xml 125 | *.cover 126 | *.py,cover 127 | .hypothesis/ 128 | .pytest_cache/ 129 | cover/ 130 | 131 | # Translations 132 | *.mo 133 | *.pot 134 | 135 | # Django stuff: 136 | *.log 137 | local_settings.py 138 | db.sqlite3 139 | db.sqlite3-journal 140 | 141 | # Flask stuff: 142 | instance/ 143 | .webassets-cache 144 | 145 | # Scrapy stuff: 146 | .scrapy 147 | 148 | # Sphinx documentation 149 | docs/_build/ 150 | 151 | # PyBuilder 152 | .pybuilder/ 153 | target/ 154 | 155 | # Jupyter Notebook 156 | .ipynb_checkpoints 157 | 158 | # IPython 159 | profile_default/ 160 | ipython_config.py 161 | 162 | # pyenv 163 | # For a library or package, you might want to ignore these files since the code is 164 | # intended to run in multiple environments; otherwise, check them in: 165 | # .python-version 166 | 167 | # pipenv 168 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 169 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 170 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 171 | # install all needed dependencies. 172 | #Pipfile.lock 173 | 174 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 175 | __pypackages__/ 176 | 177 | # Celery stuff 178 | celerybeat-schedule 179 | celerybeat.pid 180 | 181 | # SageMath parsed files 182 | *.sage.py 183 | 184 | # Environments 185 | .env 186 | .venv 187 | env/ 188 | venv/ 189 | ENV/ 190 | env.bak/ 191 | venv.bak/ 192 | 193 | # Spyder project settings 194 | .spyderproject 195 | .spyproject 196 | 197 | # Rope project settings 198 | .ropeproject 199 | 200 | # mkdocs documentation 201 | /site 202 | 203 | # mypy 204 | .mypy_cache/ 205 | .dmypy.json 206 | dmypy.json 207 | 208 | # Pyre type checker 209 | .pyre/ 210 | 211 | # pytype static type analyzer 212 | .pytype/ 213 | 214 | # Cython debug symbols 215 | cython_debug/ 216 | 217 | ### Java template 218 | # Compiled class file 219 | *.class 220 | 221 | # Log file 222 | *.log 223 | 224 | # BlueJ files 225 | *.ctxt 226 | 227 | # Mobile Tools for Java (J2ME) 228 | .mtj.tmp/ 229 | 230 | # Package Files # 231 | *.jar 232 | *.war 233 | *.nar 234 | *.ear 235 | *.zip 236 | *.tar.gz 237 | *.rar 238 | 239 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml 240 | hs_err_pid* 241 | 242 | ### macOS template 243 | # General 244 | .DS_Store 245 | .AppleDouble 246 | .LSOverride 247 | 248 | # Icon must end with two \r 249 | Icon 250 | 251 | # Thumbnails 252 | ._* 253 | 254 | # Files that might appear in the root of a volume 255 | .DocumentRevisions-V100 256 | .fseventsd 257 | .Spotlight-V100 258 | .TemporaryItems 259 | .Trashes 260 | .VolumeIcon.icns 261 | .com.apple.timemachine.donotpresent 262 | 263 | # Directories potentially created on remote AFP share 264 | .AppleDB 265 | .AppleDesktop 266 | Network Trash Folder 267 | Temporary Items 268 | .apdisk 269 | 270 | ### Maven template 271 | target/ 272 | pom.xml.tag 273 | pom.xml.releaseBackup 274 | pom.xml.versionsBackup 275 | pom.xml.next 276 | release.properties 277 | dependency-reduced-pom.xml 278 | buildNumber.properties 279 | .mvn/timing.properties 280 | # https://github.com/takari/maven-wrapper#usage-without-binary-jar 281 | .mvn/wrapper/maven-wrapper.jar 282 | 283 | -------------------------------------------------------------------------------- /.mvn/wrapper/MavenWrapperDownloader.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2007-present the original author or authors. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | import java.net.*; 17 | import java.io.*; 18 | import java.nio.channels.*; 19 | import java.util.Properties; 20 | 21 | public class MavenWrapperDownloader { 22 | 23 | private static final String WRAPPER_VERSION = "0.5.6"; 24 | /** 25 | * Default URL to download the maven-wrapper.jar from, if no 'downloadUrl' is provided. 
26 | */ 27 | private static final String DEFAULT_DOWNLOAD_URL = "https://repo.maven.apache.org/maven2/io/takari/maven-wrapper/" 28 | + WRAPPER_VERSION + "/maven-wrapper-" + WRAPPER_VERSION + ".jar"; 29 | 30 | /** 31 | * Path to the maven-wrapper.properties file, which might contain a downloadUrl property to 32 | * use instead of the default one. 33 | */ 34 | private static final String MAVEN_WRAPPER_PROPERTIES_PATH = 35 | ".mvn/wrapper/maven-wrapper.properties"; 36 | 37 | /** 38 | * Path where the maven-wrapper.jar will be saved to. 39 | */ 40 | private static final String MAVEN_WRAPPER_JAR_PATH = 41 | ".mvn/wrapper/maven-wrapper.jar"; 42 | 43 | /** 44 | * Name of the property which should be used to override the default download url for the wrapper. 45 | */ 46 | private static final String PROPERTY_NAME_WRAPPER_URL = "wrapperUrl"; 47 | 48 | public static void main(String args[]) { 49 | System.out.println("- Downloader started"); 50 | File baseDirectory = new File(args[0]); 51 | System.out.println("- Using base directory: " + baseDirectory.getAbsolutePath()); 52 | 53 | // If the maven-wrapper.properties exists, read it and check if it contains a custom 54 | // wrapperUrl parameter. 55 | File mavenWrapperPropertyFile = new File(baseDirectory, MAVEN_WRAPPER_PROPERTIES_PATH); 56 | String url = DEFAULT_DOWNLOAD_URL; 57 | if(mavenWrapperPropertyFile.exists()) { 58 | FileInputStream mavenWrapperPropertyFileInputStream = null; 59 | try { 60 | mavenWrapperPropertyFileInputStream = new FileInputStream(mavenWrapperPropertyFile); 61 | Properties mavenWrapperProperties = new Properties(); 62 | mavenWrapperProperties.load(mavenWrapperPropertyFileInputStream); 63 | url = mavenWrapperProperties.getProperty(PROPERTY_NAME_WRAPPER_URL, url); 64 | } catch (IOException e) { 65 | System.out.println("- ERROR loading '" + MAVEN_WRAPPER_PROPERTIES_PATH + "'"); 66 | } finally { 67 | try { 68 | if(mavenWrapperPropertyFileInputStream != null) { 69 | mavenWrapperPropertyFileInputStream.close(); 70 | } 71 | } catch (IOException e) { 72 | // Ignore ... 
73 | } 74 | } 75 | } 76 | System.out.println("- Downloading from: " + url); 77 | 78 | File outputFile = new File(baseDirectory.getAbsolutePath(), MAVEN_WRAPPER_JAR_PATH); 79 | if(!outputFile.getParentFile().exists()) { 80 | if(!outputFile.getParentFile().mkdirs()) { 81 | System.out.println( 82 | "- ERROR creating output directory '" + outputFile.getParentFile().getAbsolutePath() + "'"); 83 | } 84 | } 85 | System.out.println("- Downloading to: " + outputFile.getAbsolutePath()); 86 | try { 87 | downloadFileFromURL(url, outputFile); 88 | System.out.println("Done"); 89 | System.exit(0); 90 | } catch (Throwable e) { 91 | System.out.println("- Error downloading"); 92 | e.printStackTrace(); 93 | System.exit(1); 94 | } 95 | } 96 | 97 | private static void downloadFileFromURL(String urlString, File destination) throws Exception { 98 | if (System.getenv("MVNW_USERNAME") != null && System.getenv("MVNW_PASSWORD") != null) { 99 | String username = System.getenv("MVNW_USERNAME"); 100 | char[] password = System.getenv("MVNW_PASSWORD").toCharArray(); 101 | Authenticator.setDefault(new Authenticator() { 102 | @Override 103 | protected PasswordAuthentication getPasswordAuthentication() { 104 | return new PasswordAuthentication(username, password); 105 | } 106 | }); 107 | } 108 | URL website = new URL(urlString); 109 | ReadableByteChannel rbc; 110 | rbc = Channels.newChannel(website.openStream()); 111 | FileOutputStream fos = new FileOutputStream(destination); 112 | fos.getChannel().transferFrom(rbc, 0, Long.MAX_VALUE); 113 | fos.close(); 114 | rbc.close(); 115 | } 116 | 117 | } 118 | -------------------------------------------------------------------------------- /.mvn/wrapper/maven-wrapper.properties: -------------------------------------------------------------------------------- 1 | distributionUrl=https://repo.maven.apache.org/maven2/org/apache/maven/apache-maven/3.6.3/apache-maven-3.6.3-bin.zip 2 | wrapperUrl=https://repo.maven.apache.org/maven2/io/takari/maven-wrapper/0.5.6/maven-wrapper-0.5.6.jar 3 | -------------------------------------------------------------------------------- /config/config-docker-docker.properties: -------------------------------------------------------------------------------- 1 | # minio 2 | minio.host=http://minio:9000 3 | minio.accessKey=foooo 4 | minio.secretKey=barbarbar 5 | 6 | # dm 7 | dm.minio.bucketName=mini-automl-dm 8 | dm.server.port=51001 9 | dm.server.host=data-management 10 | 11 | # ps 12 | ps.server.modelCachePath=/tmp/modelCache 13 | ps.server.port=51001 14 | ps.server.host=prediction-service 15 | ps.enabledPredictors=intent-classification,intent-classification-torch 16 | 17 | # ms 18 | ms.minio.bucketName=mini-automl-ms 19 | ms.server.port=51001 20 | ms.server.host=metadata-store 21 | 22 | # ts 23 | ts.server.port=51001 24 | ts.server.host=training-service 25 | ts.backend=docker 26 | ts.backend.dockerNetwork=orca3 27 | ts.trainer.minio.host=minio:9000 28 | ts.trainer.ms.host=metadata-store:51001 29 | 30 | # predictors 31 | predictors.intent-classification.host=intent-classification-predictor 32 | predictors.intent-classification.port=51001 33 | predictors.intent-classification.techStack=customGrpc 34 | 35 | predictors.intent-classification-torch.host=intent-classification-torch-predictor 36 | predictors.intent-classification-torch.port=7070 37 | predictors.intent-classification-torch.management-port=7071 38 | predictors.intent-classification-torch.techStack=torch 39 | -------------------------------------------------------------------------------- 
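The properties file above (and the three variants that follow) are plain `java.util.Properties` files that every service in this repo loads at startup. A minimal sketch of reading one, mirroring what `ServiceBase.getConfigProperties` in `grpc-contract` does later in this dump; the class name `ConfigDemo` is ours for illustration, not part of the repo:

```java
import java.io.FileInputStream;
import java.io.IOException;
import java.util.Optional;
import java.util.Properties;

// Hypothetical demo class (not in the repo): loads a service config the same
// way ServiceBase.getConfigProperties does, then reads a few of the keys
// defined in the config files shown here.
public class ConfigDemo {
    public static void main(String[] args) throws IOException {
        // APP_CONFIG selects the profile; the repo's default is config/config-jvm-docker.properties.
        String location = Optional.ofNullable(System.getenv("APP_CONFIG"))
                .orElse("config/config-jvm-docker.properties");
        Properties props = new Properties();
        try (FileInputStream in = new FileInputStream(location)) {
            props.load(in);
        }
        System.out.println("minio host: " + props.getProperty("minio.host"));
        System.out.println("dm port:    " + props.getProperty("dm.server.port"));
    }
}
```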
/config/config-docker-kube.properties: -------------------------------------------------------------------------------- 1 | # minio 2 | minio.host=http://minio:9000 3 | minio.accessKey=foooo 4 | minio.secretKey=barbarbar 5 | 6 | # dm 7 | dm.minio.bucketName=mini-automl-dm 8 | dm.server.port=51001 9 | dm.server.host=data-management 10 | 11 | # ps 12 | ps.server.modelCachePath=/tmp/modelCache 13 | ps.server.port=51001 14 | ps.server.host=prediction-service 15 | ps.enabledPredictors=intent-classification,intent-classification-torch 16 | 17 | # ms 18 | ms.minio.bucketName=mini-automl-ms 19 | ms.server.port=51001 20 | ms.server.host=metadata-store 21 | 22 | # ts 23 | ts.server.port=51001 24 | ts.server.host=training-service 25 | ts.backend=kubectl 26 | ts.backend.kubectlConfigFile=/.kube/config 27 | ts.backend.kubectlNamespace=orca3 28 | ts.trainer.minio.host=host.docker.internal:9000 29 | ts.trainer.ms.host=host.docker.internal:6002 30 | 31 | # predictors 32 | predictors.intent-classification.host=intent-classification-predictor 33 | predictors.intent-classification.port=51001 34 | predictors.intent-classification.techStack=customGrpc 35 | 36 | predictors.intent-classification-torch.host=intent-classification-torch-predictor 37 | predictors.intent-classification-torch.port=7070 38 | predictors.intent-classification-torch.management-port=7071 39 | predictors.intent-classification-torch.techStack=torch 40 | -------------------------------------------------------------------------------- /config/config-jvm-docker.properties: -------------------------------------------------------------------------------- 1 | # minio 2 | minio.host=http://localhost:9000 3 | minio.accessKey=foooo 4 | minio.secretKey=barbarbar 5 | 6 | # dm 7 | dm.minio.bucketName=mini-automl-dm 8 | dm.server.port=6000 9 | dm.server.host=localhost 10 | 11 | # ps 12 | ps.server.modelCachePath=model_cache 13 | ps.server.port=6001 14 | ps.server.host=localhost 15 | ps.enabledPredictors=intent-classification,intent-classification-torch 16 | 17 | # ms 18 | ms.minio.bucketName=mini-automl-ms 19 | ms.server.port=6002 20 | ms.server.host=localhost 21 | 22 | # ts 23 | ts.server.port=6003 24 | ts.server.host=localhost 25 | ts.backend=docker 26 | ts.backend.dockerNetwork=orca3 27 | ts.trainer.minio.host=host.docker.internal:9000 28 | ts.trainer.ms.host=host.docker.internal:6002 29 | 30 | # predictors 31 | predictors.intent-classification.host=localhost 32 | predictors.intent-classification.port=6101 33 | predictors.intent-classification.techStack=customGrpc 34 | 35 | predictors.intent-classification-torch.host=localhost 36 | predictors.intent-classification-torch.port=6102 37 | predictors.intent-classification-torch.management-port=6103 38 | predictors.intent-classification-torch.techStack=torch 39 | 40 | 41 | -------------------------------------------------------------------------------- /config/config-jvm-kube.properties: -------------------------------------------------------------------------------- 1 | # minio 2 | minio.host=http://localhost:9000 3 | minio.accessKey=foooo 4 | minio.secretKey=barbarbar 5 | 6 | # dm 7 | dm.minio.bucketName=mini-automl-dm 8 | dm.server.port=6000 9 | dm.server.host=localhost 10 | 11 | # ps 12 | ps.server.modelCachePath=model_cache 13 | ps.server.port=6001 14 | ps.server.host=localhost 15 | ps.enabledPredictors=intent-classification,intent-classification-torch 16 | 17 | # ms 18 | ms.minio.bucketName=mini-automl-ms 19 | ms.server.port=6002 20 | ms.server.host=localhost 21 | 22 | # ts 23 | ts.server.port=6003 24 
| ts.server.host=localhost 25 | ts.backend=kubectl 26 | ts.backend.kubectlConfigFile= 27 | ts.backend.kubectlNamespace=orca3 28 | ts.trainer.minio.host=host.docker.internal:9000 29 | ts.trainer.ms.host=host.docker.internal:6002 30 | 31 | # predictors 32 | predictors.intent-classification.host=localhost 33 | predictors.intent-classification.port=6101 34 | predictors.intent-classification.techStack=customGrpc 35 | 36 | predictors.intent-classification-torch.host=localhost 37 | predictors.intent-classification-torch.port=6102 38 | predictors.intent-classification-torch.management-port=6103 39 | predictors.intent-classification-torch.techStack=torch 40 | -------------------------------------------------------------------------------- /config/torch_server_config.properties: -------------------------------------------------------------------------------- 1 | grpc_management_port=7071 2 | -------------------------------------------------------------------------------- /cover.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/orca3/MiniAutoML/fa9eb34858f9a5a743489f31348700b4ce32154e/cover.png -------------------------------------------------------------------------------- /data-management/README.md: -------------------------------------------------------------------------------- 1 | # Dataset Management (DM) Service 2 | The Dataset Management service is a sample Java (gRPC) web service demonstrating the design principles introduced in chapter 4 of the book ``<>``. 3 | The service is deliberately minimal (for example, it persists data in memory instead of in a database) so that the code is easy to read and the local setup is simple. The only external dependency this service takes is **Minio**, which we use to mimic cloud 4 | blob storage such as `AWS S3` or `Azure Blob`. Please make sure the Minio client is installed by running `mc --version`. 5 | 6 | By reading this code, you will get a concrete sense of how the dataset management design concepts can be implemented. 7 | 8 | ## Function demo 9 | 10 | See [demo.md](demo.md) 11 | 12 | -------- 13 | 14 | ## Build and play with DM locally 15 | 16 | ### Understand the config file 17 | The DM server takes a few configuration items on startup. These can be found in [config.properties](src/main/resources/config.properties): 18 | - `minio.bucketName`: The Minio bucket the DM service uses to store its files. 19 | - `minio.accessKey` & `minio.secretKey`: The credentials used to access the Minio server. 20 | - `minio.host`: The address of the Minio server. 21 | - `server.port`: The port number this server listens on. 22 | 23 | ### Start the Minio dependency 24 | This is taken care of by our script [dm-001-start-minio.sh](../scripts/dm-001-start-minio.sh). 25 | 26 | ### Build and run using Docker (recommended) 27 | 1. Modify the config if needed. Set `minio.host` to `http://minio:9000`. 28 | 2. The [dockerfile](../services.dockerfile) in the root folder can be used to build the data-management service directly. Executing `docker build -t orca3/services:latest -f services.dockerfile .` in the root directory will build a Docker image called `orca3/services` with the `latest` tag. 29 | 3. Start the service using `docker run --name data-management --network orca3 --rm -d -p 6000:51001 orca3/services:latest data-management.jar`. 30 | 4. Now the service can be reached at `localhost:6000`.
Try `grpcurl -plaintext localhost:6000 grpc.health.v1.Health/Check`, or look at the examples in the [scripts](../scripts) folder, to interact with the service 31 | 5. Everything above has the same effect as running our [dm-002-start-server.sh](../scripts/dm-002-start-server.sh) script 32 | 33 | ### Build and run using Java (for experienced Java developers) 34 | 1. Modify the config if needed. Set `minio.host` to `http://localhost:9000`. Set `server.port` to an unoccupied port number. 35 | 2. Use Maven to build the project and produce a runnable JAR: `./mvnw clean package -pl data-management -am` 36 | 3. Run the JAR using the command `java -jar data-management/target/data-management-1.0-SNAPSHOT.jar` 37 | 4. Now the service can be reached at `localhost:51001`. Try opening a new terminal tab and executing `grpcurl -plaintext localhost:51001 grpc.health.v1.Health/Check`, or look at the examples in the [scripts](../scripts) folder, to interact with the service 38 | -------------------------------------------------------------------------------- /data-management/pom.xml: -------------------------------------------------------------------------------- 1 | <?xml version="1.0" encoding="UTF-8"?> 2 | <project xmlns="http://maven.apache.org/POM/4.0.0" 3 | xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" 4 | xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> 5 | <modelVersion>4.0.0</modelVersion> 6 | 7 | <parent> 8 | <groupId>org.orca3</groupId> 9 | <artifactId>mini-auto-ml</artifactId> 10 | <version>1.0-SNAPSHOT</version> 11 | </parent> 12 | 13 | <artifactId>data-management</artifactId> 14 | <version>1.0-SNAPSHOT</version> 15 | 16 | <properties> 17 | <mainClass>org.orca3.miniAutoML.dataManagement.DataManagementService</mainClass> 18 | </properties> 19 | 20 | <dependencies> 21 | <dependency> 22 | <groupId>org.orca3</groupId> 23 | <artifactId>grpc-contract</artifactId> 24 | <version>${project.version}</version> 25 | </dependency> 26 | 27 | <dependency> 28 | <groupId>io.grpc</groupId> 29 | <artifactId>grpc-services</artifactId> 30 | </dependency> 31 | <dependency> 32 | <groupId>ch.qos.logback</groupId> 33 | <artifactId>logback-classic</artifactId> 34 | </dependency> 35 | <dependency> 36 | <groupId>com.opencsv</groupId> 37 | <artifactId>opencsv</artifactId> 38 | <version>4.1</version> 39 | </dependency> 40 | <dependency> 41 | <groupId>io.minio</groupId> 42 | <artifactId>minio</artifactId> 43 | </dependency> 44 | 45 | <dependency> 46 | <groupId>io.grpc</groupId> 47 | <artifactId>grpc-testing</artifactId> 48 | <scope>test</scope> 49 | </dependency> 50 | <dependency> 51 | <groupId>junit</groupId> 52 | <artifactId>junit</artifactId> 53 | <version>RELEASE</version> 54 | <scope>test</scope> 55 | </dependency> 56 | <dependency> 57 | <groupId>org.testcontainers</groupId> 58 | <artifactId>testcontainers</artifactId> 59 | <version>1.15.3</version> 60 | <scope>test</scope> 61 | </dependency> 62 | </dependencies> 63 | 64 | <build> 65 | <plugins> 66 | <plugin> 67 | <groupId>org.codehaus.mojo</groupId> 68 | <artifactId>exec-maven-plugin</artifactId> 69 | <version>3.0.0</version> 70 | <executions> 71 | <execution> 72 | <goals> 73 | <goal>java</goal> 74 | </goals> 75 | </execution> 76 | </executions> 77 | <configuration> 78 | <mainClass>${mainClass}</mainClass> 79 | </configuration> 80 | </plugin> 81 | <plugin> 82 | <groupId>org.apache.maven.plugins</groupId> 83 | <artifactId>maven-shade-plugin</artifactId> 84 | <executions> 85 | <execution> 86 | <phase>package</phase> 87 | <goals> 88 | <goal>shade</goal> 89 | </goals> 90 | <configuration> 91 | <transformers> 92 | <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer"> 93 | <mainClass>${mainClass}</mainClass> 94 | </transformer> 95 | </transformers> 96 | <filters> 97 | <filter> 98 | <artifact>*:*</artifact> 99 | <excludes> 100 | <exclude>META-INF/*.SF</exclude> 101 | <exclude>META-INF/*.DSA</exclude> 102 | <exclude>META-INF/*.RSA</exclude> 103 | </excludes> 104 | </filter> 105 | </filters> 106 | </configuration> 107 | </execution> 108 | </executions> 109 | </plugin> 110 | </plugins> 111 | </build> 112 | 113 | 114 | </project> -------------------------------------------------------------------------------- /data-management/src/main/java/org/orca3/miniAutoML/dataManagement/DatasetCompressor.java: -------------------------------------------------------------------------------- 1 | package org.orca3.miniAutoML.dataManagement; 2 | 3 | import io.minio.MinioClient; 4 | import org.orca3.miniAutoML.dataManagement.models.Dataset; 5 | import org.orca3.miniAutoML.dataManagement.models.MemoryStore; 6 | import org.orca3.miniAutoML.dataManagement.transformers.DatasetTransformer; 7 | import org.orca3.miniAutoML.dataManagement.transformers.GenericTransformer; 8 | import org.orca3.miniAutoML.dataManagement.transformers.IntentTextTransformer; 9 | 10 | import java.util.List; 11 | 12 | public class DatasetCompressor implements Runnable { 13 | private final MinioClient minioClient; 14 | private final String datasetId; 15 | private final DatasetType datasetType; 16 | private final MemoryStore store; 17 | private final List<DatasetPart> datasetParts; 18 | private final String versionHash; 19 | private final String bucketName; 20 | 21 | public DatasetCompressor(MinioClient minioClient, MemoryStore store, String datasetId, 22 | DatasetType datasetType, List<DatasetPart> datasetParts, 23 | String versionHash, String bucketName) { 24 | this.minioClient = minioClient; 25 |
this.store = store; 26 | this.datasetId = datasetId; 27 | this.datasetType = datasetType; 28 | this.datasetParts = datasetParts; 29 | this.versionHash = versionHash; 30 | this.bucketName = bucketName; 31 | } 32 | 33 | @Override 34 | public void run() { 35 | VersionedSnapshot versionHashDataset; 36 | DatasetTransformer transformer; 37 | Dataset dataset = store.datasets.get(datasetId); 38 | switch (datasetType) { 39 | case TEXT_INTENT: 40 | transformer = new IntentTextTransformer(); 41 | break; 42 | case GENERIC: 43 | default: 44 | transformer = new GenericTransformer(); 45 | } 46 | try { 47 | versionHashDataset = transformer.compress(datasetParts, datasetId, versionHash, bucketName, minioClient); 48 | } catch (Exception e) { 49 | store.datasets.get(datasetId).versionHashRegistry.put(versionHash, VersionedSnapshot.newBuilder() 50 | .setDatasetId(datasetId).setVersionHash(versionHash).setState(SnapshotState.FAILED).build()); 51 | throw new RuntimeException(e); 52 | } 53 | dataset.versionHashRegistry.put(versionHash, versionHashDataset); 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /data-management/src/main/java/org/orca3/miniAutoML/dataManagement/DatasetIngestion.java: -------------------------------------------------------------------------------- 1 | package org.orca3.miniAutoML.dataManagement; 2 | 3 | import io.minio.MinioClient; 4 | import org.orca3.miniAutoML.dataManagement.transformers.DatasetTransformer; 5 | import org.orca3.miniAutoML.dataManagement.transformers.GenericTransformer; 6 | import org.orca3.miniAutoML.dataManagement.transformers.IntentTextTransformer; 7 | 8 | public class DatasetIngestion { 9 | 10 | public static CommitInfo.Builder ingest(MinioClient minioClient, String datasetId, String commitId, DatasetType datasetType, String ingestBucket, String ingestPath, String bucketName) { 11 | DatasetTransformer transformer; 12 | switch (datasetType) { 13 | case TEXT_INTENT: 14 | transformer = new IntentTextTransformer(); 15 | break; 16 | case GENERIC: 17 | default: 18 | transformer = new GenericTransformer(); 19 | } 20 | try { 21 | return transformer.ingest(ingestBucket, ingestPath, datasetId, commitId, bucketName, minioClient); 22 | } catch (Exception e) { 23 | throw new RuntimeException(e); 24 | } 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /data-management/src/main/java/org/orca3/miniAutoML/dataManagement/models/Dataset.java: -------------------------------------------------------------------------------- 1 | package org.orca3.miniAutoML.dataManagement.models; 2 | 3 | import org.orca3.miniAutoML.dataManagement.CommitInfo; 4 | import org.orca3.miniAutoML.dataManagement.DatasetSummary; 5 | import org.orca3.miniAutoML.dataManagement.DatasetType; 6 | import org.orca3.miniAutoML.dataManagement.VersionedSnapshot; 7 | 8 | import java.time.Instant; 9 | import java.util.HashMap; 10 | import java.util.Map; 11 | import java.util.concurrent.atomic.AtomicInteger; 12 | 13 | import static java.time.format.DateTimeFormatter.ISO_INSTANT; 14 | 15 | public class Dataset { 16 | private final String datasetId; 17 | private final String name; 18 | private final String description; 19 | private final DatasetType datasetType; 20 | private final String updatedAt; 21 | private final AtomicInteger commitIdSeed; 22 | public final Map<String, CommitInfo> commits; 23 | public final Map<String, VersionedSnapshot> versionHashRegistry; 24 | 25 | public Dataset(String datasetId, String name, String description, DatasetType datasetType, String
updatedAt) { 26 | this.datasetId = datasetId; 27 | this.name = name; 28 | this.description = description; 29 | this.datasetType = datasetType; 30 | this.updatedAt = updatedAt; 31 | this.commits = new HashMap<>(); 32 | this.commitIdSeed = new AtomicInteger(); 33 | this.versionHashRegistry = new HashMap<>(); 34 | } 35 | 36 | public Dataset(String datasetId, String name, String description, DatasetType datasetType) { 37 | this(datasetId, name, description, datasetType, ISO_INSTANT.format(Instant.now())); 38 | } 39 | 40 | public int getNextCommitId() { 41 | return commitIdSeed.incrementAndGet(); 42 | } 43 | 44 | public int getLastCommitId() { 45 | return commitIdSeed.get(); 46 | } 47 | 48 | public String getDatasetId() { 49 | return datasetId; 50 | } 51 | 52 | public String getName() { 53 | return name; 54 | } 55 | 56 | public String getDescription() { 57 | return description; 58 | } 59 | 60 | public DatasetType getDatasetType() { 61 | return datasetType; 62 | } 63 | 64 | public String getUpdatedAt() { 65 | return updatedAt; 66 | } 67 | 68 | public DatasetSummary toDatasetSummary() { 69 | DatasetSummary.Builder builder = DatasetSummary.newBuilder() 70 | .setDatasetId(datasetId) 71 | .setName(getName()) 72 | .setDescription(getDescription()) 73 | .setDatasetType(getDatasetType()) 74 | .setLastUpdatedAt(getUpdatedAt()) 75 | .addAllCommits(commits.values()); 76 | return builder.build(); 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /data-management/src/main/java/org/orca3/miniAutoML/dataManagement/models/IntentText.java: -------------------------------------------------------------------------------- 1 | package org.orca3.miniAutoML.dataManagement.models; 2 | 3 | import com.opencsv.bean.CsvBindByPosition; 4 | 5 | import java.util.Objects; 6 | 7 | public class IntentText { 8 | @CsvBindByPosition(position = 0) 9 | private String utterance; 10 | 11 | @CsvBindByPosition(position = 1) 12 | private String labels; 13 | 14 | public String getUtterance() { 15 | return utterance; 16 | } 17 | 18 | public IntentText utterance(String utterance) { 19 | this.utterance = utterance; 20 | return this; 21 | } 22 | 23 | public String getLabels() { 24 | return labels; 25 | } 26 | 27 | public String[] getSplicedLabels() { 28 | return labels.split(";"); 29 | } 30 | 31 | public IntentText labels(String labels) { 32 | this.labels = labels; 33 | return this; 34 | } 35 | 36 | public IntentText labels(String[] labels) { 37 | this.labels = String.join(";", labels); 38 | return this; 39 | } 40 | 41 | @Override 42 | public String toString() { 43 | return "IntentText{" + 44 | "utterance='" + utterance + '\'' + 45 | ", labels='" + labels + '\'' + 46 | '}'; 47 | } 48 | 49 | @Override 50 | public boolean equals(Object o) { 51 | if (this == o) return true; 52 | if (o == null || getClass() != o.getClass()) return false; 53 | IntentText that = (IntentText) o; 54 | return Objects.equals(utterance, that.utterance) && Objects.equals(labels, that.labels); 55 | } 56 | 57 | @Override 58 | public int hashCode() { 59 | return Objects.hash(utterance, labels); 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /data-management/src/main/java/org/orca3/miniAutoML/dataManagement/models/IntentTextCollection.java: -------------------------------------------------------------------------------- 1 | package org.orca3.miniAutoML.dataManagement.models; 2 | 3 | import com.google.common.collect.ImmutableMap; 4 | import com.google.common.collect.Lists; 5 
| import com.google.common.collect.Maps; 6 | 7 | import java.util.List; 8 | import java.util.Map; 9 | 10 | public class IntentTextCollection { 11 | private List<IntentText> texts = Lists.newArrayList(); 12 | /** 13 | * Key: label string 14 | * Value: label id 15 | */ 16 | private Map<String, String> labels = Maps.newHashMap(); 17 | 18 | public List<IntentText> getTexts() { 19 | return texts; 20 | } 21 | 22 | public IntentTextCollection texts(List<IntentText> texts) { 23 | this.texts = texts; 24 | return this; 25 | } 26 | 27 | public Map<String, String> getLabels() { 28 | return labels; 29 | } 30 | 31 | public IntentTextCollection labels(Map<String, String> labels) { 32 | this.labels = labels; 33 | return this; 34 | } 35 | 36 | public Map<String, String> stats() { 37 | return ImmutableMap.<String, String>builder() 38 | .put("numLabels", Integer.toString(labels.size())) 39 | .put("numExamples", Integer.toString(texts.size())) 40 | .build(); 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /data-management/src/main/java/org/orca3/miniAutoML/dataManagement/models/Label.java: -------------------------------------------------------------------------------- 1 | package org.orca3.miniAutoML.dataManagement.models; 2 | 3 | import com.opencsv.bean.CsvBindByPosition; 4 | 5 | import java.util.Objects; 6 | 7 | public class Label { 8 | @CsvBindByPosition(position = 0) 9 | private String label; 10 | 11 | @CsvBindByPosition(position = 1) 12 | private String index; 13 | 14 | public String getIndex() { 15 | return index; 16 | } 17 | 18 | public Label index(String index) { 19 | this.index = index; 20 | return this; 21 | } 22 | 23 | public String getLabel() { 24 | return label; 25 | } 26 | 27 | public Label label(String label) { 28 | this.label = label; 29 | return this; 30 | } 31 | 32 | @Override 33 | public String toString() { 34 | return "Label{" + 35 | "index='" + index + '\'' + 36 | ", label='" + label + '\'' + 37 | '}'; 38 | } 39 | 40 | @Override 41 | public boolean equals(Object o) { 42 | if (this == o) return true; 43 | if (o == null || getClass() != o.getClass()) return false; 44 | Label label1 = (Label) o; 45 | return Objects.equals(index, label1.index) && Objects.equals(label, label1.label); 46 | } 47 | 48 | @Override 49 | public int hashCode() { 50 | return Objects.hash(index, label); 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /data-management/src/main/java/org/orca3/miniAutoML/dataManagement/models/MemoryStore.java: -------------------------------------------------------------------------------- 1 | package org.orca3.miniAutoML.dataManagement.models; 2 | 3 | import java.util.HashMap; 4 | import java.util.Map; 5 | import java.util.concurrent.atomic.AtomicInteger; 6 | 7 | public class MemoryStore { 8 | public final Map<String, Dataset> datasets; 9 | public final AtomicInteger datasetIdSeed; 10 | 11 | public MemoryStore() { 12 | this.datasets = new HashMap<>(); 13 | this.datasetIdSeed = new AtomicInteger(); 14 | } 15 | 16 | public void clear() { 17 | datasets.clear(); 18 | datasetIdSeed.set(0); 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /data-management/src/main/java/org/orca3/miniAutoML/dataManagement/transformers/DatasetTransformer.java: -------------------------------------------------------------------------------- 1 | package org.orca3.miniAutoML.dataManagement.transformers; 2 | 3 | import io.minio.MinioClient; 4 | import io.minio.errors.MinioException; 5 | import org.orca3.miniAutoML.dataManagement.CommitInfo; 6 | import
org.orca3.miniAutoML.dataManagement.DatasetPart; 7 | import org.orca3.miniAutoML.dataManagement.VersionedSnapshot; 8 | 9 | import java.nio.file.Paths; 10 | import java.util.List; 11 | 12 | public interface DatasetTransformer { 13 | VersionedSnapshot compress(List<DatasetPart> parts, String datasetId, String versionHash, String bucketName, MinioClient minioClient) throws MinioException; 14 | 15 | CommitInfo.Builder ingest(String ingestBucket, String ingestPath, String datasetId, String commitId, String bucketName, MinioClient minioClient) throws MinioException; 16 | 17 | static String getCommitRoot(String datasetId, String commitId) { 18 | return Paths.get("dataset", datasetId, "commit", commitId).toString(); 19 | } 20 | 21 | static String getVersionHashRoot(String datasetId, String versionHash) { 22 | return Paths.get("versionedDatasets", datasetId, versionHash).toString(); 23 | } 24 | 25 | } 26 | -------------------------------------------------------------------------------- /data-management/src/main/java/org/orca3/miniAutoML/dataManagement/transformers/GenericTransformer.java: -------------------------------------------------------------------------------- 1 | package org.orca3.miniAutoML.dataManagement.transformers; 2 | 3 | import io.minio.CopyObjectArgs; 4 | import io.minio.CopySource; 5 | import io.minio.ListObjectsArgs; 6 | import io.minio.MinioClient; 7 | import io.minio.Result; 8 | import io.minio.errors.MinioException; 9 | import io.minio.messages.Item; 10 | import org.orca3.miniAutoML.dataManagement.CommitInfo; 11 | import org.orca3.miniAutoML.dataManagement.DatasetPart; 12 | import org.orca3.miniAutoML.dataManagement.FileInfo; 13 | import org.orca3.miniAutoML.dataManagement.SnapshotState; 14 | import org.orca3.miniAutoML.dataManagement.VersionedSnapshot; 15 | 16 | import java.io.IOException; 17 | import java.nio.file.Paths; 18 | import java.security.InvalidKeyException; 19 | import java.security.NoSuchAlgorithmException; 20 | import java.time.Instant; 21 | import java.util.List; 22 | 23 | import static java.time.format.DateTimeFormatter.ISO_INSTANT; 24 | 25 | public class GenericTransformer implements DatasetTransformer { 26 | 27 | @Override 28 | public VersionedSnapshot compress(List<DatasetPart> parts, String datasetId, String versionHash, String bucketName, MinioClient minioClient) throws MinioException { 29 | String versionHashRoot = DatasetTransformer.getVersionHashRoot(datasetId, versionHash); 30 | VersionedSnapshot.Builder versionSnapshotBuilder = VersionedSnapshot.newBuilder() 31 | .setDatasetId(datasetId).setVersionHash(versionHash).setState(SnapshotState.READY) 32 | .setRoot(FileInfo.newBuilder().setName("root").setBucket(bucketName).setPath(versionHashRoot).build()); 33 | for (int i = 0; i < parts.size(); i++) { 34 | int j = 0; 35 | for (Result<Item> r : minioClient.listObjects(ListObjectsArgs.builder().bucket(bucketName).prefix(parts.get(i).getPathPrefix()).build())) { 36 | String newFilename = String.format("part-%d-%d", i, j); 37 | String newPath = Paths.get(versionHashRoot, newFilename).toString(); 38 | try { 39 | minioClient.copyObject(CopyObjectArgs.builder() 40 | .bucket(bucketName).object(newPath) 41 | .source(CopySource.builder().bucket(bucketName).object(r.get().objectName()).build()) 42 | .build()); 43 | versionSnapshotBuilder.addParts(FileInfo.newBuilder().setName(newFilename).setPath(newPath).setBucket(bucketName).build()); 44 | } catch (InvalidKeyException | IOException | NoSuchAlgorithmException e) { 45 | throw new RuntimeException(e); 46 | } 47 | j++; 48 | } 49 | } 50 | return
versionSnapshotBuilder 51 | .build(); 52 | } 53 | 54 | @Override 55 | public CommitInfo.Builder ingest(String ingestBucket, String ingestPath, String datasetId, String commitId, String bucketName, MinioClient minioClient) throws MinioException { 56 | int i = 0; 57 | String commitRoot = DatasetTransformer.getCommitRoot(datasetId, commitId); 58 | for (Result r : minioClient.listObjects(ListObjectsArgs.builder().bucket(bucketName).prefix(ingestPath).build())) { 59 | try { 60 | minioClient.copyObject(CopyObjectArgs.builder() 61 | .bucket(bucketName).object(Paths.get(commitRoot, String.format("part-%d", i)).toString()) 62 | .source(CopySource.builder().bucket(bucketName).object(r.get().objectName()).build()) 63 | .build()); 64 | } catch (InvalidKeyException | IOException | NoSuchAlgorithmException e) { 65 | throw new RuntimeException(e); 66 | } 67 | i++; 68 | } 69 | return CommitInfo.newBuilder() 70 | .setDatasetId(datasetId) 71 | .setCommitId(commitId) 72 | .setCreatedAt(ISO_INSTANT.format(Instant.now())) 73 | .setPath(commitRoot); 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /data-management/src/main/resources/logback.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | %d{HH:mm:ss.SSS} [%thread] %-5level %logger - %msg%n 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /data-management/src/test/resources/config-test.properties: -------------------------------------------------------------------------------- 1 | minio.bucketName=mini-automl-dm-test 2 | minio.accessKey=foooo 3 | minio.secretKey=barbarbar 4 | minio.host=http://localhost:9000 5 | server.port=51002 6 | -------------------------------------------------------------------------------- /data-management/src/test/resources/datasets/examples-1.csv: -------------------------------------------------------------------------------- 1 | "I am still waiting on my credit card?",0;1 2 | "I couldn’t purchase gas in Costco",2 3 | -------------------------------------------------------------------------------- /data-management/src/test/resources/datasets/intent-1.csv: -------------------------------------------------------------------------------- 1 | "I am still waiting on my credit card?",activate_my_card;card_arrival 2 | "I couldn’t purchase gas in Costco",card_not_working 3 | -------------------------------------------------------------------------------- /data-management/src/test/resources/datasets/labels-1.csv: -------------------------------------------------------------------------------- 1 | 0,activate_my_card 2 | 1,card_arrival 3 | 2,card_not_working 4 | -------------------------------------------------------------------------------- /data-management/src/test/resources/datasets/labels-test.csv: -------------------------------------------------------------------------------- 1 | 0,restaurant_reviews 2 | 1,nutrition_info 3 | 2,account_blocked 4 | 3,oil_change_how 5 | 4,time 6 | 5,weather 7 | 6,redeem_rewards 8 | 7,interest_rate 9 | 8,gas_type 10 | 9,accept_reservations 11 | 10,smart_home 12 | 11,user_name 13 | 12,report_lost_card 14 | 13,repeat 15 | 14,whisper_mode 16 | 15,what_are_your_hobbies 17 | 16,order 18 | 17,jump_start 19 | 18,schedule_meeting 20 | 19,meeting_schedule 21 | 20,freeze_account 22 | 21,what_song 23 | 22,meaning_of_life 24 | 23,restaurant_reservation 25 | 24,traffic 26 | 25,make_call 27 | 26,text 28 | 27,bill_balance 29 | 28,improve_credit_score 
30 | 29,change_language 31 | 30,no 32 | 31,measurement_conversion 33 | 32,timer 34 | 33,flip_coin 35 | 34,do_you_have_pets 36 | 35,balance 37 | 36,tell_joke 38 | 37,last_maintenance 39 | 38,exchange_rate 40 | 39,uber 41 | 40,car_rental 42 | 41,credit_limit 43 | 42,oos 44 | 43,shopping_list 45 | 44,expiration_date 46 | 45,routing 47 | 46,meal_suggestion 48 | 47,tire_change 49 | 48,todo_list 50 | 49,card_declined 51 | 50,rewards_balance 52 | 51,change_accent 53 | 52,vaccines 54 | 53,reminder_update 55 | 54,food_last 56 | 55,change_ai_name 57 | 56,bill_due 58 | 57,who_do_you_work_for 59 | 58,share_location 60 | 59,international_visa 61 | 60,calendar 62 | 61,translate 63 | 62,carry_on 64 | 63,book_flight 65 | 64,insurance_change 66 | 65,todo_list_update 67 | 66,timezone 68 | 67,cancel_reservation 69 | 68,transactions 70 | 69,credit_score 71 | 70,report_fraud 72 | 71,spending_history 73 | 72,directions 74 | 73,spelling 75 | 74,insurance 76 | 75,what_is_your_name 77 | 76,reminder 78 | 77,where_are_you_from 79 | 78,distance 80 | 79,payday 81 | 80,flight_status 82 | 81,find_phone 83 | 82,greeting 84 | 83,alarm 85 | 84,order_status 86 | 85,confirm_reservation 87 | 86,cook_time 88 | 87,damaged_card 89 | 88,reset_settings 90 | 89,pin_change 91 | 90,replacement_card_duration 92 | 91,new_card 93 | 92,roll_dice 94 | 93,income 95 | 94,taxes 96 | 95,date 97 | 96,who_made_you 98 | 97,pto_request 99 | 98,tire_pressure 100 | 99,how_old_are_you 101 | 100,rollover_401k 102 | 101,pto_request_status 103 | 102,how_busy 104 | 103,application_status 105 | 104,recipe 106 | 105,calendar_update 107 | 106,play_music 108 | 107,yes 109 | 108,direct_deposit 110 | 109,credit_limit_change 111 | 110,gas 112 | 111,pay_bill 113 | 112,ingredients_list 114 | 113,lost_luggage 115 | 114,goodbye 116 | 115,what_can_i_ask_you 117 | 116,book_hotel 118 | 117,are_you_a_bot 119 | 118,next_song 120 | 119,change_speed 121 | 120,plug_type 122 | 121,maybe 123 | 122,w2 124 | 123,oil_change_when 125 | 124,thank_you 126 | 125,shopping_list_update 127 | 126,pto_balance 128 | 127,order_checks 129 | 128,travel_alert 130 | 129,fun_fact 131 | 130,sync_device 132 | 131,schedule_maintenance 133 | 132,apr 134 | 133,transfer 135 | 134,ingredient_substitution 136 | 135,calories 137 | 136,current_location 138 | 137,international_fees 139 | 138,calculator 140 | 139,definition 141 | 140,next_holiday 142 | 141,update_playlist 143 | 142,mpg 144 | 143,min_payment 145 | 144,change_user_name 146 | 145,restaurant_suggestion 147 | 146,travel_notification 148 | 147,cancel 149 | 148,pto_used 150 | 149,travel_suggestion 151 | 150,change_volume 152 | -------------------------------------------------------------------------------- /data-management/src/test/resources/datasets/labels-train.csv: -------------------------------------------------------------------------------- 1 | 0,restaurant_reviews 2 | 1,nutrition_info 3 | 2,account_blocked 4 | 3,oil_change_how 5 | 4,time 6 | 5,weather 7 | 6,redeem_rewards 8 | 7,interest_rate 9 | 8,gas_type 10 | 9,accept_reservations 11 | 10,smart_home 12 | 11,user_name 13 | 12,report_lost_card 14 | 13,repeat 15 | 14,whisper_mode 16 | 15,what_are_your_hobbies 17 | 16,order 18 | 17,jump_start 19 | 18,schedule_meeting 20 | 19,meeting_schedule 21 | 20,freeze_account 22 | 21,what_song 23 | 22,meaning_of_life 24 | 23,restaurant_reservation 25 | 24,traffic 26 | 25,make_call 27 | 26,text 28 | 27,bill_balance 29 | 28,improve_credit_score 30 | 29,change_language 31 | 30,no 32 | 31,measurement_conversion 33 | 32,timer 34 | 33,flip_coin 35 | 
34,do_you_have_pets 36 | 35,balance 37 | 36,tell_joke 38 | 37,last_maintenance 39 | 38,exchange_rate 40 | 39,uber 41 | 40,car_rental 42 | 41,credit_limit 43 | 42,oos 44 | 43,shopping_list 45 | 44,expiration_date 46 | 45,routing 47 | 46,meal_suggestion 48 | 47,tire_change 49 | 48,todo_list 50 | 49,card_declined 51 | 50,rewards_balance 52 | 51,change_accent 53 | 52,vaccines 54 | 53,reminder_update 55 | 54,food_last 56 | 55,change_ai_name 57 | 56,bill_due 58 | 57,who_do_you_work_for 59 | 58,share_location 60 | 59,international_visa 61 | 60,calendar 62 | 61,translate 63 | 62,carry_on 64 | 63,book_flight 65 | 64,insurance_change 66 | 65,todo_list_update 67 | 66,timezone 68 | 67,cancel_reservation 69 | 68,transactions 70 | 69,credit_score 71 | 70,report_fraud 72 | 71,spending_history 73 | 72,directions 74 | 73,spelling 75 | 74,insurance 76 | 75,what_is_your_name 77 | 76,reminder 78 | 77,where_are_you_from 79 | 78,distance 80 | 79,payday 81 | 80,flight_status 82 | 81,find_phone 83 | 82,greeting 84 | 83,alarm 85 | 84,order_status 86 | 85,confirm_reservation 87 | 86,cook_time 88 | 87,damaged_card 89 | 88,reset_settings 90 | 89,pin_change 91 | 90,replacement_card_duration 92 | 91,new_card 93 | 92,roll_dice 94 | 93,income 95 | 94,taxes 96 | 95,date 97 | 96,who_made_you 98 | 97,pto_request 99 | 98,tire_pressure 100 | 99,how_old_are_you 101 | 100,rollover_401k 102 | 101,pto_request_status 103 | 102,how_busy 104 | 103,application_status 105 | 104,recipe 106 | 105,calendar_update 107 | 106,play_music 108 | 107,yes 109 | 108,direct_deposit 110 | 109,credit_limit_change 111 | 110,gas 112 | 111,pay_bill 113 | 112,ingredients_list 114 | 113,lost_luggage 115 | 114,goodbye 116 | 115,what_can_i_ask_you 117 | 116,book_hotel 118 | 117,are_you_a_bot 119 | 118,next_song 120 | 119,change_speed 121 | 120,plug_type 122 | 121,maybe 123 | 122,w2 124 | 123,oil_change_when 125 | 124,thank_you 126 | 125,shopping_list_update 127 | 126,pto_balance 128 | 127,order_checks 129 | 128,travel_alert 130 | 129,fun_fact 131 | 130,sync_device 132 | 131,schedule_maintenance 133 | 132,apr 134 | 133,transfer 135 | 134,ingredient_substitution 136 | 135,calories 137 | 136,current_location 138 | 137,international_fees 139 | 138,calculator 140 | 139,definition 141 | 140,next_holiday 142 | 141,update_playlist 143 | 142,mpg 144 | 143,min_payment 145 | 144,change_user_name 146 | 145,restaurant_suggestion 147 | 146,travel_notification 148 | 147,cancel 149 | 148,pto_used 150 | 149,travel_suggestion 151 | 150,change_volume 152 | -------------------------------------------------------------------------------- /data-management/src/test/resources/datasets/labels-validation.csv: -------------------------------------------------------------------------------- 1 | 0,restaurant_reviews 2 | 1,nutrition_info 3 | 2,account_blocked 4 | 3,oil_change_how 5 | 4,time 6 | 5,weather 7 | 6,redeem_rewards 8 | 7,interest_rate 9 | 8,gas_type 10 | 9,accept_reservations 11 | 10,smart_home 12 | 11,user_name 13 | 12,report_lost_card 14 | 13,repeat 15 | 14,whisper_mode 16 | 15,what_are_your_hobbies 17 | 16,order 18 | 17,jump_start 19 | 18,schedule_meeting 20 | 19,meeting_schedule 21 | 20,freeze_account 22 | 21,what_song 23 | 22,meaning_of_life 24 | 23,restaurant_reservation 25 | 24,traffic 26 | 25,make_call 27 | 26,text 28 | 27,bill_balance 29 | 28,improve_credit_score 30 | 29,change_language 31 | 30,no 32 | 31,measurement_conversion 33 | 32,timer 34 | 33,flip_coin 35 | 34,do_you_have_pets 36 | 35,balance 37 | 36,tell_joke 38 | 37,last_maintenance 39 | 38,exchange_rate 40 
| 39,uber 41 | 40,car_rental 42 | 41,credit_limit 43 | 42,oos 44 | 43,shopping_list 45 | 44,expiration_date 46 | 45,routing 47 | 46,meal_suggestion 48 | 47,tire_change 49 | 48,todo_list 50 | 49,card_declined 51 | 50,rewards_balance 52 | 51,change_accent 53 | 52,vaccines 54 | 53,reminder_update 55 | 54,food_last 56 | 55,change_ai_name 57 | 56,bill_due 58 | 57,who_do_you_work_for 59 | 58,share_location 60 | 59,international_visa 61 | 60,calendar 62 | 61,translate 63 | 62,carry_on 64 | 63,book_flight 65 | 64,insurance_change 66 | 65,todo_list_update 67 | 66,timezone 68 | 67,cancel_reservation 69 | 68,transactions 70 | 69,credit_score 71 | 70,report_fraud 72 | 71,spending_history 73 | 72,directions 74 | 73,spelling 75 | 74,insurance 76 | 75,what_is_your_name 77 | 76,reminder 78 | 77,where_are_you_from 79 | 78,distance 80 | 79,payday 81 | 80,flight_status 82 | 81,find_phone 83 | 82,greeting 84 | 83,alarm 85 | 84,order_status 86 | 85,confirm_reservation 87 | 86,cook_time 88 | 87,damaged_card 89 | 88,reset_settings 90 | 89,pin_change 91 | 90,replacement_card_duration 92 | 91,new_card 93 | 92,roll_dice 94 | 93,income 95 | 94,taxes 96 | 95,date 97 | 96,who_made_you 98 | 97,pto_request 99 | 98,tire_pressure 100 | 99,how_old_are_you 101 | 100,rollover_401k 102 | 101,pto_request_status 103 | 102,how_busy 104 | 103,application_status 105 | 104,recipe 106 | 105,calendar_update 107 | 106,play_music 108 | 107,yes 109 | 108,direct_deposit 110 | 109,credit_limit_change 111 | 110,gas 112 | 111,pay_bill 113 | 112,ingredients_list 114 | 113,lost_luggage 115 | 114,goodbye 116 | 115,what_can_i_ask_you 117 | 116,book_hotel 118 | 117,are_you_a_bot 119 | 118,next_song 120 | 119,change_speed 121 | 120,plug_type 122 | 121,maybe 123 | 122,w2 124 | 123,oil_change_when 125 | 124,thank_you 126 | 125,shopping_list_update 127 | 126,pto_balance 128 | 127,order_checks 129 | 128,travel_alert 130 | 129,fun_fact 131 | 130,sync_device 132 | 131,schedule_maintenance 133 | 132,apr 134 | 133,transfer 135 | 134,ingredient_substitution 136 | 135,calories 137 | 136,current_location 138 | 137,international_fees 139 | 138,calculator 140 | 139,definition 141 | 140,next_holiday 142 | 141,update_playlist 143 | 142,mpg 144 | 143,min_payment 145 | 144,change_user_name 146 | 145,restaurant_suggestion 147 | 146,travel_notification 148 | 147,cancel 149 | 148,pto_used 150 | 149,travel_suggestion 151 | 150,change_volume 152 | -------------------------------------------------------------------------------- /data-management/src/test/resources/genericDatasets/bar: -------------------------------------------------------------------------------- 1 | bar 2 | bar 3 | bar 4 | bar 5 | bar 6 | -------------------------------------------------------------------------------- /data-management/src/test/resources/genericDatasets/coo: -------------------------------------------------------------------------------- 1 | coo 2 | coo 3 | coo 4 | coo 5 | coo 6 | -------------------------------------------------------------------------------- /data-management/src/test/resources/genericDatasets/dzz: -------------------------------------------------------------------------------- 1 | dzz 2 | dzz 3 | dzz 4 | dzz 5 | dzz 6 | -------------------------------------------------------------------------------- /data-management/src/test/resources/genericDatasets/ell: -------------------------------------------------------------------------------- 1 | ell 2 | ell 3 | ell 4 | ell 5 | ell 6 | -------------------------------------------------------------------------------- 
/data-management/src/test/resources/genericDatasets/foo: -------------------------------------------------------------------------------- 1 | foo 2 | foo 3 | foo 4 | foo 5 | foo 6 | -------------------------------------------------------------------------------- /data-management/src/test/resources/logback-test.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | %d{HH:mm:ss.SSS} [%thread] %-5level %logger - %msg%n 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /grpc-contract/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | 7 | 8 | org.orca3 9 | mini-auto-ml 10 | 1.0-SNAPSHOT 11 | 12 | 13 | grpc-contract 14 | 1.0-SNAPSHOT 15 | 16 | 17 | 3.12.0 18 | 19 | 20 | 21 | io.grpc 22 | grpc-services 23 | provided 24 | 25 | 26 | io.grpc 27 | grpc-netty-shaded 28 | ${grpc.version} 29 | 30 | 31 | io.grpc 32 | grpc-protobuf 33 | ${grpc.version} 34 | 35 | 36 | io.grpc 37 | grpc-stub 38 | ${grpc.version} 39 | 40 | 41 | ch.qos.logback 42 | logback-classic 43 | 44 | 45 | org.apache.tomcat 46 | annotations-api 47 | 6.0.53 48 | provided 49 | 50 | 51 | 52 | 53 | 54 | 55 | kr.motd.maven 56 | os-maven-plugin 57 | 1.6.2 58 | 59 | 60 | 61 | 62 | org.xolstice.maven.plugins 63 | protobuf-maven-plugin 64 | 0.6.1 65 | 66 | com.google.protobuf:protoc:${protoc.version}:exe:${os.detected.classifier} 67 | grpc-java 68 | io.grpc:protoc-gen-grpc-java:${grpc.version}:exe:${os.detected.classifier} 69 | 70 | 71 | 72 | generate-grpc-java 73 | 74 | compile 75 | compile-custom 76 | 77 | generate-sources 78 | 79 | 80 | 81 | 82 | exec-maven-plugin 83 | org.codehaus.mojo 84 | 85 | 86 | generate-grpc-python 87 | generate-sources 88 | 89 | exec 90 | 91 | 92 | ${skipGrpcPython} 93 | ${project.build.scriptSourceDirectory}/python_code_gen.sh 94 | ${project.build.scriptSourceDirectory} 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | -------------------------------------------------------------------------------- /grpc-contract/src/main/java/org/orca3/miniAutoML/ServiceBase.java: -------------------------------------------------------------------------------- 1 | package org.orca3.miniAutoML; 2 | 3 | import io.grpc.BindableService; 4 | import io.grpc.Metadata; 5 | import io.grpc.MethodDescriptor; 6 | import io.grpc.Server; 7 | import io.grpc.ServerBuilder; 8 | import io.grpc.ServerCall; 9 | import io.grpc.ServerCallHandler; 10 | import io.grpc.ServerInterceptor; 11 | import io.grpc.Status; 12 | import io.grpc.health.v1.HealthCheckResponse; 13 | import io.grpc.protobuf.services.HealthStatusManager; 14 | import io.grpc.protobuf.services.ProtoReflectionService; 15 | import org.slf4j.Logger; 16 | import org.slf4j.LoggerFactory; 17 | 18 | import java.io.FileInputStream; 19 | import java.io.IOException; 20 | import java.util.Optional; 21 | import java.util.Properties; 22 | import java.util.concurrent.TimeUnit; 23 | 24 | public class ServiceBase { 25 | private static final Logger logger = LoggerFactory.getLogger(ServiceBase.class); 26 | 27 | public static Properties getConfigProperties() throws IOException { 28 | String configLocation = Optional.ofNullable(System.getenv("APP_CONFIG")).orElse("config/config-jvm-docker.properties"); 29 | logger.info(String.format("Reading config from %s on the file system (customizable by modifying environment variable APP_CONFIG)", configLocation)); 30 | Properties props = new Properties(); 31 | props.load(new 
FileInputStream(configLocation)); 32 | return props; 33 | } 34 | 35 | public static void startService(int port, BindableService service, Runnable shutdownHook) throws IOException, InterruptedException { 36 | Logger logger = LoggerFactory.getLogger(service.getClass()); 37 | HealthStatusManager health = new HealthStatusManager(); 38 | final Server server = ServerBuilder.forPort(port) 39 | .addService(service) 40 | .addService(ProtoReflectionService.newInstance()) 41 | .addService(health.getHealthService()) 42 | .intercept(new GrpcInterceptor()) 43 | .build() 44 | .start(); 45 | logger.info("Listening on port " + port); 46 | Runtime.getRuntime().addShutdownHook(new Thread(() -> { 47 | health.setStatus("", HealthCheckResponse.ServingStatus.NOT_SERVING); 48 | // Start graceful shutdown 49 | server.shutdown(); 50 | shutdownHook.run(); 51 | try { 52 | if (!server.awaitTermination(30, TimeUnit.SECONDS)) { 53 | server.shutdownNow(); 54 | server.awaitTermination(5, TimeUnit.SECONDS); 55 | } 56 | } catch (InterruptedException ex) { 57 | server.shutdownNow(); 58 | } 59 | })); 60 | health.setStatus("", HealthCheckResponse.ServingStatus.SERVING); 61 | server.awaitTermination(); 62 | 63 | } 64 | 65 | static class GrpcInterceptor implements ServerInterceptor { 66 | 67 | @Override 68 | public <M, R> ServerCall.Listener<M> interceptCall( 69 | ServerCall<M, R> call, Metadata headers, ServerCallHandler<M, R> next) { 70 | GrpcServerCall<M, R> grpcServerCall = new GrpcServerCall<>(call); 71 | 72 | ServerCall.Listener<M> listener = next.startCall(grpcServerCall, headers); 73 | 74 | return new GrpcForwardingServerCallListener<>(call.getMethodDescriptor(), listener) { 75 | @Override 76 | public void onMessage(M message) { 77 | logger.info("Method: {}, Message: {}", methodName, message); 78 | super.onMessage(message); 79 | } 80 | }; 81 | } 82 | 83 | 84 | } 85 | 86 | private static class GrpcServerCall<M, R> extends ServerCall<M, R> { 87 | 88 | ServerCall<M, R> serverCall; 89 | 90 | protected GrpcServerCall(ServerCall<M, R> serverCall) { 91 | this.serverCall = serverCall; 92 | } 93 | 94 | @Override 95 | public void request(int numMessages) { 96 | serverCall.request(numMessages); 97 | } 98 | 99 | @Override 100 | public void sendHeaders(Metadata headers) { 101 | serverCall.sendHeaders(headers); 102 | } 103 | 104 | @Override 105 | public void sendMessage(R message) { 106 | logger.info("Method: {}, Response: {}", serverCall.getMethodDescriptor().getFullMethodName(), message); 107 | serverCall.sendMessage(message); 108 | } 109 | 110 | @Override 111 | public void close(Status status, Metadata trailers) { 112 | serverCall.close(status, trailers); 113 | } 114 | 115 | @Override 116 | public boolean isCancelled() { 117 | return serverCall.isCancelled(); 118 | } 119 | 120 | @Override 121 | public MethodDescriptor<M, R> getMethodDescriptor() { 122 | return serverCall.getMethodDescriptor(); 123 | } 124 | } 125 | 126 | private static class GrpcForwardingServerCallListener<M> extends io.grpc.ForwardingServerCallListener.SimpleForwardingServerCallListener<M> { 127 | 128 | String methodName; 129 | 130 | protected GrpcForwardingServerCallListener(MethodDescriptor<?, ?> method, ServerCall.Listener<M> listener) { 131 | super(listener); 132 | methodName = method.getFullMethodName(); 133 | } 134 | } 135 | } 136 | 137 | -------------------------------------------------------------------------------- /grpc-contract/src/main/proto/data_management.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | option java_multiple_files = true; 4 | option
java_package = "org.orca3.miniAutoML.dataManagement"; 5 | option java_outer_classname = "DataManagementProto"; 6 | 7 | import "google/protobuf/empty.proto"; 8 | 9 | package data_management; 10 | 11 | service DataManagementService { 12 | rpc GetDatasetSummary (DatasetPointer) returns (DatasetSummary); 13 | rpc ListDatasets (ListQueryOptions) returns (stream DatasetPointer); 14 | rpc CreateDataset (CreateDatasetRequest) returns (DatasetSummary); 15 | rpc UpdateDataset (CreateCommitRequest) returns (DatasetSummary); 16 | rpc PrepareTrainingDataset (DatasetQuery) returns (SnapshotVersion); 17 | rpc FetchTrainingDataset (VersionQuery) returns (VersionedSnapshot); 18 | rpc DeleteDataset (DatasetPointer) returns (google.protobuf.Empty); 19 | } 20 | 21 | message CreateDatasetRequest { 22 | string name = 1; 23 | string description = 2; 24 | DatasetType dataset_type = 3; 25 | string bucket = 4; 26 | string path = 5; 27 | repeated Tag tags = 6; 28 | } 29 | 30 | message ListQueryOptions { 31 | int32 limit = 1; 32 | int32 index = 2; 33 | } 34 | 35 | message CreateCommitRequest { 36 | string dataset_id = 1; 37 | string commit_message = 2; 38 | string bucket = 3; 39 | string path = 4; 40 | repeated Tag tags = 5; 41 | } 42 | 43 | message DatasetQuery { 44 | string dataset_id = 1; 45 | string commit_id = 2; 46 | repeated Tag tags = 3; 47 | } 48 | 49 | message VersionQuery { 50 | string dataset_id = 1; 51 | string version_hash = 2; 52 | } 53 | 54 | message DatasetPointer { 55 | string dataset_id = 1; 56 | } 57 | 58 | message DatasetSummary { 59 | string dataset_id = 1; 60 | string name = 2; 61 | string description = 3; 62 | DatasetType dataset_type = 4; 63 | string last_updated_at = 5; 64 | repeated CommitInfo commits = 6; 65 | 66 | } 67 | 68 | message SnapshotVersion { 69 | string dataset_id = 1; 70 | string name = 2; 71 | string description = 3; 72 | DatasetType dataset_type = 4; 73 | string last_updated_at = 5; 74 | string version_hash = 6; 75 | repeated CommitInfo commits = 7; 76 | } 77 | 78 | message DatasetPart { 79 | string dataset_id = 1; 80 | string commit_id = 2; 81 | string bucket = 3; 82 | string path_prefix = 4; 83 | } 84 | 85 | message VersionedSnapshot { 86 | string dataset_id = 1; 87 | string version_hash = 2; 88 | SnapshotState state = 3; 89 | repeated FileInfo parts = 4; 90 | FileInfo root = 5; 91 | map statistics = 7; 92 | } 93 | 94 | message CommitInfo { 95 | string dataset_id = 1; 96 | string commit_id = 2; 97 | string created_at = 3; 98 | string commit_message = 4; 99 | repeated Tag tags = 5; 100 | string path = 6; 101 | map statistics = 7; 102 | } 103 | 104 | message Tag { 105 | string tag_key = 1; 106 | string tag_value = 2; 107 | } 108 | 109 | message FileInfo { 110 | string name = 1; 111 | string bucket = 2; 112 | string path = 3; 113 | } 114 | 115 | enum DatasetType { 116 | GENERIC = 0; 117 | TEXT_INTENT = 1; 118 | } 119 | 120 | enum SnapshotState { 121 | RUNNING = 0; 122 | READY = 1; 123 | FAILED = 2; 124 | } 125 | -------------------------------------------------------------------------------- /grpc-contract/src/main/proto/metadata_store.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | option java_multiple_files = true; 4 | option java_package = "org.orca3.miniAutoML.metadataStore"; 5 | option java_outer_classname = "MetadataStoreProto"; 6 | 7 | import "data_management.proto"; 8 | 9 | package metadata_store; 10 | 11 | service MetadataStoreService { 12 | rpc LogRunStart (LogRunStartRequest) returns 
(LogRunStartResponse); 13 | rpc LogEpoch (LogEpochRequest) returns (LogEpochResponse); 14 | rpc LogRunEnd (LogRunEndRequest) returns (LogRunEndResponse); 15 | rpc GetRunStatus (GetRunStatusRequest) returns (GetRunStatusResponse); 16 | rpc CreateArtifact (CreateArtifactRequest) returns (CreateArtifactResponse); 17 | rpc GetArtifact (GetArtifactRequest) returns (GetArtifactResponse); 18 | } 19 | 20 | message LogRunStartRequest { 21 | string start_time = 1; 22 | string run_id = 2; 23 | string run_name = 3; 24 | TracingInformation tracing = 4; 25 | } 26 | 27 | message LogRunStartResponse { 28 | RunInfo run_info = 1; 29 | string bucket = 2; 30 | string path = 3; 31 | } 32 | 33 | message LogEpochRequest { 34 | EpochInfo epoch_info = 1; 35 | } 36 | 37 | message LogEpochResponse { 38 | EpochInfo epoch_info = 1; 39 | } 40 | 41 | message LogRunEndRequest { 42 | string run_id = 1; 43 | string end_time = 2; 44 | bool success = 3; 45 | string message = 4; 46 | } 47 | 48 | message LogRunEndResponse { 49 | RunInfo run_info = 1; 50 | } 51 | 52 | message GetRunStatusRequest { 53 | string run_id = 5; 54 | } 55 | 56 | message GetRunStatusResponse { 57 | RunInfo run_info = 1; 58 | } 59 | 60 | message CreateArtifactRequest { 61 | data_management.FileInfo artifact = 1; 62 | string run_id = 2; 63 | string algorithm = 3; 64 | } 65 | 66 | message CreateArtifactResponse { 67 | string version = 1; 68 | data_management.FileInfo artifact = 2; 69 | string run_id = 3; 70 | string name = 4; 71 | string algorithm = 5; 72 | } 73 | 74 | message GetArtifactRequest { 75 | string run_id = 3; 76 | } 77 | 78 | message GetArtifactResponse { 79 | string name = 1; 80 | string version = 2; 81 | data_management.FileInfo artifact = 3; 82 | string run_id = 4; 83 | string algorithm = 5; 84 | } 85 | 86 | //========= 87 | message EpochInfo { 88 | string start_time = 1; 89 | string end_time = 2; 90 | string run_id = 3; 91 | string epoch_id = 4; 92 | map metrics = 5; 93 | 94 | } 95 | 96 | message RunInfo { 97 | string start_time = 1; 98 | string end_time = 2; 99 | bool success = 3; 100 | string message = 4; 101 | string run_id = 5; 102 | string run_name = 6; 103 | TracingInformation tracing = 7; 104 | map epochs = 8; 105 | } 106 | 107 | message TracingInformation { 108 | string dataset_id = 1; 109 | string version_hash = 2; 110 | string code_version = 3; 111 | } 112 | -------------------------------------------------------------------------------- /grpc-contract/src/main/proto/prediction_service.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | option java_multiple_files = true; 4 | option java_package = "org.orca3.miniAutoML.prediction"; 5 | option java_outer_classname = "PredictionServiceProto"; 6 | 7 | package prediction; 8 | 9 | service PredictionService { 10 | rpc Predict(PredictRequest) returns (PredictResponse); 11 | } 12 | 13 | service Predictor { 14 | rpc PredictorPredict(PredictorPredictRequest) returns (PredictorPredictResponse); 15 | } 16 | 17 | message PredictRequest { 18 | string runId = 3; 19 | string document = 4; 20 | } 21 | 22 | message PredictResponse { 23 | string response = 1; 24 | } 25 | 26 | message PredictorPredictRequest { 27 | string runId = 1; 28 | string document = 2; 29 | } 30 | 31 | message PredictorPredictResponse { 32 | string response = 1; 33 | } 34 | -------------------------------------------------------------------------------- /grpc-contract/src/main/proto/torch_management.proto: 
-------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package org.pytorch.serve.grpc.management; 4 | 5 | option java_multiple_files = true; 6 | 7 | message ManagementResponse { 8 | // Response string of different management API calls. 9 | string msg = 1; 10 | } 11 | 12 | message DescribeModelRequest { 13 | // Name of model to describe. 14 | string model_name = 1; //required 15 | // Version of model to describe. 16 | string model_version = 2; //optional 17 | } 18 | 19 | message ListModelsRequest { 20 | // Use this parameter to specify the maximum number of items to return. When this value is present, TorchServe does not return more than the specified number of items, but it might return fewer. This value is optional. If you include a value, it must be between 1 and 1000, inclusive. If you do not include a value, it defaults to 100. 21 | int32 limit = 1; //optional 22 | 23 | // The token to retrieve the next set of results. TorchServe provides the token when the response from a previous call has more results than the maximum page size. 24 | int32 next_page_token = 2; //optional 25 | } 26 | 27 | message RegisterModelRequest { 28 | // Inference batch size, default: 1. 29 | int32 batch_size = 1; //optional 30 | 31 | // Inference handler entry-point. This value will override handler in MANIFEST.json if present. 32 | string handler = 2; //optional 33 | 34 | // Number of initial workers, default: 0. 35 | int32 initial_workers = 3; //optional 36 | 37 | // Maximum delay for batch aggregation, default: 100. 38 | int32 max_batch_delay = 4; //optional 39 | 40 | // Name of model. This value will override modelName in MANIFEST.json if present. 41 | string model_name = 5; //optional 42 | 43 | // Maximum time, in seconds, the TorchServe waits for a response from the model inference code, default: 120. 44 | int32 response_timeout = 6; //optional 45 | 46 | // Runtime for the model custom service code. This value will override runtime in MANIFEST.json if present. 47 | string runtime = 7; //optional 48 | 49 | // Decides whether creation of worker synchronous or not, default: false. 50 | bool synchronous = 8; //optional 51 | 52 | // Model archive download url, support local file or HTTP(s) protocol. 53 | string url = 9; //required 54 | 55 | // Decides whether S3 SSE KMS enabled or not, default: false. 56 | bool s3_sse_kms = 10; //optional 57 | } 58 | 59 | message ScaleWorkerRequest { 60 | 61 | // Name of model to scale workers. 62 | string model_name = 1; //required 63 | 64 | // Model version. 65 | string model_version = 2; //optional 66 | 67 | // Maximum number of worker processes. 68 | int32 max_worker = 3; //optional 69 | 70 | // Minimum number of worker processes. 71 | int32 min_worker = 4; //optional 72 | 73 | // Number of GPU worker processes to create. 74 | int32 number_gpu = 5; //optional 75 | 76 | // Decides whether the call is synchronous or not, default: false. 77 | bool synchronous = 6; //optional 78 | 79 | // Waiting up to the specified wait time if necessary for a worker to complete all pending requests. Use 0 to terminate backend worker process immediately. Use -1 for wait infinitely. 80 | int32 timeout = 7; //optional 81 | } 82 | 83 | message SetDefaultRequest { 84 | // Name of model whose default version needs to be updated. 
85 | string model_name = 1; //required 86 | 87 | // Version of model to be set as default version for the model 88 | string model_version = 2; //required 89 | } 90 | 91 | message UnregisterModelRequest { 92 | // Name of model to unregister. 93 | string model_name = 1; //required 94 | 95 | // Name of model to unregister. 96 | string model_version = 2; //optional 97 | } 98 | 99 | service ManagementAPIsService { 100 | // Provides detailed information about the default version of a model. 101 | rpc DescribeModel(DescribeModelRequest) returns (ManagementResponse) {} 102 | 103 | // List registered models in TorchServe. 104 | rpc ListModels(ListModelsRequest) returns (ManagementResponse) {} 105 | 106 | // Register a new model in TorchServe. 107 | rpc RegisterModel(RegisterModelRequest) returns (ManagementResponse) {} 108 | 109 | // Configure number of workers for a default version of a model.This is a asynchronous call by default. Caller need to call describeModel to check if the model workers has been changed. 110 | rpc ScaleWorker(ScaleWorkerRequest) returns (ManagementResponse) {} 111 | 112 | // Set default version of a model 113 | rpc SetDefault(SetDefaultRequest) returns (ManagementResponse) {} 114 | 115 | // Unregister the default version of a model from TorchServe if it is the only version available.This is a asynchronous call by default. Caller can call listModels to confirm model is unregistered 116 | rpc UnregisterModel(UnregisterModelRequest) returns (ManagementResponse) {} 117 | } 118 | -------------------------------------------------------------------------------- /grpc-contract/src/main/proto/torch_serve.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package org.pytorch.serve.grpc.inference; 4 | 5 | import "google/protobuf/empty.proto"; 6 | 7 | option java_multiple_files = true; 8 | 9 | message PredictionsRequest { 10 | // Name of model. 11 | string model_name = 1; //required 12 | 13 | // Version of model to run prediction on. 14 | string model_version = 2; //optional 15 | 16 | // input data for model prediction 17 | map<string, bytes> input = 3; //required 18 | } 19 | 20 | message PredictionResponse { 21 | // Prediction response content 22 | bytes prediction = 1; 23 | } 24 | 25 | message TorchServeHealthResponse { 26 | // TorchServe health 27 | string health = 1; 28 | } 29 | 30 | service InferenceAPIsService { 31 | rpc Ping(google.protobuf.Empty) returns (TorchServeHealthResponse) {} 32 | 33 | // Predictions entry point to get inference using default model version.
34 | rpc Predictions(PredictionsRequest) returns (PredictionResponse) {} 35 | } 36 | -------------------------------------------------------------------------------- /grpc-contract/src/main/proto/training_service.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | option java_multiple_files = true; 4 | option java_package = "org.orca3.miniAutoML.training"; 5 | option java_outer_classname = "TrainingServiceProto"; 6 | 7 | package training; 8 | 9 | service TrainingService { 10 | rpc Train(TrainRequest) returns (TrainResponse); 11 | rpc GetTrainingStatus(GetTrainingStatusRequest) returns (GetTrainingStatusResponse); 12 | } 13 | 14 | message TrainRequest { 15 | TrainingJobMetadata metadata = 1; 16 | } 17 | 18 | message TrainResponse { 19 | int32 job_id = 1; 20 | } 21 | 22 | message GetTrainingStatusRequest { 23 | int32 job_id = 1; 24 | } 25 | 26 | message GetTrainingStatusResponse { 27 | TrainingStatus status = 1; 28 | int32 job_id = 2; 29 | string message = 3; 30 | TrainingJobMetadata metadata = 4; 31 | int32 positionInQueue = 5; 32 | } 33 | 34 | enum TrainingStatus { 35 | queuing = 0; 36 | launch = 1; 37 | running = 2; 38 | succeed = 3; 39 | failure = 4; 40 | } 41 | 42 | message TrainingJobMetadata { 43 | string algorithm = 1; 44 | string dataset_id = 2; 45 | string name = 3; 46 | string train_data_version_hash = 4; 47 | map parameters = 5; 48 | string output_model_name = 6; 49 | } 50 | -------------------------------------------------------------------------------- /grpc-contract/src/main/scripts/python_code_gen.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python3 -m grpc_tools.protoc \ 4 | -I proto \ 5 | --python_out="$(dirname "$0")/../../../../training-code/text-classification/" \ 6 | --grpc_python_out="$(dirname "$0")/../../../../training-code/text-classification/" \ 7 | --proto_path=../proto ../proto/*.proto 8 | 9 | python3 -m grpc_tools.protoc \ 10 | -I proto \ 11 | --python_out="$(dirname "$0")/../../../../predictor/" \ 12 | --grpc_python_out="$(dirname "$0")/../../../../predictor/" \ 13 | --proto_path=../proto ../proto/prediction_service.proto 14 | -------------------------------------------------------------------------------- /metadata-store/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | 7 | 8 | org.orca3 9 | mini-auto-ml 10 | 1.0-SNAPSHOT 11 | 12 | 13 | metadata-store 14 | 1.0-SNAPSHOT 15 | 16 | 17 | org.orca3.miniAutoML.metadataStore.MetadataStoreService 18 | 19 | 20 | 21 | 22 | org.orca3 23 | grpc-contract 24 | ${project.version} 25 | 26 | 27 | 28 | io.grpc 29 | grpc-services 30 | 31 | 32 | ch.qos.logback 33 | logback-classic 34 | 35 | 36 | io.minio 37 | minio 38 | 39 | 40 | 41 | io.grpc 42 | grpc-testing 43 | test 44 | 45 | 46 | junit 47 | junit 48 | RELEASE 49 | test 50 | 51 | 52 | org.testcontainers 53 | testcontainers 54 | 1.15.3 55 | test 56 | 57 | 58 | 59 | 60 | 61 | org.codehaus.mojo 62 | exec-maven-plugin 63 | 3.0.0 64 | 65 | 66 | 67 | java 68 | 69 | 70 | 71 | 72 | ${mainClass} 73 | 74 | 75 | 76 | org.apache.maven.plugins 77 | maven-shade-plugin 78 | 79 | 80 | package 81 | 82 | shade 83 | 84 | 85 | 86 | 87 | ${mainClass} 88 | 89 | 90 | 91 | 92 | 93 | *:* 94 | 95 | META-INF/*.SF 96 | META-INF/*.DSA 97 | META-INF/*.RSA 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | -------------------------------------------------------------------------------- 
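metadata_store.proto earlier in this tree defines the run-tracking API that training code calls against this module: LogRunStart opens a run, LogEpoch reports per-epoch results, and LogRunEnd seals the run. Below is a minimal client-side sketch using the grpc-java stubs generated from that proto; the address, port, and the literal run/dataset values are placeholders, not values the repository prescribes:

```java
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import org.orca3.miniAutoML.metadataStore.LogRunEndRequest;
import org.orca3.miniAutoML.metadataStore.LogRunStartRequest;
import org.orca3.miniAutoML.metadataStore.MetadataStoreServiceGrpc;
import org.orca3.miniAutoML.metadataStore.TracingInformation;

// Hypothetical trainer-side client (not in the repo): brackets one training
// run with LogRunStart/LogRunEnd calls to the metadata-store service.
public final class RunLoggingSketch {
    public static void main(String[] args) {
        ManagedChannel channel = ManagedChannelBuilder
                .forAddress("localhost", 6002) // placeholder endpoint
                .usePlaintext()
                .build();
        MetadataStoreServiceGrpc.MetadataStoreServiceBlockingStub client =
                MetadataStoreServiceGrpc.newBlockingStub(channel);

        client.logRunStart(LogRunStartRequest.newBuilder()
                .setRunId("1")                      // placeholder run id
                .setRunName("demo-intent-run")
                .setStartTime(java.time.Instant.now().toString())
                .setTracing(TracingInformation.newBuilder()
                        .setDatasetId("1")          // placeholder dataset id
                        .setVersionHash("hashDg==") // placeholder version hash
                        .setCodeVersion("80bf0da")
                        .build())
                .build());

        // ... training happens here; each epoch would be reported with
        // client.logEpoch(...) using the EpochInfo message ...

        client.logRunEnd(LogRunEndRequest.newBuilder()
                .setRunId("1")
                .setEndTime(java.time.Instant.now().toString())
                .setSuccess(true)
                .setMessage("run finished")
                .build());
        channel.shutdown();
    }
}
```

The scripts/ms-003 through ms-006 shell scripts exercise this same start/epoch/status/finish flow from the command line.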
/metadata-store/src/main/java/org/orca3/miniAutoML/metadataStore/models/ArtifactInfo.java: -------------------------------------------------------------------------------- 1 | package org.orca3.miniAutoML.metadataStore.models; 2 | 3 | import org.orca3.miniAutoML.dataManagement.FileInfo; 4 | 5 | public class ArtifactInfo { 6 | private final FileInfo fileInfo; 7 | private final String runId; 8 | private final String artifactName; 9 | private final String version; 10 | private final String algorithm; 11 | 12 | public String getRunId() { 13 | return runId; 14 | } 15 | 16 | public FileInfo getFileInfo() { 17 | return fileInfo; 18 | } 19 | 20 | public String getArtifactName() { 21 | return artifactName; 22 | } 23 | 24 | public String getVersion() { 25 | return version; 26 | } 27 | 28 | public String getAlgorithm() { 29 | return algorithm; 30 | } 31 | 32 | public ArtifactInfo(FileInfo fileInfo, String runId, String artifactName, String version, String algorithm) { 33 | this.fileInfo = fileInfo; 34 | this.runId = runId; 35 | this.artifactName = artifactName; 36 | this.version = version; 37 | this.algorithm = algorithm; 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /metadata-store/src/main/java/org/orca3/miniAutoML/metadataStore/models/ArtifactRepo.java: -------------------------------------------------------------------------------- 1 | package org.orca3.miniAutoML.metadataStore.models; 2 | 3 | import java.util.HashMap; 4 | import java.util.Map; 5 | import java.util.concurrent.atomic.AtomicInteger; 6 | 7 | public class ArtifactRepo { 8 | private final AtomicInteger seed; 9 | public final Map<String, ArtifactInfo> artifacts; 10 | private final String name; 11 | 12 | public ArtifactRepo(String name) { 13 | this.name = name; 14 | this.seed = new AtomicInteger(); 15 | this.artifacts = new HashMap<>(); 16 | } 17 | 18 | public String getName() { 19 | return name; 20 | } 21 | 22 | public int getSeed() { 23 | return seed.incrementAndGet(); 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /metadata-store/src/main/java/org/orca3/miniAutoML/metadataStore/models/MemoryStore.java: -------------------------------------------------------------------------------- 1 | package org.orca3.miniAutoML.metadataStore.models; 2 | 3 | import org.orca3.miniAutoML.metadataStore.RunInfo; 4 | 5 | import java.util.HashMap; 6 | import java.util.Map; 7 | 8 | public class MemoryStore { 9 | public final Map<String, RunInfo> runInfoMap; 10 | public final Map<String, ArtifactRepo> artifactRepos; 11 | public final Map<String, ArtifactInfo> runIdLookup; 12 | 13 | public MemoryStore() { 14 | this.runInfoMap = new HashMap<>(); 15 | this.artifactRepos = new HashMap<>(); 16 | this.runIdLookup = new HashMap<>(); 17 | } 18 | 19 | public ArtifactRepo getRepo(String artifactName) { 20 | ArtifactRepo repo; 21 | if (artifactRepos.containsKey(artifactName)) { 22 | repo = artifactRepos.get(artifactName); 23 | } else { 24 | repo = new ArtifactRepo(artifactName); 25 | artifactRepos.put(artifactName, repo); 26 | } 27 | return repo; 28 | } 29 | 30 | public void put(String artifactName, String version, ArtifactInfo artifactInfo) { 31 | getRepo(artifactName).artifacts.put(version, artifactInfo); 32 | runIdLookup.put(artifactInfo.getRunId(), artifactInfo); 33 | } 34 | 35 | public void clear() { 36 | runInfoMap.clear(); 37 | artifactRepos.clear(); 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /metadata-store/src/main/resources/logback.xml:
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | %d{HH:mm:ss.SSS} [%thread] %-5level %logger - %msg%n 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /mvnw.cmd: -------------------------------------------------------------------------------- 1 | @REM ---------------------------------------------------------------------------- 2 | @REM Licensed to the Apache Software Foundation (ASF) under one 3 | @REM or more contributor license agreements. See the NOTICE file 4 | @REM distributed with this work for additional information 5 | @REM regarding copyright ownership. The ASF licenses this file 6 | @REM to you under the Apache License, Version 2.0 (the 7 | @REM "License"); you may not use this file except in compliance 8 | @REM with the License. You may obtain a copy of the License at 9 | @REM 10 | @REM http://www.apache.org/licenses/LICENSE-2.0 11 | @REM 12 | @REM Unless required by applicable law or agreed to in writing, 13 | @REM software distributed under the License is distributed on an 14 | @REM "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 | @REM KIND, either express or implied. See the License for the 16 | @REM specific language governing permissions and limitations 17 | @REM under the License. 18 | @REM ---------------------------------------------------------------------------- 19 | 20 | @REM ---------------------------------------------------------------------------- 21 | @REM Maven Start Up Batch script 22 | @REM 23 | @REM Required ENV vars: 24 | @REM JAVA_HOME - location of a JDK home dir 25 | @REM 26 | @REM Optional ENV vars 27 | @REM M2_HOME - location of maven2's installed home dir 28 | @REM MAVEN_BATCH_ECHO - set to 'on' to enable the echoing of the batch commands 29 | @REM MAVEN_BATCH_PAUSE - set to 'on' to wait for a keystroke before ending 30 | @REM MAVEN_OPTS - parameters passed to the Java VM when running Maven 31 | @REM e.g. to debug Maven itself, use 32 | @REM set MAVEN_OPTS=-Xdebug -Xrunjdwp:transport=dt_socket,server=y,suspend=y,address=8000 33 | @REM MAVEN_SKIP_RC - flag to disable loading of mavenrc files 34 | @REM ---------------------------------------------------------------------------- 35 | 36 | @REM Begin all REM lines with '@' in case MAVEN_BATCH_ECHO is 'on' 37 | @echo off 38 | @REM set title of command window 39 | title %0 40 | @REM enable echoing by setting MAVEN_BATCH_ECHO to 'on' 41 | @if "%MAVEN_BATCH_ECHO%" == "on" echo %MAVEN_BATCH_ECHO% 42 | 43 | @REM set %HOME% to equivalent of $HOME 44 | if "%HOME%" == "" (set "HOME=%HOMEDRIVE%%HOMEPATH%") 45 | 46 | @REM Execute a user defined script before this one 47 | if not "%MAVEN_SKIP_RC%" == "" goto skipRcPre 48 | @REM check for pre script, once with legacy .bat ending and once with .cmd ending 49 | if exist "%HOME%\mavenrc_pre.bat" call "%HOME%\mavenrc_pre.bat" 50 | if exist "%HOME%\mavenrc_pre.cmd" call "%HOME%\mavenrc_pre.cmd" 51 | :skipRcPre 52 | 53 | @setlocal 54 | 55 | set ERROR_CODE=0 56 | 57 | @REM To isolate internal variables from possible post scripts, we use another setlocal 58 | @setlocal 59 | 60 | @REM ==== START VALIDATION ==== 61 | if not "%JAVA_HOME%" == "" goto OkJHome 62 | 63 | echo. 64 | echo Error: JAVA_HOME not found in your environment. >&2 65 | echo Please set the JAVA_HOME variable in your environment to match the >&2 66 | echo location of your Java installation. >&2 67 | echo. 
68 | goto error 69 | 70 | :OkJHome 71 | if exist "%JAVA_HOME%\bin\java.exe" goto init 72 | 73 | echo. 74 | echo Error: JAVA_HOME is set to an invalid directory. >&2 75 | echo JAVA_HOME = "%JAVA_HOME%" >&2 76 | echo Please set the JAVA_HOME variable in your environment to match the >&2 77 | echo location of your Java installation. >&2 78 | echo. 79 | goto error 80 | 81 | @REM ==== END VALIDATION ==== 82 | 83 | :init 84 | 85 | @REM Find the project base dir, i.e. the directory that contains the folder ".mvn". 86 | @REM Fallback to current working directory if not found. 87 | 88 | set MAVEN_PROJECTBASEDIR=%MAVEN_BASEDIR% 89 | IF NOT "%MAVEN_PROJECTBASEDIR%"=="" goto endDetectBaseDir 90 | 91 | set EXEC_DIR=%CD% 92 | set WDIR=%EXEC_DIR% 93 | :findBaseDir 94 | IF EXIST "%WDIR%"\.mvn goto baseDirFound 95 | cd .. 96 | IF "%WDIR%"=="%CD%" goto baseDirNotFound 97 | set WDIR=%CD% 98 | goto findBaseDir 99 | 100 | :baseDirFound 101 | set MAVEN_PROJECTBASEDIR=%WDIR% 102 | cd "%EXEC_DIR%" 103 | goto endDetectBaseDir 104 | 105 | :baseDirNotFound 106 | set MAVEN_PROJECTBASEDIR=%EXEC_DIR% 107 | cd "%EXEC_DIR%" 108 | 109 | :endDetectBaseDir 110 | 111 | IF NOT EXIST "%MAVEN_PROJECTBASEDIR%\.mvn\jvm.config" goto endReadAdditionalConfig 112 | 113 | @setlocal EnableExtensions EnableDelayedExpansion 114 | for /F "usebackq delims=" %%a in ("%MAVEN_PROJECTBASEDIR%\.mvn\jvm.config") do set JVM_CONFIG_MAVEN_PROPS=!JVM_CONFIG_MAVEN_PROPS! %%a 115 | @endlocal & set JVM_CONFIG_MAVEN_PROPS=%JVM_CONFIG_MAVEN_PROPS% 116 | 117 | :endReadAdditionalConfig 118 | 119 | SET MAVEN_JAVA_EXE="%JAVA_HOME%\bin\java.exe" 120 | set WRAPPER_JAR="%MAVEN_PROJECTBASEDIR%\.mvn\wrapper\maven-wrapper.jar" 121 | set WRAPPER_LAUNCHER=org.apache.maven.wrapper.MavenWrapperMain 122 | 123 | set DOWNLOAD_URL="https://repo.maven.apache.org/maven2/io/takari/maven-wrapper/0.5.6/maven-wrapper-0.5.6.jar" 124 | 125 | FOR /F "tokens=1,2 delims==" %%A IN ("%MAVEN_PROJECTBASEDIR%\.mvn\wrapper\maven-wrapper.properties") DO ( 126 | IF "%%A"=="wrapperUrl" SET DOWNLOAD_URL=%%B 127 | ) 128 | 129 | @REM Extension to allow automatically downloading the maven-wrapper.jar from Maven-central 130 | @REM This allows using the maven wrapper in projects that prohibit checking in binary data. 131 | if exist %WRAPPER_JAR% ( 132 | if "%MVNW_VERBOSE%" == "true" ( 133 | echo Found %WRAPPER_JAR% 134 | ) 135 | ) else ( 136 | if not "%MVNW_REPOURL%" == "" ( 137 | SET DOWNLOAD_URL="%MVNW_REPOURL%/io/takari/maven-wrapper/0.5.6/maven-wrapper-0.5.6.jar" 138 | ) 139 | if "%MVNW_VERBOSE%" == "true" ( 140 | echo Couldn't find %WRAPPER_JAR%, downloading it ... 141 | echo Downloading from: %DOWNLOAD_URL% 142 | ) 143 | 144 | powershell -Command "&{"^ 145 | "$webclient = new-object System.Net.WebClient;"^ 146 | "if (-not ([string]::IsNullOrEmpty('%MVNW_USERNAME%') -and [string]::IsNullOrEmpty('%MVNW_PASSWORD%'))) {"^ 147 | "$webclient.Credentials = new-object System.Net.NetworkCredential('%MVNW_USERNAME%', '%MVNW_PASSWORD%');"^ 148 | "}"^ 149 | "[Net.ServicePointManager]::SecurityProtocol = [Net.SecurityProtocolType]::Tls12; $webclient.DownloadFile('%DOWNLOAD_URL%', '%WRAPPER_JAR%')"^ 150 | "}" 151 | if "%MVNW_VERBOSE%" == "true" ( 152 | echo Finished downloading %WRAPPER_JAR% 153 | ) 154 | ) 155 | @REM End of extension 156 | 157 | @REM Provide a "standardized" way to retrieve the CLI args that will 158 | @REM work with both Windows and non-Windows executions. 
159 | set MAVEN_CMD_LINE_ARGS=%* 160 | 161 | %MAVEN_JAVA_EXE% %JVM_CONFIG_MAVEN_PROPS% %MAVEN_OPTS% %MAVEN_DEBUG_OPTS% -classpath %WRAPPER_JAR% "-Dmaven.multiModuleProjectDirectory=%MAVEN_PROJECTBASEDIR%" %WRAPPER_LAUNCHER% %MAVEN_CONFIG% %* 162 | if ERRORLEVEL 1 goto error 163 | goto end 164 | 165 | :error 166 | set ERROR_CODE=1 167 | 168 | :end 169 | @endlocal & set ERROR_CODE=%ERROR_CODE% 170 | 171 | if not "%MAVEN_SKIP_RC%" == "" goto skipRcPost 172 | @REM check for post script, once with legacy .bat ending and once with .cmd ending 173 | if exist "%HOME%\mavenrc_post.bat" call "%HOME%\mavenrc_post.bat" 174 | if exist "%HOME%\mavenrc_post.cmd" call "%HOME%\mavenrc_post.cmd" 175 | :skipRcPost 176 | 177 | @REM pause the script if MAVEN_BATCH_PAUSE is set to 'on' 178 | if "%MAVEN_BATCH_PAUSE%" == "on" pause 179 | 180 | if "%MAVEN_TERMINATE_CMD%" == "on" exit %ERROR_CODE% 181 | 182 | exit /B %ERROR_CODE% 183 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | 7 | org.orca3 8 | mini-auto-ml 9 | 1.0-SNAPSHOT 10 | pom 11 | 12 | 13 | 11 14 | 11 15 | 1.38.0 16 | true 17 | 18 | 19 | 20 | grpc-contract 21 | data-management 22 | metadata-store 23 | training-service 24 | prediction-service 25 | 26 | 27 | 28 | 29 | 30 | io.grpc 31 | grpc-services 32 | ${grpc.version} 33 | 34 | 35 | ch.qos.logback 36 | logback-classic 37 | 1.2.3 38 | 39 | 40 | io.minio 41 | minio 42 | 8.2.2 43 | 44 | 45 | io.grpc 46 | grpc-testing 47 | ${grpc.version} 48 | test 49 | 50 | 51 | 52 | 53 | 54 | 55 | python 56 | 57 | false 58 | 59 | 60 | 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /prediction-service/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | 7 | 8 | org.orca3 9 | mini-auto-ml 10 | 1.0-SNAPSHOT 11 | 12 | 13 | prediction-service 14 | 1.0-SNAPSHOT 15 | 16 | 17 | org.orca3.miniAutoML.prediction.PredictionService 18 | 19 | 20 | 21 | 22 | org.orca3 23 | grpc-contract 24 | ${project.version} 25 | 26 | 27 | io.grpc 28 | grpc-services 29 | 30 | 31 | ch.qos.logback 32 | logback-classic 33 | 34 | 35 | io.minio 36 | minio 37 | 38 | 39 | 40 | io.grpc 41 | grpc-testing 42 | test 43 | 44 | 45 | junit 46 | junit 47 | RELEASE 48 | test 49 | 50 | 51 | 52 | 53 | 54 | org.codehaus.mojo 55 | exec-maven-plugin 56 | 3.0.0 57 | 58 | 59 | 60 | java 61 | 62 | 63 | 64 | 65 | ${mainClass} 66 | 67 | 68 | 69 | org.apache.maven.plugins 70 | maven-shade-plugin 71 | 72 | 73 | package 74 | 75 | shade 76 | 77 | 78 | 79 | 80 | ${mainClass} 81 | 82 | 83 | 84 | 85 | 86 | *:* 87 | 88 | META-INF/*.SF 89 | META-INF/*.DSA 90 | META-INF/*.RSA 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | -------------------------------------------------------------------------------- /prediction-service/src/main/java/org/orca3/miniAutoML/prediction/CustomGrpcPredictorBackend.java: -------------------------------------------------------------------------------- 1 | package org.orca3.miniAutoML.prediction; 2 | 3 | import com.google.common.collect.Sets; 4 | import io.grpc.ManagedChannel; 5 | import io.minio.DownloadObjectArgs; 6 | import io.minio.ListObjectsArgs; 7 | import io.minio.MinioClient; 8 | import io.minio.Result; 9 | import io.minio.messages.Item; 10 | import org.orca3.miniAutoML.metadataStore.GetArtifactResponse; 11 | 12 | import java.io.File; 13 | import java.nio.file.Files; 14 | import 
java.nio.file.Path; 15 | import java.nio.file.Paths; 16 | import java.util.Set; 17 | 18 | public class CustomGrpcPredictorBackend implements PredictorBackend { 19 | private final PredictorGrpc.PredictorBlockingStub stub; 20 | private final String modelCachePath; 21 | private final MinioClient minioClient; 22 | private final Set<String> downloadedModels; 23 | 24 | public CustomGrpcPredictorBackend(ManagedChannel channel, String modelCachePath, MinioClient minioClient) { 25 | stub = PredictorGrpc.newBlockingStub(channel); 26 | this.modelCachePath = modelCachePath; 27 | this.minioClient = minioClient; 28 | this.downloadedModels = Sets.newHashSet(); 29 | } 30 | 31 | @Override 32 | public void registerModel(GetArtifactResponse artifact) { 33 | return; 34 | } 35 | 36 | @Override 37 | public void downloadModel(String runId, GetArtifactResponse artifactResponse) { 38 | if (downloadedModels.contains(runId)) { 39 | return; 40 | } 41 | final String bucket = artifactResponse.getArtifact().getBucket(); 42 | try { 43 | Path tempRoot = Paths.get(modelCachePath, runId); 44 | if (Files.notExists(tempRoot)) { 45 | Files.createDirectories(tempRoot); 46 | } 47 | for (Result<Item> item : minioClient.listObjects(ListObjectsArgs.builder() 48 | .bucket(bucket) 49 | .prefix(String.format("%s/", artifactResponse.getArtifact().getPath())) 50 | .build())) { 51 | String objectName = item.get().objectName(); 52 | String fileName = Paths.get(objectName).getFileName().toString(); 53 | minioClient.downloadObject(DownloadObjectArgs.builder() 54 | .bucket(bucket) 55 | .object(item.get().objectName()) 56 | .filename(new File(tempRoot.toString(), fileName).getAbsolutePath()) 57 | .build()); 58 | } 59 | } catch (Exception e) { 60 | throw new RuntimeException(e); 61 | } 62 | registerModel(artifactResponse); 63 | downloadedModels.add(runId); 64 | } 65 | 66 | @Override 67 | public String predict(GetArtifactResponse artifact, String document) { 68 | return stub.predictorPredict(PredictorPredictRequest.newBuilder() 69 | .setDocument(document).setRunId(artifact.getRunId()).build()).getResponse(); 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /prediction-service/src/main/java/org/orca3/miniAutoML/prediction/PredictionService.java: -------------------------------------------------------------------------------- 1 | package org.orca3.miniAutoML.prediction; 2 | 3 | import io.grpc.ManagedChannel; 4 | import io.grpc.ManagedChannelBuilder; 5 | import io.grpc.Status; 6 | import io.grpc.stub.StreamObserver; 7 | import io.minio.MinioClient; 8 | import org.orca3.miniAutoML.ServiceBase; 9 | import org.orca3.miniAutoML.metadataStore.GetArtifactRequest; 10 | import org.orca3.miniAutoML.metadataStore.GetArtifactResponse; 11 | import org.orca3.miniAutoML.metadataStore.MetadataStoreServiceGrpc; 12 | import org.slf4j.Logger; 13 | import org.slf4j.LoggerFactory; 14 | 15 | import java.io.IOException; 16 | import java.util.Properties; 17 | 18 | public class PredictionService extends PredictionServiceGrpc.PredictionServiceImplBase { 19 | private static final Logger logger = LoggerFactory.getLogger(PredictionService.class); 20 | private final Config config; 21 | private final MetadataStoreServiceGrpc.MetadataStoreServiceBlockingStub msClient; 22 | private final PredictorConnectionManager predictorManager; 23 | 24 | public PredictionService(ManagedChannel msChannel, PredictorConnectionManager predictorManager, Config config) { 25 | this.config = config; 26 | this.msClient = MetadataStoreServiceGrpc.newBlockingStub(msChannel); 27 | this.predictorManager = predictorManager; 28 | } 29 | 30 | public static void main(String[] args) throws IOException, InterruptedException { 31 | logger.info("Hello, Prediction Service!"); 32 | Properties props = ServiceBase.getConfigProperties(); 33 | Config config = new Config(props); 34 | ManagedChannel msChannel = ManagedChannelBuilder.forAddress(config.msHost, Integer.parseInt(config.msPort)) 35 | .usePlaintext().build(); 36 | MinioClient minioClient = MinioClient.builder() 37 | .endpoint(config.minioHost) 38 | .credentials(config.minioAccessKey, config.minioSecretKey) 39 | .build(); 40 | PredictorConnectionManager connectionManager = new PredictorConnectionManager(config.modelCachePath, minioClient); 41 | for (String predictor : config.predictors) { 42 | connectionManager.registerPredictor(predictor, props); 43 | } 44 | PredictionService psService = new PredictionService(msChannel, connectionManager, config); 45 | ServiceBase.startService(Integer.parseInt(config.serverPort), psService, () -> { 46 | // Graceful shutdown 47 | msChannel.shutdown(); 48 | connectionManager.shutdown(); 49 | }); 50 | 51 | } 52 | 53 | @Override 54 | public void predict(PredictRequest request, StreamObserver<PredictResponse> responseObserver) { 55 | String runId = request.getRunId(); 56 | GetArtifactResponse artifactInfo; 57 | 58 | if (predictorManager.containsArtifact(runId)) { 59 | artifactInfo = predictorManager.getArtifact(runId); 60 | } else { 61 | try { 62 | artifactInfo = msClient.getArtifact(GetArtifactRequest.newBuilder() 63 | .setRunId(runId).build()); 64 | } catch (Exception ex) { 65 | String msg = String.format("Cannot locate model artifact for runId %s.", runId); 66 | logger.error(msg, ex); 67 | responseObserver.onError(Status.NOT_FOUND.withDescription(msg).asException()); 68 | return; 69 | } 70 | } 71 | 72 | PredictorBackend predictor; 73 | if (predictorManager.containsPredictor(artifactInfo.getAlgorithm())) { 74 | predictor = predictorManager.getPredictor(artifactInfo.getAlgorithm()); 75 | } else { 76 | String msg = String.format("Algorithm %s doesn't have supporting predictor.", artifactInfo.getAlgorithm()); 77 | logger.error(msg); 78 | responseObserver.onError(Status.FAILED_PRECONDITION.withDescription(msg).asException()); 79 | return; 80 | } 81 | try { 82 | predictor.downloadModel(runId, artifactInfo); 83 | String r = predictor.predict(artifactInfo, request.getDocument()); 84 | responseObserver.onNext(PredictResponse.newBuilder().setResponse(r).build()); 85 | responseObserver.onCompleted(); 86 | } catch (Exception ex) { 87 | String msg = String.format("Prediction failed for algorithm %s: %s", artifactInfo.getAlgorithm(), ex.getMessage()); 88 | logger.error(msg, ex); 89 | responseObserver.onError(Status.UNKNOWN 90 | .withDescription(msg) 91 | .asException()); 92 | } 93 | 94 | } 95 | 96 | static class Config { 97 | final String msPort; 98 | final String msHost; 99 | final String minioAccessKey; 100 | final String minioSecretKey; 101 | final String minioHost; 102 | final String serverPort; 103 | final String modelCachePath; 104 | final String[] predictors; 105 | 106 | 107 | public Config(Properties properties) { 108 | this.msPort = properties.getProperty("ms.server.port"); 109 | this.msHost = properties.getProperty("ms.server.host"); 110 | this.minioAccessKey = properties.getProperty("minio.accessKey"); 111 | this.minioSecretKey = properties.getProperty("minio.secretKey"); 112 | this.minioHost = properties.getProperty("minio.host"); 113 | this.serverPort = properties.getProperty("ps.server.port"); 114 | this.modelCachePath = properties.getProperty("ps.server.modelCachePath"); 115 | this.predictors = properties.getProperty("ps.enabledPredictors").split(","); 116 | 117 | } 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /prediction-service/src/main/java/org/orca3/miniAutoML/prediction/PredictorBackend.java: -------------------------------------------------------------------------------- 1 | package org.orca3.miniAutoML.prediction; 2 | 3 | import org.orca3.miniAutoML.metadataStore.GetArtifactResponse; 4 | 5 | public interface PredictorBackend { 6 | void downloadModel(String runId, GetArtifactResponse artifactResponse); 7 | 8 | String predict(GetArtifactResponse artifact, String document); 9 | 10 | void registerModel(GetArtifactResponse artifact); 11 | } 12 | -------------------------------------------------------------------------------- /prediction-service/src/main/java/org/orca3/miniAutoML/prediction/PredictorConnectionManager.java: -------------------------------------------------------------------------------- 1 | package org.orca3.miniAutoML.prediction; 2 | 3 | import com.google.common.collect.Maps; 4 | import io.grpc.ManagedChannel; 5 | import io.grpc.ManagedChannelBuilder; 6 | import io.grpc.Status; 7 | import io.minio.DownloadObjectArgs; 8 | import io.minio.ListObjectsArgs; 9 | import io.minio.MinioClient; 10 | import io.minio.Result; 11 | import io.minio.messages.Item; 12 | import org.orca3.miniAutoML.metadataStore.GetArtifactResponse; 13 | 14 | import java.io.File; 15 | import java.nio.file.Files; 16 | import java.nio.file.Path; 17 | import java.nio.file.Paths; 18 | import java.util.HashMap; 19 | import java.util.List; 20 | import java.util.Map; 21 | import java.util.Properties; 22 | 23 | public class PredictorConnectionManager { 24 | private final Map<String, List<ManagedChannel>> channels = new HashMap<>(); 25 | private final Map<String, PredictorBackend> clients = new HashMap<>(); 26 | private final String modelCachePath; 27 | private final MinioClient minioClient; 28 | private final Map<String, GetArtifactResponse> artifactCache; 29 | 30 | public PredictorConnectionManager(String modelCachePath, MinioClient minioClient) { 31 | this.modelCachePath = modelCachePath; 32 | this.minioClient = minioClient; 33 | this.artifactCache = Maps.newHashMap(); 34 | } 35 | 36 | public boolean containsArtifact(String runId) { 37 | return artifactCache.containsKey(runId); 38 | } 39 | 40 | public GetArtifactResponse getArtifact(String runId) { 41 | return artifactCache.get(runId); 42 | } 43 | 44 | 45 | public void registerPredictor(String algorithm, Properties properties) { 46 | String host = properties.getProperty(String.format("predictors.%s.host", algorithm)); 47 | int port = Integer.parseInt(properties.getProperty(String.format("predictors.%s.port", algorithm))); 48 | String predictorType = properties.getProperty(String.format("predictors.%s.techStack", algorithm)); 49 | if (channels.containsKey(algorithm)) { 50 | channels.remove(algorithm).forEach(ManagedChannel::shutdown); 51 | } 52 | ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port) 53 | .usePlaintext().build(); 54 | switch (predictorType) { 55 | case "torch": 56 | int managementPort = Integer.parseInt(properties.getProperty(String.format("predictors.%s.management-port", algorithm))); 57 | ManagedChannel managementChannel = ManagedChannelBuilder.forAddress(host, managementPort) 58 | .usePlaintext().build(); 59 | channels.put(algorithm, List.of(channel, managementChannel)); 60 | clients.put(algorithm, new TorchGrpcPredictorBackend(channel, managementChannel, modelCachePath, minioClient)); 61 | break; 62 | case "customGrpc": 63 | default: 64 | channels.put(algorithm, List.of(channel)); 65 | clients.put(algorithm, new CustomGrpcPredictorBackend(channel, modelCachePath, minioClient)); 66 | break; 67 | } 68 | } 69 | 70 | public boolean containsPredictor(String algorithm) { 71 | return clients.containsKey(algorithm); 72 | } 73 | 74 | public PredictorBackend getPredictor(String algorithm) { 75 | return clients.get(algorithm); 76 | } 77 | 78 | public void shutdown() { 79 | channels.values().forEach(channels -> channels.forEach(ManagedChannel::shutdown)); 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /prediction-service/src/main/java/org/orca3/miniAutoML/prediction/TorchGrpcPredictorBackend.java: -------------------------------------------------------------------------------- 1 | package org.orca3.miniAutoML.prediction; 2 | 3 | import com.google.common.collect.ImmutableMap; 4 | import com.google.common.collect.Sets; 5 | import com.google.protobuf.ByteString; 6 | import io.grpc.ManagedChannel; 7 | import io.minio.DownloadObjectArgs; 8 | import io.minio.MinioClient; 9 | import org.orca3.miniAutoML.metadataStore.GetArtifactResponse; 10 | import org.pytorch.serve.grpc.inference.InferenceAPIsServiceGrpc; 11 | import org.pytorch.serve.grpc.inference.PredictionsRequest; 12 | import org.pytorch.serve.grpc.management.ManagementAPIsServiceGrpc; 13 | import org.pytorch.serve.grpc.management.ManagementResponse; 14 | import org.pytorch.serve.grpc.management.RegisterModelRequest; 15 | import org.pytorch.serve.grpc.management.ScaleWorkerRequest; 16 | import org.slf4j.Logger; 17 | import org.slf4j.LoggerFactory; 18 | 19 | import java.io.File; 20 | import java.nio.charset.StandardCharsets; 21 | import java.nio.file.Paths; 22 | import java.util.Set; 23 | 24 | public class TorchGrpcPredictorBackend implements PredictorBackend { 25 | private static final Logger logger = LoggerFactory.getLogger(TorchGrpcPredictorBackend.class); 26 | private static final String MODEL_FILE_NAME_TEMPLATE = "model-%s.mar"; 27 | private static final String TORCH_MODEL_MAR = "model.mar"; 28 | private static final String TORCH_MODEL_NAME_TEMPLATE = "%s-%s"; 29 | 30 | private final InferenceAPIsServiceGrpc.InferenceAPIsServiceBlockingStub stub; 31 | private final ManagementAPIsServiceGrpc.ManagementAPIsServiceBlockingStub managementStub; 32 | private final Set<String> downloadedModels; 33 | private final String modelCachePath; 34 | private final MinioClient minioClient; 35 | 36 | 37 | public TorchGrpcPredictorBackend(ManagedChannel predictorChannel, ManagedChannel managementChannel, 38 | String modelCachePath, MinioClient minioClient) { 39 | this.stub = InferenceAPIsServiceGrpc.newBlockingStub(predictorChannel); 40 | this.managementStub = ManagementAPIsServiceGrpc.newBlockingStub(managementChannel); 41 | this.modelCachePath = modelCachePath; 42 | this.minioClient = minioClient; 43 | this.downloadedModels = Sets.newHashSet(); 44 | } 45 | 46 | @Override 47 | public void downloadModel(String runId, GetArtifactResponse artifactResponse) { 48 | if (downloadedModels.contains(runId)) { 49 | return; 50 | } 51 | final String bucket = artifactResponse.getArtifact().getBucket(); 52 | try { 53 | minioClient.downloadObject(DownloadObjectArgs.builder() 54 | .bucket(bucket) 55 | .object(Paths.get(artifactResponse.getArtifact().getPath(), TORCH_MODEL_MAR).toString()) 56 | .filename(new File(modelCachePath,
String.format(MODEL_FILE_NAME_TEMPLATE, runId)).getAbsolutePath()) 57 | .build()); 58 | } catch (Exception e) { 59 | throw new RuntimeException(e); 60 | } 61 | registerModel(artifactResponse); 62 | downloadedModels.add(runId); 63 | } 64 | 65 | @Override 66 | public void registerModel(GetArtifactResponse artifact) { 67 | String modelUrl = String.format(MODEL_FILE_NAME_TEMPLATE, artifact.getRunId()); 68 | try { 69 | String torchModelName = String.format(TORCH_MODEL_NAME_TEMPLATE, artifact.getName(), artifact.getVersion()); 70 | ManagementResponse r = managementStub.registerModel(RegisterModelRequest.newBuilder() 71 | .setUrl(modelUrl) 72 | .setModelName(torchModelName) 73 | .build()); 74 | logger.info(r.getMsg()); 75 | managementStub.scaleWorker(ScaleWorkerRequest.newBuilder() 76 | .setModelName(torchModelName) 77 | .setMinWorker(1) 78 | .build()); 79 | } catch (Exception e) { 80 | logger.error("Failed to register model", e); 81 | throw new RuntimeException(e); 82 | } 83 | } 84 | 85 | @Override 86 | public String predict(GetArtifactResponse artifact, String document) { 87 | return stub.predictions(PredictionsRequest.newBuilder() 88 | .setModelName(String.format(TORCH_MODEL_NAME_TEMPLATE, artifact.getName(), artifact.getVersion())) 89 | .putAllInput(ImmutableMap.of("data", ByteString.copyFrom(document, StandardCharsets.UTF_8))) 90 | .build()).getPrediction() 91 | .toString(StandardCharsets.UTF_8); 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /prediction-service/src/main/resources/logback.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | %d{HH:mm:ss.SSS} [%thread] %-5level %logger - %msg%n 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /predictor/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:1.9.0-cuda10.2-cudnn7-runtime 2 | 3 | RUN pip3 install grpcio protobuf~=3.20.0 grpcio-health-checking 4 | ENV PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python 5 | 6 | RUN mkdir /opt/intent-predictor 7 | COPY *.py /opt/intent-predictor/ 8 | WORKDIR /opt/intent-predictor 9 | 10 | # Add folder for the logs. 11 | RUN mkdir /models 12 | RUN mkdir /logs 13 | 14 | ENV MODEL_DIR=/models 15 | EXPOSE 51001 16 | 17 | RUN chgrp -R 0 /opt/intent-predictor \ 18 | && chmod -R g+rwX /opt/intent-predictor \ 19 | && chgrp -R 0 /models \ 20 | && chmod -R g+rwX /models \ 21 | && chgrp -R 0 /logs \ 22 | && chmod -R g+rwX /logs 23 | 24 | ENTRYPOINT ["python3", "-u", "/opt/intent-predictor/predict.py"] 25 | -------------------------------------------------------------------------------- /predictor/README.md: -------------------------------------------------------------------------------- 1 | # Run TorchServe Locally with a Sample Intent Model 2 | 3 | The following instructions assume that your current working directory is in the 4 | `predictor` subfolder. Please make sure Docker is running before proceeding. 
5 | 6 | ## Step 1: Copy the sample intent model to a directory for TorchServe 7 | ```shell 8 | mkdir -p /tmp/model_store/torchserving 9 | cp sample_models/1/intent*.mar /tmp/model_store/torchserving 10 | ``` 11 | 12 | ## Step 2: Run the TorchServe container 13 | ```shell 14 | docker pull pytorch/torchserve:0.4.2-cpu 15 | docker run --rm --shm-size=1g \ 16 | --ulimit memlock=-1 \ 17 | --ulimit stack=67108864 \ 18 | -p8080:8080 \ 19 | -p8081:8081 \ 20 | -p8082:8082 \ 21 | -p7070:7070 \ 22 | -p7071:7071 \ 23 | --mount type=bind,source=/tmp/model_store/torchserving,target=/tmp/models pytorch/torchserve:0.4.2-cpu torchserve --model-store=/tmp/models 24 | ``` 25 | 26 | ## Step 3: Register model with TorchServe management API 27 | ```shell 28 | curl -X POST "http://localhost:8081/models?url=intent_80bf0da.mar&initial_workers=1&model_name=intent" 29 | ``` 30 | The response should look like 31 | ```shell 32 | { 33 | "status": "Model \"intent\" Version: 1.0 registered with 1 initial workers" 34 | } 35 | ``` 36 | 37 | ## Step 4: Request predictions from the default version of the intent model 38 | ```shell 39 | curl --location --request GET 'http://localhost:8080/predictions/intent' \ 40 | --header 'Content-Type: text/plain' \ 41 | --data-raw 'make a 10 minute timer' 42 | ``` 43 | The response should look like 44 | ```shell 45 | { 46 | "predict_res": "timer" 47 | } 48 | ``` 49 | 50 | ## Step 5: Request predictions from a specific version of the intent model 51 | This version is created at training time. 52 | ```shell 53 | curl --location --request GET 'http://localhost:8080/predictions/intent/1.0' \ 54 | --header 'Content-Type: text/plain' \ 55 | --data-raw 'make a 10 minute timer' 56 | ``` 57 | The response should look like 58 | ```shell 59 | { 60 | "predict_res": "timer" 61 | } 62 | ``` -------------------------------------------------------------------------------- /predictor/prediction_service_pb2_grpc.py: -------------------------------------------------------------------------------- 1 | # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 2 | """Client and server classes corresponding to protobuf-defined services.""" 3 | import grpc 4 | 5 | import prediction_service_pb2 as prediction__service__pb2 6 | 7 | 8 | class PredictionServiceStub(object): 9 | """Missing associated documentation comment in .proto file.""" 10 | 11 | def __init__(self, channel): 12 | """Constructor. 13 | 14 | Args: 15 | channel: A grpc.Channel. 
16 | """ 17 | self.Predict = channel.unary_unary( 18 | '/prediction.PredictionService/Predict', 19 | request_serializer=prediction__service__pb2.PredictRequest.SerializeToString, 20 | response_deserializer=prediction__service__pb2.PredictResponse.FromString, 21 | ) 22 | 23 | 24 | class PredictionServiceServicer(object): 25 | """Missing associated documentation comment in .proto file.""" 26 | 27 | def Predict(self, request, context): 28 | """Missing associated documentation comment in .proto file.""" 29 | context.set_code(grpc.StatusCode.UNIMPLEMENTED) 30 | context.set_details('Method not implemented!') 31 | raise NotImplementedError('Method not implemented!') 32 | 33 | 34 | def add_PredictionServiceServicer_to_server(servicer, server): 35 | rpc_method_handlers = { 36 | 'Predict': grpc.unary_unary_rpc_method_handler( 37 | servicer.Predict, 38 | request_deserializer=prediction__service__pb2.PredictRequest.FromString, 39 | response_serializer=prediction__service__pb2.PredictResponse.SerializeToString, 40 | ), 41 | } 42 | generic_handler = grpc.method_handlers_generic_handler( 43 | 'prediction.PredictionService', rpc_method_handlers) 44 | server.add_generic_rpc_handlers((generic_handler,)) 45 | 46 | 47 | # This class is part of an EXPERIMENTAL API. 48 | class PredictionService(object): 49 | """Missing associated documentation comment in .proto file.""" 50 | 51 | @staticmethod 52 | def Predict(request, 53 | target, 54 | options=(), 55 | channel_credentials=None, 56 | call_credentials=None, 57 | insecure=False, 58 | compression=None, 59 | wait_for_ready=None, 60 | timeout=None, 61 | metadata=None): 62 | return grpc.experimental.unary_unary(request, target, '/prediction.PredictionService/Predict', 63 | prediction__service__pb2.PredictRequest.SerializeToString, 64 | prediction__service__pb2.PredictResponse.FromString, 65 | options, channel_credentials, 66 | insecure, call_credentials, compression, wait_for_ready, timeout, metadata) 67 | 68 | 69 | class PredictorStub(object): 70 | """Missing associated documentation comment in .proto file.""" 71 | 72 | def __init__(self, channel): 73 | """Constructor. 74 | 75 | Args: 76 | channel: A grpc.Channel. 77 | """ 78 | self.PredictorPredict = channel.unary_unary( 79 | '/prediction.Predictor/PredictorPredict', 80 | request_serializer=prediction__service__pb2.PredictorPredictRequest.SerializeToString, 81 | response_deserializer=prediction__service__pb2.PredictorPredictResponse.FromString, 82 | ) 83 | 84 | 85 | class PredictorServicer(object): 86 | """Missing associated documentation comment in .proto file.""" 87 | 88 | def PredictorPredict(self, request, context): 89 | """Missing associated documentation comment in .proto file.""" 90 | context.set_code(grpc.StatusCode.UNIMPLEMENTED) 91 | context.set_details('Method not implemented!') 92 | raise NotImplementedError('Method not implemented!') 93 | 94 | 95 | def add_PredictorServicer_to_server(servicer, server): 96 | rpc_method_handlers = { 97 | 'PredictorPredict': grpc.unary_unary_rpc_method_handler( 98 | servicer.PredictorPredict, 99 | request_deserializer=prediction__service__pb2.PredictorPredictRequest.FromString, 100 | response_serializer=prediction__service__pb2.PredictorPredictResponse.SerializeToString, 101 | ), 102 | } 103 | generic_handler = grpc.method_handlers_generic_handler( 104 | 'prediction.Predictor', rpc_method_handlers) 105 | server.add_generic_rpc_handlers((generic_handler,)) 106 | 107 | 108 | # This class is part of an EXPERIMENTAL API. 
109 | class Predictor(object): 110 | """Missing associated documentation comment in .proto file.""" 111 | 112 | @staticmethod 113 | def PredictorPredict(request, 114 | target, 115 | options=(), 116 | channel_credentials=None, 117 | call_credentials=None, 118 | insecure=False, 119 | compression=None, 120 | wait_for_ready=None, 121 | timeout=None, 122 | metadata=None): 123 | return grpc.experimental.unary_unary(request, target, '/prediction.Predictor/PredictorPredict', 124 | prediction__service__pb2.PredictorPredictRequest.SerializeToString, 125 | prediction__service__pb2.PredictorPredictResponse.FromString, 126 | options, channel_credentials, 127 | insecure, call_credentials, compression, wait_for_ready, timeout, metadata) 128 | -------------------------------------------------------------------------------- /predictor/sample_models/1/intent_80bf0da.mar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/orca3/MiniAutoML/fa9eb34858f9a5a743489f31348700b4ce32154e/predictor/sample_models/1/intent_80bf0da.mar -------------------------------------------------------------------------------- /predictor/sample_models/1/manifest.json: -------------------------------------------------------------------------------- 1 | {"Algorithm": "intent-classification", "Framework": "Pytorch", "FrameworkVersion": "1.9.0", "ModelName": "intent", "CodeVersion": "80bf0da", "ModelVersion": "1.0", "classes": {"0": "cancel", "1": "ingredients_list", "2": "nutrition_info", "3": "greeting", "4": "taxes", "5": "change_volume", "6": "order_status", "7": "change_speed", "8": "find_phone", "9": "balance", "10": "plug_type", "11": "weather", "12": "interest_rate", "13": "flip_coin", "14": "measurement_conversion", "15": "food_last", "16": "text", "17": "meeting_schedule", "18": "direct_deposit", "19": "order", "20": "calculator", "21": "apr", "22": "pto_used", "23": "account_blocked", "24": "change_language", "25": "exchange_rate", "26": "mpg", "27": "restaurant_suggestion", "28": "transactions", "29": "play_music", "30": "insurance_change", "31": "tire_change", "32": "flight_status", "33": "report_fraud", "34": "calendar_update", "35": "thank_you", "36": "change_ai_name", "37": "fun_fact", "38": "shopping_list_update", "39": "change_accent", "40": "how_old_are_you", "41": "improve_credit_score", "42": "credit_score", "43": "confirm_reservation", "44": "international_fees", "45": "maybe", "46": "user_name", "47": "goodbye", "48": "book_flight", "49": "how_busy", "50": "cancel_reservation", "51": "routing", "52": "schedule_maintenance", "53": "restaurant_reviews", "54": "schedule_meeting", "55": "lost_luggage", "56": "international_visa", "57": "what_song", "58": "payday", "59": "tire_pressure", "60": "update_playlist", "61": "new_card", "62": "calendar", "63": "order_checks", "64": "who_made_you", "65": "last_maintenance", "66": "oil_change_how", "67": "reminder", "68": "gas_type", "69": "credit_limit_change", "70": "next_song", "71": "book_hotel", "72": "spending_history", "73": "travel_suggestion", "74": "rollover_401k", "75": "todo_list_update", "76": "make_call", "77": "insurance", "78": "date", "79": "who_do_you_work_for", "80": "damaged_card", "81": "meaning_of_life", "82": "min_payment", "83": "expiration_date", "84": "translate", "85": "are_you_a_bot", "86": "cook_time", "87": "oos", "88": "application_status", "89": "w2", "90": "todo_list", "91": "tell_joke", "92": "pay_bill", "93": "current_location", "94": "traffic", "95": "where_are_you_from", "96": 
"change_user_name", "97": "pin_change", "98": "what_is_your_name", "99": "travel_alert", "100": "report_lost_card", "101": "share_location", "102": "bill_balance", "103": "shopping_list", "104": "pto_request_status", "105": "next_holiday", "106": "reminder_update", "107": "transfer", "108": "sync_device", "109": "directions", "110": "bill_due", "111": "car_rental", "112": "card_declined", "113": "reset_settings", "114": "replacement_card_duration", "115": "income", "116": "no", "117": "smart_home", "118": "distance", "119": "vaccines", "120": "timezone", "121": "meal_suggestion", "122": "recipe", "123": "rewards_balance", "124": "uber", "125": "oil_change_when", "126": "timer", "127": "pto_balance", "128": "repeat", "129": "ingredient_substitution", "130": "alarm", "131": "credit_limit", "132": "gas", "133": "accept_reservations", "134": "definition", "135": "redeem_rewards", "136": "pto_request", "137": "carry_on", "138": "roll_dice", "139": "jump_start", "140": "spelling", "141": "yes", "142": "what_can_i_ask_you", "143": "whisper_mode", "144": "calories", "145": "what_are_your_hobbies", "146": "freeze_account", "147": "travel_notification", "148": "restaurant_reservation", "149": "do_you_have_pets", "150": "time"}, "fc_size": 128} 2 | -------------------------------------------------------------------------------- /predictor/sample_models/1/model.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/orca3/MiniAutoML/fa9eb34858f9a5a743489f31348700b4ce32154e/predictor/sample_models/1/model.pth -------------------------------------------------------------------------------- /predictor/sample_models/1/vocab.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/orca3/MiniAutoML/fa9eb34858f9a5a743489f31348700b4ce32154e/predictor/sample_models/1/vocab.pth -------------------------------------------------------------------------------- /predictor/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | class PredictorConfig: 5 | @staticmethod 6 | def int_or_default(variable, default): 7 | if variable is None: 8 | return default 9 | else: 10 | return int(variable) 11 | 12 | def __str__(self) -> str: 13 | results = [ 14 | "{}={}".format("MODEL_DIR", self.MODEL_DIR), 15 | "{}={}".format("FC_SIZE", self.FC_SIZE), 16 | ] 17 | return "\n".join(results) 18 | 19 | def __init__(self): 20 | self.MODEL_DIR = os.getenv('MODEL_DIR') or "/models" 21 | self.FC_SIZE = self.int_or_default(os.getenv('FC_SIZE'), 128) 22 | -------------------------------------------------------------------------------- /scripts/build-images-locally.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source "$(dirname "$0")/env-vars.sh" 3 | 4 | echo "Building orca3/services" 5 | docker build \ 6 | -t orca3/services:latest \ 7 | -f "$(dirname "$0")/../services.dockerfile" \ 8 | "$(dirname "$0")/.." 
9 | echo "" 10 | 11 | echo "Building orca3/intent-classification-predictor" 12 | docker build \ 13 | -t orca3/intent-classification-predictor:latest \ 14 | -f "$(dirname "$0")/../predictor/Dockerfile" \ 15 | "$(dirname "$0")/../predictor" 16 | echo "" 17 | 18 | echo "Building orca3/intent-classification & orca3/intent-classification-torch" 19 | docker build \ 20 | -t orca3/intent-classification:latest \ 21 | -t orca3/intent-classification-torch:latest \ 22 | -f "$(dirname "$0")/../training-code/text-classification/Dockerfile" \ 23 | "$(dirname "$0")/../training-code/text-classification" 24 | echo "" 25 | 26 | -------------------------------------------------------------------------------- /scripts/dm-001-start-minio.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source "$(dirname "$0")/env-vars.sh" 3 | 4 | if [ ! "$(docker network ls | grep orca3)" ]; then 5 | docker network create orca3 6 | echo "Created docker network orca3" 7 | else 8 | echo "Docker network orca3 already exists" 9 | fi 10 | 11 | if [ ! "$(docker ps -a | grep minio)" ]; then 12 | docker run --name minio --network orca3 --rm -d -p "${MINIO_PORT}":9000 -e MINIO_ROOT_USER -e MINIO_ROOT_PASSWORD minio/minio server /data 13 | echo "Started minio docker container and listen on port 9000" 14 | else 15 | echo "Minio docker container is already running" 16 | fi 17 | -------------------------------------------------------------------------------- /scripts/dm-002-start-server.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source "$(dirname "$0")/env-vars.sh" 3 | 4 | if [ ! "$(docker ps -a | grep data-management)" ]; then 5 | docker run --name data-management \ 6 | --network orca3 \ 7 | --rm -d \ 8 | -p "${DM_PORT}":51001 \ 9 | "${IMAGE_NAME}" \ 10 | data-management.jar 11 | echo "Started data-management docker container and listen on port ${DM_PORT}" 12 | else 13 | echo "data-management docker container is already running" 14 | fi 15 | -------------------------------------------------------------------------------- /scripts/dm-003-create-dataset.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source "$(dirname "$0")/env-vars.sh" 3 | mc alias -q set myminio http://127.0.0.1:"${MINIO_PORT}" "${MINIO_ROOT_USER}" "${MINIO_ROOT_PASSWORD}" 4 | 5 | echo 6 | echo "Upload raw data to cloud object storage to get a data url. 
For demo purposes, we upload the data to the 'mini-automl-dm' bucket on the local MinIO server; the data url referencing the data is 'upload/001.csv'" 7 | mc -q cp data-management/src/test/resources/datasets/demo-part1.csv myminio/"${MINIO_DM_BUCKET}"/upload/001.csv 8 | echo 9 | echo "Creating intent dataset" 10 | grpcurl -plaintext \ 11 | -d '{"name": "dataset-1", "dataset_type": "TEXT_INTENT", "bucket": "mini-automl-dm", "path": "upload/001.csv", "tags": [{"tag_key": "category", "tag_value": "aaa"}]}' \ 12 | localhost:"${DM_PORT}" data_management.DataManagementService/CreateDataset 13 | -------------------------------------------------------------------------------- /scripts/dm-004-add-commits.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source "$(dirname "$0")/env-vars.sh" 3 | mc alias -q set myminio http://127.0.0.1:"${MINIO_PORT}" "${MINIO_ROOT_USER}" "${MINIO_ROOT_PASSWORD}" 4 | 5 | if [ "$1" != "" ]; then 6 | echo 7 | echo "Uploading new data files" 8 | mc -q cp data-management/src/test/resources/datasets/demo-part2.csv myminio/"${MINIO_DM_BUCKET}"/upload/002.csv 9 | mc -q cp data-management/src/test/resources/datasets/demo-part3.csv myminio/"${MINIO_DM_BUCKET}"/upload/003.csv 10 | echo 11 | echo "Adding new commit to dataset $1" 12 | grpcurl -plaintext \ 13 | -d "{\"dataset_id\": \"$1\", \"commit_message\": \"More training data\", \"bucket\": \"${MINIO_DM_BUCKET}\", \"path\": \"upload/002.csv\", \"tags\": [{\"tag_key\": \"category\", \"tag_value\": \"aaa\"}]}" \ 14 | localhost:"${DM_PORT}" data_management.DataManagementService/UpdateDataset 15 | grpcurl -plaintext \ 16 | -d "{\"dataset_id\": \"$1\", \"commit_message\": \"More training data\", \"bucket\": \"${MINIO_DM_BUCKET}\", \"path\": \"upload/003.csv\", \"tags\": [{\"tag_key\": \"category\", \"tag_value\": \"bbb\"}]}" \ 17 | localhost:"${DM_PORT}" data_management.DataManagementService/UpdateDataset 18 | else 19 | echo "Requires dataset_id as the first parameter" 20 | fi 21 | 22 | -------------------------------------------------------------------------------- /scripts/dm-005-prepare-dataset.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source "$(dirname "$0")/env-vars.sh" 3 | 4 | if [ "$1" != "" ]; then 5 | echo "Prepare a version of dataset $1 that contains all commits" 6 | grpcurl -plaintext \ 7 | -d "{\"dataset_id\": \"$1\"}" \ 8 | localhost:"${DM_PORT}" data_management.DataManagementService/PrepareTrainingDataset 9 | else 10 | echo "Requires dataset_id as the first parameter" 11 | fi 12 | 13 | -------------------------------------------------------------------------------- /scripts/dm-006-prepare-partial-dataset.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source "$(dirname "$0")/env-vars.sh" 3 | 4 | if [ "$1" != "" ] && [ "$2" != "" ]; then 5 | echo "Prepare a version of dataset $1 that contains only training data with tag category:$2" 6 | grpcurl -plaintext \ 7 | -d "{\"dataset_id\": \"$1\", \"tags\":[{\"tag_key\":\"category\", \"tag_value\":\"$2\"}]}" \ 8 | localhost:"${DM_PORT}" data_management.DataManagementService/PrepareTrainingDataset 9 | else 10 | echo "Requires dataset_id as the first parameter" 11 | echo "Requires tag_value as the second parameter" 12 | 13 | fi 14 | 15 | -------------------------------------------------------------------------------- /scripts/dm-007-fetch-dataset-version.sh:
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source "$(dirname "$0")/env-vars.sh" 3 | 4 | if [ "$1" != "" ] && [ "$2" != "" ]; then 5 | echo "Fetching dataset $1 with version $2" 6 | grpcurl -plaintext \ 7 | -d "{\"dataset_id\": \"$1\", \"version_hash\": \"$2\"}" \ 8 | localhost:"${DM_PORT}" data_management.DataManagementService/FetchTrainingDataset 9 | else 10 | echo "Requires dataset_id as the first parameter" 11 | echo "Requires version_hash as the second parameter" 12 | fi 13 | 14 | -------------------------------------------------------------------------------- /scripts/env-vars.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | export IMAGE_NAME=orca3/services:latest 3 | export MINIO_ROOT_USER=foooo 4 | export MINIO_ROOT_PASSWORD=barbarbar 5 | export MINIO_DM_BUCKET=mini-automl-dm 6 | export MINIO_PORT=9000 7 | export DM_PORT=6000 8 | export PS_PORT=6001 9 | export MS_PORT=6002 10 | export TS_PORT=6003 11 | export ICP_PORT=6101 12 | export ICP_TORCH_PORT=6102 13 | export ICP_TORCH_MGMT_PORT=6103 14 | -------------------------------------------------------------------------------- /scripts/lab-001-start-all.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source "$(dirname "$0")/env-vars.sh" 3 | 4 | if ! docker network ls | grep -q orca3 ; then 5 | docker network create orca3 6 | echo "Created docker network orca3" 7 | else 8 | echo "Docker network orca3 already exists" 9 | fi 10 | 11 | if ! docker ps -a | grep -q minio ; then 12 | docker run --name minio \ 13 | --network orca3 \ 14 | -d \ 15 | -p "${MINIO_PORT}":9000 \ 16 | -e MINIO_ROOT_USER -e MINIO_ROOT_PASSWORD \ 17 | minio/minio server /data 18 | echo "Started minio docker container and listen on port 9000" 19 | else 20 | echo "Minio docker container is already running" 21 | fi 22 | 23 | if ! docker ps -a | grep -q data-management ; then 24 | docker run --name data-management \ 25 | --network orca3 \ 26 | -d \ 27 | -p "${DM_PORT}":51001 \ 28 | "${IMAGE_NAME}" \ 29 | data-management.jar 30 | echo "Started data-management docker container and listen on port ${DM_PORT}" 31 | else 32 | echo "data-management docker container is already running" 33 | fi 34 | 35 | if ! docker ps -a | grep -q metadata-store ; then 36 | docker run --name metadata-store \ 37 | --network orca3 \ 38 | -d \ 39 | -p "${MS_PORT}":51001 \ 40 | "${IMAGE_NAME}" \ 41 | metadata-store.jar 42 | echo "Started metadata-store docker container and listen on port ${MS_PORT}" 43 | else 44 | echo "metadata-store docker container is already running" 45 | fi 46 | 47 | rm -r model_cache 48 | mkdir -p model_cache 49 | MODEL_CACHE_DIR="$(pwd)/model_cache" 50 | 51 | if ! docker ps -a | grep -q intent-classification-predictor ; then 52 | docker run --name intent-classification-predictor \ 53 | --network orca3 \ 54 | -d \ 55 | -p "${ICP_PORT}":51001 \ 56 | -v "${MODEL_CACHE_DIR}":/models \ 57 | orca3/intent-classification-predictor:latest 58 | echo "Started intent-classification-predictor docker container and listen on port ${ICP_PORT}" 59 | else 60 | echo "intent-classification-predictor docker container is already running" 61 | fi 62 | 63 | if ! 
docker ps -a | grep -q intent-classification-torch-predictor ; then 64 | docker run --name intent-classification-torch-predictor \ 65 | --network orca3 \ 66 | -d \ 67 | -p "${ICP_TORCH_PORT}":7070 -p "${ICP_TORCH_MGMT_PORT}":7071 \ 68 | -v "${MODEL_CACHE_DIR}":/models \ 69 | -v "$(pwd)/config/torch_server_config.properties":/home/model-server/config.properties \ 70 | pytorch/torchserve:0.5.2-cpu torchserve \ 71 | --start --model-store /models 72 | echo "Started intent-classification-torch-predictor docker container and listen on port ${ICP_TORCH_PORT} & ${ICP_TORCH_MGMT_PORT}" 73 | else 74 | echo "intent-classification-torch-predictor docker container is already running" 75 | fi 76 | 77 | if ! docker ps -a | grep -q prediction-service ; then 78 | docker run --name prediction-service \ 79 | --network orca3 \ 80 | -d \ 81 | -p "${PS_PORT}":51001 \ 82 | -v "${MODEL_CACHE_DIR}":/tmp/modelCache \ 83 | "${IMAGE_NAME}" \ 84 | prediction-service.jar 85 | echo "Started prediction-service docker container and listen on port ${PS_PORT}" 86 | else 87 | echo "prediction-service docker container is already running" 88 | fi 89 | 90 | if ! docker image ls | grep -v predictor | grep -q "orca3/intent-classification" ; then 91 | docker pull orca3/intent-classification:latest 92 | echo "Pulled intent-classification training image" 93 | else 94 | echo "intent-classification image already exists" 95 | fi 96 | 97 | if ! docker ps -a | grep -q training-service ; then 98 | docker run --name training-service \ 99 | --network orca3 \ 100 | -d \ 101 | -p "${TS_PORT}":51001 \ 102 | -v /var/run/docker.sock:/var/run/docker.sock \ 103 | "${IMAGE_NAME}" \ 104 | training-service.jar 105 | echo "Started training-service docker container and listen on port ${TS_PORT}" 106 | else 107 | echo "training-service docker container is already running" 108 | fi 109 | -------------------------------------------------------------------------------- /scripts/lab-002-upload-data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source "$(dirname "$0")/env-vars.sh" 3 | 4 | mc alias -q set myminio http://127.0.0.1:"${MINIO_PORT}" "${MINIO_ROOT_USER}" "${MINIO_ROOT_PASSWORD}" 5 | 6 | conda install -c huggingface -c conda-forge datasets~=1.18.0 7 | python3 "$(dirname "$0")/prepare_data.py" 8 | 9 | 10 | echo 11 | echo "Upload raw data to cloud object storage to get a data url" 12 | mc -q cp tweet_emotion_part1.csv myminio/"${MINIO_DM_BUCKET}"/upload/tweet_emotion_part1.csv 13 | mc -q cp tweet_emotion_part2.csv myminio/"${MINIO_DM_BUCKET}"/upload/tweet_emotion_part2.csv 14 | echo 15 | echo "Creating intent dataset" 16 | 17 | grpcurl -plaintext \ 18 | -d '{"name": "tweet_emotion", "dataset_type": "TEXT_INTENT", "bucket": "mini-automl-dm", "path": "upload/tweet_emotion_part1.csv"}' \ 19 | localhost:"${DM_PORT}" data_management.DataManagementService/CreateDataset 20 | -------------------------------------------------------------------------------- /scripts/lab-003-first-training.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source "$(dirname "$0")/env-vars.sh" 3 | 4 | if [ "$1" == "" ]; then 5 | echo "Requires dataset_id as the first parameter" 6 | exit 1 7 | fi 8 | 9 | echo "dataset_id is $1" 10 | dataset_id=$1 11 | 12 | function prepare_dataset() { 13 | grpcurl -plaintext \ 14 | -d "{\"dataset_id\": \"$1\"}" \ 15 | localhost:"${DM_PORT}" data_management.DataManagementService/PrepareTrainingDataset 16 | } 17 | 18 |
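# jq ".version_hash" (no -r flag) prints the value with its JSON quotes intact,
# so $version_hash can be embedded unquoted in the request bodies built for
# grpcurl below (e.g. in start_training).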
_temp=$(prepare_dataset "$dataset_id") 19 | version_hash=$(echo -n "$_temp" | jq ".version_hash") 20 | 21 | echo "version_hash is $version_hash" 22 | 23 | function start_training() { 24 | grpcurl -plaintext \ 25 | -d "{ 26 | \"metadata\": { 27 | \"algorithm\":\"intent-classification\", 28 | \"dataset_id\":\"$1\", 29 | \"name\":\"test1\", 30 | \"train_data_version_hash\":$2, 31 | \"output_model_name\":\"twitter-model\", 32 | \"parameters\": { 33 | \"LR\":\"4\", 34 | \"EPOCHS\":\"15\", 35 | \"BATCH_SIZE\":\"64\", 36 | \"FC_SIZE\":\"128\" 37 | } 38 | } 39 | }" \ 40 | localhost:"${TS_PORT}" training.TrainingService/Train 41 | 42 | } 43 | 44 | _temp=$(start_training "$dataset_id" "$version_hash") 45 | job_id=$(echo -n "$_temp" | jq ".job_id") 46 | echo "job_id is $job_id" 47 | 48 | function check_job_status() { 49 | grpcurl -plaintext \ 50 | -d "{\"job_id\": \"$1\"}" \ 51 | localhost:"${TS_PORT}" training.TrainingService/GetTrainingStatus 52 | } 53 | 54 | job_status="unknown" 55 | until [ "$job_status" == "\"failure\"" ] || [ "$job_status" == "\"succeed\"" ]; 56 | do 57 | echo "job $job_id is currently in $job_status status, check back in 5 seconds" 58 | sleep 5 59 | _temp=$(check_job_status "$job_id") 60 | job_status=$(echo -n "$_temp" | jq ".status") 61 | done 62 | 63 | grpcurl -plaintext \ 64 | -d "{\"run_id\": \"$job_id\"}" \ 65 | localhost:"${MS_PORT}" metadata_store.MetadataStoreService/GetRunStatus 66 | 67 | grpcurl -plaintext \ 68 | -d "{ 69 | \"runId\": \"$job_id\", 70 | \"document\": \"You can have a certain #arrogance, and I think that's fine, but what you should never lose is the #respect for the others.\" 71 | }" \ 72 | localhost:"${PS_PORT}" prediction.PredictionService/Predict 73 | -------------------------------------------------------------------------------- /scripts/lab-004-model-serving.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source "$(dirname "$0")/env-vars.sh" 3 | 4 | if [ "$1" == "" ]; then 5 | echo "Requires model_id/run_id as the first parameter" 6 | exit 1 7 | fi 8 | 9 | # model_id is run_id and job_id from training service 10 | model_id=$1 11 | document=$2 12 | 13 | echo "model_id is $model_id" 14 | echo "document is $document" 15 | 16 | grpcurl -plaintext \ 17 | -d "{ 18 | \"runId\": \"$model_id\", 19 | \"document\": \"$document\" 20 | }" \ 21 | localhost:"${PS_PORT}" prediction.PredictionService/Predict 22 | 23 | -------------------------------------------------------------------------------- /scripts/lab-005-second-training.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source "$(dirname "$0")/env-vars.sh" 3 | 4 | if [ "$1" == "" ]; then 5 | echo "Requires dataset_id as the first parameter" 6 | exit 1 7 | fi 8 | 9 | dataset_id=$1 10 | echo "dataset_id is $dataset_id" 11 | 12 | function update_dataset() { 13 | grpcurl -plaintext \ 14 | -d "{ 15 | \"dataset_id\": \"$1\", 16 | \"commit_message\": \"tweet_emotion_part2\", 17 | \"bucket\": \"${MINIO_DM_BUCKET}\", 18 | \"path\": \"upload/tweet_emotion_part2.csv\" 19 | }" \ 20 | localhost:"${DM_PORT}" data_management.DataManagementService/UpdateDataset 21 | } 22 | 23 | update_dataset "$dataset_id" 24 | 25 | function prepare_dataset() { 26 | grpcurl -plaintext \ 27 | -d "{\"dataset_id\": \"$1\"}" \ 28 | localhost:"${DM_PORT}" data_management.DataManagementService/PrepareTrainingDataset 29 | } 30 | 31 | _temp=$(prepare_dataset "$dataset_id") 32 | version_hash=$(echo -n "$_temp" | jq 
".version_hash") 33 | 34 | echo "version_hash is $version_hash" 35 | 36 | function start_training() { 37 | grpcurl -plaintext \ 38 | -d "{ 39 | \"metadata\": { 40 | \"algorithm\":\"intent-classification\", 41 | \"dataset_id\":\"$1\", 42 | \"name\":\"test1\", 43 | \"train_data_version_hash\":$2, 44 | \"output_model_name\":\"twitter-model\", 45 | \"parameters\": { 46 | \"LR\":\"50\", 47 | \"EPOCHS\":\"15\", 48 | \"BATCH_SIZE\":\"64\", 49 | \"FC_SIZE\":\"1024\" 50 | } 51 | } 52 | }" \ 53 | localhost:"${TS_PORT}" training.TrainingService/Train 54 | 55 | } 56 | 57 | _temp=$(start_training "$dataset_id" "$version_hash") 58 | job_id=$(echo -n "$_temp" | jq ".job_id") 59 | echo "job_id is $job_id" 60 | 61 | function check_job_status() { 62 | grpcurl -plaintext \ 63 | -d "{\"job_id\": \"$1\"}" \ 64 | localhost:"${TS_PORT}" training.TrainingService/GetTrainingStatus 65 | } 66 | 67 | job_status="unknown" 68 | until [ "$job_status" == "\"failure\"" ] || [ "$job_status" == "\"succeed\"" ]; 69 | do 70 | echo "job $job_id is currently in $job_status status, check back in 5 seconds" 71 | sleep 5 72 | _temp=$(check_job_status "$job_id") 73 | job_status=$(echo -n "$_temp" | jq ".status") 74 | done 75 | 76 | grpcurl -plaintext \ 77 | -d "{\"run_id\": \"$job_id\"}" \ 78 | localhost:"${MS_PORT}" metadata_store.MetadataStoreService/GetRunStatus 79 | 80 | grpcurl -plaintext \ 81 | -d "{ 82 | \"runId\": \"$job_id\", 83 | \"document\": \"You can have a certain #arrogance, and I think that's fine, but what you should never lose is the #respect for the others.\" 84 | }" \ 85 | localhost:"${PS_PORT}" prediction.PredictionService/Predict 86 | 87 | -------------------------------------------------------------------------------- /scripts/lab-006-model-serving-torchserve.sh: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /scripts/lab-999-tear-down.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source "$(dirname "$0")/env-vars.sh" 3 | 4 | function tear_down() { 5 | container_name=$1 6 | if docker ps -a | grep -q "$container_name" ; then 7 | docker stop "$container_name" > /dev/null 2>&1 8 | docker rm "$container_name" > /dev/null 2>&1 9 | fi 10 | } 11 | 12 | tear_down "minio" 13 | tear_down "data-management" 14 | tear_down "prediction-service" 15 | tear_down "metadata-store" 16 | tear_down "training-service" 17 | tear_down "intent-classification-predictor" 18 | tear_down "intent-classification-torch-predictor" 19 | -------------------------------------------------------------------------------- /scripts/ms-001-start-minio.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source "$(dirname "$0")/env-vars.sh" 3 | 4 | if [ ! "$(docker network ls | grep orca3)" ]; then 5 | docker network create orca3 6 | echo "Created docker network orca3" 7 | else 8 | echo "Docker network orca3 already exists" 9 | fi 10 | 11 | if [ ! 
"$(docker ps -a | grep minio)" ]; then 12 | docker run --name minio --network orca3 --rm -d -p "${MINIO_PORT}":9000 -e MINIO_ROOT_USER -e MINIO_ROOT_PASSWORD minio/minio server /data 13 | echo "Started minio docker container and listen on port 9000" 14 | else 15 | echo "Minio docker container is already running" 16 | fi 17 | -------------------------------------------------------------------------------- /scripts/ms-002-start-server.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source "$(dirname "$0")/env-vars.sh" 3 | 4 | if [ ! "$(docker ps -a | grep metadata-store)" ]; then 5 | docker run --name metadata-store \ 6 | --network orca3 \ 7 | --rm -d \ 8 | -p "${MS_PORT}":51001 \ 9 | "${IMAGE_NAME}" \ 10 | metadata-store.jar 11 | echo "Started metadata-store docker container and listen on port ${MS_PORT}" 12 | else 13 | echo "metadata-store docker container is already running" 14 | fi 15 | -------------------------------------------------------------------------------- /scripts/ms-003-start-run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source "$(dirname "$0")/env-vars.sh" 3 | 4 | echo 5 | run_id="${1:-1}" 6 | echo "Starting run $run_id" 7 | grpcurl -plaintext \ 8 | -d "{ 9 | \"start_time\": \"2021-08-01T00:00:00Z\", 10 | \"run_id\": \"$run_id\", 11 | \"run_name\": \"demo-run\", 12 | \"tracing\": { 13 | \"dataset_id\": \"1\", 14 | \"version_hash\": \"hashBA==\", 15 | \"code_version\": \"12a3bfd\" 16 | } 17 | }" \ 18 | localhost:"${MS_PORT}" metadata_store.MetadataStoreService/LogRunStart 19 | -------------------------------------------------------------------------------- /scripts/ms-004-post-epoch.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source "$(dirname "$0")/env-vars.sh" 3 | 4 | echo 5 | run_id="${1:-1}" 6 | epoch_id="${2:-1}" 7 | echo "Posting epoch $epoch_id for run $run_id" 8 | grpcurl -plaintext \ 9 | -d "{ 10 | \"epoch_info\": { 11 | \"start_time\":\"2021-08-01T01:00:00Z\", 12 | \"end_time\":\"2021-08-01T01:30:00Z\", 13 | \"run_id\":\"$run_id\", 14 | \"epoch_id\":\"$epoch_id\", 15 | \"metrics\": {\"foo_metrics\":\"bar_value\"} 16 | } 17 | }" \ 18 | localhost:"${MS_PORT}" metadata_store.MetadataStoreService/LogEpoch 19 | -------------------------------------------------------------------------------- /scripts/ms-005-check-run-status.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source "$(dirname "$0")/env-vars.sh" 3 | 4 | echo 5 | run_id="${1:-1}" 6 | echo "Checking run $run_id's status" 7 | grpcurl -plaintext \ 8 | -d "{ 9 | \"run_id\":\"$run_id\" 10 | }" \ 11 | localhost:"${MS_PORT}" metadata_store.MetadataStoreService/GetRunStatus 12 | -------------------------------------------------------------------------------- /scripts/ms-006-finish-run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source "$(dirname "$0")/env-vars.sh" 3 | 4 | echo 5 | run_id="${1:-1}" 6 | echo "Finishing run $run_id" 7 | grpcurl -plaintext \ 8 | -d "{ 9 | \"end_time\":\"2021-08-01T02:00:00Z\", 10 | \"run_id\":\"$run_id\", 11 | \"success\":true, 12 | \"message\":\"$run_id successfully completed\" 13 | }" \ 14 | localhost:"${MS_PORT}" metadata_store.MetadataStoreService/LogRunEnd 15 | -------------------------------------------------------------------------------- 
/scripts/ms-008-get-artifact.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source "$(dirname "$0")/env-vars.sh" 3 | 4 | echo 5 | if [ "$1" != "" ] && [ "$2" != "" ]; then 6 | name="${1}" 7 | version="${2}" 8 | echo "GetArtifact $name version $version" 9 | grpcurl -plaintext \ 10 | -d "{ 11 | \"name\":\"$name\", 12 | \"version\":\"$version\" 13 | }" \ 14 | localhost:"${MS_PORT}" metadata_store.MetadataStoreService/GetArtifact 15 | else 16 | echo "Requires artifact_name as the first parameter" 17 | echo "Requires artifact_version as the second parameter" 18 | 19 | fi 20 | -------------------------------------------------------------------------------- /scripts/prepare_data.py: -------------------------------------------------------------------------------- 1 | import csv 2 | 3 | from datasets import load_dataset 4 | dataset = load_dataset('tweet_eval', 'emotion', split='train') 5 | dataset = dataset.map(lambda x: {'label': dataset.features['label'].names[x['label']]}) 6 | part_1 = dataset.filter(lambda x: x['label'] != 'optimism') 7 | part_2 = dataset.filter(lambda x: x['label'] == 'optimism') 8 | part_1.to_pandas().to_csv('tweet_emotion_part1.csv', header=False, index=False, quoting=csv.QUOTE_ALL) 9 | part_2.to_pandas().to_csv('tweet_emotion_part2.csv', header=False, index=False, quoting=csv.QUOTE_ALL) 10 | -------------------------------------------------------------------------------- /scripts/ps-001-start-predictor.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source "$(dirname "$0")/env-vars.sh" 3 | 4 | mkdir -p model_cache 5 | MODEL_CACHE_DIR="$(pwd)/model_cache" 6 | 7 | if [ ! "$(docker ps -a | grep intent-classification-predictor)" ]; then 8 | docker run --name intent-classification-predictor \ 9 | --network orca3 \ 10 | --rm -d \ 11 | -p "${ICP_PORT}":51001 \ 12 | -v "${MODEL_CACHE_DIR}":/models \ 13 | orca3/intent-classification-predictor:latest 14 | echo "Started intent-classification-predictor docker container and listen on port ${ICP_PORT}" 15 | else 16 | echo "intent-classification-predictor docker container is already running" 17 | fi 18 | 19 | if [ ! "$(docker ps -a | grep intent-classification-torch-predictor)" ]; then 20 | docker run --name intent-classification-torch-predictor \ 21 | --network orca3 \ 22 | --rm -d \ 23 | -p "${ICP_TORCH_PORT}":7070 -p "${ICP_TORCH_MGMT_PORT}":7071 \ 24 | -v "${MODEL_CACHE_DIR}":/models \ 25 | -v "$(pwd)/config/torch_server_config.properties":/home/model-server/config.properties \ 26 | pytorch/torchserve:0.5.2-cpu torchserve \ 27 | --start --model-store /models 28 | echo "Started intent-classification-torch-predictor docker container and listen on port ${ICP_TORCH_PORT} & ${ICP_TORCH_MGMT_PORT}" 29 | else 30 | echo "intent-classification-torch-predictor docker container is already running" 31 | fi 32 | 33 | -------------------------------------------------------------------------------- /scripts/ps-002-start-server.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source "$(dirname "$0")/env-vars.sh" 3 | 4 | mkdir -p model_cache 5 | MODEL_CACHE_DIR="$(pwd)/model_cache" 6 | 7 | if [ ! 
"$(docker ps -a | grep prediction-service)" ]; then 8 | docker run --name prediction-service \ 9 | --network orca3 \ 10 | --rm -d \ 11 | -p "${PS_PORT}":51001 \ 12 | -v "${MODEL_CACHE_DIR}":/tmp/modelCache \ 13 | "${IMAGE_NAME}" \ 14 | prediction-service.jar 15 | echo "Started prediction-service docker container and listen on port ${PS_PORT}" 16 | else 17 | echo "prediction-service docker container is already running" 18 | fi 19 | -------------------------------------------------------------------------------- /scripts/ps-003-predict.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source "$(dirname "$0")/env-vars.sh" 3 | 4 | echo 5 | if [ "$1" != "" ]; then 6 | echo "Check the status of run $1" 7 | grpcurl -plaintext \ 8 | -d "{ 9 | \"runId\": \"$1\", 10 | \"document\": \"merry chirstmas\" 11 | }" \ 12 | localhost:"${PS_PORT}" prediction.PredictionService/Predict 13 | else 14 | echo "Requires run_id as the first parameter" 15 | fi 16 | 17 | -------------------------------------------------------------------------------- /scripts/ts-001-start-server-kube.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source "$(dirname "$0")/env-vars.sh" 3 | 4 | if [ "$(docker ps | grep training-service)" ]; then 5 | echo "training-service docker container is already running, stop it" 6 | docker stop training-service 7 | fi 8 | 9 | if [ ! "$(docker ps -a | grep training-service)" ]; then 10 | docker run --name training-service \ 11 | --network orca3 \ 12 | --rm -d \ 13 | -p "${TS_PORT}":51001 \ 14 | -v $HOME/.kube/config:/.kube/config \ 15 | --env APP_CONFIG=config/config-docker-kube.properties \ 16 | "${IMAGE_NAME}" \ 17 | training-service.jar 18 | echo "Started training-service docker container and listen on port ${TS_PORT}" 19 | else 20 | echo "training-service docker container is already running" 21 | fi 22 | -------------------------------------------------------------------------------- /scripts/ts-001-start-server.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source "$(dirname "$0")/env-vars.sh" 3 | 4 | if [ "$(docker ps | grep training-service)" ]; then 5 | echo "training-service docker container is already running, stop it" 6 | docker stop training-service 7 | fi 8 | 9 | if [ ! 
"$(docker ps -a | grep training-service)" ]; then 10 | docker run --name training-service \ 11 | --network orca3 \ 12 | --rm -d \ 13 | -p "${TS_PORT}":51001 \ 14 | -v /var/run/docker.sock:/var/run/docker.sock \ 15 | "${IMAGE_NAME}" \ 16 | training-service.jar 17 | echo "Started training-service docker container and listen on port ${TS_PORT}" 18 | else 19 | echo "training-service docker container is already running" 20 | fi 21 | -------------------------------------------------------------------------------- /scripts/ts-002-start-run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source "$(dirname "$0")/env-vars.sh" 3 | 4 | echo "dataset_id is $1" 5 | dataset_id=$1 6 | 7 | function prepare_dataset() { 8 | grpcurl -plaintext \ 9 | -d "{\"dataset_id\": \"$1\"}" \ 10 | localhost:"${DM_PORT}" data_management.DataManagementService/PrepareTrainingDataset 11 | } 12 | 13 | _temp=$(prepare_dataset "$dataset_id") 14 | version_hash=$(echo -n "$_temp" | jq ".version_hash") 15 | 16 | echo "version_hash is $version_hash" 17 | 18 | echo 19 | grpcurl -plaintext \ 20 | -d "{ 21 | \"metadata\": { 22 | \"algorithm\":\"intent-classification\", 23 | \"dataset_id\":\"$1\", 24 | \"name\":\"test1\", 25 | \"train_data_version_hash\":$version_hash, 26 | \"output_model_name\":\"my-intent-classification-model\", 27 | \"parameters\": { 28 | \"LR\":\"4\", 29 | \"EPOCHS\":\"15\", 30 | \"BATCH_SIZE\":\"64\", 31 | \"FC_SIZE\":\"128\" 32 | } 33 | } 34 | }" \ 35 | localhost:"${TS_PORT}" training.TrainingService/Train 36 | -------------------------------------------------------------------------------- /scripts/ts-003-check-run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source "$(dirname "$0")/env-vars.sh" 3 | 4 | echo 5 | if [ "$1" != "" ]; then 6 | echo "Check the status of run $1" 7 | grpcurl -plaintext \ 8 | -d "{\"job_id\": \"$1\"}" \ 9 | localhost:"${TS_PORT}" training.TrainingService/GetTrainingStatus 10 | else 11 | echo "Requires run_id as the first parameter" 12 | fi 13 | -------------------------------------------------------------------------------- /scripts/ts-004-start-parallel-run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source "$(dirname "$0")/env-vars.sh" 3 | 4 | echo "dataset_id is $1" 5 | dataset_id=$1 6 | 7 | function prepare_dataset() { 8 | grpcurl -plaintext \ 9 | -d "{\"dataset_id\": \"$1\"}" \ 10 | localhost:"${DM_PORT}" data_management.DataManagementService/PrepareTrainingDataset 11 | } 12 | 13 | _temp=$(prepare_dataset "$dataset_id") 14 | version_hash=$(echo -n "$_temp" | jq ".version_hash") 15 | 16 | echo "version_hash is $version_hash" 17 | 18 | echo 19 | grpcurl -plaintext \ 20 | -d "{ 21 | \"metadata\": { 22 | \"algorithm\":\"intent-classification\", 23 | \"dataset_id\":\"$1\", 24 | \"name\":\"test-parallel\", 25 | \"train_data_version_hash\":$version_hash, 26 | \"output_model_name\":\"my-parallel-intent-classification-model\", 27 | \"parameters\": { 28 | \"LR\":\"4\", 29 | \"EPOCHS\":\"10\", 30 | \"BATCH_SIZE\":\"64\", 31 | \"PARALLEL_INSTANCES\":\"3\", 32 | \"FC_SIZE\":\"128\" 33 | } 34 | } 35 | }" \ 36 | localhost:${TS_PORT} training.TrainingService/Train 37 | -------------------------------------------------------------------------------- /scripts/ts-005-start-run-as-torch.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 
source "$(dirname "$0")/env-vars.sh" 3 | 4 | echo 5 | grpcurl -plaintext \ 6 | -d "{ 7 | \"metadata\": { 8 | \"algorithm\":\"intent-classification-torch\", 9 | \"dataset_id\":\"1\", 10 | \"name\":\"test1\", 11 | \"train_data_version_hash\":\"hashDg==\", 12 | \"output_model_name\":\"my-intent-classification-model\", 13 | \"parameters\": { 14 | \"LR\":\"4\", 15 | \"EPOCHS\":\"15\", 16 | \"BATCH_SIZE\":\"64\", 17 | \"FC_SIZE\":\"128\" 18 | } 19 | } 20 | }" \ 21 | localhost:"${TS_PORT}" training.TrainingService/Train 22 | -------------------------------------------------------------------------------- /services.dockerfile: -------------------------------------------------------------------------------- 1 | # syntax=docker/dockerfile:1 2 | # on local, try docker build -t orca3/services:latest -f services.dockerfile . 3 | FROM openjdk:11 AS builder 4 | WORKDIR /app 5 | COPY .mvn/ .mvn 6 | COPY mvnw pom.xml ./ 7 | COPY data-management data-management 8 | COPY grpc-contract grpc-contract 9 | COPY metadata-store metadata-store 10 | COPY training-service training-service 11 | COPY prediction-service prediction-service 12 | RUN ./mvnw package -DskipTests 13 | 14 | FROM openjdk:11 AS run 15 | WORKDIR /app 16 | RUN GRPC_HEALTH_PROBE_VERSION=v0.3.1 && \ 17 | wget -qO/bin/grpc_health_probe https://github.com/grpc-ecosystem/grpc-health-probe/releases/download/${GRPC_HEALTH_PROBE_VERSION}/grpc_health_probe-linux-amd64 && \ 18 | chmod +x /bin/grpc_health_probe 19 | COPY --from=builder /app/training-service/target/training-service-1.0-SNAPSHOT.jar ./training-service.jar 20 | COPY --from=builder /app/data-management/target/data-management-1.0-SNAPSHOT.jar ./data-management.jar 21 | COPY --from=builder /app/metadata-store/target/metadata-store-1.0-SNAPSHOT.jar ./metadata-store.jar 22 | COPY --from=builder /app/prediction-service/target/prediction-service-1.0-SNAPSHOT.jar ./prediction-service.jar 23 | COPY config ./config 24 | ENV APP_CONFIG config/config-docker-docker.properties 25 | 26 | ENTRYPOINT ["java", "-jar"] 27 | -------------------------------------------------------------------------------- /training-code/text-classification/Dockerfile: -------------------------------------------------------------------------------- 1 | # docker build -t orca3/intent-classification-predictor:latest -f predictor/Dockerfile predictor 2 | FROM pytorch/pytorch:1.9.0-cuda10.2-cudnn7-runtime 3 | 4 | RUN pip3 install minio protobuf~=3.20.0 grpcio torch-model-archiver 5 | ENV PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python 6 | 7 | RUN mkdir /opt/intent-classification 8 | COPY *.py /opt/intent-classification/ 9 | WORKDIR /opt/intent-classification 10 | 11 | # Add folder for the logs. 
12 | RUN mkdir /model 13 | RUN mkdir /logs 14 | 15 | RUN chgrp -R 0 /opt/intent-classification \ 16 | && chmod -R g+rwX /opt/intent-classification \ 17 | && chgrp -R 0 /model \ 18 | && chmod -R g+rwX /model \ 19 | && chgrp -R 0 /logs \ 20 | && chmod -R g+rwX /logs 21 | 22 | ENTRYPOINT ["python3", "-u", "/opt/intent-classification/train.py"] 23 | -------------------------------------------------------------------------------- /training-code/text-classification/Readme.md: -------------------------------------------------------------------------------- 1 | ## Set up local python env 2 | ``` 3 | conda create --name intent-pytorch python=3.7 4 | conda activate intent-pytorch 5 | conda install -c pytorch pytorch torchtext torch-model-archiver 6 | conda install -c conda-forge minio 7 | 8 | 9 | ``` 10 | ### Run training python script 11 | 12 | `examples.csv` and `labels.csv` make up the training dataset. Upload them to MinIO before running the training script below (the "Build and run docker image" section covers how to upload data to the MinIO server). 13 | 14 | ``` 15 | # single process training 16 | python train.py 17 | python train.py 0 1 18 | 19 | # distributed training with two processes, first parameter is RANK, second is WORLD_SIZE. 20 | # RANK=0 is master 21 | python train.py 0 2 22 | python train.py 1 2 23 | ``` 24 | 25 | ## Build and run docker image 26 | Run `scripts/ts-000-build-trainer.sh` to build the image; the image name is `localhost:3000/orca3/intent-classification`. 27 | 28 | Then start the MinIO server by running `scripts/ms-001-start-minio.sh` and upload the training dataset: 29 | 30 | ``` 31 | mc mb minio/mini-automl-dm 32 | mc mb minio/mini-automl-ms 33 | mc cp examples.csv minio/mini-automl-dm/versionedDatasets/1/hashDg==/examples.csv 34 | mc cp labels.csv minio/mini-automl-dm/versionedDatasets/1/hashDg==/labels.csv 35 | ``` 36 | Last step: run distributed training with two docker instances.
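The MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE variables passed to the containers below are the standard torch.distributed rendezvous settings. A minimal sketch of how a training script typically consumes them (an illustration only, not the repo's `train.py`):

```python
import os
import torch.distributed as dist

def init_distributed():
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    if world_size > 1:
        # init_method="env://" reads MASTER_ADDR/MASTER_PORT from the
        # environment, so every process rendezvouses with the RANK=0 master.
        dist.init_process_group(backend="gloo", init_method="env://",
                                rank=rank, world_size=world_size)
    return rank, world_size
```

The two commands below start the master (RANK=0) and one worker (RANK=1):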
37 | 38 | ``` 39 | docker run --name trainer1 --hostname=trainer1 --rm --net orca3 -p 12356:12356 -e MASTER_ADDR="trainer1" -e WORLD_SIZE=2 -e RANK=0 -e MASTER_PORT=12356 -e MINIO_SERVER="minio:9000" -e TRAINING_DATA_PATH="versionedDatasets/1/hashDg==/" -it localhost:3000/orca3/intent-classification 40 | 41 | docker run --name trainer2 --hostname=trainer2 --rm --net orca3 -e MASTER_ADDR="trainer1" -e WORLD_SIZE="2" -e RANK=1 -e MASTER_PORT=12356 -e MINIO_SERVER="minio:9000" -e TRAINING_DATA_PATH="versionedDatasets/1/hashDg==/" -it localhost:3000/orca3/intent-classification 42 | ``` 43 | 44 | Single docker container training 45 | ``` 46 | docker run --name trainer1 --hostname=trainer1 --rm --net orca3 -p 12356:12356 -e WORLD_SIZE=1 -e MASTER_PORT=12356 -e MINIO_SERVER="minio:9000" -e TRAINING_DATA_PATH="versionedDatasets/1/hashDg==/" -it localhost:3000/orca3/intent-classification 47 | ``` 48 | -------------------------------------------------------------------------------- /training-code/text-classification/labels.csv: -------------------------------------------------------------------------------- 1 | "1","cancel" 2 | "2","ingredients_list" 3 | "3","nutrition_info" 4 | "4","greeting" 5 | "5","taxes" 6 | "6","change_volume" 7 | "7","order_status" 8 | "8","change_speed" 9 | "9","find_phone" 10 | "10","balance" 11 | "11","plug_type" 12 | "12","weather" 13 | "13","interest_rate" 14 | "14","flip_coin" 15 | "15","measurement_conversion" 16 | "16","food_last" 17 | "17","text" 18 | "18","meeting_schedule" 19 | "19","direct_deposit" 20 | "20","order" 21 | "21","calculator" 22 | "22","apr" 23 | "23","pto_used" 24 | "24","account_blocked" 25 | "25","change_language" 26 | "26","exchange_rate" 27 | "27","mpg" 28 | "28","restaurant_suggestion" 29 | "29","transactions" 30 | "30","play_music" 31 | "31","insurance_change" 32 | "32","tire_change" 33 | "33","flight_status" 34 | "34","report_fraud" 35 | "35","calendar_update" 36 | "36","thank_you" 37 | "37","change_ai_name" 38 | "38","fun_fact" 39 | "39","shopping_list_update" 40 | "40","change_accent" 41 | "41","how_old_are_you" 42 | "42","improve_credit_score" 43 | "43","credit_score" 44 | "44","confirm_reservation" 45 | "45","international_fees" 46 | "46","maybe" 47 | "47","user_name" 48 | "48","goodbye" 49 | "49","book_flight" 50 | "50","how_busy" 51 | "51","cancel_reservation" 52 | "52","routing" 53 | "53","schedule_maintenance" 54 | "54","restaurant_reviews" 55 | "55","schedule_meeting" 56 | "56","lost_luggage" 57 | "57","international_visa" 58 | "58","what_song" 59 | "59","payday" 60 | "60","tire_pressure" 61 | "61","update_playlist" 62 | "62","new_card" 63 | "63","calendar" 64 | "64","order_checks" 65 | "65","who_made_you" 66 | "66","last_maintenance" 67 | "67","oil_change_how" 68 | "68","reminder" 69 | "69","gas_type" 70 | "70","credit_limit_change" 71 | "71","next_song" 72 | "72","book_hotel" 73 | "73","spending_history" 74 | "74","travel_suggestion" 75 | "75","rollover_401k" 76 | "76","todo_list_update" 77 | "77","make_call" 78 | "78","insurance" 79 | "79","date" 80 | "80","who_do_you_work_for" 81 | "81","damaged_card" 82 | "82","meaning_of_life" 83 | "83","min_payment" 84 | "84","expiration_date" 85 | "85","translate" 86 | "86","are_you_a_bot" 87 | "87","cook_time" 88 | "88","oos" 89 | "89","application_status" 90 | "90","w2" 91 | "91","todo_list" 92 | "92","tell_joke" 93 | "93","pay_bill" 94 | "94","current_location" 95 | "95","traffic" 96 | "96","where_are_you_from" 97 | "97","change_user_name" 98 | "98","pin_change" 99 | 
"99","what_is_your_name" 100 | "100","travel_alert" 101 | "101","report_lost_card" 102 | "102","share_location" 103 | "103","bill_balance" 104 | "104","shopping_list" 105 | "105","pto_request_status" 106 | "106","next_holiday" 107 | "107","reminder_update" 108 | "108","transfer" 109 | "109","sync_device" 110 | "110","directions" 111 | "111","bill_due" 112 | "112","car_rental" 113 | "113","card_declined" 114 | "114","reset_settings" 115 | "115","replacement_card_duration" 116 | "116","income" 117 | "117","no" 118 | "118","smart_home" 119 | "119","distance" 120 | "120","vaccines" 121 | "121","timezone" 122 | "122","meal_suggestion" 123 | "123","recipe" 124 | "124","rewards_balance" 125 | "125","uber" 126 | "126","oil_change_when" 127 | "127","timer" 128 | "128","pto_balance" 129 | "129","repeat" 130 | "130","ingredient_substitution" 131 | "131","alarm" 132 | "132","credit_limit" 133 | "133","gas" 134 | "134","accept_reservations" 135 | "135","definition" 136 | "136","redeem_rewards" 137 | "137","pto_request" 138 | "138","carry_on" 139 | "139","roll_dice" 140 | "140","jump_start" 141 | "141","spelling" 142 | "142","yes" 143 | "143","what_can_i_ask_you" 144 | "144","whisper_mode" 145 | "145","calories" 146 | "146","what_are_your_hobbies" 147 | "147","freeze_account" 148 | "148","travel_notification" 149 | "149","restaurant_reservation" 150 | "150","do_you_have_pets" 151 | "151","time" 152 | -------------------------------------------------------------------------------- /training-code/text-classification/prediction_service_pb2_grpc.py: -------------------------------------------------------------------------------- 1 | # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 2 | """Client and server classes corresponding to protobuf-defined services.""" 3 | import grpc 4 | 5 | import prediction_service_pb2 as prediction__service__pb2 6 | 7 | 8 | class PredictionServiceStub(object): 9 | """Missing associated documentation comment in .proto file.""" 10 | 11 | def __init__(self, channel): 12 | """Constructor. 13 | 14 | Args: 15 | channel: A grpc.Channel. 16 | """ 17 | self.Predict = channel.unary_unary( 18 | '/prediction.PredictionService/Predict', 19 | request_serializer=prediction__service__pb2.PredictRequest.SerializeToString, 20 | response_deserializer=prediction__service__pb2.PredictResponse.FromString, 21 | ) 22 | 23 | 24 | class PredictionServiceServicer(object): 25 | """Missing associated documentation comment in .proto file.""" 26 | 27 | def Predict(self, request, context): 28 | """Missing associated documentation comment in .proto file.""" 29 | context.set_code(grpc.StatusCode.UNIMPLEMENTED) 30 | context.set_details('Method not implemented!') 31 | raise NotImplementedError('Method not implemented!') 32 | 33 | 34 | def add_PredictionServiceServicer_to_server(servicer, server): 35 | rpc_method_handlers = { 36 | 'Predict': grpc.unary_unary_rpc_method_handler( 37 | servicer.Predict, 38 | request_deserializer=prediction__service__pb2.PredictRequest.FromString, 39 | response_serializer=prediction__service__pb2.PredictResponse.SerializeToString, 40 | ), 41 | } 42 | generic_handler = grpc.method_handlers_generic_handler( 43 | 'prediction.PredictionService', rpc_method_handlers) 44 | server.add_generic_rpc_handlers((generic_handler,)) 45 | 46 | 47 | # This class is part of an EXPERIMENTAL API. 
48 | class PredictionService(object): 49 | """Missing associated documentation comment in .proto file.""" 50 | 51 | @staticmethod 52 | def Predict(request, 53 | target, 54 | options=(), 55 | channel_credentials=None, 56 | call_credentials=None, 57 | insecure=False, 58 | compression=None, 59 | wait_for_ready=None, 60 | timeout=None, 61 | metadata=None): 62 | return grpc.experimental.unary_unary(request, target, '/prediction.PredictionService/Predict', 63 | prediction__service__pb2.PredictRequest.SerializeToString, 64 | prediction__service__pb2.PredictResponse.FromString, 65 | options, channel_credentials, 66 | insecure, call_credentials, compression, wait_for_ready, timeout, metadata) 67 | 68 | 69 | class PredictorStub(object): 70 | """Missing associated documentation comment in .proto file.""" 71 | 72 | def __init__(self, channel): 73 | """Constructor. 74 | 75 | Args: 76 | channel: A grpc.Channel. 77 | """ 78 | self.PredictorPredict = channel.unary_unary( 79 | '/prediction.Predictor/PredictorPredict', 80 | request_serializer=prediction__service__pb2.PredictorPredictRequest.SerializeToString, 81 | response_deserializer=prediction__service__pb2.PredictorPredictResponse.FromString, 82 | ) 83 | 84 | 85 | class PredictorServicer(object): 86 | """Missing associated documentation comment in .proto file.""" 87 | 88 | def PredictorPredict(self, request, context): 89 | """Missing associated documentation comment in .proto file.""" 90 | context.set_code(grpc.StatusCode.UNIMPLEMENTED) 91 | context.set_details('Method not implemented!') 92 | raise NotImplementedError('Method not implemented!') 93 | 94 | 95 | def add_PredictorServicer_to_server(servicer, server): 96 | rpc_method_handlers = { 97 | 'PredictorPredict': grpc.unary_unary_rpc_method_handler( 98 | servicer.PredictorPredict, 99 | request_deserializer=prediction__service__pb2.PredictorPredictRequest.FromString, 100 | response_serializer=prediction__service__pb2.PredictorPredictResponse.SerializeToString, 101 | ), 102 | } 103 | generic_handler = grpc.method_handlers_generic_handler( 104 | 'prediction.Predictor', rpc_method_handlers) 105 | server.add_generic_rpc_handlers((generic_handler,)) 106 | 107 | 108 | # This class is part of an EXPERIMENTAL API. 
109 | class Predictor(object): 110 | """Missing associated documentation comment in .proto file.""" 111 | 112 | @staticmethod 113 | def PredictorPredict(request, 114 | target, 115 | options=(), 116 | channel_credentials=None, 117 | call_credentials=None, 118 | insecure=False, 119 | compression=None, 120 | wait_for_ready=None, 121 | timeout=None, 122 | metadata=None): 123 | return grpc.experimental.unary_unary(request, target, '/prediction.Predictor/PredictorPredict', 124 | prediction__service__pb2.PredictorPredictRequest.SerializeToString, 125 | prediction__service__pb2.PredictorPredictResponse.FromString, 126 | options, channel_credentials, 127 | insecure, call_credentials, compression, wait_for_ready, timeout, metadata) 128 | -------------------------------------------------------------------------------- /training-code/text-classification/torchserve_handler.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import os 4 | import logging 5 | 6 | from torch import nn 7 | from torchtext.data.utils import get_tokenizer 8 | from ts.torch_handler.base_handler import BaseHandler 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | class ModelHandler(BaseHandler): 13 | """ 14 | A custom model handler implementation for serving intent classification predictions 15 | on a TorchServe server. 16 | """ 17 | 18 | class TextClassificationModel(nn.Module): 19 | def __init__(self, vocab_size, embed_dim, fc_size, num_class): 20 | super(ModelHandler.TextClassificationModel, self).__init__() 21 | self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True) 22 | self.fc1 = nn.Linear(embed_dim, fc_size) 23 | self.fc2 = nn.Linear(fc_size, fc_size) 24 | self.fc3 = nn.Linear(fc_size, num_class) 25 | self.init_weights() 26 | 27 | def init_weights(self): 28 | initrange = 0.5 29 | self.embedding.weight.data.uniform_(-initrange, initrange) 30 | self.fc1.weight.data.uniform_(-initrange, initrange) 31 | self.fc1.bias.data.zero_() 32 | self.fc2.weight.data.uniform_(-initrange, initrange) 33 | self.fc2.bias.data.zero_() 34 | self.fc3.weight.data.uniform_(-initrange, initrange) 35 | self.fc3.bias.data.zero_() 36 | 37 | def forward(self, text, offsets): 38 | embedded = self.embedding(text, offsets) 39 | return self.fc3(self.fc2(self.fc1(embedded))) 40 | 41 | def __init__(self): 42 | self.context = None 43 | self.model = None 44 | self.initialized = False 45 | self.fcsize = 128 46 | self.manifest = None 47 | self.tokenizer = get_tokenizer('basic_english') 48 | 49 | def initialize(self, ctx): 50 | """ 51 | Initialize the model. This is called at model loading time. 52 | :param ctx: Initial context; contains model server system properties. 53 | :return: 54 | """ 55 | 56 | self.context = ctx 57 | properties = ctx.system_properties 58 | model_dir = properties.get("model_dir") 59 | model_path = os.path.join(model_dir, "model.pth") 60 | vocab_path = os.path.join(model_dir, "vocab.pth") 61 | manifest_path = os.path.join(model_dir, "manifest.json") 62 | 63 | # load vocabulary 64 | self.vocab = torch.load(vocab_path) 65 | 66 | # load model manifest, including label index map.
67 |         with open(manifest_path, 'r') as f:
68 |             self.manifest = json.loads(f.read())
69 |         classes = self.manifest['classes']
70 | 
71 |         num_class = len(classes)
72 |         vocab_size = len(self.vocab)
73 |         emsize = 64
74 |         self.model = self.TextClassificationModel(vocab_size, emsize, self.fcsize, num_class).to("cpu")
75 |         self.model.load_state_dict(torch.load(model_path))
76 |         self.model.eval()
77 | 
78 |         logger.info('intent classification model file loaded successfully')
79 |         self.initialized = True
80 | 
81 |     def preprocess(self, data):
82 |         """
83 |         Transform raw input into model input data.
84 |         :param data: list of raw requests, should match batch size
85 |         :return: preprocessed model input data
86 |         """
87 |         # Take the input data and make it inference ready
88 |         logger.info('data={}'.format(data))
89 | 
90 |         preprocessed_data = data[0].get("data")
91 |         if preprocessed_data is None:
92 |             preprocessed_data = data[0].get("body")
93 | 
94 |         text_pipeline = lambda x: self.vocab(self.tokenizer(x))
95 | 
96 |         user_input = str(preprocessed_data, "utf-8") if isinstance(preprocessed_data, (bytes, bytearray)) else str(preprocessed_data)  # decode byte payloads into plain text before tokenizing
97 |         processed_text = torch.tensor(text_pipeline(user_input), dtype=torch.int64)
98 |         offsets = [0, processed_text.size(0)]  # EmbeddingBag offsets: start index of each sequence in the flattened token tensor
99 |         offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
100 | 
101 |         logger.info('UserInput={}; TensorInput={}; Offset={}'.format(user_input, processed_text, offsets))
102 |         return (processed_text, offsets)
103 | 
104 |     def inference(self, model_input):
105 |         """
106 |         Internal inference method.
107 |         :param model_input: transformed model input data
108 |         :return: inference output (raw logits)
109 |         """
110 |         # Invoke the module (runs forward() with any registered hooks applied)
111 |         model_output = self.model(model_input[0], model_input[1])
112 |         return model_output
113 | 
114 |     def postprocess(self, inference_output):
115 |         """
116 |         Return inference result.
117 |         :param inference_output: inference output (raw logits)
118 |         :return: list of prediction results
119 |         """
120 |         # Take the network output and post-process it into the desired format
121 |         res_index = inference_output.argmax(1).item()
122 |         logger.info("return {}".format(res_index))
123 |         classes = self.manifest['classes']
124 |         postprocess_output = classes[str(res_index)]
125 |         return [{"result": postprocess_output}]
126 | 
127 |     def handle(self, data, context):
128 |         """
129 |         Invoked by TorchServe for prediction requests.
130 |         Does pre-processing of the data, prediction using the model, and post-processing of the prediction output.
131 |         :param data: Input data for prediction
132 |         :param context: Initial context contains model server system properties.
133 | :return: prediction output 134 | """ 135 | model_input = self.preprocess(data) 136 | model_output = self.inference(model_input) 137 | return self.postprocess(model_output) 138 | 139 | 140 | ## local test 141 | # class Context: 142 | # system_properties={"model_dir":"/Users/chi.wang/workspace/cw/book/MiniAutoML/42"} 143 | 144 | # class PredictPayload: 145 | # def get(self, str): 146 | # return "make a 10 minute timer" 147 | 148 | # ctx = Context() 149 | # handler = ModelHandler() 150 | # handler.initialize(ctx) 151 | 152 | # print("prediction={}".format(handler.handle([PredictPayload()], ctx))) 153 | 154 | ## torch serve package command 155 | # torch-model-archiver --model-name intent_classification --version 1.0 --model-file torchserve_model.py --serialized-file /Users/chi.wang/workspace/cw/book/MiniAutoML/42/model.pth --handler torchserve_handler.py --extra-files /Users/chi.wang/workspace/cw/book/MiniAutoML/42/vocab.pth,/Users/chi.wang/workspace/cw/book/MiniAutoML/42/manifest.json 156 | -------------------------------------------------------------------------------- /training-code/text-classification/training_service_pb2_grpc.py: -------------------------------------------------------------------------------- 1 | # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 2 | """Client and server classes corresponding to protobuf-defined services.""" 3 | import grpc 4 | 5 | import training_service_pb2 as training__service__pb2 6 | 7 | 8 | class TrainingServiceStub(object): 9 | """Missing associated documentation comment in .proto file.""" 10 | 11 | def __init__(self, channel): 12 | """Constructor. 13 | 14 | Args: 15 | channel: A grpc.Channel. 16 | """ 17 | self.Train = channel.unary_unary( 18 | '/training.TrainingService/Train', 19 | request_serializer=training__service__pb2.TrainRequest.SerializeToString, 20 | response_deserializer=training__service__pb2.TrainResponse.FromString, 21 | ) 22 | self.GetTrainingStatus = channel.unary_unary( 23 | '/training.TrainingService/GetTrainingStatus', 24 | request_serializer=training__service__pb2.GetTrainingStatusRequest.SerializeToString, 25 | response_deserializer=training__service__pb2.GetTrainingStatusResponse.FromString, 26 | ) 27 | 28 | 29 | class TrainingServiceServicer(object): 30 | """Missing associated documentation comment in .proto file.""" 31 | 32 | def Train(self, request, context): 33 | """Missing associated documentation comment in .proto file.""" 34 | context.set_code(grpc.StatusCode.UNIMPLEMENTED) 35 | context.set_details('Method not implemented!') 36 | raise NotImplementedError('Method not implemented!') 37 | 38 | def GetTrainingStatus(self, request, context): 39 | """Missing associated documentation comment in .proto file.""" 40 | context.set_code(grpc.StatusCode.UNIMPLEMENTED) 41 | context.set_details('Method not implemented!') 42 | raise NotImplementedError('Method not implemented!') 43 | 44 | 45 | def add_TrainingServiceServicer_to_server(servicer, server): 46 | rpc_method_handlers = { 47 | 'Train': grpc.unary_unary_rpc_method_handler( 48 | servicer.Train, 49 | request_deserializer=training__service__pb2.TrainRequest.FromString, 50 | response_serializer=training__service__pb2.TrainResponse.SerializeToString, 51 | ), 52 | 'GetTrainingStatus': grpc.unary_unary_rpc_method_handler( 53 | servicer.GetTrainingStatus, 54 | request_deserializer=training__service__pb2.GetTrainingStatusRequest.FromString, 55 | response_serializer=training__service__pb2.GetTrainingStatusResponse.SerializeToString, 56 | ), 57 | } 58 | 
    generic_handler = grpc.method_handlers_generic_handler(
59 |         'training.TrainingService', rpc_method_handlers)
60 |     server.add_generic_rpc_handlers((generic_handler,))
61 | 
62 | 
63 | # This class is part of an EXPERIMENTAL API.
64 | class TrainingService(object):
65 |     """Missing associated documentation comment in .proto file."""
66 | 
67 |     @staticmethod
68 |     def Train(request,
69 |             target,
70 |             options=(),
71 |             channel_credentials=None,
72 |             call_credentials=None,
73 |             insecure=False,
74 |             compression=None,
75 |             wait_for_ready=None,
76 |             timeout=None,
77 |             metadata=None):
78 |         return grpc.experimental.unary_unary(request, target, '/training.TrainingService/Train',
79 |             training__service__pb2.TrainRequest.SerializeToString,
80 |             training__service__pb2.TrainResponse.FromString,
81 |             options, channel_credentials,
82 |             insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
83 | 
84 |     @staticmethod
85 |     def GetTrainingStatus(request,
86 |             target,
87 |             options=(),
88 |             channel_credentials=None,
89 |             call_credentials=None,
90 |             insecure=False,
91 |             compression=None,
92 |             wait_for_ready=None,
93 |             timeout=None,
94 |             metadata=None):
95 |         return grpc.experimental.unary_unary(request, target, '/training.TrainingService/GetTrainingStatus',
96 |             training__service__pb2.GetTrainingStatusRequest.SerializeToString,
97 |             training__service__pb2.GetTrainingStatusResponse.FromString,
98 |             options, channel_credentials,
99 |             insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
-------------------------------------------------------------------------------- /training-code/text-classification/version.py: --------------------------------------------------------------------------------
1 | gitsha="231c0d2"
-------------------------------------------------------------------------------- /training-service/README.md: --------------------------------------------------------------------------------
1 | # Training Service (TS)
2 | Training service is a sample Java (gRPC) web service demonstrating the design principles introduced in chapter 5 of the book `Engineering Deep Learning Systems`.
3 | This service is written minimally (for example, it persists data in memory instead of a database) so that the code is simple to read and the local setup is easy.
4 | This service requires multiple external dependencies to run:
5 | - **Minio**, which we use to mimic cloud blob storage such as `AWS S3` or `Azure Blob`.
6 | - **Data Management Service**, which we introduced in chapter 4.
7 | - **Metadata Store Service**, which we introduced in chapter 5.
8 | - Training container runtime, either **Docker** or **Kubernetes** (recommended)
9 | 
10 | By reading this code, you will learn how the training service design concepts can be implemented.
11 | 
12 | ## Function demo
13 | 
14 | See [single trainer demo](single_trainer_demo.md)
15 | 
16 | See [distributed trainer demo](distributed_trainer_demo.md)
17 | 
18 | --------
19 | 
20 | ## Build and play with TS locally
21 | 
22 | ### Understand the config file
23 | The TS server takes a few configuration items on startup. These can be found at [config.properties](src/main/resources/config.properties); a sketch of a complete file follows the list below.
24 | > For your convenience, we have provided another config file [config-kube.properties](src/main/resources/config-kube.properties) that has some out-of-the-box Kubernetes configs
25 | - `dm.host` & `dm.port`: The address of the data-management service.
26 | - `server.port`: The port number that this server listens to.
27 | - `trainer.minio.accessKey` & `trainer.minio.secretKey`: The credentials used to access the minio server.
28 | - `trainer.minio.host`: The training container launched by the training service needs to talk to minio to access training data. This address changes based on the selected container runtime (`backend` config).
29 | - `trainer.minio.metadataStore.bucketName`: The metadata store service's minio bucket name. The training container needs to write model artifacts into this bucket.
30 | - `trainer.metadataStore.host`: The metadata store service address. The training container needs to communicate with the metadata store service periodically.
31 | - `backend`: This can be either `kubectl` or `docker`.
32 |   - **kubectl**: also provide `kubectl.configFile`, with which the training service can talk to the Kubernetes cluster over its API, and `kubectl.namespace`; all training containers will be submitted in this namespace.
33 |   - **docker**: also provide `docker.network`; all training containers will be connected to this Docker network (so they can access the Metadata Store Service)
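To make this concrete, here is a sketch of what a complete config file might look like. The values are placeholders, not the repo's actual defaults, and note that the key names in the shipped files are namespaced a bit differently from the shorthand above (the service code reads keys such as `ts.server.port` and `dm.server.host`; see `TrainingService.Config` and `Tracker.SharedConfig`). Check the real files under [config](../config) and `src/main/resources` before relying on any of these:

```properties
# illustrative sketch only -- verify key names and values against the repo's config files
ts.server.port=51001
ts.backend=docker
dm.server.host=data-management
dm.server.port=51001
minio.accessKey=<minio-access-key>
minio.secretKey=<minio-secret-key>
ts.trainer.minio.host=minio:9000
ts.trainer.ms.host=metadata-store:51001
ms.minio.bucketName=<metadata-store-bucket>
docker.network=orca3
```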
34 | 
35 | ### Start dependency minio
36 | This can be taken care of by our script [ms-001-start-minio.sh](../scripts/ms-001-start-minio.sh)
37 | 
38 | ### Start dependency Data Management Service
39 | This can be taken care of by our script [dm-002-start-server.sh](../scripts/dm-002-start-server.sh)
40 | 
41 | ### Start dependency Metadata Store Service
42 | This can be taken care of by our script [ms-002-start-server.sh](../scripts/ms-002-start-server.sh)
43 | 
44 | ### Prepare text classification trainer image
45 | This can be taken care of by our script [ts-000-build-trainer.sh](../scripts/ts-000-build-trainer.sh).
46 | To make the trainer image available to both the Kubernetes and Docker runtimes, our script starts a local Docker registry container on port 3000, then "uploads" the trainer image to that local registry.
47 | Rest assured, there will be **no actual network traffic** involved in this process.
48 | 
49 | ### Build and run using docker
50 | 1. Modify the config if needed. Set `dm.host` to `data-management`.
51 | 2. The [dockerfile](../services.dockerfile) in the root folder can be used to build the training service directly. Executing `docker build -t orca3/services:latest -f services.dockerfile .` in the root directory will build a Docker image called `orca3/services` with the `latest` tag.
52 | 3. **Using the docker training backend**: Start the service using `docker run --name training-service -v /var/run/docker.sock:/var/run/docker.sock --network orca3 --rm -d -p 6003:51001 orca3/services:latest training-service.jar`. Note that we mount your `docker.sock` into the container so the training service container can talk to your Docker daemon (to launch training containers). This is the same as running our [ts-001-start-server.sh](../scripts/ts-001-start-server.sh)
53 | 4. **Using the kubernetes training backend**: Start the service using `docker run --name training-service -v $HOME/.kube/config:/.kube/config --env APP_CONFIG=config-kube.properties --network orca3 --rm -d -p 6003:51001 orca3/services:latest training-service.jar`. Note that we used a different config file [config-kube.properties](src/main/resources/config-kube.properties) and mounted your kube config into the container so the training service container can talk to your Kubernetes cluster (to launch training pods). This is the same as running our [ts-001-start-server-kube.sh](../scripts/ts-001-start-server-kube.sh)
54 | 5. Now the service can be reached at `localhost:6003`.
Try `grpcurl -plaintext localhost:6003 grpc.health.v1.Health/Check` or look at the examples in the [scripts](../scripts) folder to interact with the service
55 | 
56 | ### Build and run using java (for experienced Java developers)
57 | 1. Modify the config if needed. Set `dm.host` to `localhost`. Set `dm.port` to `6000`. Set `kubectl.configFile` to your kube config; it is either the value of the environment variable `KUBECONFIG` or, by default, `${HOME}/.kube/config`. Please use an absolute path (e.g., in my case on macOS it is `/Users/robert.xue/.kube/config`)
58 | 2. Use Maven to build the project and produce a runnable jar: `./mvnw clean package -pl training-service -am`.
59 | 3. **Using the docker training backend**: Run the jar using the command `java -jar training-service/target/training-service-1.0-SNAPSHOT.jar`
60 | 4. **Using the kubernetes training backend**: Run the jar using the command `APP_CONFIG=config-kube.properties java -jar training-service/target/training-service-1.0-SNAPSHOT.jar`
61 | 5. Now the service can be reached at `localhost:51001`. Try `grpcurl -plaintext localhost:51001 grpc.health.v1.Health/Check` or look at the examples in the [scripts](../scripts) folder to interact with the service
62 | 
-------------------------------------------------------------------------------- /training-service/pom.xml: --------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project xmlns="http://maven.apache.org/POM/4.0.0"
3 |          xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
4 |          xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
5 |     <modelVersion>4.0.0</modelVersion>
6 | 
7 |     <parent>
8 |         <groupId>org.orca3</groupId>
9 |         <artifactId>mini-auto-ml</artifactId>
10 |         <version>1.0-SNAPSHOT</version>
11 |     </parent>
12 | 
13 |     <artifactId>training-service</artifactId>
14 |     <version>1.0-SNAPSHOT</version>
15 | 
16 |     <properties>
17 |         <mainClass>org.orca3.miniAutoML.training.TrainingService</mainClass>
18 |     </properties>
19 | 
20 |     <dependencies>
21 |         <dependency>
22 |             <groupId>org.orca3</groupId>
23 |             <artifactId>grpc-contract</artifactId>
24 |             <version>${project.version}</version>
25 |         </dependency>
26 |         <dependency>
27 |             <groupId>io.grpc</groupId>
28 |             <artifactId>grpc-services</artifactId>
29 |         </dependency>
30 |         <dependency>
31 |             <groupId>ch.qos.logback</groupId>
32 |             <artifactId>logback-classic</artifactId>
33 |         </dependency>
34 |         <dependency>
35 |             <groupId>com.github.docker-java</groupId>
36 |             <artifactId>docker-java-core</artifactId>
37 |             <version>3.2.11</version>
38 |         </dependency>
39 |         <dependency>
40 |             <groupId>com.github.docker-java</groupId>
41 |             <artifactId>docker-java-transport-httpclient5</artifactId>
42 |             <version>3.2.11</version>
43 |         </dependency>
44 |         <dependency>
45 |             <groupId>io.kubernetes</groupId>
46 |             <artifactId>client-java</artifactId>
47 |             <version>13.0.0</version>
48 |         </dependency>
49 | 
50 |         <dependency>
51 |             <groupId>io.grpc</groupId>
52 |             <artifactId>grpc-testing</artifactId>
53 |             <scope>test</scope>
54 |         </dependency>
55 |         <dependency>
56 |             <groupId>junit</groupId>
57 |             <artifactId>junit</artifactId>
58 |             <version>RELEASE</version>
59 |             <scope>test</scope>
60 |         </dependency>
61 |     </dependencies>
62 | 
63 |     <build>
64 |         <plugins>
65 |             <plugin>
66 |                 <groupId>org.codehaus.mojo</groupId>
67 |                 <artifactId>exec-maven-plugin</artifactId>
68 |                 <version>3.0.0</version>
69 |                 <executions>
70 |                     <execution>
71 |                         <goals>
72 |                             <goal>java</goal>
73 |                         </goals>
74 |                     </execution>
75 |                 </executions>
76 |                 <configuration>
77 |                     <mainClass>${mainClass}</mainClass>
78 |                 </configuration>
79 |             </plugin>
80 |             <plugin>
81 |                 <groupId>org.apache.maven.plugins</groupId>
82 |                 <artifactId>maven-shade-plugin</artifactId>
83 |                 <executions>
84 |                     <execution>
85 |                         <phase>package</phase>
86 |                         <goals>
87 |                             <goal>shade</goal>
88 |                         </goals>
89 |                         <configuration>
90 |                             <transformers>
91 |                                 <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
92 |                                     <mainClass>${mainClass}</mainClass>
93 |                                 </transformer>
94 |                             </transformers>
95 |                             <filters>
96 |                                 <filter>
97 |                                     <artifact>*:*</artifact>
98 |                                     <excludes>
99 |                                         <exclude>META-INF/*.SF</exclude>
100 |                                         <exclude>META-INF/*.DSA</exclude>
101 |                                         <exclude>META-INF/*.RSA</exclude>
102 |                                     </excludes>
103 |                                 </filter>
104 |                             </filters>
105 |                         </configuration>
106 |                     </execution>
107 |                 </executions>
108 |             </plugin>
109 |         </plugins>
110 |     </build>
111 | </project>
-------------------------------------------------------------------------------- /training-service/single_trainer_demo.md: --------------------------------------------------------------------------------
1 | # Single Trainer Training Demo
2 | 
3 | > Please run all scripts from the repository's root directory.
4 | 
5 | Please go over the [Data Management Function Demo](../data-management/demo.md) first.
6 | This demo depends on the dataset created in that demo.
7 | Take note of the value of `datasetId` in that demo.
8 | In the following instructions, we will assume `datasetId` is `1`.
9 | We will also need both the MinIO server and the Data Management server running from that demo.
10 | 
11 | ## Step 1: Start the Metadata Store Server and the Training Server
12 | 
13 | Run the following to start the Metadata Store server:
14 | ```shell
15 | ./scripts/ms-002-start-server.sh
16 | ```
17 | 
18 | The script will start the Metadata Store server container using the image `orca3/services:latest`.
19 | 
20 | If this image does not exist on your machine, Docker will attempt to pull it from Docker Hub.
21 | 
22 | The script will name the container `metadata-store`, launch it in network `orca3`, and bind it to port `6002`.
23 | 
24 | The server will be accessible at `localhost:6002` from your machine, or `metadata-store:51001` from other containers within the same network `orca3`.
25 | 
26 | 
27 | You should see the following when you run the script:
28 | ```shell
29 | Started metadata-store docker container and listen on port 6002
30 | ```
31 | 
32 | Next, run the following:
33 | ```shell
34 | ./scripts/ts-001-start-server.sh
35 | ```
36 | 
37 | The script will start the Training Service server container using the image `orca3/services:latest`.
38 | 
39 | The script will name the container `training-service`, launch it in network `orca3`, and bind it to port `6003`.
40 | 
41 | The server will be accessible at `localhost:6003` from your machine, or `training-service:51001` from other containers within the same network `orca3`.
42 | 
43 | You should see the following when you run the script:
44 | ```shell
45 | Started training-service docker container and listen on port 6003
46 | ```
47 | 
48 | ## Step 2: Submit a training job
49 | 
50 | > The following instructions assume a value of `1` for `datasetId`.
51 | > Make sure the value is the one that you received from the [Data Management Function Demo](../data-management/demo.md).
52 | 
53 | > **If you are running on Apple Silicon**, you may need to perform
54 | > ```shell
55 | > docker pull orca3/intent-classification
56 | > ```
57 | > before submitting a training job.
58 | 
59 | Run the following (replace `1` with your `datasetId` if needed):
60 | ```shell
61 | ./scripts/ts-002-start-run.sh 1
62 | ```
63 | 
64 | The script will:
65 | 1. Invoke the `Train` API method of the Training Server to start an `intent-classification` training job on the dataset with `datasetId=1`,
66 |    using training parameters `LR=4;EPOCHS=15;BATCH_SIZE=64;FC_SIZE=128;`.
67 | 2. The API method will respond with `jobId`, which we can use to track the training status; an equivalent raw gRPC call is sketched right after this list.
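For reference, here is a sketch of what an equivalent direct invocation looks like with `grpcurl`. The service and method names match the generated gRPC code in this repo, and the request JSON mirrors the metadata shown in the status outputs below; still, treat the exact field layout as an assumption and consult `grpc-contract/src/main/proto/training_service.proto` if it does not match:

```shell
grpcurl -plaintext \
  -d '{
        "metadata": {
          "algorithm": "intent-classification",
          "dataset_id": "1",
          "name": "test1",
          "train_data_version_hash": "hashDg==",
          "output_model_name": "my-intent-classification-model",
          "parameters": {"LR": "4", "EPOCHS": "15", "BATCH_SIZE": "64", "FC_SIZE": "128"}
        }
      }' \
  localhost:6003 training.TrainingService/Train
```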
68 | 
69 | You should see the following when you run the script:
70 | ```shell
71 | {
72 |     "job_id": 1
73 | }
74 | ```
75 | 
76 | ## Step 3: Inspect the training job's status
77 | 
78 | Run the following (replace `1` if you have a different `job_id` from the step above):
79 | ```shell
80 | ./scripts/ts-003-check-run.sh 1
81 | ```
82 | 
83 | This script will invoke the `GetTrainingStatus` API method of the Training Server to query the training job's status.
84 | 
85 | As the job progresses, you may see it in different statuses. Here are some examples.
86 | 87 | ### Job is in the queue 88 | ```shell 89 | { 90 | "job_id": 1, 91 | "message": "Queueing, there are 0 training jobs waiting before this.", 92 | "metadata": { 93 | "algorithm": "intent-classification", 94 | "dataset_id": "1", 95 | "name": "test1", 96 | "train_data_version_hash": "hashDg==", 97 | "parameters": { 98 | "BATCH_SIZE": "64", 99 | "EPOCHS": "15", 100 | "FC_SIZE": "128", 101 | "LR": "4" 102 | }, 103 | "output_model_name": "my-intent-classification-model" 104 | } 105 | } 106 | ``` 107 | 108 | ### Job is being launched 109 | ```shell 110 | { 111 | "status": "launch", 112 | "job_id": 1, 113 | "metadata": { 114 | "algorithm": "intent-classification", 115 | "dataset_id": "1", 116 | "name": "test1", 117 | "train_data_version_hash": "hashDg==", 118 | "parameters": { 119 | "BATCH_SIZE": "64", 120 | "EPOCHS": "15", 121 | "FC_SIZE": "128", 122 | "LR": "4" 123 | }, 124 | "output_model_name": "my-intent-classification-model" 125 | } 126 | } 127 | ``` 128 | 129 | ### Job is running 130 | ```shell 131 | { 132 | "status": "running", 133 | "job_id": 1, 134 | "metadata": { 135 | "algorithm": "intent-classification", 136 | "dataset_id": "1", 137 | "name": "test1", 138 | "train_data_version_hash": "hashDg==", 139 | "parameters": { 140 | "BATCH_SIZE": "64", 141 | "EPOCHS": "15", 142 | "FC_SIZE": "128", 143 | "LR": "4" 144 | }, 145 | "output_model_name": "my-intent-classification-model" 146 | } 147 | } 148 | ``` 149 | 150 | ### Job completed successfully 151 | ```shell 152 | { 153 | "status": "succeed", 154 | "job_id": 1, 155 | "metadata": { 156 | "algorithm": "intent-classification", 157 | "dataset_id": "1", 158 | "name": "test1", 159 | "train_data_version_hash": "hashDg==", 160 | "parameters": { 161 | "BATCH_SIZE": "64", 162 | "EPOCHS": "15", 163 | "FC_SIZE": "128", 164 | "LR": "4" 165 | }, 166 | "output_model_name": "my-intent-classification-model" 167 | } 168 | } 169 | ``` 170 | 171 | ### Job failed due to missing image 172 | If you run into an error like 173 | ```shell 174 | { 175 | "status": "failure", 176 | "job_id": 1, 177 | "message": "Status 404: {\"message\":\"No such image: orca3/intent-classification:latest\"}\n", 178 | "metadata": { 179 | "algorithm": "intent-classification", 180 | "dataset_id": "1", 181 | "name": "test1", 182 | "train_data_version_hash": "hashDg==", 183 | "parameters": { 184 | "BATCH_SIZE": "64", 185 | "EPOCHS": "15", 186 | "FC_SIZE": "128", 187 | "LR": "4" 188 | }, 189 | "output_model_name": "my-intent-classification-model" 190 | } 191 | } 192 | ``` 193 | run 194 | ```shell 195 | docker pull orca3/intent-classification:latest 196 | ``` 197 | and retry from Step 2. 198 | 199 | ## Clean up 200 | 201 | > If you would like to run the [distributed training service lab (Chapter 4)](distributed_trainer_demo.md), skip this step and keep containers running. 202 | > They will provide the required dataset. 
203 | 
204 | Run the following:
205 | ```shell
206 | ./scripts/lab-999-tear-down.sh
207 | ```
-------------------------------------------------------------------------------- /training-service/src/main/java/org/orca3/miniAutoML/training/TrainingService.java: --------------------------------------------------------------------------------
1 | package org.orca3.miniAutoML.training;
2 | 
3 | import com.google.common.base.Strings;
4 | import io.grpc.ManagedChannel;
5 | import io.grpc.ManagedChannelBuilder;
6 | import io.grpc.Status;
7 | import io.grpc.stub.StreamObserver;
8 | import org.orca3.miniAutoML.ServiceBase;
9 | import org.orca3.miniAutoML.training.models.ExecutedTrainingJob;
10 | import org.orca3.miniAutoML.training.models.MemoryStore;
11 | import org.orca3.miniAutoML.training.tracker.DockerTracker;
12 | import org.orca3.miniAutoML.training.tracker.KubectlTracker;
13 | import org.orca3.miniAutoML.training.tracker.Tracker;
14 | import org.slf4j.Logger;
15 | import org.slf4j.LoggerFactory;
16 | 
17 | import java.io.IOException;
18 | import java.util.Properties;
19 | import java.util.concurrent.Executors;
20 | import java.util.concurrent.ScheduledExecutorService;
21 | import java.util.concurrent.ScheduledFuture;
22 | 
23 | import static java.util.concurrent.TimeUnit.SECONDS;
24 | 
25 | public class TrainingService extends TrainingServiceGrpc.TrainingServiceImplBase {
26 |     private final MemoryStore store;
27 |     private final Config config;
28 |     private static final Logger logger = LoggerFactory.getLogger(TrainingService.class);
29 | 
30 |     public TrainingService(MemoryStore store, Config config) {
31 |         this.config = config;
32 |         this.store = store;
33 |     }
34 | 
35 |     public static void main(String[] args) throws IOException, InterruptedException {
36 |         logger.info("Hello, Training Service!");
37 |         Properties props = ServiceBase.getConfigProperties();
38 |         Config config = new Config(props);
39 | 
40 |         MemoryStore store = new MemoryStore();
41 |         ManagedChannel dmChannel = ManagedChannelBuilder.forAddress(config.dmHost, Integer.parseInt(config.dmPort))
42 |                 .usePlaintext().build();
43 |         Tracker<?> tracker;
44 |         if (config.backend.equals("docker")) {
45 |             logger.info("Using docker backend.");
46 |             tracker = new DockerTracker(store, props, dmChannel);
47 |         } else if (config.backend.equals("kubectl")) {
48 |             logger.info("Using kubernetes backend.");
49 |             tracker = new KubectlTracker(store, props, dmChannel);
50 |         } else {
51 |             throw new IllegalArgumentException(String.format("Unsupported backend %s", config.backend));
52 |         }
53 |         TrainingService trainingService = new TrainingService(store, config);
54 |         final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(2);
55 |         final ScheduledFuture<?> launchingTask =
56 |                 scheduler.scheduleAtFixedRate(tracker::launchAll, 5, 5, SECONDS); // drain the job queue every 5 seconds
57 |         final ScheduledFuture<?> refreshingTask =
58 |                 scheduler.scheduleAtFixedRate(tracker::updateContainerStatus, 7, 5, SECONDS); // refresh container states every 5 seconds
59 |         ServiceBase.startService(Integer.parseInt(config.serverPort), trainingService, () -> {
60 |             dmChannel.shutdownNow();
61 |             tracker.shutdownAll();
62 |             launchingTask.cancel(true);
63 |             refreshingTask.cancel(true);
64 |             scheduler.shutdown();
65 | 
66 |         });
67 |     }
68 | 
69 |     static class Config {
70 |         final String serverPort;
71 |         final String dmHost;
72 |         final String dmPort;
73 |         final String backend;
74 | 
75 |         public Config(Properties properties) {
76 |             this.serverPort = properties.getProperty("ts.server.port");
77 |             this.dmPort = properties.getProperty("dm.server.port");
78 |             this.dmHost = properties.getProperty("dm.server.host");
79 |             this.backend = properties.getProperty("ts.backend");
80 |         }
81 |     }
82 | 
83 |     @Override
84 |     public void train(TrainRequest request, StreamObserver<TrainResponse> responseObserver) {
85 |         if (Strings.isNullOrEmpty(request.getMetadata().getOutputModelName())) {
86 |             responseObserver.onError(Status.INVALID_ARGUMENT
87 |                     .withDescription("\"outputModelName\" is required.")
88 |                     .asException());
89 |             return;
90 |         }
91 |         if (Strings.isNullOrEmpty(request.getMetadata().getDatasetId())) {
92 |             responseObserver.onError(Status.INVALID_ARGUMENT
93 |                     .withDescription("\"datasetId\" is required.")
94 |                     .asException());
95 |             return;
96 |         }
97 |         if (Strings.isNullOrEmpty(request.getMetadata().getTrainDataVersionHash())) {
98 |             responseObserver.onError(Status.INVALID_ARGUMENT
99 |                     .withDescription("\"trainDataVersionHash\" is required.")
100 |                     .asException());
101 |             return;
102 |         }
103 |         if (Strings.isNullOrEmpty(request.getMetadata().getAlgorithm())) {
104 |             responseObserver.onError(Status.INVALID_ARGUMENT
105 |                     .withDescription("\"algorithm\" is required.")
106 |                     .asException());
107 |             return;
108 |         }
109 |         int jobId = store.offer(request);
110 |         responseObserver.onNext(TrainResponse.newBuilder().setJobId(jobId).build());
111 |         responseObserver.onCompleted();
112 |     }
113 | 
114 |     @Override
115 |     public void getTrainingStatus(GetTrainingStatusRequest request, StreamObserver<GetTrainingStatusResponse> responseObserver) {
116 |         int jobId = request.getJobId();
117 |         final ExecutedTrainingJob job;
118 |         final TrainingStatus status;
119 |         if (store.finalizedJobs.containsKey(jobId)) {
120 |             job = store.finalizedJobs.get(jobId);
121 |             status = job.isSuccess() ? TrainingStatus.succeed : TrainingStatus.failure;
122 |         } else if (store.launchingList.containsKey(jobId)) {
123 |             job = store.launchingList.get(jobId);
124 |             status = TrainingStatus.launch;
125 |         } else if (store.runningList.containsKey(jobId)) {
126 |             job = store.runningList.get(jobId);
127 |             status = TrainingStatus.running;
128 |         } else {
129 |             TrainingJobMetadata metadata = store.jobQueue.get(jobId);
130 |             if (metadata != null) {
131 |                 int position = store.getQueuePosition(jobId);
132 |                 responseObserver.onNext(GetTrainingStatusResponse.newBuilder()
133 |                         .setJobId(jobId)
134 |                         .setStatus(TrainingStatus.queuing)
135 |                         .setMetadata(metadata)
136 |                         .setMessage(String.format("Queueing, there are %s training jobs waiting before this.", position))
137 |                         .setPositionInQueue(position)
138 |                         .build());
139 |                 responseObserver.onCompleted();
140 |             } else {
141 |                 responseObserver.onError(Status.NOT_FOUND
142 |                         .withDescription(String.format("Job %s doesn't exist", jobId))
143 |                         .asException());
144 |             }
145 |             return;
146 |         }
147 |         responseObserver.onNext(GetTrainingStatusResponse.newBuilder()
148 |                 .setJobId(jobId)
149 |                 .setStatus(status)
150 |                 .setMessage(job.getMessage())
151 |                 .setMetadata(job.getMetadata())
152 |                 .build());
153 |         responseObserver.onCompleted();
154 |     }
155 | }
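For orientation, here is a minimal sketch of how a client could call this service through the blocking stub generated from `training_service.proto` in the grpc-contract module. The host, port, and field values below are illustrative (they mirror the single-trainer demo); this class is not part of the service source:

```java
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import org.orca3.miniAutoML.training.*;

public class TrainClientSketch {
    public static void main(String[] args) {
        // connect to a locally running training service (port taken from the demo scripts)
        ManagedChannel channel = ManagedChannelBuilder
                .forAddress("localhost", 6003).usePlaintext().build();
        TrainingServiceGrpc.TrainingServiceBlockingStub stub =
                TrainingServiceGrpc.newBlockingStub(channel);

        // submit a training job; these values mirror ts-002-start-run.sh
        TrainResponse resp = stub.train(TrainRequest.newBuilder()
                .setMetadata(TrainingJobMetadata.newBuilder()
                        .setAlgorithm("intent-classification")
                        .setDatasetId("1")
                        .setTrainDataVersionHash("hashDg==")
                        .setOutputModelName("my-intent-classification-model")
                        .putParameters("EPOCHS", "15"))
                .build());

        // poll the job status by the returned id
        GetTrainingStatusResponse status = stub.getTrainingStatus(
                GetTrainingStatusRequest.newBuilder().setJobId(resp.getJobId()).build());
        System.out.println(status.getStatus());
        channel.shutdownNow();
    }
}
```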
-------------------------------------------------------------------------------- /training-service/src/main/java/org/orca3/miniAutoML/training/models/ExecutedTrainingJob.java: --------------------------------------------------------------------------------
1 | package org.orca3.miniAutoML.training.models;
2 | 
3 | import com.google.common.base.Strings;
4 | import org.orca3.miniAutoML.training.TrainingJobMetadata;
5 | 
6 | import javax.annotation.Nullable;
7 | 
8 | public class ExecutedTrainingJob {
9 |     private final long launchedAt;
10 |     private final Long finishedAt;
11 |     private final boolean success;
12 |     private final TrainingJobMetadata metadata;
13 |     private final String message;
14 | 
15 |     ExecutedTrainingJob(long launchedAt, Long finishedAt, boolean success, TrainingJobMetadata metadata, String message) {
16 |         this.launchedAt = launchedAt;
17 |         this.finishedAt = finishedAt;
18 |         this.success = success;
19 |         this.metadata = metadata;
20 |         this.message = Strings.nullToEmpty(message);
21 |     }
22 | 
23 |     public ExecutedTrainingJob(Long launchedAt, TrainingJobMetadata metadata, String message) {
24 |         this(launchedAt, null, false, metadata, message);
25 |     }
26 | 
27 |     public boolean isSuccess() {
28 |         return success;
29 |     }
30 | 
31 |     public TrainingJobMetadata getMetadata() {
32 |         return metadata;
33 |     }
34 | 
35 |     public String getMessage() {
36 |         return message;
37 |     }
38 | 
39 |     public Long getLaunchedAt() {
40 |         return launchedAt;
41 |     }
42 | 
43 |     @Nullable
44 |     public Long getFinishedAt() {
45 |         return finishedAt;
46 |     }
47 | 
48 |     public ExecutedTrainingJob finished(long finishedAt, boolean success, String message) {
49 |         return new ExecutedTrainingJob(this.launchedAt, finishedAt, success, this.metadata, message);
50 |     }
51 | }
-------------------------------------------------------------------------------- /training-service/src/main/java/org/orca3/miniAutoML/training/models/MemoryStore.java: --------------------------------------------------------------------------------
1 | package org.orca3.miniAutoML.training.models;
2 | 
3 | import org.orca3.miniAutoML.training.TrainRequest;
4 | import org.orca3.miniAutoML.training.TrainingJobMetadata;
5 | 
6 | import java.util.HashMap;
7 | import java.util.Map;
8 | import java.util.SortedMap;
9 | import java.util.TreeMap;
10 | import java.util.concurrent.atomic.AtomicInteger;
11 | 
12 | public class MemoryStore {
13 |     public SortedMap<Integer, TrainingJobMetadata> jobQueue = new TreeMap<>();
14 |     public Map<Integer, ExecutedTrainingJob> launchingList = new HashMap<>();
15 |     public Map<Integer, ExecutedTrainingJob> runningList = new HashMap<>();
16 |     public Map<Integer, ExecutedTrainingJob> finalizedJobs = new HashMap<>();
17 |     AtomicInteger jobIdSeed = new AtomicInteger();
18 | 
19 |     public int offer(TrainRequest request) {
20 |         int jobId = jobIdSeed.incrementAndGet();
21 |         jobQueue.put(jobId, request.getMetadata());
22 |         return jobId;
23 |     }
24 | 
25 |     public int getQueuePosition(int jobId) {
26 |         // jobs with smaller ids were queued earlier; headMap(jobId) holds exactly those entries
27 |         return jobQueue.headMap(jobId).size();
28 |     }
29 | }
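A side note on the queue-position trick above: `TreeMap` keeps keys sorted, and `headMap(k)` is a view of all entries with keys strictly less than `k`, so its size is exactly the number of jobs queued ahead. A tiny self-contained illustration (values are made up):

```java
import java.util.SortedMap;
import java.util.TreeMap;

public class QueuePositionSketch {
    public static void main(String[] args) {
        SortedMap<Integer, String> queue = new TreeMap<>();
        queue.put(3, "job-3"); // queued first (smallest id)
        queue.put(5, "job-5");
        queue.put(9, "job-9");
        // two jobs (3 and 5) sit strictly ahead of job 9
        System.out.println(queue.headMap(9).size()); // prints 2
    }
}
```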
-------------------------------------------------------------------------------- /training-service/src/main/java/org/orca3/miniAutoML/training/tracker/Tracker.java: --------------------------------------------------------------------------------
1 | package org.orca3.miniAutoML.training.tracker;
2 | 
3 | import io.grpc.ManagedChannel;
4 | import org.orca3.miniAutoML.dataManagement.DataManagementServiceGrpc;
5 | import org.orca3.miniAutoML.dataManagement.SnapshotState;
6 | import org.orca3.miniAutoML.dataManagement.VersionQuery;
7 | import org.orca3.miniAutoML.dataManagement.VersionedSnapshot;
8 | import org.orca3.miniAutoML.training.TrainingJobMetadata;
9 | import org.orca3.miniAutoML.training.models.ExecutedTrainingJob;
10 | import org.orca3.miniAutoML.training.models.MemoryStore;
11 | import org.slf4j.Logger;
12 | 
13 | import java.util.HashMap;
14 | import java.util.Map;
15 | import java.util.Properties;
16 | 
17 | public abstract class Tracker<T extends Tracker.SharedConfig> {
18 |     protected final MemoryStore store;
19 |     protected final DataManagementServiceGrpc.DataManagementServiceBlockingStub dmClient;
20 |     protected final Logger logger;
21 |     protected final T config;
22 | 
23 |     Tracker(MemoryStore store, ManagedChannel dmChannel, Logger logger, T config) {
24 |         this.store = store;
25 |         this.dmClient = DataManagementServiceGrpc.newBlockingStub(dmChannel);
26 |         this.logger = logger;
27 |         this.config = config;
28 |     }
29 | 
30 |     public void launchAll() {
31 |         // launch queued jobs in id order while the backend still has capacity
32 |         while (hasCapacity() && !store.jobQueue.isEmpty()) {
33 |             int jobId = store.jobQueue.firstKey();
34 |             TrainingJobMetadata metadata = store.jobQueue.get(jobId);
35 |             try {
36 |                 VersionedSnapshot r = dmClient.fetchTrainingDataset(VersionQuery.newBuilder()
37 |                         .setDatasetId(metadata.getDatasetId()).setVersionHash(metadata.getTrainDataVersionHash())
38 |                         .build());
39 |                 if (r.getState() == SnapshotState.READY) {
40 |                     store.jobQueue.remove(jobId);
41 |                     launch(jobId, metadata, r);
42 |                     store.launchingList.put(jobId, new ExecutedTrainingJob(System.currentTimeMillis(), metadata, ""));
43 |                 } else {
44 |                     logger.info(String.format("Dataset %s of version hash %s is not ready yet. Current state: %s.",
45 |                             metadata.getDatasetId(), metadata.getTrainDataVersionHash(), r.getState()));
46 |                 }
47 |             } catch (Exception ex) {
48 |                 store.jobQueue.remove(jobId);
49 |                 store.finalizedJobs.put(jobId, new ExecutedTrainingJob(System.currentTimeMillis(), metadata, "")
50 |                         .finished(System.currentTimeMillis(), false, ex.getMessage()));
51 |                 logger.warn(String.format("Failed to launch job %d.", jobId), ex);
52 |             }
53 |         }
54 |     }
55 | 
56 |     protected Map<String, String> containerEnvVars(int jobId, TrainingJobMetadata metadata, VersionedSnapshot versionedSnapshot) {
57 |         Map<String, String> envs = new HashMap<>();
58 |         envs.put("JOB_ID", Integer.toString(jobId));
59 |         envs.put("ALGORITHM_NAME", metadata.getAlgorithm());
60 |         envs.put("METADATA_STORE_SERVER", config.metadataStoreHost);
61 |         envs.put("TRAINING_DATASET_ID", metadata.getDatasetId());
62 |         envs.put("TRAINING_DATASET_VERSION_HASH", metadata.getTrainDataVersionHash());
63 |         envs.put("MODEL_BUCKET", config.metadataStoreBucketName);
64 |         envs.put("MODEL_NAME", metadata.getOutputModelName());
65 |         envs.put("MINIO_SERVER", config.minioHost);
66 |         envs.put("MINIO_SERVER_ACCESS_KEY", config.minioAccessKey);
67 |         envs.put("MINIO_SERVER_SECRET_KEY", config.minioSecretKey);
68 |         envs.put("TRAINING_DATA_BUCKET", versionedSnapshot.getRoot().getBucket());
69 |         envs.put("TRAINING_DATA_PATH", versionedSnapshot.getRoot().getPath());
70 |         envs.putAll(metadata.getParametersMap());
71 |         return envs;
72 |     }
73 | 
74 |     public abstract boolean hasCapacity();
75 | 
76 |     protected abstract String launch(int jobId, TrainingJobMetadata metadata, VersionedSnapshot versionedSnapshot);
77 | 
78 |     public abstract void updateContainerStatus();
79 | 
80 |     public abstract void shutdownAll();
81 | 
82 |     public static class SharedConfig {
83 |         final String minioAccessKey;
84 |         final String minioSecretKey;
85 |         final String minioHost;
86 |         final String metadataStoreBucketName;
87 |         final String metadataStoreHost;
88 | 
89 |         SharedConfig(Properties props) {
90 |             this.minioAccessKey = props.getProperty("minio.accessKey");
91 |             this.minioSecretKey = props.getProperty("minio.secretKey");
92 |             this.minioHost = props.getProperty("ts.trainer.minio.host");
93 |             this.metadataStoreHost = props.getProperty("ts.trainer.ms.host");
94 |             this.metadataStoreBucketName = props.getProperty("ms.minio.bucketName");
95 |         }
96 |     }
97 | }
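To make the `containerEnvVars` contract concrete: a Docker-backed tracker can flatten that map into `KEY=value` strings and hand it to a container at creation time. The snippet below is only a sketch of that idea using the docker-java client this module already depends on; the image tag and names are illustrative, and the repo's own `DockerTracker.java` should be treated as the reference implementation:

```java
import com.github.dockerjava.api.DockerClient;
import com.github.dockerjava.api.command.CreateContainerResponse;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

// Sketch only: launching one training container from an env-var map
// like the one containerEnvVars() builds. Not the repo's DockerTracker.
class DockerLaunchSketch {
    static String launch(DockerClient dockerClient, int jobId, Map<String, String> envVars) {
        // docker-java expects env vars as a flat list of "KEY=value" strings
        List<String> env = envVars.entrySet().stream()
                .map(e -> e.getKey() + "=" + e.getValue())
                .collect(Collectors.toList());
        CreateContainerResponse container = dockerClient
                .createContainerCmd("orca3/intent-classification:latest")
                .withName("training-job-" + jobId)
                .withEnv(env)
                .exec();
        dockerClient.startContainerCmd(container.getId()).exec();
        return container.getId();
    }
}
```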
-------------------------------------------------------------------------------- /training-service/src/main/resources/logback.xml: --------------------------------------------------------------------------------
1 | <configuration>
2 |     <appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
3 |         <encoder>
4 |             <pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger - %msg%n</pattern>
5 |         </encoder>
6 |     </appender>
7 |     <root level="INFO">
8 |         <appender-ref ref="STDOUT"/>
9 |     </root>
10 | </configuration>
--------------------------------------------------------------------------------