├── .gitignore ├── Kafkastreamsserver └── src │ └── main │ └── java │ └── com │ └── lightbend │ ├── modelserver │ ├── store │ │ ├── ModelStateSerde.java │ │ ├── ModelStateStore.java │ │ ├── ModelStateStoreChangeLogger.java │ │ ├── ModelStateStoreSupplier.java │ │ ├── ReadableModelStateStore.java │ │ └── StoreState.java │ └── withstore │ │ ├── DataProcessorWithStore.java │ │ ├── ModelProcessorWithStore.java │ │ └── ModelServerWithStore.java │ └── queriablestate │ ├── HostStoreInfo.java │ ├── MetadataService.java │ ├── ModelServingInfo.java │ └── QueriesRestService.java ├── README.md ├── akkaserver └── src │ └── main │ └── scala │ └── com │ └── lightbend │ ├── modelserver │ ├── AkkaModelServer.scala │ ├── ModelStage.scala │ └── ReadableModelStateStore.scala │ └── queriablestate │ └── QueriesAkkaHttpResource.scala ├── build.sbt ├── data ├── WineQuality │ ├── saved_model.pb │ └── variables │ │ ├── variables.data-00000-of-00001 │ │ └── variables.index ├── optimized_WineQuality.pb ├── winequalityDecisionTreeClassification.pmml ├── winequalityDesisionTreeRegression.pmml ├── winequalityGeneralizedLinearRegressionGamma.pmml ├── winequalityGeneralizedLinearRegressionGaussian.pmml ├── winequalityLinearRegression.pmml ├── winequalityMultilayerPerceptron.pmml ├── winequalityRandonForrestClassification.pmml ├── winequality_red.csv └── winequality_red_names.csv ├── flinkserver └── src │ └── main │ ├── resources │ ├── log4j.properties.bat │ └── logback.xml.bat │ └── scala │ └── com │ └── lightbend │ └── modelserver │ ├── BadDataHandler.scala │ ├── keyed │ ├── DataProcessorKeyed.scala │ └── ModelServingKeyedJob.scala │ ├── partitioned │ ├── DataProcessorMap.scala │ └── ModelServingFlatJob.scala │ ├── query │ └── ModelStateQuery.scala │ └── typeschema │ ├── ByteArraySchema.scala │ └── ModelTypeSerializer.scala ├── kafkaclient └── src │ └── main │ └── scala │ └── com │ └── lightbend │ └── kafka │ ├── DataProvider.scala │ ├── KafkaMessageSender.scala │ └── ModelProvider.scala ├── kafkaconfiguration └── src │ └── main │ └── java │ └── com │ └── lightbend │ └── configuration │ └── kafka │ └── ApplicationKafkaParameters.java ├── model └── src │ └── main │ ├── java │ └── com │ │ └── lightbend │ │ └── model │ │ └── java │ │ ├── Model.java │ │ ├── ModelFactory.java │ │ ├── PMML │ │ ├── PMMLModel.java │ │ └── PMMLModelFactory.java │ │ └── tensorflow │ │ ├── TensorflowModel.java │ │ └── TensorflowModelFactory.java │ └── scala │ └── com │ └── lightbend │ └── model │ └── scala │ ├── DataWithModel.scala │ ├── Model.scala │ ├── ModelFactory.scala │ ├── PMML │ └── PMMLModel.scala │ └── tensorflow │ └── TensorFlowModel.scala ├── project ├── Dependencies.scala ├── Versions.scala ├── assembly.sbt ├── build.properties └── scalapb.sbt ├── protobufs └── src │ └── main │ └── protobuf │ ├── modeldescriptor.proto │ └── winerecord.proto ├── servingsamples └── src │ └── main │ └── scala │ └── com │ └── lightbend │ ├── jpmml │ └── WineQualityRandomForestClassifier.scala │ └── tensorflow │ ├── WineModelServing.scala │ └── WineModelServingBundle.scala ├── sparkML └── src │ └── main │ └── scala │ └── com │ └── lightbend │ └── spark │ └── ml │ ├── WineQualityDecisionTreeClassifier.scala │ ├── WineQualityDecisionTreeRegressor.scala │ ├── WineQualityPerceptron.scala │ ├── WineQualityRandomForrestClassifier.scala │ ├── WinequalityGeneralizedLinearRegression.scala │ └── WinequalityLinearRegression.scala ├── sparkserver └── src │ └── main │ └── scala │ └── com │ └── lightbend │ └── modelserver │ ├── DataRecord.scala │ ├── 
ModelSerializerKryo.scala │ └── SparkModelServer.scala └── utils └── src └── main ├── java └── com │ └── lightbend │ └── modelserver │ └── support │ └── java │ ├── DataConverter.java │ └── ModelToServe.java └── scala └── com └── lightbend └── modelserver ├── kafka ├── EmbeddedSingleNodeKafkaCluster.scala ├── KafkaEmbedded.scala └── KafkaSupport.scala └── support └── scala ├── DataReader.scala └── ModelToServe.scala /.gitignore: -------------------------------------------------------------------------------- 1 | ### Kerberos/keytab ### 2 | *.keytab 3 | *_jaas.conf 4 | 5 | ### SQLite ### 6 | *.db 7 | 8 | # Created by https://www.gitignore.io/api/eclipse,gradle,intellij,intellij+iml,java,maven,sbt,scala,visualstudiocode 9 | 10 | ### Eclipse ### 11 | 12 | .metadata 13 | bin/ 14 | tmp/ 15 | *.tmp 16 | *.bak 17 | *.swp 18 | *~.nib 19 | local.properties 20 | .settings/ 21 | .loadpath 22 | .recommenders 23 | 24 | # Eclipse Core 25 | .project 26 | 27 | # External tool builders 28 | .externalToolBuilders/ 29 | 30 | # Locally stored "Eclipse launch configurations" 31 | *.launch 32 | 33 | # PyDev specific (Python IDE for Eclipse) 34 | *.pydevproject 35 | 36 | # CDT-specific (C/C++ Development Tooling) 37 | .cproject 38 | 39 | # JDT-specific (Eclipse Java Development Tools) 40 | .classpath 41 | 42 | # Java annotation processor (APT) 43 | .factorypath 44 | 45 | # PDT-specific (PHP Development Tools) 46 | .buildpath 47 | 48 | # sbteclipse plugin 49 | .target 50 | 51 | # Tern plugin 52 | .tern-project 53 | 54 | # TeXlipse plugin 55 | .texlipse 56 | 57 | # STS (Spring Tool Suite) 58 | .springBeans 59 | 60 | # Code Recommenders 61 | .recommenders/ 62 | 63 | 64 | ### Intellij ### 65 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 66 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 67 | # Everything 68 | .idea/ 69 | 70 | ## File-based project format: 71 | *.iws 72 | 73 | ## Plugin-specific files: 74 | 75 | # IntelliJ 76 | /out/ 77 | 78 | # mpeltonen/sbt-idea plugin 79 | .idea_modules/ 80 | 81 | # JIRA plugin 82 | atlassian-ide-plugin.xml 83 | 84 | # Crashlytics plugin (for Android Studio and IntelliJ) 85 | com_crashlytics_export_strings.xml 86 | crashlytics.properties 87 | crashlytics-build.properties 88 | fabric.properties 89 | 90 | ### Intellij Patch ### 91 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 92 | 93 | # *.iml 94 | # modules.xml 95 | # .idea/misc.xml 96 | # *.ipr 97 | 98 | 99 | ### Intellij+iml ### 100 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 101 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 102 | 103 | # User-specific stuff: 104 | 105 | # Sensitive or high-churn files: 106 | 107 | # Gradle: 108 | 109 | # Mongo Explorer plugin: 110 | 111 | ## File-based project format: 112 | 113 | ## Plugin-specific files: 114 | 115 | # IntelliJ 116 | 117 | # mpeltonen/sbt-idea plugin 118 | 119 | # JIRA plugin 120 | 121 | # Crashlytics plugin (for Android Studio and IntelliJ) 122 | 123 | ### Intellij+iml Patch ### 124 | # Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023 125 | 126 | *.iml 127 | modules.xml 128 | .idea/misc.xml 129 | *.ipr 130 | 131 | 132 | ### Maven ### 133 | target/ 134 | pom.xml.tag 135 | pom.xml.releaseBackup 136 | pom.xml.versionsBackup 137 | pom.xml.next 138 | release.properties 139 | dependency-reduced-pom.xml 140 | 
buildNumber.properties 141 | .mvn/timing.properties 142 | 143 | # Exclude maven wrapper 144 | !/.mvn/wrapper/maven-wrapper.jar 145 | 146 | 147 | ### SBT ### 148 | # Simple Build Tool 149 | # http://www.scala-sbt.org/release/docs/Getting-Started/Directories.html#configuring-version-control 150 | 151 | lib_managed/ 152 | src_managed/ 153 | project/boot/ 154 | .history 155 | .cache 156 | 157 | 158 | ### Scala ### 159 | *.class 160 | *.log 161 | 162 | # sbt specific 163 | .lib/ 164 | dist/* 165 | project/plugins/project/ 166 | 167 | # Scala-IDE specific 168 | .scala_dependencies 169 | .worksheet 170 | 171 | # ENSIME specific 172 | .ensime_cache/ 173 | .ensime 174 | 175 | 176 | ### VisualStudioCode ### 177 | .vscode/* 178 | !.vscode/settings.json 179 | !.vscode/tasks.json 180 | !.vscode/launch.json 181 | !.vscode/extensions.json 182 | 183 | 184 | ### Java ### 185 | 186 | # BlueJ files 187 | *.ctxt 188 | 189 | # Mobile Tools for Java (J2ME) 190 | .mtj.tmp/ 191 | 192 | # Package Files # 193 | *.jar 194 | *.war 195 | *.ear 196 | 197 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml 198 | hs_err_pid* 199 | 200 | 201 | ### Gradle ### 202 | .gradle 203 | /build/ 204 | 205 | # Ignore Gradle GUI config 206 | gradle-app.setting 207 | 208 | # Avoid ignoring Gradle wrapper jar file (.jar files are usually ignored) 209 | !gradle-wrapper.jar 210 | 211 | # Cache of project 212 | .gradletasknamecache 213 | 214 | # # Work around https://youtrack.jetbrains.com/issue/IDEA-116898 215 | # gradle/wrapper/gradle-wrapper.properties 216 | 217 | deploy.conf -------------------------------------------------------------------------------- /Kafkastreamsserver/src/main/java/com/lightbend/modelserver/store/ModelStateSerde.java: -------------------------------------------------------------------------------- 1 | package com.lightbend.modelserver.store; 2 | 3 | import com.lightbend.model.java.Model; 4 | import com.lightbend.model.java.ModelFactory; 5 | import com.lightbend.model.Modeldescriptor; 6 | import com.lightbend.model.java.PMML.PMMLModelFactory; 7 | import com.lightbend.model.java.tensorflow.TensorflowModelFactory; 8 | import com.lightbend.queriablestate.ModelServingInfo; 9 | import org.apache.kafka.common.serialization.Deserializer; 10 | import org.apache.kafka.common.serialization.Serde; 11 | import org.apache.kafka.common.serialization.Serializer; 12 | 13 | import java.io.ByteArrayInputStream; 14 | import java.io.ByteArrayOutputStream; 15 | import java.io.DataInputStream; 16 | import java.io.DataOutputStream; 17 | import java.util.HashMap; 18 | import java.util.Map; 19 | 20 | /** 21 | * Created by boris on 7/11/17. 
22 | * based on 23 | * https://github.com/confluentinc/examples/blob/3.2.x/kafka-streams/src/main/scala/io/confluent/examples/streams/algebird/TopCMSSerde.scala 24 | */ 25 | public class ModelStateSerde implements Serde { 26 | 27 | final private Serializer serializer; 28 | final private Deserializer deserializer; 29 | 30 | public ModelStateSerde(){ 31 | serializer = new ModelStateSerializer(); 32 | deserializer = new ModelStateDeserializer(); 33 | } 34 | 35 | @Override public void configure(Map configs, boolean isKey) {} 36 | 37 | @Override public void close() {} 38 | 39 | @Override public Serializer serializer() { 40 | return serializer; 41 | } 42 | 43 | @Override public Deserializer deserializer() { 44 | return deserializer; 45 | } 46 | 47 | public static class ModelStateSerializer implements Serializer { 48 | 49 | private ByteArrayOutputStream bos = new ByteArrayOutputStream(); 50 | 51 | 52 | @Override public void configure(Map configs, boolean isKey) {} 53 | 54 | @Override public byte[] serialize(String topic, StoreState state) { 55 | 56 | System.out.println("Serializing Store !!"); 57 | 58 | bos.reset(); 59 | DataOutputStream output = new DataOutputStream(bos); 60 | 61 | writeModel(state.getCurrentModel(), output); 62 | writeModel(state.getNewModel(), output); 63 | 64 | writeServingInfo(state.getCurrentServingInfo(), output); 65 | writeServingInfo(state.getNewServingInfo(), output); 66 | 67 | try { 68 | output.flush(); 69 | output.close(); 70 | } 71 | catch(Throwable t){} 72 | return bos.toByteArray(); 73 | 74 | } 75 | 76 | private void writeModel(Model model, DataOutputStream output){ 77 | try{ 78 | if(model == null){ 79 | output.writeLong(0); 80 | return; 81 | } 82 | byte[] bytes = model.getBytes(); 83 | output.writeLong(bytes.length); 84 | output.writeLong(model.getType()); 85 | output.write(bytes); 86 | } 87 | catch (Throwable t){ 88 | System.out.println("Error Serializing model"); 89 | t.printStackTrace(); 90 | } 91 | } 92 | 93 | private void writeServingInfo(ModelServingInfo servingInfo, DataOutputStream output){ 94 | try{ 95 | if(servingInfo == null) { 96 | output.writeLong(0); 97 | return; 98 | } 99 | output.writeLong(5); 100 | output.writeUTF(servingInfo.getDescription()); 101 | output.writeUTF(servingInfo.getName()); 102 | output.writeDouble(servingInfo.getDuration()); 103 | output.writeLong(servingInfo.getInvocations()); 104 | output.writeLong(servingInfo.getMax()); 105 | output.writeLong(servingInfo.getMin()); 106 | output.writeLong(servingInfo.getSince()); 107 | } 108 | catch (Throwable t){ 109 | System.out.println("Error Serializing servingInfo"); 110 | t.printStackTrace(); 111 | } 112 | } 113 | 114 | @Override public void close() {} 115 | } 116 | public static class ModelStateDeserializer implements Deserializer { 117 | 118 | private static final Map factories = new HashMap() { 119 | { 120 | put(Modeldescriptor.ModelDescriptor.ModelType.TENSORFLOW.getNumber(), TensorflowModelFactory.getInstance()); 121 | put(Modeldescriptor.ModelDescriptor.ModelType.PMML.getNumber(), PMMLModelFactory.getInstance()); 122 | } 123 | }; 124 | 125 | @Override 126 | public void configure(Map configs, boolean isKey) { 127 | } 128 | 129 | @Override 130 | public StoreState deserialize(String topic, byte[] data) { 131 | 132 | System.out.println("Deserializing Store !!"); 133 | 134 | ByteArrayInputStream bis = new ByteArrayInputStream(data); 135 | DataInputStream input = new DataInputStream(bis); 136 | 137 | Model currentModel = readModel(input); 138 | Model newModel = readModel(input); 139 | 140 
| ModelServingInfo currentServingInfo = readServingInfo(input); 141 | ModelServingInfo newServingInfo = readServingInfo(input); 142 | 143 | return new StoreState(currentModel, newModel, currentServingInfo, newServingInfo); 144 | } 145 | 146 | @Override 147 | public void close() { 148 | } 149 | 150 | private Model readModel(DataInputStream input) { 151 | try { 152 | int length = (int)input.readLong(); 153 | if (length == 0) 154 | return null; 155 | int type = (int) input.readLong(); 156 | byte[] bytes = new byte[length]; 157 | input.read(bytes); 158 | ModelFactory factory = factories.get(type); 159 | return factory.restore(bytes); 160 | } catch (Throwable t) { 161 | System.out.println("Error Deserializing model"); 162 | t.printStackTrace(); 163 | return null; 164 | } 165 | } 166 | 167 | private ModelServingInfo readServingInfo(DataInputStream input) { 168 | try { 169 | long length = input.readLong(); 170 | if (length == 0) 171 | return null; 172 | String descriprtion = input.readUTF(); 173 | String name = input.readUTF(); 174 | double duration = input.readDouble(); 175 | long invocations = input.readLong(); 176 | long max = input.readLong(); 177 | long min = input.readLong(); 178 | long since = input.readLong(); 179 | return new ModelServingInfo(name, descriprtion, since, invocations, duration, min, max); 180 | } catch (Throwable t) { 181 | System.out.println("Error Deserializing serving info"); 182 | t.printStackTrace(); 183 | return null; 184 | } 185 | } 186 | } 187 | } 188 | -------------------------------------------------------------------------------- /Kafkastreamsserver/src/main/java/com/lightbend/modelserver/store/ModelStateStore.java: -------------------------------------------------------------------------------- 1 | package com.lightbend.modelserver.store; 2 | 3 | import com.lightbend.model.java.Model; 4 | import com.lightbend.queriablestate.ModelServingInfo; 5 | import org.apache.kafka.common.serialization.Serdes; 6 | import org.apache.kafka.streams.processor.ProcessorContext; 7 | import org.apache.kafka.streams.processor.StateRestoreCallback; 8 | import org.apache.kafka.streams.processor.StateStore; 9 | import org.apache.kafka.streams.state.QueryableStoreType; 10 | import org.apache.kafka.streams.state.StateSerdes; 11 | import org.apache.kafka.streams.state.internals.StateStoreProvider; 12 | 13 | /** 14 | * Created by boris on 7/11/17. 15 | * Implementation of a custom state store based on 16 | * http://docs.confluent.io/current/streams/developer-guide.html#streams-developer-guide-state-store-custom 17 | * and example at: 18 | * https://github.com/confluentinc/examples/blob/3.2.x/kafka-streams/src/main/scala/io/confluent/examples/streams/algebird/CMSStore.scala 19 | * 20 | */ 21 | public class ModelStateStore implements StateStore, ReadableModelStateStore { 22 | 23 | private String name = null; 24 | private boolean loggingEnabled = false; 25 | private ModelStateStoreChangeLogger changeLogger = null; 26 | /** 27 | * The "storage backend" of this store. 28 | * Needs proper initializing in case the store's changelog is empty. 
29 | */ 30 | private StoreState state = null; 31 | private boolean open = false; 32 | private int changelogKey = 42; 33 | 34 | /* 35 | * @param name The name of this store instance 36 | */ 37 | 38 | public ModelStateStore(String name, boolean loggingEnabled) { 39 | this.name = name; 40 | this.loggingEnabled = loggingEnabled; 41 | state = new StoreState(); 42 | } 43 | 44 | @Override 45 | public String name() { 46 | return name; 47 | } 48 | 49 | @Override 50 | public void init(ProcessorContext context, StateStore root) { 51 | StateSerdes serdes = new StateSerdes( 52 | name, Serdes.Integer(), new ModelStateSerde()); 53 | changeLogger = new ModelStateStoreChangeLogger(name, context, serdes); 54 | if (root != null && loggingEnabled) { 55 | context.register(root, loggingEnabled, new StateRestoreCallback() { 56 | @Override 57 | public void restore(byte[] key, byte[] value) { 58 | if (value == null) { 59 | state.zero(); 60 | } else { 61 | state = serdes.valueFrom(value); 62 | } 63 | } 64 | }); 65 | } 66 | open = true; 67 | } 68 | 69 | /** 70 | * Periodically saves the latest state to Kafka. 71 | * =Implementation detail= 72 | * The changelog records have the form: (hardcodedKey, StoreState). That is, we are backing up the 73 | * underlying StoreState data structure in its entirety to Kafka. 74 | */ 75 | @Override 76 | public void flush() { 77 | if (loggingEnabled) { 78 | changeLogger.logChange(changelogKey, state); 79 | } 80 | } 81 | 82 | @Override 83 | public void close() { 84 | open = false; 85 | } 86 | 87 | @Override 88 | public boolean persistent() { 89 | return false; 90 | } 91 | 92 | @Override 93 | public boolean isOpen() { 94 | return open; 95 | } 96 | 97 | public Model getCurrentModel() { 98 | return state.getCurrentModel(); 99 | } 100 | 101 | public void setCurrentModel(Model currentModel) { 102 | state.setCurrentModel(currentModel); 103 | } 104 | 105 | public Model getNewModel() { 106 | return state.getNewModel(); 107 | } 108 | 109 | public void setNewModel(Model newModel) { 110 | state.setNewModel(newModel); 111 | } 112 | 113 | public ModelServingInfo getCurrentServingInfo() { 114 | return state.getCurrentServingInfo(); 115 | } 116 | 117 | public void setCurrentServingInfo(ModelServingInfo currentServingInfo) { 118 | state.setCurrentServingInfo(currentServingInfo); 119 | } 120 | 121 | public ModelServingInfo getNewServingInfo() { 122 | return state.getNewServingInfo(); 123 | } 124 | 125 | public void setNewServingInfo(ModelServingInfo newServingInfo) { 126 | state.setNewServingInfo(newServingInfo); 127 | } 128 | 129 | public static class ModelStateStoreType implements QueryableStoreType { 130 | 131 | @Override 132 | public boolean accepts(StateStore stateStore) { 133 | return stateStore instanceof ModelStateStore; 134 | } 135 | 136 | @Override 137 | public ReadableModelStateStore create(StateStoreProvider provider, String storeName) { 138 | return provider.stores(storeName, this).get(0); 139 | } 140 | 141 | } 142 | } 143 | -------------------------------------------------------------------------------- /Kafkastreamsserver/src/main/java/com/lightbend/modelserver/store/ModelStateStoreChangeLogger.java: -------------------------------------------------------------------------------- 1 | package com.lightbend.modelserver.store; 2 | 3 | import org.apache.kafka.common.serialization.Serializer; 4 | import org.apache.kafka.streams.processor.ProcessorContext; 5 | import org.apache.kafka.streams.processor.internals.ProcessorStateManager; 6 | import 
org.apache.kafka.streams.processor.internals.RecordCollector; 7 | import org.apache.kafka.streams.state.StateSerdes; 8 | 9 | /** 10 | * Created by boris on 7/11/17. 11 | * based on 12 | * https://github.com/confluentinc/examples/blob/3.2.x/kafka-streams/src/main/scala/io/confluent/examples/streams/algebird/CMSStoreChangeLogger.scala 13 | */ 14 | public class ModelStateStoreChangeLogger { 15 | 16 | private String topic; 17 | private RecordCollector collector; 18 | private int partition; 19 | private StateSerdes serialization; 20 | private ProcessorContext context; 21 | 22 | public ModelStateStoreChangeLogger(String storeName, ProcessorContext context, int partition, StateSerdes serialization){ 23 | topic = ProcessorStateManager.storeChangelogTopic (context.applicationId(), storeName); 24 | collector = ((RecordCollector.Supplier)context).recordCollector(); 25 | this.partition = partition; 26 | this.serialization = serialization; 27 | this.context = context; 28 | } 29 | 30 | public ModelStateStoreChangeLogger(String storeName, ProcessorContext context, StateSerdes serialization){ 31 | this(storeName, context, context.taskId().partition, serialization); 32 | } 33 | 34 | public void logChange(K key, V value) { 35 | if (collector != null) { 36 | Serializer keySerializer = serialization.keySerializer(); 37 | Serializer valueSerializer = serialization.valueSerializer(); 38 | collector.send(this.topic, key, value, this.partition, context.timestamp(), keySerializer, valueSerializer); 39 | } 40 | } 41 | } -------------------------------------------------------------------------------- /Kafkastreamsserver/src/main/java/com/lightbend/modelserver/store/ModelStateStoreSupplier.java: -------------------------------------------------------------------------------- 1 | package com.lightbend.modelserver.store; 2 | 3 | import org.apache.kafka.common.serialization.Serde; 4 | import org.apache.kafka.streams.processor.StateStoreSupplier; 5 | 6 | import java.util.HashMap; 7 | import java.util.Map; 8 | 9 | /** 10 | * Created by boris on 7/11/17. 
11 | * based on https://github.com/confluentinc/examples/blob/3.2.x/kafka-streams/src/main/scala/io/confluent/examples/streams/algebird/CMSStoreSupplier.scala 12 | */ 13 | public class ModelStateStoreSupplier implements StateStoreSupplier { 14 | 15 | private String name; 16 | private Serde serde; 17 | private boolean loggingEnabled; 18 | private Map logConfig; 19 | 20 | public ModelStateStoreSupplier(String name, Serde serde, boolean loggingEnabled, Map logConfig){ 21 | 22 | this.name = name; 23 | this.serde = serde; 24 | this.loggingEnabled = loggingEnabled; 25 | this.logConfig = logConfig; 26 | } 27 | 28 | public ModelStateStoreSupplier(String name, Serde serde) { 29 | this(name, serde, true, new HashMap<>()); 30 | } 31 | 32 | public ModelStateStoreSupplier(String name, Serde serde, boolean loggingEnabled) { 33 | this(name, serde, loggingEnabled, new HashMap<>()); 34 | } 35 | 36 | @Override public String name() { 37 | return name; 38 | } 39 | 40 | @Override public ModelStateStore get() {return new ModelStateStore(name, loggingEnabled);} 41 | 42 | @Override public Map logConfig() { 43 | return logConfig; 44 | } 45 | 46 | @Override public boolean loggingEnabled() { 47 | return loggingEnabled; 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /Kafkastreamsserver/src/main/java/com/lightbend/modelserver/store/ReadableModelStateStore.java: -------------------------------------------------------------------------------- 1 | package com.lightbend.modelserver.store; 2 | 3 | import com.lightbend.queriablestate.ModelServingInfo; 4 | 5 | /** 6 | * Created by boris on 7/13/17. 7 | */ 8 | public interface ReadableModelStateStore { 9 | ModelServingInfo getCurrentServingInfo(); 10 | } 11 | 12 | -------------------------------------------------------------------------------- /Kafkastreamsserver/src/main/java/com/lightbend/modelserver/store/StoreState.java: -------------------------------------------------------------------------------- 1 | package com.lightbend.modelserver.store; 2 | 3 | import com.lightbend.model.java.Model; 4 | import com.lightbend.queriablestate.ModelServingInfo; 5 | 6 | /** 7 | * Created by boris on 7/18/17. 
8 | */ 9 | public class StoreState { 10 | private Model currentModel = null; 11 | private Model newModel = null; 12 | private ModelServingInfo currentServingInfo = null; 13 | private ModelServingInfo newServingInfo = null; 14 | 15 | public StoreState() { 16 | currentModel = null; 17 | newModel = null; 18 | currentServingInfo = null; 19 | newServingInfo = null; 20 | } 21 | 22 | public StoreState(Model currentModel, Model newModel, ModelServingInfo currentServingInfo, ModelServingInfo newServingInfo) { 23 | this.currentModel = currentModel; 24 | this.newModel = newModel; 25 | this.currentServingInfo = currentServingInfo; 26 | this.newServingInfo = newServingInfo; 27 | } 28 | 29 | public void zero() { 30 | currentModel = null; 31 | newModel = null; 32 | currentServingInfo = null; 33 | newServingInfo = null; 34 | } 35 | 36 | public Model getCurrentModel() { 37 | return currentModel; 38 | } 39 | 40 | public void setCurrentModel(Model currentModel) { 41 | this.currentModel = currentModel; 42 | } 43 | 44 | public Model getNewModel() { 45 | return newModel; 46 | } 47 | 48 | public void setNewModel(Model newModel) { 49 | this.newModel = newModel; 50 | } 51 | 52 | public ModelServingInfo getCurrentServingInfo() { 53 | return currentServingInfo; 54 | } 55 | 56 | public void setCurrentServingInfo(ModelServingInfo currentServingInfo) { 57 | this.currentServingInfo = currentServingInfo; 58 | } 59 | 60 | public ModelServingInfo getNewServingInfo() { 61 | return newServingInfo; 62 | } 63 | 64 | public void setNewServingInfo(ModelServingInfo newServingInfo) { 65 | this.newServingInfo = newServingInfo; 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /Kafkastreamsserver/src/main/java/com/lightbend/modelserver/withstore/DataProcessorWithStore.java: -------------------------------------------------------------------------------- 1 | package com.lightbend.modelserver.withstore; 2 | 3 | import com.lightbend.model.Winerecord; 4 | import com.lightbend.modelserver.store.ModelStateStore; 5 | import com.lightbend.modelserver.support.java.DataConverter; 6 | import com.lightbend.queriablestate.ModelServingInfo; 7 | import org.apache.kafka.streams.processor.AbstractProcessor; 8 | import org.apache.kafka.streams.processor.ProcessorContext; 9 | 10 | import java.util.Objects; 11 | import java.util.Optional; 12 | 13 | /** 14 | * Created by boris on 7/12/17. 
15 | * used 16 | * https://github.com/bbejeck/kafka-streams/blob/master/src/main/java/bbejeck/processor/stocks/StockSummaryProcessor.java 17 | */ 18 | public class DataProcessorWithStore extends AbstractProcessor { 19 | 20 | private ModelStateStore modelStore; 21 | private ProcessorContext context; 22 | 23 | @Override 24 | public void process(byte[] key, byte[] value) { 25 | Optional dataRecord = DataConverter.convertData(value); 26 | if(!dataRecord.isPresent()) { 27 | // context().commit(); 28 | return; // Bad record 29 | } 30 | if(modelStore.getNewModel() != null){ 31 | // update the model 32 | if(modelStore.getCurrentModel() != null) 33 | modelStore.getCurrentModel().cleanup(); 34 | modelStore.setCurrentModel(modelStore.getNewModel()); 35 | modelStore.setCurrentServingInfo(new ModelServingInfo(modelStore.getNewServingInfo().getName(), 36 | modelStore.getNewServingInfo().getDescription(), System.currentTimeMillis())); 37 | modelStore.setNewServingInfo(null); 38 | modelStore.setNewModel(null); 39 | } 40 | // Actually score 41 | if(modelStore.getCurrentModel() == null) { 42 | // No model currently 43 | System.out.println("No model available - skipping"); 44 | // context().forward(key,Optional.empty()); 45 | // context().commit(); 46 | } 47 | else{ 48 | // Score the model 49 | long start = System.currentTimeMillis(); 50 | double quality = (double) modelStore.getCurrentModel().score(dataRecord.get()); 51 | long duration = System.currentTimeMillis() - start; 52 | modelStore.getCurrentServingInfo().update(duration); 53 | System.out.println("Calculated quality - " + quality + " in " + duration + "ms"); 54 | // context().forward(key,Optional.of(quality)); 55 | // context().commit(); 56 | } 57 | 58 | } 59 | 60 | @Override 61 | public void init(ProcessorContext context) { 62 | this.context = context; 63 | this.context.schedule(10000); 64 | modelStore = (ModelStateStore) this.context.getStateStore("modelStore"); 65 | Objects.requireNonNull(modelStore, "State store can't be null"); 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /Kafkastreamsserver/src/main/java/com/lightbend/modelserver/withstore/ModelProcessorWithStore.java: -------------------------------------------------------------------------------- 1 | package com.lightbend.modelserver.withstore; 2 | 3 | import com.lightbend.model.Modeldescriptor; 4 | import com.lightbend.model.java.Model; 5 | import com.lightbend.model.java.ModelFactory; 6 | import com.lightbend.model.java.PMML.PMMLModelFactory; 7 | import com.lightbend.model.java.tensorflow.TensorflowModelFactory; 8 | import com.lightbend.modelserver.store.ModelStateStore; 9 | import com.lightbend.modelserver.support.java.DataConverter; 10 | import com.lightbend.modelserver.support.java.ModelToServe; 11 | import com.lightbend.queriablestate.ModelServingInfo; 12 | import org.apache.kafka.streams.processor.AbstractProcessor; 13 | import org.apache.kafka.streams.processor.ProcessorContext; 14 | 15 | import java.util.HashMap; 16 | import java.util.Map; 17 | import java.util.Objects; 18 | import java.util.Optional; 19 | 20 | /** 21 | * Created by boris on 7/12/17. 
22 | */ 23 | public class ModelProcessorWithStore extends AbstractProcessor { 24 | 25 | private static final Map factories = new HashMap() { 26 | { 27 | put(Modeldescriptor.ModelDescriptor.ModelType.TENSORFLOW.getNumber(), TensorflowModelFactory.getInstance()); 28 | put(Modeldescriptor.ModelDescriptor.ModelType.PMML.getNumber(), PMMLModelFactory.getInstance()); 29 | } 30 | }; 31 | private ModelStateStore modelStore; 32 | private ProcessorContext context; 33 | 34 | @Override 35 | public void process(byte[] key, byte[] value) { 36 | 37 | Optional descriptor = DataConverter.convertModel(value); 38 | if(!descriptor.isPresent()){ 39 | return; // Bad record 40 | } 41 | ModelToServe model = descriptor.get(); 42 | System.out.println("New scoring model " + model); 43 | if(model.getModelData() == null) { 44 | System.out.println("Location based model is not yet supported"); 45 | return; 46 | } 47 | ModelFactory factory = factories.get(model.getModelType().ordinal()); 48 | if(factory == null){ 49 | System.out.println("Bad model type " + model.getModelType()); 50 | return; 51 | } 52 | Optional current = factory.create(model); 53 | if(current.isPresent()) { 54 | modelStore.setNewModel(current.get()); 55 | modelStore.setNewServingInfo(new ModelServingInfo(model.getName(), model.getDescription(), 0)); 56 | return; 57 | } 58 | } 59 | 60 | @Override 61 | public void init(ProcessorContext context) { 62 | this.context = context; 63 | this.context.schedule(10000); 64 | modelStore = (ModelStateStore) this.context.getStateStore("modelStore"); 65 | Objects.requireNonNull(modelStore, "State store can't be null"); 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /Kafkastreamsserver/src/main/java/com/lightbend/modelserver/withstore/ModelServerWithStore.java: -------------------------------------------------------------------------------- 1 | package com.lightbend.modelserver.withstore; 2 | 3 | import com.lightbend.configuration.kafka.ApplicationKafkaParameters; 4 | import com.lightbend.modelserver.store.ModelStateSerde; 5 | import com.lightbend.modelserver.store.ModelStateStoreSupplier; 6 | import com.lightbend.modelserver.store.StoreState; 7 | import com.lightbend.queriablestate.QueriesRestService; 8 | import org.apache.kafka.common.serialization.ByteArrayDeserializer; 9 | import org.apache.kafka.common.serialization.Serde; 10 | import org.apache.kafka.streams.KafkaStreams; 11 | import org.apache.kafka.streams.StreamsConfig; 12 | import org.apache.kafka.streams.kstream.KStreamBuilder; 13 | 14 | import java.io.File; 15 | import java.nio.file.Files; 16 | import java.util.Properties; 17 | 18 | /** 19 | * Created by boris on 6/28/17. 20 | */ 21 | @SuppressWarnings("Duplicates") 22 | public class ModelServerWithStore { 23 | 24 | final static int port=8888; // Port for queryable state 25 | 26 | public static void main(String [ ] args) throws Throwable { 27 | 28 | Properties streamsConfiguration = new Properties(); 29 | // Give the Streams application a unique name. The name must be unique in the Kafka cluster 30 | // against which the application is run. 31 | streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, "interactive-queries-example"); 32 | streamsConfiguration.put(StreamsConfig.CLIENT_ID_CONFIG, "interactive-queries-example-client"); 33 | // Where to find Kafka broker(s). 
34 | streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, ApplicationKafkaParameters.LOCAL_KAFKA_BROKER); 35 | // Provide the details of our embedded http service that we'll use to connect to this streams 36 | // instance and discover locations of stores. 37 | streamsConfiguration.put(StreamsConfig.APPLICATION_SERVER_CONFIG, "localhost:" + port); 38 | final File example = Files.createTempDirectory(new File("/tmp").toPath(), "example").toFile(); 39 | streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, example.getPath()); 40 | // Create topology 41 | final KafkaStreams streams = createStreams(streamsConfiguration); 42 | streams.cleanUp(); 43 | streams.start(); 44 | // Start the Restful proxy for servicing remote access to state stores 45 | final QueriesRestService restService = startRestProxy(streams, port); 46 | // Add shutdown hook to respond to SIGTERM and gracefully close Kafka Streams 47 | Runtime.getRuntime().addShutdownHook(new Thread(() -> { 48 | try { 49 | streams.close(); 50 | restService.stop(); 51 | } catch (Exception e) { 52 | // ignored 53 | } 54 | })); 55 | } 56 | 57 | static KafkaStreams createStreams(final Properties streamsConfiguration) { 58 | 59 | Serde stateSerde = new ModelStateSerde(); 60 | ByteArrayDeserializer deserializer = new ByteArrayDeserializer(); 61 | ModelStateStoreSupplier storeSupplier = new ModelStateStoreSupplier("modelStore", stateSerde); 62 | 63 | 64 | KStreamBuilder builder = new KStreamBuilder(); 65 | // Data input streams 66 | 67 | builder.addSource("data-source", deserializer, deserializer, ApplicationKafkaParameters.DATA_TOPIC) 68 | .addProcessor("ProcessData", DataProcessorWithStore::new, "data-source"); 69 | builder.addSource("model-source", deserializer, deserializer, ApplicationKafkaParameters.MODELS_TOPIC) 70 | .addProcessor("ProcessModels", ModelProcessorWithStore::new, "model-source"); 71 | builder.addStateStore(storeSupplier, "ProcessData", "ProcessModels"); 72 | 73 | 74 | return new KafkaStreams(builder, streamsConfiguration); 75 | } 76 | 77 | static QueriesRestService startRestProxy(final KafkaStreams streams, final int port) throws Exception { 78 | final QueriesRestService restService = new QueriesRestService(streams); 79 | restService.start(port); 80 | return restService; 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /Kafkastreamsserver/src/main/java/com/lightbend/queriablestate/HostStoreInfo.java: -------------------------------------------------------------------------------- 1 | package com.lightbend.queriablestate; 2 | 3 | // https://github.com/confluentinc/examples/blob/3.2.x/kafka-streams/src/main/java/io/confluent/examples/streams/interactivequeries/WordCountInteractiveQueriesRestService.java 4 | 5 | import java.util.Objects; 6 | import java.util.Set; 7 | 8 | public class HostStoreInfo { 9 | 10 | private String host; 11 | private int port; 12 | private Set storeNames; 13 | 14 | public HostStoreInfo(){} 15 | 16 | public HostStoreInfo(final String host, final int port, final Set storeNames) { 17 | this.host = host; 18 | this.port = port; 19 | this.storeNames = storeNames; 20 | } 21 | 22 | public String getHost() { 23 | return host; 24 | } 25 | 26 | public void setHost(final String host) { 27 | this.host = host; 28 | } 29 | 30 | public int getPort() { 31 | return port; 32 | } 33 | 34 | public void setPort(final int port) { 35 | this.port = port; 36 | } 37 | 38 | public Set getStoreNames() { 39 | return storeNames; 40 | } 41 | 42 | public void setStoreNames(final 
Set storeNames) { 43 | this.storeNames = storeNames; 44 | } 45 | 46 | @Override 47 | public String toString() { 48 | return "HostStoreInfo{" + 49 | "host='" + host + '\'' + 50 | ", port=" + port + 51 | ", storeNames=" + storeNames + 52 | '}'; 53 | } 54 | 55 | @Override 56 | public boolean equals(final Object o) { 57 | if (this == o) { 58 | return true; 59 | } 60 | if (o == null || getClass() != o.getClass()) { 61 | return false; 62 | } 63 | final HostStoreInfo that = (HostStoreInfo) o; 64 | return port == that.port && 65 | Objects.equals(host, that.host) && 66 | Objects.equals(storeNames, that.storeNames); 67 | } 68 | 69 | @Override 70 | public int hashCode() { 71 | return Objects.hash(host, port, storeNames); 72 | } 73 | } -------------------------------------------------------------------------------- /Kafkastreamsserver/src/main/java/com/lightbend/queriablestate/MetadataService.java: -------------------------------------------------------------------------------- 1 | package com.lightbend.queriablestate; 2 | 3 | import org.apache.kafka.streams.KafkaStreams; 4 | import org.apache.kafka.streams.state.StreamsMetadata; 5 | 6 | import java.util.Collection; 7 | import java.util.List; 8 | import java.util.stream.Collectors; 9 | 10 | /** 11 | * Looks up StreamsMetadata from KafkaStreams and converts the results 12 | * into Beans that can be JSON serialized via Jersey. 13 | * https://github.com/confluentinc/examples/blob/3.2.x/kafka-streams/src/main/java/io/confluent/examples/streams/interactivequeries/MetadataService.java 14 | */ 15 | public class MetadataService { 16 | 17 | private final KafkaStreams streams; 18 | 19 | public MetadataService(final KafkaStreams streams) { 20 | this.streams = streams; 21 | } 22 | 23 | /** 24 | * Get the metadata for all of the instances of this Kafka Streams application 25 | * @return List of {@link HostStoreInfo} 26 | */ 27 | public List streamsMetadata() { 28 | // Get metadata for all of the instances of this Kafka Streams application 29 | final Collection metadata = streams.allMetadata(); 30 | return mapInstancesToHostStoreInfo(metadata); 31 | } 32 | 33 | /** 34 | * Get the metadata for all instances of this Kafka Streams application that currently 35 | * has the provided store. 
36 | * @param store The store to locate 37 | * @return List of {@link HostStoreInfo} 38 | */ 39 | public List streamsMetadataForStore(final String store) { 40 | // Get metadata for all of the instances of this Kafka Streams application hosting the store 41 | final Collection metadata = streams.allMetadataForStore(store); 42 | return mapInstancesToHostStoreInfo(metadata); 43 | } 44 | 45 | 46 | private List mapInstancesToHostStoreInfo(final Collection metadatas) { 47 | return metadatas.stream().map(metadata -> new HostStoreInfo(metadata.host(), 48 | metadata.port(), 49 | metadata.stateStoreNames())) 50 | // addCustomStore(metadata.stateStoreNames()))) 51 | .collect(Collectors.toList()); 52 | } 53 | 54 | } 55 | -------------------------------------------------------------------------------- /Kafkastreamsserver/src/main/java/com/lightbend/queriablestate/ModelServingInfo.java: -------------------------------------------------------------------------------- 1 | package com.lightbend.queriablestate; 2 | 3 | import java.util.Objects; 4 | 5 | public class ModelServingInfo { 6 | 7 | private String name; 8 | private String description; 9 | private long since; 10 | private long invocations; 11 | private double duration; 12 | private long min; 13 | private long max; 14 | 15 | public ModelServingInfo(){} 16 | 17 | public ModelServingInfo(final String name, final String description, final long since) { 18 | this.name = name; 19 | this.description = description; 20 | this.since = since; 21 | this.invocations = 0; 22 | this.duration = 0.; 23 | this.min = Long.MAX_VALUE; 24 | this.max = Long.MIN_VALUE; 25 | } 26 | 27 | public ModelServingInfo(final String name, final String description, final long since, final long invocations, 28 | final double duration, final long min, final long max) { 29 | this.name = name; 30 | this.description = description; 31 | this.since = since; 32 | this.invocations = invocations; 33 | this.duration = duration; 34 | this.min = min; 35 | this.max = max; 36 | } 37 | 38 | public void update(long execution){ 39 | invocations++; 40 | duration += execution; 41 | if(execution < min) min = execution; 42 | if(execution > max) max = execution; 43 | } 44 | 45 | public String getName() {return name;} 46 | 47 | public void setName(String name) {this.name = name;} 48 | 49 | public String getDescription() {return description;} 50 | 51 | public void setDescription(String description) {this.description = description;} 52 | 53 | public long getSince() {return since;} 54 | 55 | public void setSince(long since) {this.since = since;} 56 | 57 | public long getInvocations() {return invocations;} 58 | 59 | public void setInvocations(long invocations) {this.invocations = invocations;} 60 | 61 | public double getDuration() {return duration;} 62 | 63 | public void setDuration(double duration) {this.duration = duration;} 64 | 65 | public long getMin() {return min;} 66 | 67 | public void setMin(long min) {this.min = min;} 68 | 69 | public long getMax() {return max;} 70 | 71 | public void setMax(long max) {this.max = max;} 72 | 73 | @Override 74 | public String toString() { 75 | return "ModelServingInfo{" + 76 | "name='" + name + '\'' + 77 | ", description='" + description + '\'' + 78 | ", since=" + since + 79 | ", invocations=" + invocations + 80 | ", duration=" + duration + 81 | ", min=" + min + 82 | ", max=" + max + 83 | '}'; 84 | } 85 | 86 | @Override 87 | public boolean equals(final Object o) { 88 | if (this == o) { 89 | return true; 90 | } 91 | if (o == null || getClass() != o.getClass()) { 92 | return 
false; 93 | } 94 | final ModelServingInfo that = (ModelServingInfo) o; 95 | return name.equals(that.name) && 96 | description.equals(that.description); 97 | } 98 | 99 | @Override 100 | public int hashCode() { 101 | return Objects.hash(name, description); 102 | } 103 | } -------------------------------------------------------------------------------- /Kafkastreamsserver/src/main/java/com/lightbend/queriablestate/QueriesRestService.java: -------------------------------------------------------------------------------- 1 | package com.lightbend.queriablestate; 2 | 3 | import com.lightbend.modelserver.store.ModelStateStore; 4 | import com.lightbend.modelserver.store.ReadableModelStateStore; 5 | import org.apache.kafka.streams.KafkaStreams; 6 | import org.eclipse.jetty.server.Server; 7 | import org.eclipse.jetty.servlet.ServletContextHandler; 8 | import org.eclipse.jetty.servlet.ServletHolder; 9 | import org.glassfish.jersey.jackson.JacksonFeature; 10 | import org.glassfish.jersey.server.ResourceConfig; 11 | import org.glassfish.jersey.servlet.ServletContainer; 12 | 13 | import javax.ws.rs.*; 14 | import javax.ws.rs.core.MediaType; 15 | import java.util.List; 16 | 17 | /** 18 | * A simple REST proxy that runs embedded in the {@link com.lightbend.modelserver.ModelServer}. This is used to 19 | * demonstrate how a developer can use the Interactive Queries APIs exposed by Kafka Streams to 20 | * locate and query the State Stores within a Kafka Streams Application. 21 | * https://github.com/confluentinc/examples/blob/3.2.x/kafka-streams/src/main/java/io/confluent/examples/streams/interactivequeries/WordCountInteractiveQueriesRestService.java 22 | */ 23 | @Path("state") 24 | public class QueriesRestService { 25 | 26 | private final KafkaStreams streams; 27 | private final MetadataService metadataService; 28 | private Server jettyServer; 29 | 30 | public QueriesRestService(final KafkaStreams streams) { 31 | this.streams = streams; 32 | this.metadataService = new MetadataService(streams); 33 | } 34 | 35 | /** 36 | * Get the metadata for all of the instances of this Kafka Streams application 37 | * @return List of {@link HostStoreInfo} 38 | */ 39 | @GET() 40 | @Path("/instances") 41 | @Produces(MediaType.APPLICATION_JSON) 42 | public List streamsMetadata() { 43 | return metadataService.streamsMetadata(); 44 | } 45 | 46 | /** 47 | * Get the metadata for all instances of this Kafka Streams application that currently 48 | * has the provided store. 
49 | * @param store The store to locate 50 | * @return List of {@link HostStoreInfo} 51 | */ 52 | @GET() 53 | @Path("/instances/{storeName}") 54 | @Produces(MediaType.APPLICATION_JSON) 55 | public List streamsMetadataForStore(@PathParam("storeName") String store) { 56 | return metadataService.streamsMetadataForStore(store); 57 | } 58 | 59 | /** 60 | * Get the current value of the state 61 | * @return {@link ModelServingInfo} representing the key-value pair 62 | */ 63 | @GET 64 | @Path("{storeName}/value") 65 | @Produces(MediaType.APPLICATION_JSON) 66 | public ModelServingInfo servingInfo(@PathParam("storeName") final String storeName) { 67 | // Get the Store 68 | final ReadableModelStateStore store = streams.store(storeName, new ModelStateStore.ModelStateStoreType()); 69 | if (store == null) { 70 | throw new NotFoundException(); 71 | } 72 | return store.getCurrentServingInfo(); 73 | } 74 | 75 | /** 76 | * Start an embedded Jetty Server on the given port 77 | * @param port port to run the Server on 78 | * @throws Exception 79 | */ 80 | public void start(final int port) throws Exception { 81 | ServletContextHandler context = new ServletContextHandler(ServletContextHandler.SESSIONS); 82 | context.setContextPath("/"); 83 | 84 | jettyServer = new Server(port); 85 | jettyServer.setHandler(context); 86 | 87 | ResourceConfig rc = new ResourceConfig(); 88 | rc.register(this); 89 | rc.register(JacksonFeature.class); 90 | 91 | ServletContainer sc = new ServletContainer(rc); 92 | ServletHolder holder = new ServletHolder(sc); 93 | context.addServlet(holder, "/*"); 94 | 95 | jettyServer.start(); 96 | } 97 | 98 | /** 99 | * Stop the Jetty Server 100 | * @throws Exception 101 | */ 102 | public void stop() throws Exception { 103 | if (jettyServer != null) { 104 | jettyServer.stop(); 105 | } 106 | } 107 | 108 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Model serving 2 | 3 | This is an umbrella project for all things model serving that comprises multiple projects 4 | 5 | -**akkaserver** - implementation of model scoring and statistics serving using Akka streams and Akka HTTP 6 | 7 | 8 | -**flinkserver** - implementation of model scoring and queryable state using Flink. Both 9 | key-based and partition-based approaches are implemented here 10 | 11 | -**kafkaclient** - generic client used for testing all implementations (except serving samples). 12 | Reads data files, splits them into records, converts them to protobuf and publishes them to Kafka 13 | 14 | -**kafkaconfiguration** - simple module containing a class with Kafka definitions - server location, 15 | topics, etc. - used by all applications 16 | 17 | -**kafkastreamserver** - implementation of model scoring and queryable state using Kafka streams. 18 | Also includes an implementation of a custom Kafka streams store. 19 | 20 | -**model** - implementation of support classes representing models and model factories used by all applications. 21 | Because Kafka streams is Java and the rest of the implementations are Scala, there are two versions of these 22 | classes - Java and Scala 23 | 24 | -**serving samples** - This module contains simple implementations of model scoring using PMML and 25 | tensorflow model definitions. It does not use any streaming frameworks - just straight Scala code 26 | 27 | -**protobufs** - a module containing protobufs that are used by all streaming frameworks. 
28 | These protobufs describe the model and data definitions in the stream. Because Kafka streams is Java 29 | and the rest of the implementations are Scala, both Java and Scala implementations of the protobufs are 30 | generated 31 | 32 | 33 | -**sparkML** - examples of using SparkML for machine learning and exporting results to PMML 34 | using the JPMML evaluator for Spark - https://github.com/jpmml/jpmml-evaluator-spark 35 | 36 | -**sparkserver** - implementation of model scoring using Spark 37 | 38 | -**utils** - a module containing some utility code. Most importantly, it contains an embedded Kafka implementation 39 | which can be used for testing in the absence of a Kafka server. To use it, just add these 40 | lines to your code: 41 | 42 | 43 | // Create embedded Kafka and topics 44 | EmbeddedSingleNodeKafkaCluster.start() // Create and start the cluster 45 | EmbeddedSingleNodeKafkaCluster.createTopic(DATA_TOPIC) // Add topic 46 | EmbeddedSingleNodeKafkaCluster.createTopic(MODELS_TOPIC) // Add topic 47 | 48 | If you are using both a server and a client, add embedded Kafka only to the server and start it before the client. 49 | In addition to embedded Kafka, this module contains some utility classes used by all applications. 50 | Because Kafka streams is Java and the rest of the implementations are Scala, there are two versions of these 51 | classes - Java and Scala 52 | 53 | -**data** - a directory of data files used as sources for all applications 54 | 55 | Not included in this project are: 56 | 57 | -**Beam implementation** - the Beam Flink runner is still on Scala 2.10, so it is in its own 58 | separate project - https://github.com/typesafehub/fdp-beam-modelServer 59 | 60 | -**Python/Tensorflow/Keras** - since it is Python, it is in its own 61 | separate project - https://github.com/typesafehub/fdp-tensorflow-python-examples 62 | -------------------------------------------------------------------------------- /akkaserver/src/main/scala/com/lightbend/modelserver/AkkaModelServer.scala: -------------------------------------------------------------------------------- 1 | package com.lightbend.modelserver 2 | 3 | import akka.actor.ActorSystem 4 | import akka.http.scaladsl.server.Route 5 | import akka.kafka.{ConsumerSettings, Subscriptions} 6 | import akka.kafka.scaladsl.Consumer 7 | import akka.stream.{ActorMaterializer, SourceShape} 8 | import akka.stream.scaladsl.{GraphDSL, Sink, Source} 9 | import akka.util.Timeout 10 | 11 | import scala.concurrent.duration._ 12 | import com.lightbend.configuration.kafka.ApplicationKafkaParameters 13 | import com.lightbend.configuration.kafka.ApplicationKafkaParameters.{DATA_GROUP, LOCAL_KAFKA_BROKER, MODELS_GROUP} 14 | import com.lightbend.model.winerecord.WineRecord 15 | import com.lightbend.modelserver.kafka.EmbeddedSingleNodeKafkaCluster 16 | import org.apache.kafka.clients.consumer.ConsumerConfig 17 | import org.apache.kafka.common.serialization.ByteArrayDeserializer 18 | import akka.http.scaladsl.Http 19 | import com.lightbend.modelserver.modelServer.ReadableModelStateStore 20 | import com.lightbend.modelserver.queriablestate.QueriesAkkaHttpResource 21 | import com.lightbend.modelserver.support.scala.{DataReader, ModelToServe} 22 | 23 | /** 24 | * Created by boris on 7/21/17. 
25 | */ 26 | object AkkaModelServer { 27 | 28 | implicit val system = ActorSystem("ModelServing") 29 | implicit val materializer = ActorMaterializer() 30 | implicit val executionContext = system.dispatcher 31 | 32 | val dataConsumerSettings = ConsumerSettings(system, new ByteArrayDeserializer, new ByteArrayDeserializer) 33 | .withBootstrapServers(LOCAL_KAFKA_BROKER) 34 | .withGroupId(DATA_GROUP) 35 | .withProperty(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "latest") 36 | 37 | val modelConsumerSettings = ConsumerSettings(system, new ByteArrayDeserializer, new ByteArrayDeserializer) 38 | .withBootstrapServers(LOCAL_KAFKA_BROKER) 39 | .withGroupId(MODELS_GROUP) 40 | .withProperty(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "latest") 41 | 42 | def main(args: Array[String]): Unit = { 43 | 44 | 45 | import ApplicationKafkaParameters._ 46 | 47 | // Create embedded Kafka and topics 48 | // EmbeddedSingleNodeKafkaCluster.start() 49 | // EmbeddedSingleNodeKafkaCluster.createTopic(DATA_TOPIC) 50 | // EmbeddedSingleNodeKafkaCluster.createTopic(MODELS_TOPIC) 51 | 52 | val modelStream: Source[ModelToServe, Consumer.Control] = 53 | Consumer.atMostOnceSource(modelConsumerSettings, Subscriptions.topics(MODELS_TOPIC)) 54 | .map(record => ModelToServe.fromByteArray(record.value())).filter(_.isSuccess).map(_.get) 55 | 56 | val dataStream: Source[WineRecord, Consumer.Control] = 57 | Consumer.atMostOnceSource(dataConsumerSettings, Subscriptions.topics(DATA_TOPIC)) 58 | .map(record => DataReader.fromByteArray(record.value())).filter(_.isSuccess).map(_.get) 59 | 60 | val model = new ModelStage() 61 | 62 | def keepModelMaterializedValue[M1, M2, M3](m1: M1, m2: M2, m3: M3): M3 = m3 63 | 64 | val modelPredictions : Source[Option[Double], ReadableModelStateStore] = Source.fromGraph( 65 | GraphDSL.create(dataStream, modelStream, model)(keepModelMaterializedValue) { 66 | implicit builder => (d, m, w) => 67 | import GraphDSL.Implicits._ 68 | 69 | // wire together the input streams with the model stage (2 in, 1 out) 70 | /* 71 | dataStream --> | | 72 | | model | -> predictions 73 | modelStream -> | | 74 | */ 75 | 76 | d ~> w.dataRecordIn 77 | m ~> w.modelRecordIn 78 | SourceShape(w.scoringResultOut) 79 | } 80 | ) 81 | 82 | 83 | val materializedReadableModelStateStore: ReadableModelStateStore = 84 | modelPredictions 85 | .map(println(_)) 86 | .to(Sink.ignore) // we do not read the results directly 87 | .run() // we run the stream, materializing the stage's StateStore 88 | 89 | startRest(materializedReadableModelStateStore) 90 | } 91 | 92 | def startRest(service : ReadableModelStateStore) : Unit = { 93 | 94 | implicit val timeout = Timeout(10 seconds) 95 | val host = "localhost" 96 | val port = 5000 97 | val routes: Route = QueriesAkkaHttpResource.storeRoutes(service) 98 | 99 | Http().bindAndHandle(routes, host, port) map 100 | { binding => println(s"REST interface bound to ${binding.localAddress}") } recover { case ex => 101 | println(s"REST interface could not bind to $host:$port", ex.getMessage) 102 | } 103 | } 104 | } -------------------------------------------------------------------------------- /akkaserver/src/main/scala/com/lightbend/modelserver/ModelStage.scala: -------------------------------------------------------------------------------- 1 | package com.lightbend.modelserver 2 | 3 | import akka.stream._ 4 | import akka.stream.stage.{GraphStageLogicWithLogging, _} 5 | import com.lightbend.model.modeldescriptor.ModelDescriptor 6 | import com.lightbend.model.winerecord.WineRecord 7 | import 
com.lightbend.model.scala.PMML.PMMLModel 8 | import com.lightbend.model.scala.tensorflow.TensorFlowModel 9 | import com.lightbend.modelserver.modelServer.ReadableModelStateStore 10 | import com.lightbend.model.scala.Model 11 | import com.lightbend.modelserver.support.scala.{ModelToServe, ModelToServeStats} 12 | 13 | import scala.collection.immutable 14 | 15 | class ModelStage extends GraphStageWithMaterializedValue[ModelStageShape, ReadableModelStateStore] { 16 | 17 | private val factories = Map( 18 | ModelDescriptor.ModelType.PMML -> PMMLModel, 19 | ModelDescriptor.ModelType.TENSORFLOW -> TensorFlowModel) 20 | 21 | override val shape: ModelStageShape = new ModelStageShape 22 | 23 | override def createLogicAndMaterializedValue(inheritedAttributes: Attributes): (GraphStageLogic, ReadableModelStateStore) = { 24 | 25 | 26 | val logic = new GraphStageLogicWithLogging(shape) { 27 | // state must be kept in the Logic instance, since it is created per stream materialization 28 | private var currentModel : Option[Model] = None 29 | private var newModel : Option[Model] = None 30 | var currentState : Option[ModelToServeStats] = None // exposed in materialized value 31 | private var newState : Option[ModelToServeStats] = None 32 | 33 | 34 | 35 | // TODO the pulls needed to get the stage actually pulling from the input streams 36 | override def preStart(): Unit = { 37 | tryPull(shape.modelRecordIn) 38 | tryPull(shape.dataRecordIn) 39 | } 40 | 41 | setHandler(shape.modelRecordIn, new InHandler { 42 | override def onPush(): Unit = { 43 | val model = grab(shape.modelRecordIn) 44 | println(s"New model - $model") 45 | newState = Some(new ModelToServeStats(model)) 46 | newModel = factories.get(model.modelType) match{ 47 | case Some(factory) => factory.create(model) 48 | case _ => None 49 | } 50 | pull(shape.modelRecordIn) 51 | } 52 | }) 53 | 54 | setHandler(shape.dataRecordIn, new InHandler { 55 | override def onPush(): Unit = { 56 | val record = grab(shape.dataRecordIn) 57 | newModel match { 58 | case Some(model) => { 59 | // close current model first 60 | currentModel match { 61 | case Some(m) => m.cleanup() 62 | case _ => 63 | } 64 | // Update model 65 | currentModel = Some(model) 66 | currentState = newState 67 | newModel = None 68 | } 69 | case _ => 70 | } 71 | currentModel match { 72 | case Some(model) => { 73 | val start = System.currentTimeMillis() 74 | val quality = model.score(record.asInstanceOf[AnyVal]).asInstanceOf[Double] 75 | val duration = System.currentTimeMillis() - start 76 | println(s"Calculated quality - $quality calculated in $duration ms") 77 | currentState.get.incrementUsage(duration) 78 | push(shape.scoringResultOut, Some(quality)) 79 | } 80 | case _ => { 81 | println("No model available - skipping") 82 | push(shape.scoringResultOut, None) 83 | } 84 | } 85 | pull(shape.dataRecordIn) 86 | } 87 | }) 88 | 89 | setHandler(shape.scoringResultOut, new OutHandler { 90 | override def onPull(): Unit = { 91 | } 92 | }) 93 | } 94 | // we materialize this value so whoever runs the stream can get the current serving info 95 | val readableModelStateStore = new ReadableModelStateStore() { 96 | override def getCurrentServingInfo: ModelToServeStats = logic.currentState.getOrElse(ModelToServeStats.empty) 97 | } 98 | new Tuple2[GraphStageLogic, ReadableModelStateStore](logic, readableModelStateStore) 99 | } 100 | } 101 | 102 | class ModelStageShape() extends Shape { 103 | var dataRecordIn = Inlet[WineRecord]("dataRecordIn") 104 | var modelRecordIn = Inlet[ModelToServe]("modelRecordIn") 105 | var 
scoringResultOut = Outlet[Option[Double]]("scoringOut") 106 | 107 | def this(dataRecordIn: Inlet[WineRecord], modelRecordIn: Inlet[ModelToServe], scoringResultOut: Outlet[Option[Double]]) { 108 | this() 109 | this.dataRecordIn = dataRecordIn 110 | this.modelRecordIn = modelRecordIn 111 | this.scoringResultOut = scoringResultOut 112 | } 113 | 114 | override def deepCopy(): Shape = new ModelStageShape(dataRecordIn.carbonCopy(), modelRecordIn.carbonCopy(), scoringResultOut) 115 | 116 | override def copyFromPorts(inlets: immutable.Seq[Inlet[_]], outlets: immutable.Seq[Outlet[_]]): Shape = 117 | new ModelStageShape( 118 | inlets(0).asInstanceOf[Inlet[WineRecord]], 119 | inlets(1).asInstanceOf[Inlet[ModelToServe]], 120 | outlets(0).asInstanceOf[Outlet[Option[Double]]]) 121 | 122 | override val inlets = List(dataRecordIn, modelRecordIn) 123 | override val outlets = List(scoringResultOut) 124 | } 125 | -------------------------------------------------------------------------------- /akkaserver/src/main/scala/com/lightbend/modelserver/ReadableModelStateStore.scala: -------------------------------------------------------------------------------- 1 | package com.lightbend.modelserver.modelServer 2 | 3 | import com.lightbend.modelserver.support.scala.ModelToServeStats 4 | 5 | 6 | /** 7 | * Created by boris on 7/21/17. 8 | */ 9 | trait ReadableModelStateStore { 10 | def getCurrentServingInfo: ModelToServeStats 11 | } 12 | 13 | -------------------------------------------------------------------------------- /akkaserver/src/main/scala/com/lightbend/queriablestate/QueriesAkkaHttpResource.scala: -------------------------------------------------------------------------------- 1 | package com.lightbend.modelserver.queriablestate 2 | 3 | 4 | import akka.http.scaladsl.server.Route 5 | import akka.http.scaladsl.server.Directives._ 6 | import com.lightbend.modelserver.modelServer.ReadableModelStateStore 7 | import com.lightbend.modelserver.support.scala.ModelToServeStats 8 | import de.heikoseeberger.akkahttpjackson.JacksonSupport 9 | 10 | object QueriesAkkaHttpResource extends JacksonSupport { 11 | 12 | def storeRoutes(predictions: ReadableModelStateStore): Route = 13 | get { 14 | path("stats") { 15 | val info: ModelToServeStats = predictions.getCurrentServingInfo 16 | complete(info) 17 | } 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | 2 | name := "ModelServing" 3 | 4 | version := "1.0" 5 | 6 | scalaVersion in ThisBuild := "2.11.11" 7 | 8 | 9 | lazy val protobufs = (project in file("./protobufs")) 10 | .settings( 11 | PB.targets in Compile := Seq( 12 | PB.gens.java -> (sourceManaged in Compile).value, 13 | scalapb.gen(javaConversions=true) -> (sourceManaged in Compile).value 14 | ) 15 | ) 16 | 17 | lazy val kafkaclient = (project in file("./kafkaclient")) 18 | .settings(libraryDependencies ++= Dependencies.kafkabaseDependencies) 19 | .dependsOn(protobufs, kafkaconfiguration) 20 | 21 | lazy val model = (project in file("./model")) 22 | .settings(libraryDependencies ++= Dependencies.modelsDependencies) 23 | .dependsOn(protobufs, utils) 24 | 25 | 26 | lazy val kafkastreamsserver = (project in file("./Kafkastreamsserver")) 27 | .settings(libraryDependencies ++= Dependencies.kafkaDependencies ++ Dependencies.webDependencies) 28 | .dependsOn(model, kafkaconfiguration, utils) 29 | 30 | lazy val akkaServer = (project in file("./akkaserver")) 31 | 
.settings(libraryDependencies ++= Dependencies.kafkaDependencies ++ Dependencies.akkaServerDependencies 32 | ++ Dependencies.modelsDependencies) 33 | .dependsOn(model, kafkaconfiguration, utils) 34 | 35 | lazy val flinkserver = (project in file("./flinkserver")) 36 | .settings(libraryDependencies ++= Dependencies.flinkDependencies ++ Seq(Dependencies.joda, Dependencies.akkaslf)) 37 | .settings(dependencyOverrides += "com.typesafe.akka" % "akka-actor-2.11" % "2.3") 38 | .dependsOn(model, kafkaconfiguration, utils) 39 | 40 | lazy val sparkserver = (project in file("./sparkserver")) 41 | .settings(libraryDependencies ++= Dependencies.sparkDependencies) 42 | .settings(dependencyOverrides += "com.fasterxml.jackson.core" % "jackson-core" % "2.8.9") 43 | .settings(dependencyOverrides += "com.fasterxml.jackson.core" % "jackson-databind" % "2.8.9") 44 | .settings(dependencyOverrides += "com.fasterxml.jackson.module" % "jackson-module-scala_2.11" % "2.8.9") 45 | .dependsOn(model, kafkaconfiguration, utils) 46 | 47 | lazy val sparkML = (project in file("./sparkML")) 48 | .settings(libraryDependencies ++= Dependencies.sparkMLDependencies) 49 | .settings(dependencyOverrides += "com.fasterxml.jackson.core" % "jackson-core" % "2.8.9") 50 | .settings(dependencyOverrides += "com.fasterxml.jackson.core" % "jackson-databind" % "2.8.9") 51 | .settings(dependencyOverrides += "com.fasterxml.jackson.module" % "jackson-module-scala_2.11" % "2.8.9") 52 | 53 | 54 | lazy val servingsamples = (project in file("./servingsamples")) 55 | .settings(libraryDependencies ++= Dependencies.modelsDependencies ++ Seq(Dependencies.tensorflowProto)) 56 | 57 | 58 | lazy val kafkaconfiguration = (project in file("./kafkaconfiguration")) 59 | 60 | lazy val utils = (project in file("./utils")) 61 | .settings(libraryDependencies ++= Dependencies.kafkaDependencies ++ Seq(Dependencies.curator)) 62 | .dependsOn(protobufs) 63 | 64 | lazy val root = (project in file(".")). 
65 | aggregate(protobufs, kafkaclient, model, utils, kafkaconfiguration, kafkastreamsserver, akkaServer, sparkserver) 66 | 67 | -------------------------------------------------------------------------------- /data/WineQuality/saved_model.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/typesafehub/fdp-modelserver/7a8cf8bd5fe8476c36822e3520da794cf547eb76/data/WineQuality/saved_model.pb -------------------------------------------------------------------------------- /data/WineQuality/variables/variables.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/typesafehub/fdp-modelserver/7a8cf8bd5fe8476c36822e3520da794cf547eb76/data/WineQuality/variables/variables.data-00000-of-00001 -------------------------------------------------------------------------------- /data/WineQuality/variables/variables.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/typesafehub/fdp-modelserver/7a8cf8bd5fe8476c36822e3520da794cf547eb76/data/WineQuality/variables/variables.index -------------------------------------------------------------------------------- /data/optimized_WineQuality.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/typesafehub/fdp-modelserver/7a8cf8bd5fe8476c36822e3520da794cf547eb76/data/optimized_WineQuality.pb -------------------------------------------------------------------------------- /data/winequalityDesisionTreeRegression.pmml: -------------------------------------------------------------------------------- 1 | 2 | 3 |
(PMML model definition for the decision tree regression model; the XML markup was not preserved in this dump. Export timestamp: 2017-05-03T18:15:04Z)
-------------------------------------------------------------------------------- /data/winequalityGeneralizedLinearRegressionGamma.pmml: -------------------------------------------------------------------------------- 1 | 2 | 3 |
(PMML model definition for the generalized linear regression (gamma) model; the XML markup was not preserved in this dump. Export timestamp: 2017-05-03T23:53:03Z)
-------------------------------------------------------------------------------- /data/winequalityGeneralizedLinearRegressionGaussian.pmml: -------------------------------------------------------------------------------- 1 | 2 | 3 |
(PMML model definition for the generalized linear regression (Gaussian) model; the XML markup was not preserved in this dump. Export timestamp: 2017-05-03T23:45:23Z)
-------------------------------------------------------------------------------- /data/winequalityLinearRegression.pmml: -------------------------------------------------------------------------------- 1 | 2 |
(PMML model definition for the linear regression model; the XML markup was not preserved in this dump. Export timestamp: 2017-05-03T17:51:32Z)
-------------------------------------------------------------------------------- /flinkserver/src/main/resources/log4j.properties.bat: -------------------------------------------------------------------------------- 1 | 2 | log4j.rootLogger=INFO, console 3 | 4 | log4j.appender.console=org.apache.log4j.ConsoleAppender 5 | log4j.appender.console.layout=org.apache.log4j.PatternLayout 6 | log4j.appender.console.layout.ConversionPattern=%d{HH:mm:ss,SSS} %-5p %-60c %x - %m%n -------------------------------------------------------------------------------- /flinkserver/src/main/resources/logback.xml.bat: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | %d{HH:mm:ss.SSS} [%thread] %-5level %logger{60} %X{sourceThread} - %msg%n 5 | 6 | 7 | 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /flinkserver/src/main/scala/com/lightbend/modelserver/BadDataHandler.scala: -------------------------------------------------------------------------------- 1 | package com.lightbend.modelserver 2 | 3 | import org.apache.flink.api.common.functions.FlatMapFunction 4 | import org.apache.flink.util.Collector 5 | 6 | 7 | import scala.util.{Failure, Success, Try} 8 | 9 | object BadDataHandler { 10 | def apply[T] = new BadDataHandler[T] 11 | } 12 | 13 | class BadDataHandler[T] extends FlatMapFunction[Try[T], T] { 14 | override def flatMap(t: Try[T], out: Collector[T]): Unit = { 15 | t match { 16 | case Success(t) => out.collect(t) 17 | case Failure(e) => println(s"BAD DATA: ${e.getMessage}") 18 | } 19 | } 20 | } -------------------------------------------------------------------------------- /flinkserver/src/main/scala/com/lightbend/modelserver/keyed/DataProcessorKeyed.scala: -------------------------------------------------------------------------------- 1 | package com.lightbend.modelserver.keyed 2 | 3 | import com.lightbend.model.modeldescriptor.ModelDescriptor 4 | import com.lightbend.model.scala.PMML.PMMLModel 5 | import com.lightbend.model.scala.tensorflow.TensorFlowModel 6 | import com.lightbend.model.winerecord.WineRecord 7 | import com.lightbend.model.scala.Model 8 | import com.lightbend.model.scala.PMML.PMMLModel 9 | import com.lightbend.model.scala.tensorflow.TensorFlowModel 10 | import com.lightbend.modelserver.typeschema.ModelTypeSerializer 11 | import com.lightbend.modelserver.support.scala.{ModelToServe, ModelToServeStats} 12 | import org.apache.flink.api.common.state.{ListState, ListStateDescriptor, ValueState, ValueStateDescriptor} 13 | import org.apache.flink.api.scala.createTypeInformation 14 | import org.apache.flink.configuration.Configuration 15 | import org.apache.flink.runtime.state.{FunctionInitializationContext, FunctionSnapshotContext} 16 | import org.apache.flink.streaming.api.checkpoint.{CheckpointedFunction, CheckpointedRestoring} 17 | import org.apache.flink.streaming.api.functions.co.CoProcessFunction 18 | import org.apache.flink.util.Collector 19 | 20 | /** 21 | * Created by boris on 5/8/17. 
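 * Keeps the currently served model and the newly arrived model in Flink keyed state and swaps them on the next scored record;
 * serving statistics are held in the queryable "currentModel" state.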
22 | * 23 | * Main class processing data using models 24 | * 25 | * see http://dataartisans.github.io/flink-training/exercises/eventTimeJoin.html for details 26 | */ 27 | 28 | object DataProcessorKeyed { 29 | def apply() = new DataProcessorKeyed 30 | private val factories = Map(ModelDescriptor.ModelType.PMML -> PMMLModel, 31 | ModelDescriptor.ModelType.TENSORFLOW -> TensorFlowModel) 32 | } 33 | 34 | class DataProcessorKeyed extends CoProcessFunction[WineRecord, ModelToServe, Double] 35 | with CheckpointedFunction with CheckpointedRestoring[List[Option[Model]]] { 36 | 37 | // The managed keyed state see https://ci.apache.org/projects/flink/flink-docs-release-1.3/dev/stream/state.html 38 | var modelState: ValueState[ModelToServeStats] = _ 39 | var newModelState: ValueState[ModelToServeStats] = _ 40 | 41 | var currentModel : Option[Model] = None 42 | var newModel : Option[Model] = None 43 | 44 | @transient private var checkpointedState: ListState[Option[Model]] = null 45 | 46 | 47 | override def open(parameters: Configuration): Unit = { 48 | val modelDesc = new ValueStateDescriptor[ModelToServeStats]( 49 | "currentModel", // state name 50 | createTypeInformation[ModelToServeStats]) // type information 51 | modelDesc.setQueryable("currentModel") 52 | 53 | modelState = getRuntimeContext.getState(modelDesc) 54 | val newModelDesc = new ValueStateDescriptor[ModelToServeStats]( 55 | "newModel", // state name 56 | createTypeInformation[ModelToServeStats]) // type information 57 | newModelState = getRuntimeContext.getState(newModelDesc) 58 | } 59 | 60 | override def snapshotState(context: FunctionSnapshotContext): Unit = { 61 | checkpointedState.clear() 62 | checkpointedState.add(currentModel) 63 | checkpointedState.add(newModel) 64 | } 65 | 66 | override def initializeState(context: FunctionInitializationContext): Unit = { 67 | val descriptor = new ListStateDescriptor[Option[Model]] ( 68 | "modelState", 69 | new ModelTypeSerializer) 70 | 71 | checkpointedState = context.getOperatorStateStore.getListState (descriptor) 72 | 73 | if (context.isRestored) { 74 | val iterator = checkpointedState.get().iterator() 75 | currentModel = iterator.next() 76 | newModel = iterator.next() 77 | } 78 | } 79 | 80 | override def restoreState(state: List[Option[Model]]): Unit = { 81 | currentModel = state(0) 82 | newModel = state(1) 83 | } 84 | 85 | override def processElement2(model: ModelToServe, ctx: CoProcessFunction[WineRecord, ModelToServe, Double]#Context, out: Collector[Double]): Unit = { 86 | 87 | import DataProcessorKeyed._ 88 | 89 | println(s"New model - $model") 90 | newModelState.update(new ModelToServeStats(model)) 91 | newModel = factories.get(model.modelType) match { 92 | case Some(factory) => factory.create (model) 93 | case _ => None 94 | } 95 | } 96 | 97 | override def processElement1(record: WineRecord, ctx: CoProcessFunction[WineRecord, ModelToServe, Double]#Context, out: Collector[Double]): Unit = { 98 | 99 | // See if we have update for the model 100 | newModel match { 101 | case Some(model) => { 102 | // Clean up current model 103 | currentModel match { 104 | case Some(m) => m.cleanup() 105 | case _ => 106 | } 107 | // Update model 108 | currentModel = Some(model) 109 | modelState.update(newModelState.value()) 110 | newModel = None 111 | } 112 | case _ => 113 | } 114 | currentModel match { 115 | case Some(model) => { 116 | val start = System.currentTimeMillis() 117 | val quality = model.score(record.asInstanceOf[AnyVal]).asInstanceOf[Double] 118 | val duration = System.currentTimeMillis() - start 
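// Record the scoring latency in the queryable "currentModel" state so it can be read externally (see ModelStateQuery)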
119 | modelState.update(modelState.value().incrementUsage(duration)) 120 | println(s"Calculated quality - $quality calculated in $duration ms") 121 | } 122 | case _ => println("No model available - skipping") 123 | } 124 | } 125 | } 126 | -------------------------------------------------------------------------------- /flinkserver/src/main/scala/com/lightbend/modelserver/keyed/ModelServingKeyedJob.scala: -------------------------------------------------------------------------------- 1 | package com.lightbend.modelserver.keyed 2 | 3 | import java.util.Properties 4 | 5 | import com.lightbend.configuration.kafka.ApplicationKafkaParameters 6 | import com.lightbend.model.winerecord.WineRecord 7 | import com.lightbend.modelserver.typeschema.ByteArraySchema 8 | import com.lightbend.modelserver.BadDataHandler 9 | import com.lightbend.modelserver.support.scala.ModelToServe 10 | import com.lightbend.modelserver.support.scala.DataReader 11 | import org.apache.flink.api.scala._ 12 | import org.apache.flink.configuration._ 13 | import org.apache.flink.runtime.concurrent.Executors 14 | import org.apache.flink.runtime.highavailability.HighAvailabilityServicesUtils 15 | import org.apache.flink.runtime.minicluster.LocalFlinkMiniCluster 16 | import org.apache.flink.streaming.api.TimeCharacteristic 17 | import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment 18 | import org.apache.flink.streaming.connectors.kafka.FlinkKafkaConsumer010 19 | 20 | 21 | /** 22 | * Created by boris on 5/9/17. 23 | * loosely based on http://dataartisans.github.io/flink-training/exercises/eventTimeJoin.html approach 24 | * for queriable state 25 | * https://github.com/dataArtisans/flink-queryable_state_demo/blob/master/README.md 26 | * Using Flink min server to enable Queryable data access 27 | * see https://github.com/dataArtisans/flink-queryable_state_demo/blob/master/src/main/java/com/dataartisans/queryablestatedemo/EventCountJob.java 28 | * 29 | * This little application is based on a RichCoProcessFunction which works on a keyed streams. It is applicable 30 | * when a single applications serves multiple different models for different data types. Every model is keyed with 31 | * the type of data what it is designed for. Same key should be present in the data, if it wants to use a specific 32 | * model. 33 | * Scaling of the application is based on the data type - for every key there is a separate instance of the 34 | * RichCoProcessFunction dedicated to this type. All messages of the same type are processed by the same instance 35 | * of RichCoProcessFunction 36 | */ 37 | object ModelServingKeyedJob { 38 | 39 | def main(args: Array[String]): Unit = { 40 | // executeLocal() 41 | executeServer() 42 | } 43 | 44 | // Execute on the local Flink server - to test queariable state 45 | def executeServer() : Unit = { 46 | 47 | // We use a mini cluster here for sake of simplicity, because I don't want 48 | // to require a Flink installation to run this demo. Everything should be 49 | // contained in this JAR. 50 | 51 | val port = 6124 52 | val parallelism = 2 53 | 54 | val config = new Configuration() 55 | config.setInteger(JobManagerOptions.PORT, port) 56 | config.setString(JobManagerOptions.ADDRESS, "localhost"); 57 | config.setInteger(ConfigConstants.TASK_MANAGER_NUM_TASK_SLOTS, parallelism) 58 | // In a non MiniCluster setup queryable state is enabled by default. 
59 | config.setBoolean(QueryableStateOptions.SERVER_ENABLE, true) 60 | config.setBoolean(ConfigConstants.LOCAL_START_WEBSERVER, true); 61 | // needed because queryable state server is always disabled with only one TaskManager 62 | config.setInteger(ConfigConstants.LOCAL_NUMBER_TASK_MANAGER, 2); 63 | 64 | // Create a local Flink server 65 | val flinkCluster = new LocalFlinkMiniCluster( 66 | config, 67 | HighAvailabilityServicesUtils.createHighAvailabilityServices( 68 | config, 69 | Executors.directExecutor(), 70 | HighAvailabilityServicesUtils.AddressResolution.TRY_ADDRESS_RESOLUTION), 71 | false); 72 | try { 73 | // Start server and create environment 74 | flinkCluster.start(true); 75 | 76 | val env = StreamExecutionEnvironment.createRemoteEnvironment("localhost", port) 77 | env.setParallelism(parallelism) 78 | // Build Graph 79 | buildGraph(env) 80 | env.execute() 81 | val jobGraph = env.getStreamGraph.getJobGraph 82 | // Submit to the server and wait for completion 83 | flinkCluster.submitJobAndWait(jobGraph, false) 84 | } catch { 85 | case e: Exception => e.printStackTrace() 86 | } 87 | } 88 | 89 | // Execute localle in the environment 90 | def executeLocal() : Unit = { 91 | val env = StreamExecutionEnvironment.getExecutionEnvironment 92 | buildGraph(env) 93 | System.out.println("[info] Job ID: " + env.getStreamGraph.getJobGraph.getJobID) 94 | env.execute() 95 | } 96 | 97 | // Build execution Graph 98 | def buildGraph(env : StreamExecutionEnvironment) : Unit = { 99 | env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) 100 | env.enableCheckpointing(5000) 101 | 102 | // configure Kafka consumer 103 | // Data 104 | val dataKafkaProps = new Properties 105 | dataKafkaProps.setProperty("zookeeper.connect", ApplicationKafkaParameters.LOCAL_ZOOKEEPER_HOST) 106 | dataKafkaProps.setProperty("bootstrap.servers", ApplicationKafkaParameters.LOCAL_KAFKA_BROKER) 107 | dataKafkaProps.setProperty("group.id", ApplicationKafkaParameters.DATA_GROUP) 108 | // always read the Kafka topic from the current location 109 | dataKafkaProps.setProperty("auto.offset.reset", "latest") 110 | 111 | // Model 112 | val modelKafkaProps = new Properties 113 | modelKafkaProps.setProperty("zookeeper.connect", ApplicationKafkaParameters.LOCAL_ZOOKEEPER_HOST) 114 | modelKafkaProps.setProperty("bootstrap.servers", ApplicationKafkaParameters.LOCAL_KAFKA_BROKER) 115 | modelKafkaProps.setProperty("group.id", ApplicationKafkaParameters.MODELS_GROUP) 116 | // always read the Kafka topic from the current location 117 | modelKafkaProps.setProperty("auto.offset.reset", "latest") 118 | 119 | // create a Kafka consumers 120 | // Data 121 | val dataConsumer = new FlinkKafkaConsumer010[Array[Byte]]( 122 | ApplicationKafkaParameters.DATA_TOPIC, 123 | new ByteArraySchema, 124 | dataKafkaProps 125 | ) 126 | 127 | // Model 128 | val modelConsumer = new FlinkKafkaConsumer010[Array[Byte]]( 129 | ApplicationKafkaParameters.MODELS_TOPIC, 130 | new ByteArraySchema, 131 | modelKafkaProps 132 | ) 133 | 134 | // Create input data streams 135 | val modelsStream = env.addSource(modelConsumer) 136 | val dataStream = env.addSource(dataConsumer) 137 | 138 | // Read data from streams 139 | val models = modelsStream.map(ModelToServe.fromByteArray(_)) 140 | .flatMap(BadDataHandler[ModelToServe]) 141 | .keyBy(_.dataType) 142 | val data = dataStream.map(DataReader.fromByteArray(_)) 143 | .flatMap(BadDataHandler[WineRecord]) 144 | .keyBy(_.dataType) 145 | 146 | // Merge streams 147 | data 148 | .connect(models) 149 | .process(DataProcessorKeyed()) 150 
| } 151 | } -------------------------------------------------------------------------------- /flinkserver/src/main/scala/com/lightbend/modelserver/partitioned/DataProcessorMap.scala: -------------------------------------------------------------------------------- 1 | package com.lightbend.modelserver.partitioned 2 | 3 | /** 4 | * Created by boris on 5/14/17. 5 | * 6 | * Main class processing data using models 7 | * 8 | */ 9 | import com.lightbend.model.modeldescriptor.ModelDescriptor 10 | import com.lightbend.model.winerecord.WineRecord 11 | import com.lightbend.model.scala.Model 12 | import com.lightbend.model.scala.PMML.PMMLModel 13 | import com.lightbend.model.scala.tensorflow.TensorFlowModel 14 | import com.lightbend.modelserver.typeschema.ModelTypeSerializer 15 | import com.lightbend.modelserver.support.scala.ModelToServe 16 | import org.apache.flink.api.common.state.{ListState, ListStateDescriptor} 17 | import org.apache.flink.runtime.state.{FunctionInitializationContext, FunctionSnapshotContext} 18 | import org.apache.flink.streaming.api.checkpoint.{CheckpointedFunction, CheckpointedRestoring} 19 | import org.apache.flink.streaming.api.functions.co.RichCoFlatMapFunction 20 | import org.apache.flink.util.Collector 21 | 22 | object DataProcessorMap{ 23 | def apply() : DataProcessorMap = new DataProcessorMap() 24 | 25 | private val factories = Map(ModelDescriptor.ModelType.PMML -> PMMLModel, 26 | ModelDescriptor.ModelType.TENSORFLOW -> TensorFlowModel) 27 | } 28 | 29 | class DataProcessorMap extends RichCoFlatMapFunction[WineRecord, ModelToServe, Double] 30 | with CheckpointedFunction with CheckpointedRestoring[List[Option[Model]]] { 31 | 32 | var currentModel : Option[Model] = None 33 | var newModel : Option[Model] = None 34 | @transient private var checkpointedState: ListState[Option[Model]] = null 35 | 36 | override def snapshotState(context: FunctionSnapshotContext): Unit = { 37 | checkpointedState.clear() 38 | checkpointedState.add(currentModel) 39 | checkpointedState.add(newModel) 40 | } 41 | 42 | override def initializeState(context: FunctionInitializationContext): Unit = { 43 | val descriptor = new ListStateDescriptor[Option[Model]] ( 44 | "modelState", 45 | new ModelTypeSerializer) 46 | 47 | checkpointedState = context.getOperatorStateStore.getListState (descriptor) 48 | 49 | if (context.isRestored) { 50 | val iterator = checkpointedState.get().iterator() 51 | currentModel = iterator.next() 52 | newModel = iterator.next() 53 | } 54 | } 55 | 56 | override def restoreState(state: List[Option[Model]]): Unit = { 57 | currentModel = state(0) 58 | newModel = state(1) 59 | } 60 | 61 | override def flatMap2(model: ModelToServe, out: Collector[Double]): Unit = { 62 | 63 | import DataProcessorMap._ 64 | 65 | println(s"New model - $model") 66 | newModel = factories.get(model.modelType) match{ 67 | case Some(factory) => factory.create(model) 68 | case _ => None 69 | } 70 | } 71 | 72 | override def flatMap1(record: WineRecord, out: Collector[Double]): Unit = { 73 | // See if we need to update 74 | newModel match { 75 | case Some(model) => { 76 | // close current model first 77 | currentModel match { 78 | case Some(m) => m.cleanup(); 79 | case _ => 80 | } 81 | // Update model 82 | currentModel = Some(model) 83 | newModel = None 84 | } 85 | case _ => 86 | } 87 | currentModel match { 88 | case Some(model) => { 89 | val start = System.currentTimeMillis() 90 | val quality = model.score(record.asInstanceOf[AnyVal]).asInstanceOf[Double] 91 | val duration = System.currentTimeMillis() - start 92 | 
println(s"Subtask ${this.getRuntimeContext.getIndexOfThisSubtask} calculated quality - $quality calculated in $duration ms") 93 | } 94 | case _ => println("No model available - skipping") 95 | } 96 | } 97 | } -------------------------------------------------------------------------------- /flinkserver/src/main/scala/com/lightbend/modelserver/partitioned/ModelServingFlatJob.scala: -------------------------------------------------------------------------------- 1 | package com.lightbend.modelserver.partitioned 2 | 3 | import java.util.Properties 4 | 5 | import com.lightbend.configuration.kafka.ApplicationKafkaParameters 6 | import com.lightbend.model.winerecord.WineRecord 7 | import com.lightbend.modelserver.typeschema.ByteArraySchema 8 | import com.lightbend.modelserver.BadDataHandler 9 | import com.lightbend.modelserver.support.scala.{DataReader, ModelToServe} 10 | import org.apache.flink.api.scala._ 11 | import org.apache.flink.configuration.{ConfigConstants, Configuration, JobManagerOptions, QueryableStateOptions} 12 | import org.apache.flink.runtime.concurrent.Executors 13 | import org.apache.flink.runtime.highavailability.HighAvailabilityServicesUtils 14 | import org.apache.flink.runtime.minicluster.LocalFlinkMiniCluster 15 | import org.apache.flink.streaming.api.TimeCharacteristic 16 | import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment 17 | import org.apache.flink.streaming.connectors.kafka.FlinkKafkaConsumer010 18 | 19 | 20 | /** 21 | * Created by boris on 5/9/17. 22 | * loosely based on http://dataartisans.github.io/flink-training/exercises/eventTimeJoin.html approach 23 | * for queriable state 24 | * https://github.com/dataArtisans/flink-queryable_state_demo/blob/master/README.md 25 | * Using Flink min server to enable Queryable data access 26 | * see https://github.com/dataArtisans/flink-queryable_state_demo/blob/master/src/main/java/com/dataartisans/queryablestatedemo/EventCountJob.java 27 | * 28 | * This little application is based on a RichCoFlatMapFunction which works on a non keyed streams. It is 29 | * applicable when a single applications serves a single model(model set) for a single data type. 30 | * Scaling of the application is based on the parallelism of input stream and RichCoFlatMapFunction. 31 | * The model is broadcasted to all RichCoFlatMapFunction instances. The messages are processed by different 32 | * instances of RichCoFlatMapFunction in a round-robin fashion. 33 | */ 34 | 35 | object ModelServingFlatJob { 36 | 37 | def main(args: Array[String]): Unit = { 38 | // executeLocal() 39 | executeServer() 40 | } 41 | 42 | // Execute on the local Flink server - to test queariable state 43 | def executeServer() : Unit = { 44 | 45 | // We use a mini cluster here for sake of simplicity, because I don't want 46 | // to require a Flink installation to run this demo. Everything should be 47 | // contained in this JAR. 48 | 49 | val port = 6124 50 | val parallelism = 4 51 | 52 | 53 | val config = new Configuration() 54 | config.setInteger(JobManagerOptions.PORT, port) 55 | config.setString(JobManagerOptions.ADDRESS, "localhost"); 56 | config.setInteger(ConfigConstants.TASK_MANAGER_NUM_TASK_SLOTS, parallelism) 57 | // In a non MiniCluster setup queryable state is enabled by default. 
58 | config.setBoolean(QueryableStateOptions.SERVER_ENABLE, true) 59 | config.setBoolean(ConfigConstants.LOCAL_START_WEBSERVER, true); 60 | // needed because queryable state server is always disabled with only one TaskManager 61 | config.setInteger(ConfigConstants.LOCAL_NUMBER_TASK_MANAGER, 2); 62 | 63 | // Create a local Flink server 64 | val flinkCluster = new LocalFlinkMiniCluster( 65 | config, 66 | HighAvailabilityServicesUtils.createHighAvailabilityServices( 67 | config, 68 | Executors.directExecutor(), 69 | HighAvailabilityServicesUtils.AddressResolution.TRY_ADDRESS_RESOLUTION), 70 | false); 71 | try { 72 | // Start server and create environment 73 | flinkCluster.start(true); 74 | val env = StreamExecutionEnvironment.createRemoteEnvironment("localhost", port, parallelism) 75 | // Build Graph 76 | buildGraph(env) 77 | env.execute() 78 | val jobGraph = env.getStreamGraph.getJobGraph 79 | // Submit to the server and wait for completion 80 | flinkCluster.submitJobAndWait(jobGraph, false) 81 | } catch { 82 | case e: Exception => e.printStackTrace() 83 | } 84 | } 85 | 86 | // Execute localle in the environment 87 | def executeLocal() : Unit = { 88 | val env = StreamExecutionEnvironment.getExecutionEnvironment 89 | buildGraph(env) 90 | System.out.println("[info] Job ID: " + env.getStreamGraph.getJobGraph.getJobID) 91 | env.execute() 92 | } 93 | 94 | // Build execution Graph 95 | def buildGraph(env : StreamExecutionEnvironment) : Unit = { 96 | env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) 97 | env.enableCheckpointing(5000) 98 | // configure Kafka consumer 99 | // Data 100 | val dataKafkaProps = new Properties 101 | dataKafkaProps.setProperty("zookeeper.connect", ApplicationKafkaParameters.LOCAL_ZOOKEEPER_HOST) 102 | dataKafkaProps.setProperty("bootstrap.servers", ApplicationKafkaParameters.LOCAL_KAFKA_BROKER) 103 | dataKafkaProps.setProperty("group.id", ApplicationKafkaParameters.DATA_GROUP) 104 | dataKafkaProps.setProperty("auto.offset.reset", "latest") 105 | 106 | // Model 107 | val modelKafkaProps = new Properties 108 | modelKafkaProps.setProperty("zookeeper.connect", ApplicationKafkaParameters.LOCAL_ZOOKEEPER_HOST) 109 | modelKafkaProps.setProperty("bootstrap.servers", ApplicationKafkaParameters.LOCAL_KAFKA_BROKER) 110 | modelKafkaProps.setProperty("group.id", ApplicationKafkaParameters.MODELS_GROUP) 111 | // always read the Kafka topic from the current location 112 | modelKafkaProps.setProperty("auto.offset.reset", "latest") 113 | 114 | // create a Kafka consumers 115 | // Data 116 | val dataConsumer = new FlinkKafkaConsumer010[Array[Byte]]( 117 | ApplicationKafkaParameters.DATA_TOPIC, 118 | new ByteArraySchema, 119 | dataKafkaProps 120 | ) 121 | 122 | // Model 123 | val modelConsumer = new FlinkKafkaConsumer010[Array[Byte]]( 124 | ApplicationKafkaParameters.MODELS_TOPIC, 125 | new ByteArraySchema, 126 | modelKafkaProps 127 | ) 128 | 129 | // Create input data streams 130 | val modelsStream = env.addSource(modelConsumer) 131 | val dataStream = env.addSource(dataConsumer) 132 | 133 | // Read data from streams 134 | val models = modelsStream.map(ModelToServe.fromByteArray(_)) 135 | .flatMap(BadDataHandler[ModelToServe]) 136 | .broadcast 137 | val data = dataStream.map(DataReader.fromByteArray(_)) 138 | .flatMap(BadDataHandler[WineRecord]) 139 | 140 | // Merge streams 141 | data 142 | .connect(models) 143 | .flatMap(DataProcessorMap()) 144 | } 145 | } -------------------------------------------------------------------------------- 
/flinkserver/src/main/scala/com/lightbend/modelserver/query/ModelStateQuery.scala: -------------------------------------------------------------------------------- 1 | package com.lightbend.modelserver.query 2 | 3 | import java.util.concurrent.{Executors, TimeUnit} 4 | 5 | import com.lightbend.modelserver.support.scala.ModelToServeStats 6 | import org.apache.flink.api.common.{ExecutionConfig, JobID} 7 | import org.apache.flink.api.scala.createTypeInformation 8 | import org.apache.flink.configuration.{Configuration, JobManagerOptions} 9 | import org.apache.flink.runtime.highavailability.HighAvailabilityServicesUtils 10 | import org.apache.flink.runtime.query.QueryableStateClient 11 | import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer 12 | import org.apache.flink.runtime.state.{VoidNamespace, VoidNamespaceSerializer} 13 | import org.joda.time.DateTime 14 | 15 | import scala.concurrent.Await 16 | import scala.concurrent.duration.FiniteDuration 17 | 18 | /** 19 | * Created by boris on 5/12/17. 20 | * see https://ci.apache.org/projects/flink/flink-docs-release-1.3/dev/stream/queryable_state.html 21 | * It uses default port 6123 to access Flink server 22 | */ 23 | object ModelStateQuery { 24 | 25 | val timeInterval = 1000 * 20 // 20 sec 26 | 27 | def main(args: Array[String]) { 28 | 29 | val jobId = JobID.fromHexString("817bfacb1f0317eb15fb20c1201b9e1a") 30 | val types = Array("wine") 31 | 32 | val config = new Configuration() 33 | config.setString(JobManagerOptions.ADDRESS, "localhost") 34 | config.setInteger(JobManagerOptions.PORT, 6124) 35 | 36 | val highAvailabilityServices = HighAvailabilityServicesUtils.createHighAvailabilityServices( 37 | config, Executors.newSingleThreadScheduledExecutor, HighAvailabilityServicesUtils.AddressResolution.TRY_ADDRESS_RESOLUTION) 38 | val client = new QueryableStateClient(config, highAvailabilityServices) 39 | 40 | val execConfig = new ExecutionConfig 41 | val keySerializer = createTypeInformation[String].createSerializer(execConfig) 42 | val valueSerializer = createTypeInformation[ModelToServeStats].createSerializer(execConfig) 43 | 44 | println(" Name | Description | Since | Average | Min | Max |") 45 | while(true) { 46 | val stats = for (key <- types) yield { 47 | val serializedKey = KvStateRequestSerializer.serializeKeyAndNamespace( 48 | key, keySerializer, 49 | VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE) 50 | 51 | // now wait for the result and return it 52 | try { 53 | val serializedResult = client.getKvState(jobId, "currentModel", key.hashCode(), serializedKey) 54 | val serializedValue = Await.result(serializedResult, FiniteDuration(2, TimeUnit.SECONDS)) 55 | val value = KvStateRequestSerializer.deserializeValue(serializedValue, valueSerializer) 56 | List(value.name, value.description, value.since, value.usage, value.duration, value.min, value.max) 57 | } catch { 58 | case e: Exception => { 59 | e.printStackTrace() 60 | List() 61 | } 62 | } 63 | } 64 | stats.toList.filter(_.nonEmpty).foreach(row => 65 | println(s" ${row(0)} | ${row(1)} | ${new DateTime(row(2)).toString("yyyy/MM/dd HH:MM:SS")} | ${row(3)} |" + 66 | s" ${row(4).asInstanceOf[Double]/row(3).asInstanceOf[Long]} | ${row(5)} | ${row(6)} |") 67 | ) 68 | Thread.sleep(timeInterval) 69 | } 70 | } 71 | } -------------------------------------------------------------------------------- /flinkserver/src/main/scala/com/lightbend/modelserver/typeschema/ByteArraySchema.scala: -------------------------------------------------------------------------------- 1 | 
package com.lightbend.modelserver.typeschema 2 | 3 | /** 4 | * Created by boris on 5/9/17. 5 | */ 6 | import org.apache.flink.api.common.typeinfo.TypeInformation 7 | import org.apache.flink.api.java.typeutils.TypeExtractor 8 | import org.apache.flink.streaming.util.serialization.{DeserializationSchema, SerializationSchema} 9 | 10 | class ByteArraySchema extends DeserializationSchema[Array[Byte]] with SerializationSchema[Array[Byte]] { 11 | 12 | private val serialVersionUID: Long = 1234567L 13 | 14 | override def isEndOfStream(nextElement: Array[Byte]): Boolean = false 15 | 16 | override def deserialize(message: Array[Byte]): Array[Byte] = message 17 | 18 | override def serialize(element: Array[Byte]): Array[Byte] = element 19 | 20 | override def getProducedType: TypeInformation[Array[Byte]] = 21 | //PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO 22 | TypeExtractor.getForClass(classOf[Array[Byte]]) 23 | } 24 | -------------------------------------------------------------------------------- /flinkserver/src/main/scala/com/lightbend/modelserver/typeschema/ModelTypeSerializer.scala: -------------------------------------------------------------------------------- 1 | package com.lightbend.modelserver.typeschema 2 | 3 | import java.io.IOException 4 | 5 | import com.lightbend.model.modeldescriptor.ModelDescriptor 6 | import org.apache.flink.api.common.typeutils.{CompatibilityResult, GenericTypeSerializerConfigSnapshot, TypeSerializer, TypeSerializerConfigSnapshot} 7 | import com.lightbend.model.scala.Model 8 | import com.lightbend.model.scala.PMML.PMMLModel 9 | import com.lightbend.model.scala.tensorflow.TensorFlowModel 10 | import org.apache.flink.core.memory.{DataInputView, DataOutputView} 11 | 12 | class ModelTypeSerializer extends TypeSerializer[Option[Model]] { 13 | 14 | import ModelTypeSerializer._ 15 | 16 | override def createInstance(): Option[Model] = None 17 | 18 | override def canEqual(obj: scala.Any): Boolean = obj.isInstanceOf[ModelTypeSerializer] 19 | 20 | override def duplicate(): TypeSerializer[Option[Model]] = new ModelTypeSerializer 21 | 22 | override def ensureCompatibility(configSnapshot: TypeSerializerConfigSnapshot): CompatibilityResult[Option[Model]] = 23 | CompatibilityResult.requiresMigration() 24 | 25 | override def serialize(record: Option[Model], target: DataOutputView): Unit = { 26 | record match { 27 | case Some(model) => { 28 | target.writeBoolean(true) 29 | val content = model.toBytes() 30 | target.writeLong(model.getType) 31 | target.writeLong(content.length) 32 | target.write(content) 33 | } 34 | case _ => target.writeBoolean(false) 35 | } 36 | } 37 | 38 | override def isImmutableType: Boolean = false 39 | 40 | override def getLength: Int = -1 41 | 42 | override def snapshotConfiguration(): TypeSerializerConfigSnapshot = new ModelSerializerConfigSnapshot 43 | 44 | override def copy(from: Option[Model]): Option[Model] = 45 | from match { 46 | case Some(model) => Some(factories.get(model.getType.asInstanceOf[Int]).get.restore(model.toBytes())) 47 | case _ => None 48 | } 49 | 50 | override def copy(from: Option[Model], reuse: Option[Model]): Option[Model] = 51 | from match { 52 | case Some(model) => Some(factories.get(model.getType.asInstanceOf[Int]).get.restore(model.toBytes())) 53 | case _ => None 54 | } 55 | 56 | override def copy(source: DataInputView, target: DataOutputView): Unit = { 57 | val exist = source.readBoolean() 58 | target.writeBoolean(exist) 59 | exist match { 60 | case true => { 61 | target.writeLong (source.readLong () ) 62 | val clen = 
source.readLong ().asInstanceOf[Int] 63 | target.writeLong (clen) 64 | val content = new Array[Byte] (clen) 65 | source.read (content) 66 | target.write (content) 67 | } 68 | case _ => 69 | } 70 | } 71 | 72 | override def deserialize(source: DataInputView): Option[Model] = 73 | source.readBoolean() match { 74 | case true => { 75 | val t = source.readLong().asInstanceOf[Int] 76 | val size = source.readLong().asInstanceOf[Int] 77 | val content = new Array[Byte] (size) 78 | source.read (content) 79 | Some(factories.get(t).get.restore(content)) 80 | } 81 | case _ => None 82 | } 83 | 84 | override def deserialize(reuse: Option[Model], source: DataInputView): Option[Model] = 85 | source.readBoolean() match { 86 | case true => { 87 | val t = source.readLong().asInstanceOf[Int] 88 | val size = source.readLong().asInstanceOf[Int] 89 | val content = new Array[Byte] (size) 90 | source.read (content) 91 | Some(factories.get(t).get.restore(content)) 92 | } 93 | case _ => None 94 | } 95 | 96 | override def equals(obj: scala.Any): Boolean = obj.isInstanceOf[ModelTypeSerializer] 97 | 98 | override def hashCode(): Int = 42 99 | } 100 | 101 | object ModelTypeSerializer{ 102 | private val factories = Map(ModelDescriptor.ModelType.PMML.value -> PMMLModel, 103 | ModelDescriptor.ModelType.TENSORFLOW.value -> TensorFlowModel) 104 | 105 | def apply : ModelTypeSerializer = new ModelTypeSerializer() 106 | } 107 | 108 | 109 | object ModelSerializerConfigSnapshot { 110 | val VERSION = 1 111 | } 112 | 113 | class ModelSerializerConfigSnapshot[T <: Model] 114 | extends TypeSerializerConfigSnapshot{ 115 | 116 | import ModelSerializerConfigSnapshot._ 117 | 118 | // def this() {this(classOf[T])} 119 | 120 | override def getVersion = VERSION 121 | 122 | var typeClass = classOf[Model] 123 | 124 | override def write(out: DataOutputView): Unit = { 125 | super.write(out) 126 | // write only the classname to avoid Java serialization 127 | out.writeUTF(classOf[Model].getName) 128 | } 129 | 130 | override def read(in: DataInputView): Unit = { 131 | super.read(in) 132 | val genericTypeClassname = in.readUTF 133 | try 134 | typeClass = Class.forName(genericTypeClassname, true, getUserCodeClassLoader).asInstanceOf[Class[Model]] 135 | catch { 136 | case e: ClassNotFoundException => 137 | throw new IOException("Could not find the requested class " + genericTypeClassname + " in classpath.", e) 138 | } 139 | } 140 | 141 | def getTypeClass: Class[T] = typeClass.asInstanceOf[Class[T]] 142 | 143 | override def equals(obj: Any): Boolean = { 144 | if (obj == this) return true 145 | if (obj == null) return false 146 | (obj.getClass == getClass) && typeClass == obj.asInstanceOf[GenericTypeSerializerConfigSnapshot[_]].getTypeClass 147 | } 148 | 149 | override def hashCode: Int = 42 150 | } 151 | -------------------------------------------------------------------------------- /kafkaclient/src/main/scala/com/lightbend/kafka/DataProvider.scala: -------------------------------------------------------------------------------- 1 | package com.lightbend.kafka 2 | 3 | import java.io.ByteArrayOutputStream 4 | 5 | import com.lightbend.configuration.kafka.ApplicationKafkaParameters 6 | import com.lightbend.model.winerecord.WineRecord 7 | 8 | import scala.io.Source 9 | 10 | /** 11 | * Created by boris on 5/10/17. 
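 * Reads wine records from data/winequality_red.csv and publishes them to the Kafka data topic, roughly one record per second.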
12 | * 13 | * Application publishing models from /data directory to Kafka 14 | */ 15 | object DataProvider { 16 | 17 | val file = "data/winequality_red.csv" 18 | val timeInterval = 1000 * 1 // 1 sec 19 | 20 | def main(args: Array[String]) { 21 | val sender = KafkaMessageSender(ApplicationKafkaParameters.LOCAL_KAFKA_BROKER, ApplicationKafkaParameters.LOCAL_ZOOKEEPER_HOST) 22 | sender.createTopic(ApplicationKafkaParameters.DATA_TOPIC) 23 | val bos = new ByteArrayOutputStream() 24 | val records = getListOfRecords(file) 25 | var nrec = 0 26 | while (true) { 27 | records.foreach(r => { 28 | bos.reset() 29 | r.writeTo(bos) 30 | sender.writeValue(ApplicationKafkaParameters.DATA_TOPIC, bos.toByteArray) 31 | nrec = nrec + 1 32 | if(nrec % 10 == 0) 33 | println(s"printed $nrec records") 34 | pause() 35 | }) 36 | } 37 | } 38 | 39 | private def pause() : Unit = { 40 | try{ 41 | Thread.sleep(timeInterval) 42 | } 43 | catch { 44 | case _: Throwable => // Ignore 45 | } 46 | } 47 | 48 | def getListOfRecords(file: String): Seq[WineRecord] = { 49 | 50 | var result = Seq.empty[WineRecord] 51 | val bufferedSource = Source.fromFile(file) 52 | for (line <- bufferedSource.getLines) { 53 | val cols = line.split(";").map(_.trim) 54 | val record = new WineRecord( 55 | fixedAcidity = cols(0).toDouble, 56 | volatileAcidity = cols(1).toDouble, 57 | citricAcid = cols(2).toDouble, 58 | residualSugar = cols(3).toDouble, 59 | chlorides = cols(4).toDouble, 60 | freeSulfurDioxide = cols(5).toDouble, 61 | totalSulfurDioxide = cols(6).toDouble, 62 | density = cols(7).toDouble, 63 | pH = cols(8).toDouble, 64 | sulphates = cols(9).toDouble, 65 | alcohol = cols(10).toDouble, 66 | dataType = "wine" 67 | ) 68 | result = record +: result 69 | } 70 | bufferedSource.close 71 | result 72 | } 73 | } -------------------------------------------------------------------------------- /kafkaclient/src/main/scala/com/lightbend/kafka/KafkaMessageSender.scala: -------------------------------------------------------------------------------- 1 | package com.lightbend.kafka 2 | 3 | /** 4 | * Created by boris on 5/10/17. 
5 | * Byte array sender to Kafka 6 | */ 7 | 8 | import java.util.Properties 9 | 10 | import kafka.admin.AdminUtils 11 | import kafka.utils.ZkUtils 12 | import org.apache.kafka.clients.producer.{KafkaProducer, ProducerConfig, ProducerRecord, RecordMetadata} 13 | import org.apache.kafka.common.serialization.ByteArraySerializer 14 | 15 | import scala.collection.mutable.Map 16 | 17 | 18 | class KafkaMessageSender (brokers: String, zookeeper : String){ 19 | 20 | // Configure 21 | val props = new Properties 22 | props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, brokers) 23 | props.put(ProducerConfig.ACKS_CONFIG, KafkaMessageSender.ACKCONFIGURATION) 24 | props.put(ProducerConfig.RETRIES_CONFIG, KafkaMessageSender.RETRYCOUNT) 25 | props.put(ProducerConfig.BATCH_SIZE_CONFIG, KafkaMessageSender.BATCHSIZE) 26 | props.put(ProducerConfig.LINGER_MS_CONFIG, KafkaMessageSender.LINGERTIME) 27 | props.put(ProducerConfig.BUFFER_MEMORY_CONFIG, KafkaMessageSender.BUFFERMEMORY) 28 | props.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) 29 | props.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) 30 | 31 | // Create producer 32 | val producer = new KafkaProducer[Array[Byte], Array[Byte]](props) 33 | val zkUtils = ZkUtils.apply(zookeeper, KafkaMessageSender.sessionTimeout, 34 | KafkaMessageSender.connectionTimeout, false) 35 | 36 | // Write value to the queue 37 | def writeValue(topic: String, value: Array[Byte]): RecordMetadata = { 38 | val result = producer.send(new ProducerRecord[Array[Byte], Array[Byte]](topic, null, value)).get 39 | producer.flush() 40 | result 41 | } 42 | 43 | // Close producer 44 | def close(): Unit = { 45 | producer.close 46 | } 47 | 48 | def createTopic(topic : String, numPartitions: Int = 1, replicationFactor : Int = 1): Unit = { 49 | if (!AdminUtils.topicExists(zkUtils, topic)){ 50 | try { 51 | AdminUtils.createTopic(zkUtils,topic, numPartitions, replicationFactor) 52 | println(s"Topic $topic with $numPartitions partitions and replication factor $replicationFactor is created") 53 | }catch { 54 | case t: Throwable => println(s"Failed to create topic $topic. ${t.getMessage}") 55 | } 56 | } 57 | else 58 | println(s"Topic $topic already exists") 59 | } 60 | } 61 | 62 | object KafkaMessageSender{ 63 | private val ACKCONFIGURATION = "all" // Blocking on the full commit of the record 64 | private val RETRYCOUNT = "1" // Number of retries on put 65 | private val BATCHSIZE = "1024" // Buffers for unsent records for each partition - controlls batching 66 | private val LINGERTIME = "1" // Timeout for more records to arive - controlls batching 67 | private val BUFFERMEMORY = "1024000" // Controls the total amount of memory available to the producer for buffering. If records are sent faster than they can be transmitted to the server then this buffer space will be exhausted. When the buffer space is exhausted additional send calls will block. The threshold for time to block is determined by max.block.ms after which it throws a TimeoutException. 
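// Producers are cached per broker list so that repeated apply() calls reuse a single KafkaProducer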
68 | private val senders : Map[String, KafkaMessageSender] = Map() // Producer instances 69 | 70 | private val sessionTimeout = 10 * 1000 71 | private val connectionTimeout = 8 * 1000 72 | 73 | def apply(brokers: String, zookeeper : String): KafkaMessageSender = { 74 | senders.get(brokers) match { 75 | case Some(sender) => sender // Producer already exists 76 | case _ => { // Does not exist - create a new one 77 | val sender = new KafkaMessageSender(brokers, zookeeper) 78 | senders.put(brokers, sender) 79 | sender 80 | } 81 | } 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /kafkaclient/src/main/scala/com/lightbend/kafka/ModelProvider.scala: -------------------------------------------------------------------------------- 1 | package com.lightbend.kafka 2 | 3 | import java.io.{ByteArrayOutputStream, File} 4 | import java.nio.file.{Files, Paths} 5 | 6 | import com.google.protobuf.ByteString 7 | import com.lightbend.configuration.kafka.ApplicationKafkaParameters 8 | import com.lightbend.model.modeldescriptor.ModelDescriptor 9 | 10 | /** 11 | * Created by boris on 5/10/17. 12 | * 13 | * Application publishing models from /data directory to Kafka 14 | */ 15 | object ModelProvider { 16 | 17 | val directory = "data/" 18 | val tensorfile = "data/optimized_WineQuality.pb" 19 | val timeInterval = 1000 * 60 * 1 // 1 mins 20 | 21 | def main(args: Array[String]) { 22 | val sender = KafkaMessageSender(ApplicationKafkaParameters.LOCAL_KAFKA_BROKER, ApplicationKafkaParameters.LOCAL_ZOOKEEPER_HOST) 23 | sender.createTopic(ApplicationKafkaParameters.MODELS_TOPIC) 24 | val files = getListOfFiles(directory) 25 | val bos = new ByteArrayOutputStream() 26 | while (true) { 27 | files.foreach(f => { 28 | // PMML 29 | val pByteArray = Files.readAllBytes(Paths.get(directory + f)) 30 | val pRecord = ModelDescriptor(name = f.dropRight(5), 31 | description = "generated from SparkML", modeltype = ModelDescriptor.ModelType.PMML, 32 | dataType = "wine").withData(ByteString.copyFrom(pByteArray)) 33 | bos.reset() 34 | pRecord.writeTo(bos) 35 | sender.writeValue(ApplicationKafkaParameters.MODELS_TOPIC, bos.toByteArray) 36 | pause() 37 | // TF 38 | val tByteArray = Files.readAllBytes(Paths.get(tensorfile)) 39 | val tRecord = ModelDescriptor(name = tensorfile.dropRight(3), 40 | description = "generated from TensorFlow", modeltype = ModelDescriptor.ModelType.TENSORFLOW, 41 | dataType = "wine").withData(ByteString.copyFrom(tByteArray)) 42 | bos.reset() 43 | tRecord.writeTo(bos) 44 | sender.writeValue(ApplicationKafkaParameters.MODELS_TOPIC, bos.toByteArray) 45 | pause() 46 | }) 47 | } 48 | } 49 | 50 | private def pause() : Unit = { 51 | try{ 52 | Thread.sleep(timeInterval) 53 | } 54 | catch { 55 | case _: Throwable => // Ignore 56 | } 57 | } 58 | 59 | private def getListOfFiles(dir: String):Seq[String] = { 60 | val d = new File(dir) 61 | if (d.exists && d.isDirectory) { 62 | d.listFiles.filter(f => (f.isFile) && (f.getName.endsWith(".pmml"))).map(_.getName) 63 | } else { 64 | Seq.empty[String] 65 | } 66 | } 67 | } -------------------------------------------------------------------------------- /kafkaconfiguration/src/main/java/com/lightbend/configuration/kafka/ApplicationKafkaParameters.java: -------------------------------------------------------------------------------- 1 | package com.lightbend.configuration.kafka; 2 | 3 | /** 4 | * Created by boris on 5/18/17. 
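 * Kafka broker/ZooKeeper addresses, topic names and consumer groups shared by the servers and clients in this project.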
5 | * Set of parameters for running applications 6 | */ 7 | public class ApplicationKafkaParameters { 8 | 9 | private ApplicationKafkaParameters(){} 10 | 11 | public static final String LOCAL_ZOOKEEPER_HOST = "localhost:2181"; 12 | public static final String LOCAL_KAFKA_BROKER = "localhost:9092"; 13 | 14 | public static final String DATA_TOPIC = "mdata"; 15 | public static final String MODELS_TOPIC = "models"; 16 | 17 | public static final String DATA_GROUP = "wineRecordsGroup"; 18 | public static final String MODELS_GROUP = "modelRecordsGroup"; 19 | } 20 | -------------------------------------------------------------------------------- /model/src/main/java/com/lightbend/model/java/Model.java: -------------------------------------------------------------------------------- 1 | package com.lightbend.model.java; 2 | 3 | import java.io.Serializable; 4 | 5 | /** 6 | * Created by boris on 5/9/17. 7 | * Basic trait for model 8 | */ 9 | public interface Model extends Serializable { 10 | Object score(Object input); 11 | void cleanup(); 12 | byte[] getBytes(); 13 | long getType(); 14 | } -------------------------------------------------------------------------------- /model/src/main/java/com/lightbend/model/java/ModelFactory.java: -------------------------------------------------------------------------------- 1 | package com.lightbend.model.java; 2 | 3 | import com.lightbend.modelserver.support.java.ModelToServe; 4 | 5 | import java.util.Optional; 6 | 7 | /** 8 | * Created by boris on 7/14/17. 9 | */ 10 | public interface ModelFactory { 11 | Optional create(ModelToServe descriptor); 12 | Model restore(byte[] bytes); 13 | } 14 | -------------------------------------------------------------------------------- /model/src/main/java/com/lightbend/model/java/PMML/PMMLModel.java: -------------------------------------------------------------------------------- 1 | package com.lightbend.model.java.PMML; 2 | 3 | import com.google.protobuf.Descriptors; 4 | import com.lightbend.model.java.Model; 5 | import com.lightbend.model.Modeldescriptor; 6 | import com.lightbend.model.Winerecord; 7 | import org.dmg.pmml.FieldName; 8 | import org.dmg.pmml.PMML; 9 | import org.dmg.pmml.Visitor; 10 | import org.jpmml.evaluator.*; 11 | import org.jpmml.evaluator.visitors.*; 12 | import org.jpmml.model.PMMLUtil; 13 | 14 | import java.io.ByteArrayInputStream; 15 | import java.io.ByteArrayOutputStream; 16 | import java.util.*; 17 | 18 | /** 19 | * Created by boris on 5/18/17. 
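 * JPMML-based scorer: unmarshals and optimizes the PMML document, then evaluates WineRecord inputs,
 * mapping the PMML field names (with spaces) to the corresponding protobuf field names.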
20 | */ 21 | public class PMMLModel implements Model { 22 | 23 | private static List optimizers = Arrays.asList(new ExpressionOptimizer(), new FieldOptimizer(), new PredicateOptimizer(), new GeneralRegressionModelOptimizer(), new NaiveBayesModelOptimizer(), new RegressionModelOptimizer()); 24 | 25 | private static Map names = createNamesMap(); 26 | private static Map createNamesMap() { 27 | Map map = new HashMap<>(); 28 | map.put("fixed acidity", "fixed_acidity"); 29 | map.put("volatile acidity", "volatile_acidity"); 30 | map.put("citric acid", "citric_acid"); 31 | map.put("residual sugar", "residual_sugar"); 32 | map.put("chlorides", "chlorides"); 33 | map.put("free sulfur dioxide", "free_sulfur_dioxide"); 34 | map.put("total sulfur dioxide", "total_sulfur_dioxide"); 35 | map.put("density", "density"); 36 | map.put("pH", "pH"); 37 | map.put("sulphates", "sulphates"); 38 | map.put("alcohol", "alcohol"); 39 | return map; 40 | } 41 | 42 | private PMML pmml; 43 | private Evaluator evaluator; 44 | private FieldName tname; 45 | private List inputFields; 46 | private Map arguments = new LinkedHashMap<>(); 47 | 48 | public PMMLModel(byte[] input) throws Throwable{ 49 | // unmarshal PMML 50 | pmml = PMMLUtil.unmarshal(new ByteArrayInputStream(input)); 51 | // Optimize model 52 | synchronized(this) { 53 | for (Visitor optimizer : optimizers) { 54 | optimizer.applyTo(pmml); 55 | } 56 | } 57 | 58 | // Create and verify evaluator 59 | ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance(); 60 | evaluator = modelEvaluatorFactory.newModelEvaluator(pmml); 61 | evaluator.verify(); 62 | 63 | // Get input/target fields 64 | inputFields = evaluator.getInputFields(); 65 | TargetField target = evaluator.getTargetFields().get(0); 66 | tname = target.getName(); 67 | } 68 | 69 | @Override 70 | public Object score(Object input) { 71 | Winerecord.WineRecord inputs = (Winerecord.WineRecord) input; 72 | arguments.clear(); 73 | for(InputField field : inputFields){ 74 | arguments.put(field.getName(), field.prepare(getValueByName(inputs,field.getName().getValue()))); 75 | } 76 | 77 | // Calculate Output// Calculate Output 78 | Map result = evaluator.evaluate(arguments); 79 | 80 | // Prepare output 81 | double rv = 0; 82 | Object tresult = result.get(tname); 83 | if(tresult instanceof Computable){ 84 | String value = ((Computable)tresult).getResult().toString(); 85 | rv = Double.parseDouble(value); 86 | } 87 | else 88 | rv = (Double)tresult; 89 | return rv; 90 | } 91 | 92 | @Override 93 | public void cleanup() { 94 | // Do nothing 95 | 96 | } 97 | 98 | // Get variable value by name 99 | private double getValueByName(Winerecord.WineRecord input, String name){ 100 | Descriptors.FieldDescriptor descriptor = input.getDescriptorForType().findFieldByName(names.get(name)); 101 | return (double)input.getField(descriptor); 102 | } 103 | 104 | @Override 105 | public byte[] getBytes() { 106 | ByteArrayOutputStream ous = new ByteArrayOutputStream(); 107 | try { 108 | PMMLUtil.marshal(pmml, ous); 109 | } 110 | catch(Throwable t){ 111 | t.printStackTrace(); 112 | } 113 | return ous.toByteArray(); 114 | } 115 | 116 | @Override 117 | public long getType() { 118 | return (long) Modeldescriptor.ModelDescriptor.ModelType.PMML.getNumber(); 119 | } 120 | } -------------------------------------------------------------------------------- /model/src/main/java/com/lightbend/model/java/PMML/PMMLModelFactory.java: -------------------------------------------------------------------------------- 1 | package 
com.lightbend.model.java.PMML; 2 | 3 | import com.lightbend.model.java.Model; 4 | import com.lightbend.model.java.ModelFactory; 5 | import com.lightbend.modelserver.support.java.ModelToServe; 6 | 7 | import java.util.Optional; 8 | 9 | /** 10 | * Created by boris on 7/15/17. 11 | */ 12 | public class PMMLModelFactory implements ModelFactory { 13 | 14 | private static ModelFactory instance = null; 15 | 16 | private PMMLModelFactory(){} 17 | 18 | @Override 19 | public Optional create(ModelToServe descriptor) { 20 | try{ 21 | return Optional.of(new PMMLModel(descriptor.getModelData())); 22 | } 23 | catch (Throwable t){ 24 | System.out.println("Exception creating PMMLModel from " + descriptor); 25 | t.printStackTrace(); 26 | return Optional.empty(); 27 | } 28 | } 29 | 30 | @Override 31 | public Model restore(byte[] bytes) { 32 | try{ 33 | return new PMMLModel(bytes); 34 | } 35 | catch (Throwable t){ 36 | System.out.println("Exception restoring PMMLModel from "); 37 | t.printStackTrace(); 38 | return null; 39 | } 40 | } 41 | 42 | public static ModelFactory getInstance(){ 43 | if(instance == null) 44 | instance = new PMMLModelFactory(); 45 | return instance; 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /model/src/main/java/com/lightbend/model/java/tensorflow/TensorflowModel.java: -------------------------------------------------------------------------------- 1 | package com.lightbend.model.java.tensorflow; 2 | 3 | /** 4 | * Created by boris on 5/26/17. 5 | */ 6 | 7 | import com.lightbend.model.java.Model; 8 | import com.lightbend.model.Modeldescriptor; 9 | import com.lightbend.model.Winerecord; 10 | import org.tensorflow.Graph; 11 | import org.tensorflow.Session; 12 | import org.tensorflow.Tensor; 13 | 14 | public class TensorflowModel implements Model { 15 | private Graph graph = new Graph(); 16 | private Session session; 17 | 18 | public TensorflowModel(byte[] inputStream) { 19 | graph.importGraphDef(inputStream); 20 | session = new Session(graph); 21 | } 22 | 23 | @Override 24 | public Object score(Object input) { 25 | Winerecord.WineRecord record = (Winerecord.WineRecord) input; 26 | float[][] data = {{ 27 | (float)record.getFixedAcidity(), 28 | (float)record.getVolatileAcidity(), 29 | (float)record.getCitricAcid(), 30 | (float)record.getResidualSugar(), 31 | (float)record.getChlorides(), 32 | (float)record.getFreeSulfurDioxide(), 33 | (float)record.getTotalSulfurDioxide(), 34 | (float)record.getDensity(), 35 | (float)record.getPH(), 36 | (float)record.getSulphates(), 37 | (float)record.getAlcohol() 38 | }}; 39 | Tensor modelInput = Tensor.create(data); 40 | Tensor result = session.runner().feed("dense_1_input", modelInput).fetch("dense_3/Sigmoid").run().get(0); 41 | long[] rshape = result.shape(); 42 | float[][] rMatrix = new float[(int)rshape[0]][(int)rshape[1]]; 43 | result.copyTo(rMatrix); 44 | Intermediate value = new Intermediate(0, rMatrix[0][0]); 45 | for(int i=1; i < rshape[1]; i++){ 46 | if(rMatrix[0][i] > value.getValue()) { 47 | value.setIndex(i); 48 | value.setValue(rMatrix[0][i]); 49 | } 50 | } 51 | return (double)value.getIndex(); 52 | } 53 | 54 | @Override 55 | public void cleanup() { 56 | session.close(); 57 | graph.close(); 58 | } 59 | 60 | @Override 61 | public byte[] getBytes() { 62 | return graph.toGraphDef(); 63 | } 64 | 65 | public Graph getGraph() { 66 | return graph; 67 | } 68 | 69 | private class Intermediate{ 70 | private int index; 71 | private float value; 72 | public Intermediate(int i, float v){ 73 | index 
= i; 74 | value = v; 75 | } 76 | 77 | public int getIndex() { 78 | return index; 79 | } 80 | 81 | public void setIndex(int index) { 82 | this.index = index; 83 | } 84 | 85 | public float getValue() { 86 | return value; 87 | } 88 | 89 | public void setValue(float value) { 90 | this.value = value; 91 | } 92 | } 93 | 94 | @Override 95 | public long getType() { 96 | return (long) Modeldescriptor.ModelDescriptor.ModelType.TENSORFLOW.getNumber(); 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /model/src/main/java/com/lightbend/model/java/tensorflow/TensorflowModelFactory.java: -------------------------------------------------------------------------------- 1 | package com.lightbend.model.java.tensorflow; 2 | 3 | import com.lightbend.model.java.Model; 4 | import com.lightbend.model.java.ModelFactory; 5 | import com.lightbend.modelserver.support.java.ModelToServe; 6 | 7 | import java.util.Optional; 8 | 9 | /** 10 | * Created by boris on 7/15/17. 11 | */ 12 | public class TensorflowModelFactory implements ModelFactory { 13 | 14 | private static TensorflowModelFactory instance = null; 15 | 16 | @Override 17 | public Optional<Model> create(ModelToServe descriptor) { 18 | 19 | try{ 20 | return Optional.of(new TensorflowModel(descriptor.getModelData())); 21 | } 22 | catch (Throwable t){ 23 | System.out.println("Exception creating TensorflowModel from " + descriptor); 24 | t.printStackTrace(); 25 | return Optional.empty(); 26 | } 27 | } 28 | 29 | @Override 30 | public Model restore(byte[] bytes) { 31 | try{ 32 | return new TensorflowModel(bytes); 33 | } 34 | catch (Throwable t){ 35 | System.out.println("Exception restoring TensorflowModel from bytes"); 36 | t.printStackTrace(); 37 | return null; 38 | } 39 | } 40 | 41 | public static ModelFactory getInstance(){ 42 | if(instance == null) 43 | instance = new TensorflowModelFactory(); 44 | return instance; 45 | } 46 | 47 | } 48 | -------------------------------------------------------------------------------- /model/src/main/scala/com/lightbend/model/scala/DataWithModel.scala: -------------------------------------------------------------------------------- 1 | package com.lightbend.model.scala 2 | 3 | import com.lightbend.model.winerecord.WineRecord 4 | import com.lightbend.modelserver.support.scala.ModelToServe 5 | 6 | /** 7 | * Created by boris on 5/8/17. 8 | */ 9 | 10 | case class DataWithModel(model: Option[ModelToServe], data : Option[WineRecord]){ 11 | def isModel : Boolean = model.isDefined 12 | def getModel : ModelToServe = model.get 13 | def getData : WineRecord = data.get 14 | def getDataType : String = { 15 | if(isModel) 16 | getModel.dataType 17 | else 18 | getData.dataType 19 | } 20 | } -------------------------------------------------------------------------------- /model/src/main/scala/com/lightbend/model/scala/Model.scala: -------------------------------------------------------------------------------- 1 | package com.lightbend.model.scala 2 | 3 | /** 4 | * Created by boris on 5/9/17.
5 | * Basic trait for model 6 | */ 7 | trait Model { 8 | def score(input : AnyVal) : AnyVal 9 | def cleanup() : Unit 10 | def toBytes() : Array[Byte] 11 | def getType : Long 12 | } -------------------------------------------------------------------------------- /model/src/main/scala/com/lightbend/model/scala/ModelFactory.scala: -------------------------------------------------------------------------------- 1 | package com.lightbend.model.scala 2 | 3 | import com.lightbend.modelserver.support.scala.ModelToServe 4 | 5 | 6 | /** 7 | * Created by boris on 5/9/17. 8 | * Basic trait for model factory 9 | */ 10 | trait ModelFactory { 11 | def create(input : ModelToServe) : Option[Model] 12 | def restore(bytes : Array[Byte]) : Model 13 | } -------------------------------------------------------------------------------- /model/src/main/scala/com/lightbend/model/scala/PMML/PMMLModel.scala: -------------------------------------------------------------------------------- 1 | package com.lightbend.model.scala.PMML 2 | 3 | /** 4 | * Created by boris on 5/9/17. 5 | * 6 | * Class for PMML model 7 | */ 8 | 9 | import java.io.{ByteArrayInputStream, ByteArrayOutputStream} 10 | 11 | import com.lightbend.model.scala.{Model, ModelFactory} 12 | import com.lightbend.model.modeldescriptor.ModelDescriptor 13 | import com.lightbend.model.winerecord.WineRecord 14 | import com.lightbend.modelserver.support.scala.ModelToServe 15 | import org.dmg.pmml.{FieldName, PMML} 16 | import org.jpmml.evaluator.visitors._ 17 | import org.jpmml.evaluator.{Computable, FieldValue, ModelEvaluatorFactory, TargetField} 18 | import org.jpmml.model.PMMLUtil 19 | 20 | import scala.collection.JavaConversions._ 21 | import scala.collection._ 22 | 23 | 24 | class PMMLModel(inputStream: Array[Byte]) extends Model { 25 | 26 | var arguments = mutable.Map[FieldName, FieldValue]() 27 | 28 | // Marshall PMML 29 | val pmml = PMMLUtil.unmarshal(new ByteArrayInputStream(inputStream)) 30 | 31 | // Optimize model// Optimize model 32 | PMMLModel.optimize(pmml) 33 | 34 | // Create and verify evaluator 35 | val evaluator = ModelEvaluatorFactory.newInstance.newModelEvaluator(pmml) 36 | evaluator.verify() 37 | 38 | // Get input/target fields 39 | val inputFields = evaluator.getInputFields 40 | val target: TargetField = evaluator.getTargetFields.get(0) 41 | val tname = target.getName 42 | 43 | override def score(input: AnyVal): AnyVal = { 44 | val inputs = input.asInstanceOf[WineRecord] 45 | arguments.clear() 46 | inputFields.foreach(field => { 47 | arguments.put(field.getName, field.prepare(getValueByName(inputs, field.getName.getValue))) 48 | }) 49 | 50 | // Calculate Output// Calculate Output 51 | val result = evaluator.evaluate(arguments) 52 | 53 | // Prepare output 54 | result.get(tname) match { 55 | case c : Computable => c.getResult.toString.toDouble 56 | case v : Any => v.asInstanceOf[Double] 57 | } 58 | } 59 | 60 | override def cleanup(): Unit = {} 61 | 62 | private def getValueByName(inputs : WineRecord, name: String) : Double = 63 | PMMLModel.names.get(name) match { 64 | case Some(index) => { 65 | val v = inputs.getFieldByNumber(index + 1) 66 | v.asInstanceOf[Double] 67 | } 68 | case _ => .0 69 | } 70 | 71 | override def toBytes : Array[Byte] = { 72 | var stream = new ByteArrayOutputStream() 73 | PMMLUtil.marshal(pmml, stream) 74 | stream.toByteArray 75 | } 76 | 77 | override def getType: Long = ModelDescriptor.ModelType.PMML.value 78 | } 79 | 80 | object PMMLModel extends ModelFactory { 81 | 82 | private val optimizers = Array(new 
ExpressionOptimizer, new FieldOptimizer, new PredicateOptimizer, 83 | new GeneralRegressionModelOptimizer, new NaiveBayesModelOptimizer, new RegressionModelOptimizer) 84 | def optimize(pmml : PMML) = this.synchronized { 85 | optimizers.foreach(opt => 86 | try { 87 | opt.applyTo(pmml) 88 | } catch { 89 | case t: Throwable => { 90 | println(s"Error optimizing model for optimizer $opt") 91 | t.printStackTrace() 92 | } 93 | } 94 | ) 95 | } 96 | private val names = Map("fixed acidity" -> 0, 97 | "volatile acidity" -> 1,"citric acid" ->2,"residual sugar" -> 3, 98 | "chlorides" -> 4,"free sulfur dioxide" -> 5,"total sulfur dioxide" -> 6, 99 | "density" -> 7,"pH" -> 8,"sulphates" ->9,"alcohol" -> 10) 100 | 101 | override def create(input: ModelToServe): Option[Model] = { 102 | try { 103 | Some(new PMMLModel(input.model)) 104 | }catch{ 105 | case t: Throwable => None 106 | } 107 | } 108 | 109 | override def restore(bytes: Array[Byte]): Model = new PMMLModel(bytes) 110 | } -------------------------------------------------------------------------------- /model/src/main/scala/com/lightbend/model/scala/tensorflow/TensorFlowModel.scala: -------------------------------------------------------------------------------- 1 | package com.lightbend.model.scala.tensorflow 2 | 3 | import com.lightbend.model.scala.{Model, ModelFactory} 4 | import com.lightbend.model.modeldescriptor.ModelDescriptor 5 | import com.lightbend.model.winerecord.WineRecord 6 | import com.lightbend.modelserver.support.scala.ModelToServe 7 | import org.tensorflow.{Graph, Session, Tensor} 8 | 9 | /** 10 | * Created by boris on 5/26/17. 11 | * Implementation of tensorflow model 12 | */ 13 | 14 | class TensorFlowModel(inputStream : Array[Byte]) extends Model{ 15 | 16 | val graph = new Graph 17 | graph.importGraphDef(inputStream) 18 | val session = new Session(graph) 19 | 20 | override def score(input: AnyVal): AnyVal = { 21 | 22 | val record = input.asInstanceOf[WineRecord] 23 | val data = Array( 24 | record.fixedAcidity.toFloat, 25 | record.volatileAcidity.toFloat, 26 | record.citricAcid.toFloat, 27 | record.residualSugar.toFloat, 28 | record.chlorides.toFloat, 29 | record.freeSulfurDioxide.toFloat, 30 | record.totalSulfurDioxide.toFloat, 31 | record.density.toFloat, 32 | record.pH.toFloat, 33 | record.sulphates.toFloat, 34 | record.alcohol.toFloat 35 | ) 36 | val modelInput = Tensor.create(Array(data)) 37 | val result = session.runner.feed("dense_1_input", modelInput).fetch("dense_3/Sigmoid").run().get(0) 38 | val rshape = result.shape 39 | var rMatrix = Array.ofDim[Float](rshape(0).asInstanceOf[Int],rshape(1).asInstanceOf[Int]) 40 | result.copyTo(rMatrix) 41 | var value = (0, rMatrix(0)(0)) 42 | 1 to (rshape(1).asInstanceOf[Int] -1) foreach{i => { 43 | if(rMatrix(0)(i) > value._2) 44 | value = (i, rMatrix(0)(i)) 45 | }} 46 | value._1.toDouble 47 | } 48 | 49 | override def cleanup(): Unit = { 50 | try{ 51 | session.close 52 | }catch { 53 | case t: Throwable => // Swallow 54 | } 55 | try{ 56 | graph.close 57 | }catch { 58 | case t: Throwable => // Swallow 59 | } 60 | } 61 | 62 | override def toBytes(): Array[Byte] = graph.toGraphDef 63 | 64 | override def getType: Long = ModelDescriptor.ModelType.TENSORFLOW.value 65 | } 66 | 67 | object TensorFlowModel extends ModelFactory { 68 | def apply(inputStream: Array[Byte]): Option[TensorFlowModel] = { 69 | try { 70 | Some(new TensorFlowModel(inputStream)) 71 | }catch{ 72 | case t: Throwable => None 73 | } 74 | } 75 | 76 | override def create(input: ModelToServe): Option[Model] = { 77 | try { 78 | 
Some(new TensorFlowModel(input.model)) 79 | }catch{ 80 | case t: Throwable => None 81 | } 82 | } 83 | 84 | override def restore(bytes: Array[Byte]): Model = new TensorFlowModel(bytes) 85 | } 86 | -------------------------------------------------------------------------------- /project/Dependencies.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Created by boris on 7/14/17. 3 | */ 4 | import sbt._ 5 | import Versions._ 6 | 7 | object Dependencies { 8 | val reactiveKafka = "com.typesafe.akka" % "akka-stream-kafka_2.11" % reactiveKafkaVersion 9 | 10 | val akkaStream = "com.typesafe.akka" % "akka-stream_2.11" % akkaVersion 11 | val akkaHttp = "com.typesafe.akka" % "akka-http_2.11" % akkaHttpVersion 12 | val akkaHttpJsonJackson = "de.heikoseeberger" % "akka-http-jackson_2.11" % akkaHttpJsonVersion // exclude("com.fasterxml.jackson.module","jackson-module-scala_2.11") 13 | val akkaslf = "com.typesafe.akka" % "akka-slf4j_2.11" % akkaSlfVersion 14 | 15 | 16 | val kafka = "org.apache.kafka" % "kafka_2.11" % kafkaVersion 17 | val kafkaclients = "org.apache.kafka" % "kafka-clients" % kafkaVersion 18 | val kafkastreams = "org.apache.kafka" % "kafka-streams" % kafkaVersion 19 | 20 | val curator = "org.apache.curator" % "curator-test" % Curator // ApacheV2 21 | 22 | val gson = "com.google.code.gson" % "gson" % gsonVersion 23 | val jersey = "org.glassfish.jersey.containers" % "jersey-container-servlet-core" % jerseyVersion 24 | val jerseymedia = "org.glassfish.jersey.media" % "jersey-media-json-jackson" % jerseyVersion 25 | val jettyserver = "org.eclipse.jetty" % "jetty-server" % jettyVersion 26 | val jettyservlet = "org.eclipse.jetty" % "jetty-servlet" % jettyVersion 27 | val wsrs = "javax.ws.rs" % "javax.ws.rs-api" % wsrsVersion 28 | 29 | val tensorflow = "org.tensorflow" % "tensorflow" % tensorflowVersion 30 | val tensorflowProto="org.tensorflow" % "proto" % tensorflowVersion 31 | 32 | val jpmml = "org.jpmml" % "pmml-evaluator" % PMMLVersion 33 | val jpmmlextras = "org.jpmml" % "pmml-evaluator-extension" % PMMLVersion 34 | 35 | val flinkScala = "org.apache.flink" % "flink-scala_2.11" % flinkVersion 36 | val flinkStreaming= "org.apache.flink" % "flink-streaming-scala_2.11" % flinkVersion 37 | val flinkKafka = "org.apache.flink" % "flink-connector-kafka-0.10_2.11" % flinkVersion 38 | 39 | val joda = "joda-time" % "joda-time" % jodaVersion 40 | 41 | val kryo = "com.esotericsoftware.kryo" % "kryo" % kryoVersion 42 | 43 | val sparkcore = "org.apache.spark" % "spark-core_2.11" % sparkVersion 44 | val sparkstreaming= "org.apache.spark" % "spark-streaming_2.11" % sparkVersion 45 | val sparkkafka = "org.apache.spark" % "spark-streaming-kafka-0-10_2.11"% sparkVersion 46 | 47 | val scopt = "com.github.scopt" % "scopt_2.11" % scoptVersion 48 | val sparkML = "org.apache.spark" % "spark-mllib_2.11" % sparkVersion 49 | val sparkJPMML = "org.jpmml" % "jpmml-sparkml" % sparkPMMLVersion 50 | 51 | 52 | val modelsDependencies = Seq(jpmml, jpmmlextras, tensorflow) 53 | val kafkabaseDependencies = Seq(reactiveKafka) ++ Seq(kafka, kafkaclients) 54 | val kafkaDependencies = Seq(reactiveKafka) ++ Seq(kafka, kafkaclients, kafkastreams) 55 | val webDependencies = Seq(gson, jersey, jerseymedia, jettyserver, jettyservlet, wsrs) 56 | val akkaServerDependencies = Seq(reactiveKafka) ++ Seq(akkaStream, akkaHttp, akkaHttpJsonJackson, reactiveKafka) 57 | 58 | val flinkDependencies = Seq(flinkScala, flinkStreaming, flinkKafka) 59 | 60 | val sparkDependencies = Seq(sparkcore, 
sparkstreaming, sparkkafka) 61 | 62 | val sparkMLDependencies = Seq(sparkML, sparkJPMML, scopt) 63 | 64 | 65 | } 66 | -------------------------------------------------------------------------------- /project/Versions.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Created by boris on 7/14/17. 3 | */ 4 | object Versions { 5 | val reactiveKafkaVersion = "0.16" 6 | val akkaVersion = "2.4.19" 7 | val akkaSlfVersion = "2.3.2" // older version to align with Flink 8 | val akkaHttpVersion = "10.0.9" 9 | val akkaHttpJsonVersion = "1.17.0" 10 | 11 | val Curator = "3.2.0" 12 | val kafkaVersion = "0.10.2.1" 13 | val tensorflowVersion = "1.1.0" 14 | val PMMLVersion = "1.3.5" 15 | val jettyVersion = "9.2.12.v20150709" 16 | val jacksonVersion = "2.8.8" 17 | val jerseyVersion = "2.25" 18 | val gsonVersion = "2.6.2" 19 | val wsrsVersion = "2.0.1" 20 | 21 | val flinkVersion = "1.3.1" 22 | val jodaVersion = "2.9.7" 23 | 24 | val kryoVersion = "2.24.0" 25 | val sparkVersion = "2.2.0" 26 | 27 | val scoptVersion = "3.5.0" 28 | val sparkPMMLVersion = "1.1.7" 29 | 30 | val beamVersion = "2.0.0" 31 | 32 | 33 | } 34 | -------------------------------------------------------------------------------- /project/assembly.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.1") -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version = 0.13.15 -------------------------------------------------------------------------------- /project/scalapb.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("com.thesamet" % "sbt-protoc" % "0.99.6") 2 | 3 | libraryDependencies += "com.trueaccord.scalapb" %% "compilerplugin" % "0.6.0-pre3" -------------------------------------------------------------------------------- /protobufs/src/main/protobuf/modeldescriptor.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | option java_package = "com.lightbend.model"; 4 | 5 | 6 | // Description of the trained model. 7 | message ModelDescriptor { 8 | // Model name 9 | string name = 1; 10 | // Human readable description. 11 | string description = 2; 12 | // Data type for which this model is applied. 13 | string dataType = 3; 14 | // Model type 15 | enum ModelType { 16 | TENSORFLOW = 0; 17 | TENSORFLOWSAVED = 1; 18 | PMML = 2; 19 | }; 20 | ModelType modeltype = 4; 21 | oneof MessageContent { 22 | // Byte array containing the model 23 | bytes data = 5; 24 | string location = 6; 25 | } 26 | } -------------------------------------------------------------------------------- /protobufs/src/main/protobuf/winerecord.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | option java_package = "com.lightbend.model"; 4 | 5 | // Description of the wine. 
6 | message WineRecord { 7 | double fixed_acidity = 1; 8 | double volatile_acidity = 2; 9 | double citric_acid = 3; 10 | double residual_sugar = 4; 11 | double chlorides = 5; 12 | double free_sulfur_dioxide = 6; 13 | double total_sulfur_dioxide = 7; 14 | double density = 8; 15 | double pH = 9; 16 | double sulphates = 10; 17 | double alcohol = 11; 18 | // Data type for this record 19 | string dataType = 12; 20 | } -------------------------------------------------------------------------------- /servingsamples/src/main/scala/com/lightbend/jpmml/WineQualityRandomForestClassifier.scala: -------------------------------------------------------------------------------- 1 | package com.lightbend.jpmml 2 | 3 | import java.io.{FileInputStream, InputStream} 4 | 5 | import org.dmg.pmml.{FieldName, PMML} 6 | import org.jpmml.evaluator.visitors.{ExpressionOptimizer, _} 7 | import org.jpmml.evaluator.{Computable, FieldValue, ModelEvaluatorFactory, TargetField} 8 | import org.jpmml.model.PMMLUtil 9 | 10 | import scala.collection.JavaConversions._ 11 | import scala.collection._ 12 | import scala.io.Source 13 | 14 | /** 15 | * Created by boris on 5/2/17. 16 | * Test of JPMML evaluator for the model created by DecisionTreeClassificator in SparkML and exported 17 | * using JPMML Spark https://github.com/jpmml/jpmml-spark 18 | * Implementation is based on JPMML example at 19 | * https://github.com/jpmml/jpmml-evaluator/blob/master/pmml-evaluator-example/src/main/java/org/jpmml/evaluator/EvaluationExample.java 20 | */ 21 | 22 | 23 | class WineQualityRandomForestClassifier(path : String) { 24 | import WineQualityRandomForestClassifier._ 25 | 26 | var arguments = mutable.Map[FieldName, FieldValue]() 27 | 28 | // constructor 29 | val pmml: PMML = readPMML(path) 30 | optimize(pmml) 31 | 32 | // Create and verify evaluator 33 | val evaluator = ModelEvaluatorFactory.newInstance().newModelEvaluator(pmml) 34 | evaluator.verify() 35 | 36 | // Get input/target fields 37 | val inputFields = evaluator.getInputFields 38 | val target: TargetField = evaluator.getTargetFields.get(0) 39 | val tname = target.getName 40 | 41 | def score(record : Array[Float]) : Double = { 42 | arguments.clear() 43 | inputFields.foreach(field => { 44 | arguments.put(field.getName, field.prepare(getValueByName(record, field.getName.getValue))) 45 | }) 46 | 47 | // Calculate Output// Calculate Output 48 | val result = evaluator.evaluate(arguments) 49 | 50 | // Prepare output 51 | result.get(tname) match { 52 | case c : Computable => c.getResult.toString.toDouble 53 | case v : Any => v.asInstanceOf[Double] 54 | } 55 | } 56 | 57 | private def getValueByName(inputs : Array[Float], name: String) : Double = { 58 | names.get(name) match { 59 | case Some(index) => { 60 | val v = inputs(index) 61 | v.asInstanceOf[Double] 62 | } 63 | case _ =>.0 64 | } 65 | } 66 | } 67 | 68 | object WineQualityRandomForestClassifier { 69 | 70 | def main(args: Array[String]): Unit = { 71 | val model_path = "data/winequalityRandonForrestClassification.pmml" // model 72 | val data_path = "data/winequality_red.csv" // data 73 | val lmodel = new WineQualityRandomForestClassifier(model_path) 74 | 75 | val inputs = getListOfRecords(data_path) 76 | inputs.foreach(record => 77 | println(s"result ${lmodel.score(record._1)} expected ${record._2}")) 78 | } 79 | 80 | def readPMML(file: String): PMML = { 81 | var is = null.asInstanceOf[InputStream] 82 | try { 83 | is = new FileInputStream(file) 84 | PMMLUtil.unmarshal(is) 85 | } 86 | finally if (is != null) is.close() 87 | } 88 | 89 | 
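// A minimal sketch of how the ScalaPB-generated classes for the two .proto definitions shown above
// (com.lightbend.model.modeldescriptor.ModelDescriptor and com.lightbend.model.winerecord.WineRecord,
// already imported elsewhere in this project) are typically round-tripped. It assumes the standard
// ScalaPB API (named constructor parameters with proto3 defaults, toByteArray, a companion parseFrom,
// and a sealed MessageContent trait for the oneof field); generated names can vary with the ScalaPB version.
object ProtobufMessagesSketch {
  import com.google.protobuf.ByteString
  import com.lightbend.model.modeldescriptor.ModelDescriptor
  import com.lightbend.model.winerecord.WineRecord

  // Wrap serialized model bytes into the ModelDescriptor envelope published on the models topic
  def descriptorBytes(modelBytes: Array[Byte]): Array[Byte] =
    ModelDescriptor(
      name = "wine quality model",
      description = "illustrative descriptor",
      dataType = "wine",
      modeltype = ModelDescriptor.ModelType.PMML,
      messageContent = ModelDescriptor.MessageContent.Data(ByteString.copyFrom(modelBytes))
    ).toByteArray

  // Serialize a data record and parse it back, as the servers do for records on the data topic
  def recordRoundTrip(): WineRecord = {
    val record = WineRecord(fixedAcidity = 7.4, volatileAcidity = 0.7, alcohol = 9.4, dataType = "wine")
    WineRecord.parseFrom(record.toByteArray)
  }
}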
private val optimizers = Array(new ExpressionOptimizer, new FieldOptimizer, new PredicateOptimizer, 90 | new GeneralRegressionModelOptimizer, new NaiveBayesModelOptimizer, new RegressionModelOptimizer) 91 | 92 | def optimize(pmml : PMML) = this.synchronized { 93 | optimizers.foreach(opt => 94 | try { 95 | opt.applyTo(pmml) 96 | } catch { 97 | case t: Throwable => { 98 | println(s"Error optimizing model for optimizer $opt") 99 | t.printStackTrace() 100 | } 101 | } 102 | ) 103 | } 104 | 105 | def getListOfRecords(file: String): Seq[(Array[Float], Float)] = { 106 | 107 | var result = Seq.empty[(Array[Float], Float)] 108 | val bufferedSource = Source.fromFile(file) 109 | var current = 0 110 | for (line <- bufferedSource.getLines) { 111 | if (current == 0) 112 | current = 1 113 | else { 114 | val cols = line.split(";").map(_.trim) 115 | val record = Array( 116 | cols(0).toFloat, cols(1).toFloat, cols(2).toFloat, cols(3).toFloat, cols(4).toFloat, 117 | cols(5).toFloat, cols(6).toFloat, cols(7).toFloat, cols(8).toFloat, cols(9).toFloat, cols(10).toFloat) 118 | result = (record, cols(11).toFloat) +: result 119 | } 120 | } 121 | bufferedSource.close 122 | result 123 | } 124 | 125 | private val names = Map("fixed acidity" -> 0, 126 | "volatile acidity" -> 1,"citric acid" ->2,"residual sugar" -> 3, 127 | "chlorides" -> 4,"free sulfur dioxide" -> 5,"total sulfur dioxide" -> 6, 128 | "density" -> 7,"pH" -> 8,"sulphates" ->9,"alcohol" -> 10) 129 | } 130 | -------------------------------------------------------------------------------- /servingsamples/src/main/scala/com/lightbend/tensorflow/WineModelServing.scala: -------------------------------------------------------------------------------- 1 | package com.lightbend.tensorflow 2 | 3 | import java.nio.file.{Files, Path, Paths} 4 | import scala.io.Source 5 | import org.tensorflow.{Graph, Session, Tensor} 6 | 7 | /** 8 | * Created by boris on 5/25/17. 
9 | */ 10 | class WineModelServing(path : String) { 11 | import WineModelServing._ 12 | 13 | // Constructor 14 | println(s"Loading saved model from $path") 15 | val lg = readGraph(Paths.get (path)) 16 | val ls = new Session (lg) 17 | println("Model Loading complete") 18 | 19 | def score(record : Array[Float]) : Double = { 20 | val input = Tensor.create(Array(record)) 21 | val result = ls.runner.feed("dense_1_input", input).fetch("dense_3/Sigmoid").run().get(0) 22 | val rshape = result.shape 23 | var rMatrix = Array.ofDim[Float](rshape(0).asInstanceOf[Int],rshape(1).asInstanceOf[Int]) 24 | result.copyTo(rMatrix) 25 | var value = (0, rMatrix(0)(0)) 26 | 1 to (rshape(1).asInstanceOf[Int] -1) foreach{i => { 27 | if(rMatrix(0)(i) > value._2) 28 | value = (i, rMatrix(0)(i)) 29 | }} 30 | value._1.toDouble 31 | } 32 | 33 | def cleanup() : Unit = { 34 | ls.close 35 | } 36 | } 37 | 38 | object WineModelServing{ 39 | def main(args: Array[String]): Unit = { 40 | val model_path = "data/optimized_WineQuality.pb" // model 41 | val data_path = "data/winequality_red.csv" // data 42 | 43 | val lmodel = new WineModelServing(model_path) 44 | val inputs = getListOfRecords(data_path) 45 | inputs.foreach(record => 46 | println(s"result ${lmodel.score(record._1)} expected ${record._2}")) 47 | lmodel.cleanup() 48 | } 49 | 50 | private def readGraph(path: Path) : Graph = { 51 | try { 52 | val graphData = Files.readAllBytes(path) 53 | val g = new Graph 54 | g.importGraphDef(graphData) 55 | g 56 | } catch { 57 | case e: Throwable => 58 | println("Failed to read graph [" + path + "]: " + e.getMessage) 59 | System.exit(1) 60 | null.asInstanceOf[Graph] 61 | } 62 | } 63 | 64 | def getListOfRecords(file: String): Seq[(Array[Float], Float)] = { 65 | 66 | var result = Seq.empty[(Array[Float], Float)] 67 | val bufferedSource = Source.fromFile(file) 68 | try for (line <- bufferedSource.getLines) { 69 | val cols = line.split(";").map(_.trim) 70 | val record = cols.take(11).map(_.toFloat) 71 | result = (record, cols(11).toFloat) +: result 72 | } finally 73 | bufferedSource.close 74 | result 75 | } 76 | 77 | } -------------------------------------------------------------------------------- /servingsamples/src/main/scala/com/lightbend/tensorflow/WineModelServingBundle.scala: -------------------------------------------------------------------------------- 1 | package com.lightbend.tensorflow 2 | 3 | /** 4 | * Created by boris on 5/26/17. 
5 | */ 6 | 7 | 8 | import org.tensorflow.{SavedModelBundle, Session, Tensor} 9 | import org.tensorflow.framework.{MetaGraphDef, SignatureDef, TensorInfo, TensorShapeProto} 10 | 11 | import scala.collection.mutable._ 12 | import scala.collection.JavaConverters._ 13 | import scala.io.Source 14 | 15 | object WineModelServingBundle { 16 | 17 | def apply(path: String, label: String): WineModelServingBundle = new WineModelServingBundle(path, label) 18 | def main(args: Array[String]): Unit = { 19 | val data_path = "data/winequality_red.csv" // data 20 | val saved_model_path = "data/WineQuality" // Saved model directory 21 | val label = "serve" 22 | val model = WineModelServingBundle(saved_model_path, label) 23 | val inputs = getListOfRecords(data_path) 24 | inputs.foreach(record => 25 | println(s"result ${model.score(record._1)} expected ${record._2}")) 26 | model.cleanup() 27 | } 28 | 29 | def getListOfRecords(file: String): Seq[(Array[Float], Float)] = { 30 | 31 | var result = Seq.empty[(Array[Float], Float)] 32 | val bufferedSource = Source.fromFile(file) 33 | for (line <- bufferedSource.getLines) { 34 | val cols = line.split(";").map(_.trim) 35 | val record = cols.take(11).map(_.toFloat) 36 | result = (record, cols(11).toFloat) +: result 37 | } 38 | bufferedSource.close 39 | result 40 | } 41 | } 42 | 43 | class WineModelServingBundle(path : String, label : String){ 44 | // Constructor 45 | 46 | println(s"Loading saved model from $path with label $label") 47 | val bundle = SavedModelBundle.load(path, label) 48 | val ls: Session = bundle.session 49 | val metaGraphDef = MetaGraphDef.parseFrom(bundle.metaGraphDef()) 50 | val signatures = parseSignature(metaGraphDef.getSignatureDefMap.asScala) 51 | println("Model Loading complete") 52 | 53 | def score(record : Array[Float]) : Double = { 54 | val input = Tensor.create(Array(record)) 55 | val result = ls.runner.feed(signatures(0).inputs(0).name, input).fetch(signatures(0).outputs(0).name).run().get(0) 56 | val rshape = result.shape 57 | var rMatrix = Array.ofDim[Float](rshape(0).asInstanceOf[Int],rshape(1).asInstanceOf[Int]) 58 | result.copyTo(rMatrix) 59 | var value = (0, rMatrix(0)(0)) 60 | 1 to (rshape(1).asInstanceOf[Int] -1) foreach{i => { 61 | if(rMatrix(0)(i) > value._2) 62 | value = (i, rMatrix(0)(i)) 63 | }} 64 | value._1.toDouble 65 | } 66 | 67 | def cleanup() : Unit = { 68 | ls.close 69 | } 70 | 71 | def convertParameters(tensorInfo: Map[String,TensorInfo]) : Seq[Parameter] = { 72 | 73 | var parameters = Seq.empty[Parameter] 74 | tensorInfo.foreach(input => { 75 | val fields = input._2.getAllFields.asScala 76 | var name = "" 77 | var dtype = "" 78 | var shape = Seq.empty[Int] 79 | fields.foreach(descriptor => { 80 | if(descriptor._1.getName.contains("shape") ){ 81 | descriptor._2.asInstanceOf[TensorShapeProto].getDimList.toArray.map(d => 82 | d.asInstanceOf[TensorShapeProto.Dim].getSize).toSeq.foreach(v => shape = shape :+ v.toInt) 83 | 84 | } 85 | if(descriptor._1.getName.contains("name") ) { 86 | name = descriptor._2.toString.split(":")(0) 87 | } 88 | if(descriptor._1.getName.contains("dtype") ) { 89 | dtype = descriptor._2.toString 90 | } 91 | }) 92 | parameters = Parameter(name, dtype, shape) +: parameters 93 | }) 94 | parameters 95 | } 96 | 97 | def parseSignature(signatureMap : Map[String, SignatureDef]) : Seq[Signature] = { 98 | 99 | var signatures = Seq.empty[Signature] 100 | signatureMap.foreach(definition => { 101 | val inputDefs = definition._2.getInputsMap.asScala 102 | val outputDefs = definition._2.getOutputsMap.asScala 103 
| val inputs = convertParameters(inputDefs) 104 | val outputs = convertParameters(outputDefs) 105 | signatures = Signature(definition._1, inputs, outputs) +: signatures 106 | }) 107 | signatures 108 | } 109 | } 110 | 111 | case class Parameter(name : String, dtype: String, dimensions: Seq[Int] = Seq.empty[Int]){} 112 | 113 | case class Signature(name : String, inputs: Seq[Parameter], outputs: Seq[Parameter]){} -------------------------------------------------------------------------------- /sparkML/src/main/scala/com/lightbend/spark/ml/WineQualityDecisionTreeClassifier.scala: -------------------------------------------------------------------------------- 1 | package com.lightbend.spark.ml 2 | 3 | /** 4 | * Created by boris on 5/1/17. 5 | * 6 | * Decision tree learning uses a decision tree as a predictive model observations about an item (represented in the 7 | * branches) to conclusions about the item's target value (represented in the leaves). It is one of the predictive 8 | * modelling approaches used in statistics, data mining and machine learning. Tree models where the target variable 9 | * can take a finite set of values are called classification trees; in these tree structures, leaves represent class 10 | * labels and branches represent conjunctions of features that lead to those class labels. Decision trees where the 11 | * target variable can take continuous values (typically real numbers) are called regression trees. 12 | * 13 | */ 14 | 15 | import org.apache.spark.ml.Pipeline 16 | import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier} 17 | import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorAssembler} 18 | import org.apache.spark.sql.SparkSession 19 | import org.apache.spark.sql.functions._ 20 | import org.jpmml.model.MetroJAXBUtil 21 | import org.jpmml.sparkml.ConverterUtil 22 | 23 | 24 | object WineQualityDecisionTreeClassifier { 25 | 26 | def main(args: Array[String]): Unit = { 27 | val spark = SparkSession 28 | .builder 29 | .appName("WineQualityDecisionTreeClassifierPMML") 30 | .master("local") 31 | .getOrCreate() 32 | 33 | // Load and parse the data file. 34 | val df = spark.read 35 | .format("csv") 36 | .option("header", "true") 37 | .option("mode", "DROPMALFORMED") 38 | .option("delimiter", ";") 39 | .load("data/winequality_red_names.csv") 40 | val inputFields = List("fixed acidity", "volatile acidity", "citric acid", "residual sugar", "chlorides", 41 | "free sulfur dioxide", "total sulfur dioxide", "density", "pH", "sulphates", "alcohol") 42 | 43 | // CSV imports everything as Strings, fix the type 44 | val toDouble = udf[Double, String]( _.toDouble) 45 | val dff = df. 46 | withColumn("fixed acidity", toDouble(df("fixed acidity"))). // 0 + 47 | withColumn("volatile acidity", toDouble(df("volatile acidity"))). // 1 + 48 | withColumn("citric acid", toDouble(df("citric acid"))). // 2 - 49 | withColumn("residual sugar", toDouble(df("residual sugar"))). // 3 + 50 | withColumn("chlorides", toDouble(df("chlorides"))). // 4 - 51 | withColumn("free sulfur dioxide", toDouble(df("free sulfur dioxide"))). // 5 + 52 | withColumn("total sulfur dioxide", toDouble(df("total sulfur dioxide"))). // 6 + 53 | withColumn("density", toDouble(df("density"))). // 7 - 54 | withColumn("pH", toDouble(df("pH"))). // 8 + 55 | withColumn("sulphates", toDouble(df("sulphates"))). 
// 9 + 56 | withColumn("alcohol", toDouble(df("alcohol"))) // 10 + 57 | 58 | 59 | // Decision Tree operates on feature vectors not individual features, so convert to DF again 60 | val assembler = new VectorAssembler(). 61 | setInputCols(inputFields.toArray). 62 | setOutputCol("features") 63 | 64 | // Fit on whole dataset to include all labels in index. 65 | val labelIndexer = new StringIndexer() 66 | .setInputCol("quality") 67 | .setOutputCol("indexedLabel") 68 | .fit(dff) 69 | 70 | // Train a DecisionTree model. 71 | val dt = new DecisionTreeClassifier() 72 | .setLabelCol("indexedLabel") 73 | .setFeaturesCol("features") 74 | 75 | // Convert indexed labels back to original labels. 76 | val labelConverter = new IndexToString() 77 | .setInputCol("prediction") 78 | .setOutputCol("predictedLabel") 79 | .setLabels(labelIndexer.labels) 80 | 81 | // create pileline 82 | val pipeline = new Pipeline() 83 | .setStages(Array(assembler, labelIndexer, dt, labelConverter)) 84 | 85 | // Train model 86 | val model = pipeline.fit(dff) 87 | 88 | // Print results 89 | val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel] 90 | println("Learned classification tree model:\n" + treeModel.toDebugString) 91 | 92 | // PMML 93 | val schema = dff.schema 94 | val pmml = ConverterUtil.toPMML(schema, model) 95 | MetroJAXBUtil.marshalPMML(pmml, System.out) 96 | spark.stop() 97 | } 98 | } -------------------------------------------------------------------------------- /sparkML/src/main/scala/com/lightbend/spark/ml/WineQualityDecisionTreeRegressor.scala: -------------------------------------------------------------------------------- 1 | package com.lightbend.spark.ml 2 | 3 | /** 4 | * Created by boris on 5/1/17. 5 | * 6 | * Decision tree learning uses a decision tree as a predictive model which maps observations about an item (represented in the 7 | * branches) to conclusions about the item's target value (represented in the leaves). It is one of the predictive modelling 8 | * approaches used in statistics, data mining and machine learning. Tree models where the target variable can take a finite set of 9 | * values are called classification trees; in these tree structures, leaves represent class labels and branches represent 10 | * conjunctions of features that lead to those class labels. Decision trees where the target variable can take continuous values 11 | * (typically real numbers) are called regression trees. 12 | * In decision analysis, a decision tree can be used to visually and explicitly represent decisions and decision making. In data 13 | * mining, a decision tree describes data (but the resulting classification tree can be an input for decision making). This page 14 | * deals with decision trees in data mining. 15 | */ 16 | 17 | import org.apache.spark.ml.Pipeline 18 | import org.apache.spark.ml.feature.VectorAssembler 19 | import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor} 20 | import org.apache.spark.sql.SparkSession 21 | import org.apache.spark.sql.functions._ 22 | import org.jpmml.model.MetroJAXBUtil 23 | import org.jpmml.sparkml.ConverterUtil 24 | 25 | 26 | object WineQualityDecisionTreeRegressor { 27 | 28 | def main(args: Array[String]): Unit = { 29 | val spark = SparkSession 30 | .builder 31 | .appName("WineQualityDecisionTreeRegressorPMML") 32 | .master("local") 33 | .getOrCreate() 34 | 35 | // Load and parse the data file. 
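// (The winequality_red_names.csv file read below is semicolon-delimited with a header row, so Spark
//  initially loads every column as a String; the udf/withColumn casts that follow convert the eleven
//  feature columns and the "quality" label to numeric types before the feature vector is assembled.)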
36 | val df = spark.read 37 | .format("csv") 38 | .option("header", "true") 39 | .option("mode", "DROPMALFORMED") 40 | .option("delimiter", ";") 41 | .load("data/winequality_red_names.csv") 42 | val inputFields = List("fixed acidity", "volatile acidity", "citric acid", "residual sugar", "chlorides", 43 | "free sulfur dioxide", "total sulfur dioxide", "density", "pH", "sulphates", "alcohol") 44 | 45 | // CSV imports everything as Strings, fix the type 46 | val toInt = udf[Int, String]( _.toInt) 47 | val toDouble = udf[Double, String]( _.toDouble) 48 | val dff = df. 49 | withColumn("quality", toInt(df("quality"))). 50 | withColumn("fixed acidity", toDouble(df("fixed acidity"))). // 0 + 51 | withColumn("volatile acidity", toDouble(df("volatile acidity"))). // 1 + 52 | withColumn("citric acid", toDouble(df("citric acid"))). // 2 - 53 | withColumn("residual sugar", toDouble(df("residual sugar"))). // 3 + 54 | withColumn("chlorides", toDouble(df("chlorides"))). // 4 - 55 | withColumn("free sulfur dioxide", toDouble(df("free sulfur dioxide"))). // 5 + 56 | withColumn("total sulfur dioxide", toDouble(df("total sulfur dioxide"))). // 6 + 57 | withColumn("density", toDouble(df("density"))). // 7 - 58 | withColumn("pH", toDouble(df("pH"))). // 8 + 59 | withColumn("sulphates", toDouble(df("sulphates"))). // 9 + 60 | withColumn("alcohol", toDouble(df("alcohol"))) // 10 + 61 | 62 | 63 | // Decision Tree operates on feature vectors not individual features, so convert to DF again 64 | val assembler = new VectorAssembler(). 65 | setInputCols(inputFields.toArray). 66 | setOutputCol("features") 67 | 68 | // Train a DecisionTree model. 69 | val dt = new DecisionTreeRegressor() 70 | .setLabelCol("quality") 71 | .setFeaturesCol("features") 72 | 73 | // create pileline 74 | val pipeline = new Pipeline() 75 | .setStages(Array(assembler, dt)) 76 | 77 | // Train model 78 | val model = pipeline.fit(dff) 79 | 80 | // Print results 81 | val lrModel = model.stages(1).asInstanceOf[DecisionTreeRegressionModel] 82 | println("Learned regression tree model:\n" + lrModel.toDebugString) 83 | 84 | // PMML 85 | val schema = dff.schema 86 | val pmml = ConverterUtil.toPMML(schema, model) 87 | MetroJAXBUtil.marshalPMML(pmml, System.out) 88 | spark.stop() 89 | } 90 | } -------------------------------------------------------------------------------- /sparkML/src/main/scala/com/lightbend/spark/ml/WineQualityPerceptron.scala: -------------------------------------------------------------------------------- 1 | package com.lightbend.spark.ml 2 | 3 | /** 4 | * Created by boris on 5/1/17. 5 | * 6 | * A multilayer perceptron (MLP) is a feedforward artificial neural network model that maps sets of input data 7 | * onto a set of appropriate outputs. An MLP consists of multiple layers of nodes in a directed graph, with each 8 | * layer fully connected to the next one. Except for the input nodes, each node is a neuron (or processing element) 9 | * with a nonlinear activation function. MLP utilizes a supervised learning technique called backpropagation for 10 | * training the network. MLP is a modification of the standard linear perceptron and can distinguish data that is 11 | * not linearly separable. 
12 | */ 13 | 14 | import org.apache.spark.ml.Pipeline 15 | import org.apache.spark.ml.classification.{MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier} 16 | import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorAssembler} 17 | import org.apache.spark.sql.SparkSession 18 | import org.apache.spark.sql.functions._ 19 | import org.jpmml.model.MetroJAXBUtil 20 | import org.jpmml.sparkml.ConverterUtil 21 | 22 | 23 | object WineQualityPerceptronPMML { 24 | 25 | def main(args: Array[String]): Unit = { 26 | val spark = SparkSession 27 | .builder 28 | .appName("WineQualityDecisionTreeRegressorPMML") 29 | .master("local") 30 | .getOrCreate() 31 | 32 | // Load and parse the data file. 33 | val df = spark.read 34 | .format("csv") 35 | .option("header", "true") 36 | .option("mode", "DROPMALFORMED") 37 | .option("delimiter", ";") 38 | .load("data/winequality_red_names.csv") 39 | val inputFields = List("fixed acidity", "volatile acidity", "citric acid", "residual sugar", "chlorides", 40 | "free sulfur dioxide", "total sulfur dioxide", "density", "pH", "sulphates", "alcohol") 41 | 42 | // CSV imports everything as Strings, fix the type 43 | val toDouble = udf[Double, String]( _.toDouble) 44 | val dff = df. 45 | withColumn("fixed acidity", toDouble(df("fixed acidity"))). // 0 + 46 | withColumn("volatile acidity", toDouble(df("volatile acidity"))). // 1 + 47 | withColumn("citric acid", toDouble(df("citric acid"))). // 2 - 48 | withColumn("residual sugar", toDouble(df("residual sugar"))). // 3 + 49 | withColumn("chlorides", toDouble(df("chlorides"))). // 4 - 50 | withColumn("free sulfur dioxide", toDouble(df("free sulfur dioxide"))). // 5 + 51 | withColumn("total sulfur dioxide", toDouble(df("total sulfur dioxide"))). // 6 + 52 | withColumn("density", toDouble(df("density"))). // 7 - 53 | withColumn("pH", toDouble(df("pH"))). // 8 + 54 | withColumn("sulphates", toDouble(df("sulphates"))). // 9 + 55 | withColumn("alcohol", toDouble(df("alcohol"))) // 10 + 56 | 57 | 58 | // Decision Tree operates on feature vectors not individual features, so convert to DF again 59 | val assembler = new VectorAssembler(). 60 | setInputCols(inputFields.toArray). 61 | setOutputCol("features") 62 | 63 | // Fit on whole dataset to include all labels in index. 64 | val labelIndexer = new StringIndexer() 65 | .setInputCol("quality") 66 | .setOutputCol("indexedLabel") 67 | .fit(dff) 68 | 69 | // specify layers for the neural network: 70 | // input layer of size 11 (features), two intermediate of size 10 and 20 71 | // and output of size 6 (classes) 72 | 73 | val layers = Array[Int](11, 10, 20, 6) 74 | 75 | // Train a DecisionTree model. 76 | val dt = new MultilayerPerceptronClassifier() 77 | .setLayers(layers) 78 | .setBlockSize(128) 79 | .setSeed(1234L) 80 | .setMaxIter(100) 81 | .setLabelCol("indexedLabel") 82 | .setFeaturesCol("features") 83 | 84 | // Convert indexed labels back to original labels. 
85 | val labelConverter = new IndexToString() 86 | .setInputCol("prediction") 87 | .setOutputCol("predictedLabel") 88 | .setLabels(labelIndexer.labels) 89 | 90 | // create pileline 91 | val pipeline = new Pipeline() 92 | .setStages(Array(assembler, labelIndexer, dt, labelConverter)) 93 | 94 | // Train model 95 | val model = pipeline.fit(dff) 96 | 97 | // Print results 98 | val lrModel = model.stages(2).asInstanceOf[MultilayerPerceptronClassificationModel] 99 | println(s"Coefficients: \n${lrModel}") 100 | 101 | // PMML 102 | val schema = dff.schema 103 | val pmml = ConverterUtil.toPMML(schema, model) 104 | MetroJAXBUtil.marshalPMML(pmml, System.out) 105 | spark.stop() 106 | } 107 | } -------------------------------------------------------------------------------- /sparkML/src/main/scala/com/lightbend/spark/ml/WineQualityRandomForrestClassifier.scala: -------------------------------------------------------------------------------- 1 | package com.lightbend.spark.ml 2 | 3 | /** 4 | * Created by boris on 5/1/17. 5 | * 6 | * Random forests or random decision forests are an ensemble learning method for classification, regression and 7 | * other tasks, that operate by constructing a multitude of decision trees at training time and outputting the 8 | * class that is the mode of the classes (classification) or mean prediction (regression) of the individual trees. 9 | * Random decision forests correct for decision trees' habit of overfitting to their training set. 10 | */ 11 | 12 | import org.apache.spark.ml.Pipeline 13 | import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier} 14 | import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorAssembler} 15 | import org.apache.spark.sql.SparkSession 16 | import org.apache.spark.sql.functions._ 17 | import org.jpmml.model.MetroJAXBUtil 18 | import org.jpmml.sparkml.ConverterUtil 19 | 20 | 21 | object WineQualityRandomForrestClassifierPMML { 22 | 23 | def main(args: Array[String]): Unit = { 24 | val spark = SparkSession 25 | .builder 26 | .appName("WineQualityDecisionTreeClassifierPMML") 27 | .master("local") 28 | .getOrCreate() 29 | 30 | // Load and parse the data file. 31 | val df = spark.read 32 | .format("csv") 33 | .option("header", "true") 34 | .option("mode", "DROPMALFORMED") 35 | .option("delimiter", ";") 36 | .load("data/winequality_red_names.csv") 37 | val inputFields = List("fixed acidity", "volatile acidity", "citric acid", "residual sugar", "chlorides", 38 | "free sulfur dioxide", "total sulfur dioxide", "density", "pH", "sulphates", "alcohol") 39 | 40 | // CSV imports everything as Strings, fix the type 41 | val toDouble = udf[Double, String]( _.toDouble) 42 | val dff = df. 43 | withColumn("fixed acidity", toDouble(df("fixed acidity"))). // 0 + 44 | withColumn("volatile acidity", toDouble(df("volatile acidity"))). // 1 + 45 | withColumn("citric acid", toDouble(df("citric acid"))). // 2 - 46 | withColumn("residual sugar", toDouble(df("residual sugar"))). // 3 + 47 | withColumn("chlorides", toDouble(df("chlorides"))). // 4 - 48 | withColumn("free sulfur dioxide", toDouble(df("free sulfur dioxide"))). // 5 + 49 | withColumn("total sulfur dioxide", toDouble(df("total sulfur dioxide"))). // 6 + 50 | withColumn("density", toDouble(df("density"))). // 7 - 51 | withColumn("pH", toDouble(df("pH"))). // 8 + 52 | withColumn("sulphates", toDouble(df("sulphates"))). 
// 9 + 53 | withColumn("alcohol", toDouble(df("alcohol"))) // 10 + 54 | 55 | 56 | // Decision Tree operates on feature vectors not individual features, so convert to DF again 57 | val assembler = new VectorAssembler(). 58 | setInputCols(inputFields.toArray). 59 | setOutputCol("features") 60 | 61 | // Fit on whole dataset to include all labels in index. 62 | val labelIndexer = new StringIndexer() 63 | .setInputCol("quality") 64 | .setOutputCol("indexedLabel") 65 | .fit(dff) 66 | 67 | // Train a DecisionTree model. 68 | val dt = new RandomForestClassifier() 69 | .setLabelCol("indexedLabel") 70 | .setFeaturesCol("features") 71 | .setNumTrees(10) 72 | 73 | // Convert indexed labels back to original labels. 74 | val labelConverter = new IndexToString() 75 | .setInputCol("prediction") 76 | .setOutputCol("predictedLabel") 77 | .setLabels(labelIndexer.labels) 78 | 79 | // create pileline 80 | val pipeline = new Pipeline() 81 | .setStages(Array(assembler, labelIndexer, dt, labelConverter)) 82 | 83 | // Train model 84 | val model = pipeline.fit(dff) 85 | 86 | // Print results 87 | val treeModel = model.stages(2).asInstanceOf[RandomForestClassificationModel] 88 | println("Learned classification tree model:\n" + treeModel.toDebugString) 89 | 90 | // PMML 91 | val schema = dff.schema 92 | val pmml = ConverterUtil.toPMML(schema, model) 93 | MetroJAXBUtil.marshalPMML(pmml, System.out) 94 | spark.stop() 95 | } 96 | } -------------------------------------------------------------------------------- /sparkML/src/main/scala/com/lightbend/spark/ml/WinequalityGeneralizedLinearRegression.scala: -------------------------------------------------------------------------------- 1 | package com.lightbend.spark.ml 2 | 3 | import org.apache.spark.ml.Pipeline 4 | import org.apache.spark.ml.feature.VectorAssembler 5 | import org.apache.spark.ml.regression.{GeneralizedLinearRegression, GeneralizedLinearRegressionModel} 6 | import org.apache.spark.sql.SparkSession 7 | import org.apache.spark.sql.functions.udf 8 | import org.jpmml.model.MetroJAXBUtil 9 | import org.jpmml.sparkml.ConverterUtil 10 | 11 | 12 | /** 13 | * Created by boris on 5/3/17. 14 | * 15 | * Contrasted with linear regression where the output is assumed to follow a Gaussian distribution, generalized 16 | * linear models (GLMs) are specifications of linear models where the response variable YiYi follows some 17 | * distribution from the exponential family of distributions. Spark’s GeneralizedLinearRegression interface allows 18 | * for flexible specification of GLMs which can be used for various types of prediction problems including linear 19 | * regression, Poisson regression, logistic regression, and others. Currently in spark.ml, only a following subset 20 | * of the exponential family distributions are supported: 21 | * Family Response Type Supported Links 22 | * Gaussian Continuous Identity*, Log, Inverse 23 | * Binomial Binary Logit*, Probit, CLogLog 24 | * Poisson Count Log*, Identity, Sqrt 25 | * Gamma Continuous Inverse*, Idenity, Log 26 | */ 27 | 28 | object WinequalityGeneralizedLinearRegression { 29 | 30 | def main(args: Array[String]): Unit = { 31 | 32 | val spark = SparkSession 33 | .builder 34 | .appName("WineQualityDecisionTreeClassifierPMML") 35 | .master("local") 36 | .getOrCreate() 37 | 38 | // Load and parse the data file. 
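// (For reference: a GLM models g(E[Y]) = x · beta for a link function g. With the Gamma family and
//  identity link configured further below, E[quality] = x · beta directly, while the response is assumed
//  to follow a Gamma rather than a Gaussian distribution, i.e. to be strictly positive.)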
39 | val df = spark.read 40 | .format("csv") 41 | .option("header", "true") 42 | .option("mode", "DROPMALFORMED") 43 | .option("delimiter", ";") 44 | .load("data/winequality_red_names.csv") 45 | val inputFields = List("fixed acidity", "volatile acidity", "citric acid", "residual sugar", "chlorides", 46 | "free sulfur dioxide", "total sulfur dioxide", "density", "pH", "sulphates", "alcohol") 47 | 48 | // CSV imports everything as Strings, fix the type 49 | val toDouble = udf[Double, String](_.toDouble) 50 | val dff = df. 51 | withColumn("quality", toDouble(df("quality"))). 52 | withColumn("fixed acidity", toDouble(df("fixed acidity"))). // 0 53 | withColumn("volatile acidity", toDouble(df("volatile acidity"))). // 1 54 | withColumn("citric acid", toDouble(df("citric acid"))). // 2 55 | withColumn("residual sugar", toDouble(df("residual sugar"))). // 3 56 | withColumn("chlorides", toDouble(df("chlorides"))). // 4 57 | withColumn("free sulfur dioxide", toDouble(df("free sulfur dioxide"))). // 5 58 | withColumn("total sulfur dioxide", toDouble(df("total sulfur dioxide"))). // 6 59 | withColumn("density", toDouble(df("density"))). // 7 60 | withColumn("pH", toDouble(df("pH"))). // 8 61 | withColumn("sulphates", toDouble(df("sulphates"))). // 9 62 | withColumn("alcohol", toDouble(df("alcohol"))) // 10 63 | 64 | // Regression operates on feature vectors not individual features, so convert to DF again 65 | val assembler = new VectorAssembler(). 66 | setInputCols(inputFields.toArray). 67 | setOutputCol("features") 68 | 69 | val lr = new GeneralizedLinearRegression() 70 | // .setFamily("gaussian") 71 | .setFamily("gamma") 72 | .setLink("identity") 73 | .setMaxIter(100) 74 | .setRegParam(0.3) 75 | .setFeaturesCol("features") 76 | .setLabelCol("quality") 77 | 78 | // create pileline 79 | val pipeline = new Pipeline().setStages(Array(assembler,lr)) 80 | 81 | // Fit the model 82 | // val lrModel = lr.fit(dataFrame) 83 | val model = pipeline.fit(dff) 84 | val lrModel = model.stages(1).asInstanceOf[GeneralizedLinearRegressionModel] 85 | 86 | 87 | // Summarize the model over the training set and print out some metrics 88 | // Print the coefficients and intercept for generalized linear regression model 89 | println(s"Coefficients: ${lrModel.coefficients}") 90 | println(s"Intercept: ${lrModel.intercept}") 91 | 92 | // Summarize the model over the training set and print out some metrics 93 | val summary = lrModel.summary 94 | println(s"Coefficient Standard Errors: ${summary.coefficientStandardErrors.mkString(",")}") 95 | println(s"T Values: ${summary.tValues.mkString(",")}") 96 | println(s"P Values: ${summary.pValues.mkString(",")}") 97 | println(s"Dispersion: ${summary.dispersion}") 98 | println(s"Null Deviance: ${summary.nullDeviance}") 99 | println(s"Residual Degree Of Freedom Null: ${summary.residualDegreeOfFreedomNull}") 100 | println(s"Deviance: ${summary.deviance}") 101 | println(s"Residual Degree Of Freedom: ${summary.residualDegreeOfFreedom}") 102 | println(s"AIC: ${summary.aic}") 103 | println("Deviance Residuals: ") 104 | summary.residuals().show() 105 | 106 | // PMML 107 | val pmml = ConverterUtil.toPMML(dff.schema, model) 108 | MetroJAXBUtil.marshalPMML(pmml, System.out) 109 | spark.stop() 110 | } 111 | } -------------------------------------------------------------------------------- /sparkML/src/main/scala/com/lightbend/spark/ml/WinequalityLinearRegression.scala: -------------------------------------------------------------------------------- 1 | package com.lightbend.spark.ml 2 | 3 | import 
org.apache.spark.ml.Pipeline 4 | import org.apache.spark.ml.feature.VectorAssembler 5 | import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel} 6 | import org.apache.spark.sql.SparkSession 7 | import org.apache.spark.sql.functions.udf 8 | import org.jpmml.model.MetroJAXBUtil 9 | import org.jpmml.sparkml.ConverterUtil 10 | 11 | 12 | /** 13 | * Created by boris on 5/3/17. 14 | * Ordinary Linear regression 15 | * 16 | */ 17 | 18 | object WinequalityLinearRegression { 19 | 20 | def main(args: Array[String]): Unit = { 21 | 22 | val spark = SparkSession 23 | .builder 24 | .appName("WineQualityDecisionTreeClassifierPMML") 25 | .master("local") 26 | .getOrCreate() 27 | 28 | // Load and parse the data file. 29 | val df = spark.read 30 | .format("csv") 31 | .option("header", "true") 32 | .option("mode", "DROPMALFORMED") 33 | .option("delimiter", ";") 34 | .load("data/winequality_red_names.csv") 35 | val inputFields = List("fixed acidity", "volatile acidity", "citric acid", "residual sugar", "chlorides", 36 | "free sulfur dioxide", "total sulfur dioxide", "density", "pH", "sulphates", "alcohol") 37 | 38 | // CSV imports everything as Strings, fix the type 39 | val toDouble = udf[Double, String](_.toDouble) 40 | val dff = df. 41 | withColumn("quality", toDouble(df("quality"))). 42 | withColumn("fixed acidity", toDouble(df("fixed acidity"))). // 0 43 | withColumn("volatile acidity", toDouble(df("volatile acidity"))). // 1 44 | withColumn("citric acid", toDouble(df("citric acid"))). // 2 45 | withColumn("residual sugar", toDouble(df("residual sugar"))). // 3 46 | withColumn("chlorides", toDouble(df("chlorides"))). // 4 47 | withColumn("free sulfur dioxide", toDouble(df("free sulfur dioxide"))). // 5 48 | withColumn("total sulfur dioxide", toDouble(df("total sulfur dioxide"))). // 6 49 | withColumn("density", toDouble(df("density"))). // 7 50 | withColumn("pH", toDouble(df("pH"))). // 8 51 | withColumn("sulphates", toDouble(df("sulphates"))). // 9 52 | withColumn("alcohol", toDouble(df("alcohol"))) // 10 53 | 54 | // Regression operates on feature vectors not individual features, so convert to DF again 55 | val assembler = new VectorAssembler(). 56 | setInputCols(inputFields.toArray). 
57 | setOutputCol("features") 58 | 59 | val lr = new LinearRegression() 60 | .setMaxIter(50) 61 | .setRegParam(0.1) 62 | .setElasticNetParam(0.5) 63 | .setFeaturesCol("features") 64 | .setLabelCol("quality") 65 | 66 | // Create the pipeline: feature assembly followed by the regression stage 67 | val pipeline = new Pipeline().setStages(Array(assembler,lr)) 68 | 69 | // Fit the model 70 | // val lrModel = lr.fit(dataFrame) 71 | val model = pipeline.fit(dff) 72 | val lrModel = model.stages(1).asInstanceOf[LinearRegressionModel] 73 | 74 | // Print the coefficients and intercept for linear regression 75 | println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}") 76 | 77 | // Summarize the model over the training set and print out some metrics 78 | val trainingSummary = lrModel.summary 79 | println(s"numIterations: ${trainingSummary.totalIterations}") 80 | println(s"objectiveHistory: [${trainingSummary.objectiveHistory.mkString(",")}]") 81 | trainingSummary.residuals.show() 82 | println(s"RMSE: ${trainingSummary.rootMeanSquaredError}") 83 | println(s"r2: ${trainingSummary.r2}") 84 | 85 | // Export the fitted pipeline as PMML 86 | val pmml = ConverterUtil.toPMML(dff.schema, model) 87 | MetroJAXBUtil.marshalPMML(pmml, System.out) 88 | spark.stop() 89 | } 90 | } -------------------------------------------------------------------------------- /sparkserver/src/main/scala/com/lightbend/modelserver/DataRecord.scala: -------------------------------------------------------------------------------- 1 | package com.lightbend.modelserver 2 | 3 | import scala.util.Try 4 | import com.lightbend.model.modeldescriptor.ModelDescriptor 5 | import com.lightbend.model.winerecord.WineRecord 6 | 7 | /** 8 | * Created by boris on 5/8/17. 9 | */ 10 | 11 | object DataRecord { 12 | 13 | def fromByteArray(message: Array[Byte]): Try[WineRecord] = Try { 14 | WineRecord.parseFrom(message) 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /sparkserver/src/main/scala/com/lightbend/modelserver/ModelSerializerKryo.scala: -------------------------------------------------------------------------------- 1 | package com.lightbend.modelserver 2 | 3 | /** 4 | * Created by boris on 6/2/17. 5 | */ 6 | 7 | import com.esotericsoftware.kryo.io.{Input, Output} 8 | import com.esotericsoftware.kryo.{Kryo, Serializer} 9 | import com.lightbend.model.modeldescriptor.ModelDescriptor 10 | import com.lightbend.model.scala.Model 11 | import com.lightbend.model.scala.PMML.PMMLModel 12 | import com.lightbend.model.scala.tensorflow.TensorFlowModel 13 | import org.apache.spark.serializer.KryoRegistrator 14 | 15 | 16 | class ModelSerializerKryo extends Serializer[Model]{ 17 | 18 | super.setAcceptsNull(false) 19 | super.setImmutable(true) 20 | 21 | /** Reads bytes and returns a new object of the specified concrete type. 22 | *

23 | * Before Kryo can be used to read child objects, {@link Kryo#reference(Object)} must be called with the parent object to 24 | * ensure it can be referenced by the child objects. Any serializer that uses {@link Kryo} to read a child object may need to 25 | * be reentrant. 26 | *

27 | This method should not be called directly, instead this serializer can be passed to {@link Kryo} read methods that accept a 28 | serializer. 29 | * 30 | * @return May be null if {@link #getAcceptsNull()} is true. */ 31 | 32 | override def read(kryo: Kryo, input: Input, `type`: Class[Model]): Model = { 33 | import ModelSerializerKryo._ 34 | 35 | println("KRYO deserialization") 36 | val mType = input.readLong().asInstanceOf[Int] 37 | val bytes = Stream.continually(input.readByte()).takeWhile(_ != -1).toArray 38 | factories.get(mType) match { 39 | case Some(factory) => factory.restore(bytes) 40 | case _ => throw new Exception(s"Unknown model type $mType to restore") 41 | } 42 | } 43 | 44 | /** Writes the bytes for the object to the output. 45 | *

46 | This method should not be called directly, instead this serializer can be passed to {@link Kryo} write methods that accept a 47 | serializer. 48 | * 49 | * @param value May be null if {@link #getAcceptsNull()} is true. */ 50 | 51 | override def write(kryo: Kryo, output: Output, value: Model): Unit = { 52 | println("KRYO serialization") 53 | output.writeLong(value.getType) 54 | output.write(value.toBytes) 55 | } 56 | } 57 | 58 | object ModelSerializerKryo{ 59 | private val factories = Map(ModelDescriptor.ModelType.PMML.value -> PMMLModel, 60 | ModelDescriptor.ModelType.TENSORFLOW.value -> TensorFlowModel) 61 | } 62 | 63 | class ModelRegistrator extends KryoRegistrator { 64 | override def registerClasses(kryo: Kryo) { 65 | kryo.register(classOf[Model], new ModelSerializerKryo()) 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /sparkserver/src/main/scala/com/lightbend/modelserver/SparkModelServer.scala: -------------------------------------------------------------------------------- 1 | package com.lightbend.modelserver 2 | 3 | /** 4 | * Created by boris on 6/26/17. 5 | */ 6 | 7 | import com.lightbend.configuration.kafka.ApplicationKafkaParameters 8 | import com.lightbend.model.modeldescriptor.ModelDescriptor 9 | import com.lightbend.modelserver.kafka.KafkaSupport 10 | import com.lightbend.model.scala.PMML.PMMLModel 11 | import com.lightbend.model.scala.tensorflow.TensorFlowModel 12 | import com.lightbend.model.scala.{DataWithModel, Model} 13 | import com.lightbend.modelserver.support.scala.ModelToServe 14 | import org.apache.spark.streaming.kafka010._ 15 | import org.apache.spark.streaming.kafka010.LocationStrategies.PreferConsistent 16 | import org.apache.spark.streaming.kafka010.ConsumerStrategies.Subscribe 17 | import org.apache.spark.streaming._ 18 | import org.apache.spark.SparkConf 19 | 20 | 21 | object SparkModelServer { 22 | 23 | private val factories = Map(ModelDescriptor.ModelType.PMML.value -> PMMLModel, 24 | ModelDescriptor.ModelType.TENSORFLOW.value -> TensorFlowModel) 25 | 26 | def main(args: Array[String]): Unit = { 27 | // Create context 28 | val sparkConf = new SparkConf() 29 | .setAppName("SparkModelServer") 30 | .setMaster("local") 31 | sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") 32 | sparkConf.set("spark.kryo.registrator", "com.lightbend.modelserver.ModelRegistrator") 33 | // sparkConf.set("spark.kryo.registrator", "com.lightbend.modelServer.model.TensorFlowModelRegistrator") 34 | val ssc = new StreamingContext(sparkConf, Seconds(1)) 35 | ssc.checkpoint("./cpt") 36 | 37 | // Initial state RDD for current models 38 | val modelsRDD = ssc.sparkContext.emptyRDD[(String, Model)] 39 | 40 | // Create models kafka stream 41 | val kafkaParams = KafkaSupport.getKafkaConsumerConfig(ApplicationKafkaParameters.LOCAL_KAFKA_BROKER) 42 | val modelsStream = KafkaUtils.createDirectStream[Array[Byte], Array[Byte]](ssc,PreferConsistent, 43 | Subscribe[Array[Byte], Array[Byte]](Set(ApplicationKafkaParameters.MODELS_TOPIC),kafkaParams)) 44 | // Create data kafka stream 45 | val dataStream = KafkaUtils.createDirectStream[Array[Byte], Array[Byte]](ssc, PreferConsistent, 46 | Subscribe[Array[Byte], Array[Byte]](Set(ApplicationKafkaParameters.DATA_TOPIC),kafkaParams)) 47 | // Convert streams 48 | val data = dataStream.map(r => 49 | DataRecord.fromByteArray(r.value())).
50 | filter(_.isSuccess).map(d => DataWithModel(None, Some(d.get))) 51 | val models = modelsStream.map(r => ModelToServe.fromByteArray(r.value())). 52 | filter(_.isSuccess).map(m => DataWithModel(Some(m.get), None)) 53 | // Combine streams 54 | val unionStream = ssc.union(Seq(data, models)).map(r => (r.getDataType, r)) 55 | 56 | // Score records using mapWithState, where the state holds the currently deployed model 57 | val mappingFunc = (dataType: String, dataWithModel: Option[DataWithModel], state: State[Model]) => { 58 | val currentModel = state.getOption().getOrElse(null.asInstanceOf[Model]) 59 | dataWithModel match { 60 | case Some(value) => 61 | if (value.isModel) { 62 | // Process a model update: clean up the current model and replace it with the new one 63 | if (currentModel != null) currentModel.cleanup() 64 | val model = factories.get(value.getModel.modelType.value) match { 65 | case Some(factory) => factory.create(value.getModel) 66 | case _ => None 67 | } 68 | model match { 69 | case Some(m) => state.update(m) 70 | case _ => 71 | } 72 | None 73 | } 74 | else { 75 | // Process a data record: score it if a model is currently available 76 | if (currentModel != null) 77 | Some(currentModel.score(value.getData.asInstanceOf[AnyVal]).asInstanceOf[Double]) 78 | else 79 | None 80 | } 81 | case _ => None 82 | } 83 | } 84 | // Define StateSpec - types are derived from the mapping function 85 | val resultDstream = unionStream.mapWithState(StateSpec.function(mappingFunc).initialState(modelsRDD)) 86 | resultDstream.print() 87 | // Execute 88 | ssc.start() 89 | ssc.awaitTermination() 90 | } 91 | } -------------------------------------------------------------------------------- /utils/src/main/java/com/lightbend/modelserver/support/java/DataConverter.java: -------------------------------------------------------------------------------- 1 | package com.lightbend.modelserver.support.java; 2 | 3 | import com.lightbend.model.Modeldescriptor; 4 | import com.lightbend.model.Winerecord; 5 | 6 | import java.util.Optional; 7 | 8 | /** 9 | * Created by boris on 6/28/17.
10 | */ 11 | public class DataConverter { 12 | 13 | private DataConverter(){} 14 | 15 | public static Optional convertData(byte[] binary){ 16 | try { 17 | // Unmarshall record 18 | return Optional.of(Winerecord.WineRecord.parseFrom(binary)); 19 | } catch (Throwable t) { 20 | // Oops 21 | System.out.println("Exception parsing input record" + new String(binary)); 22 | t.printStackTrace(); 23 | return Optional.empty(); 24 | } 25 | } 26 | 27 | public static Optional convertModel(byte[] binary){ 28 | try { 29 | // Unmarshall record 30 | Modeldescriptor.ModelDescriptor model = Modeldescriptor.ModelDescriptor.parseFrom(binary); 31 | // Return it 32 | if(model.getMessageContentCase().equals(Modeldescriptor.ModelDescriptor.MessageContentCase.DATA)){ 33 | return Optional.of(new ModelToServe( 34 | model.getName(), model.getDescription(), model.getModeltype(), 35 | model.getData().toByteArray(), null, model.getDataType())); 36 | } 37 | else { 38 | System.out.println("Location based model is not yet supported"); 39 | return Optional.empty(); 40 | } 41 | } catch (Throwable t) { 42 | // Oops 43 | System.out.println("Exception parsing input record" + new String(binary)); 44 | t.printStackTrace(); 45 | return Optional.empty(); 46 | } 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /utils/src/main/java/com/lightbend/modelserver/support/java/ModelToServe.java: -------------------------------------------------------------------------------- 1 | package com.lightbend.modelserver.support.java; 2 | 3 | import com.lightbend.model.Modeldescriptor; 4 | 5 | import java.io.Serializable; 6 | 7 | /** 8 | * Created by boris on 6/28/17. 9 | */ 10 | public class ModelToServe implements Serializable { 11 | 12 | private String name; 13 | private String description; 14 | private Modeldescriptor.ModelDescriptor.ModelType modelType; 15 | private byte[] modelData; 16 | private String modelDataLocation; 17 | private String dataType; 18 | 19 | public ModelToServe(String name, String description, Modeldescriptor.ModelDescriptor.ModelType modelType, 20 | byte[] dataContent, String modelDataLocation, String dataType){ 21 | this.name = name; 22 | this.description = description; 23 | this.modelType = modelType; 24 | this.modelData = dataContent; 25 | this.modelDataLocation = modelDataLocation; 26 | this.dataType = dataType; 27 | } 28 | 29 | public String getName() { 30 | return name; 31 | } 32 | 33 | public String getDescription() { 34 | return description; 35 | } 36 | 37 | public Modeldescriptor.ModelDescriptor.ModelType getModelType() { 38 | return modelType; 39 | } 40 | 41 | public String getDataType() { 42 | return dataType; 43 | } 44 | 45 | public byte[] getModelData() { 46 | return modelData; 47 | } 48 | 49 | public String getModelDataLocation() { 50 | return modelDataLocation; 51 | } 52 | 53 | @Override 54 | public String toString() { 55 | return "ModelToServe{" + 56 | "name='" + name + '\'' + 57 | ", description='" + description + '\'' + 58 | ", modelType=" + modelType + 59 | ", dataType='" + dataType + '\'' + 60 | '}'; 61 | } 62 | } -------------------------------------------------------------------------------- /utils/src/main/scala/com/lightbend/modelserver/kafka/EmbeddedSingleNodeKafkaCluster.scala: -------------------------------------------------------------------------------- 1 | package com.lightbend.modelserver.kafka 2 | 3 | import java.util.Properties 4 | 5 | import kafka.server.KafkaConfig 6 | import org.apache.curator.test.TestingServer 7 | import 
org.slf4j.LoggerFactory 8 | 9 | 10 | object EmbeddedSingleNodeKafkaCluster { 11 | private val log = LoggerFactory.getLogger("EmbeddedSingleNodeKafkaCluster") 12 | private val DEFAULT_BROKER_PORT = 9092 13 | private var zookeeper = null.asInstanceOf[TestingServer] 14 | private var broker = null.asInstanceOf[KafkaEmbedded] 15 | 16 | /** 17 | * Creates and starts a Kafka cluster. 18 | */ 19 | def start(): Unit = { 20 | // Ensure the cluster is started only once 21 | if (broker != null) 22 | EmbeddedSingleNodeKafkaCluster.log.debug("Embedded Kafka cluster is already running... returning") 23 | else { 24 | val brokerConfig = new Properties 25 | EmbeddedSingleNodeKafkaCluster.log.debug("Initiating embedded Kafka cluster startup") 26 | EmbeddedSingleNodeKafkaCluster.log.debug("Starting a ZooKeeper instance") 27 | try 28 | zookeeper = new TestingServer(2181) 29 | catch { 30 | case e: Exception => 31 | // TODO Auto-generated catch block 32 | e.printStackTrace() 33 | } 34 | EmbeddedSingleNodeKafkaCluster.log.debug(s"ZooKeeper instance is running at $zKConnectString") 35 | brokerConfig.put(KafkaConfig.ZkConnectProp, zKConnectString) 36 | brokerConfig.put(KafkaConfig.PortProp, DEFAULT_BROKER_PORT.toString) 37 | EmbeddedSingleNodeKafkaCluster.log.debug(s"Starting a Kafka instance on port ${brokerConfig.getProperty(KafkaConfig.PortProp)} ...") 38 | broker = new KafkaEmbedded(brokerConfig) 39 | broker.start() 40 | EmbeddedSingleNodeKafkaCluster.log.debug(s"Kafka instance is running at ${broker.brokerList}, connected to ZooKeeper at ${broker.zookeeperConnect}") 41 | } 42 | } 43 | 44 | /** 45 | * Stop the Kafka cluster. 46 | */ 47 | 48 | def stop(): Unit = { 49 | try { 50 | broker.stop() 51 | zookeeper.stop() 52 | } catch { 53 | case t: Throwable => 54 | } 55 | } 56 | 57 | /** 58 | * The ZooKeeper connection string aka `zookeeper.connect` in `hostnameOrIp:port` format. 59 | * Example: `127.0.0.1:2181`. 60 | * 61 | * You can use this to e.g. tell Kafka brokers how to connect to this instance. 62 | */ 63 | def zKConnectString: String = zookeeper.getConnectString 64 | 65 | /** 66 | * This cluster's `bootstrap.servers` value. Example: `127.0.0.1:9092`. 67 | * 68 | * You can use this to tell Kafka producers how to connect to this cluster. 69 | */ 70 | def bootstrapServers: String = broker.brokerList 71 | 72 | /** 73 | * Create a Kafka topic with 1 partition and a replication factor of 1. 74 | * 75 | * @param topic The name of the topic. 76 | */ 77 | def createTopic(topic: String): Unit = { 78 | createTopic(topic, 1, 1, new Properties) 79 | } 80 | 81 | /** 82 | * Create a Kafka topic with the given parameters. 83 | * 84 | * @param topic The name of the topic. 85 | * @param partitions The number of partitions for this topic. 86 | * @param replication The replication factor for (the partitions of) this topic. 87 | */ 88 | def createTopic(topic: String, partitions: Int, replication: Int): Unit = { 89 | createTopic(topic, partitions, replication, new Properties) 90 | } 91 | 92 | /** 93 | * Create a Kafka topic with the given parameters. 94 | * 95 | * @param topic The name of the topic. 96 | * @param partitions The number of partitions for this topic. 97 | * @param replication The replication factor for (partitions of) this topic. 98 | * @param topicConfig Additional topic-level configuration settings.
99 | */ 100 | def createTopic(topic: String, partitions: Int, replication: Int, topicConfig: Properties): Unit = { 101 | broker.createTopic(topic, partitions, replication, topicConfig) 102 | } 103 | } -------------------------------------------------------------------------------- /utils/src/main/scala/com/lightbend/modelserver/kafka/KafkaEmbedded.scala: -------------------------------------------------------------------------------- 1 | package com.lightbend.modelserver.kafka 2 | 3 | import java.io.File 4 | import java.util.{Collections, Properties} 5 | 6 | import com.google.common.io.Files 7 | import kafka.admin.{AdminUtils, RackAwareMode} 8 | import kafka.server.{KafkaConfig, KafkaServerStartable} 9 | import kafka.utils.{CoreUtils, ZkUtils} 10 | import org.slf4j.LoggerFactory 11 | 12 | 13 | /** 14 | * Runs an in-memory, "embedded" instance of a Kafka broker, which listens at `127.0.0.1:9092` by 15 | * default. 16 | * 17 | * Requires a running ZooKeeper instance to connect to. 18 | */ 19 | object KafkaEmbedded { 20 | private val log = LoggerFactory.getLogger(classOf[KafkaEmbedded]) 21 | private val DEFAULT_ZK_CONNECT = "127.0.0.1:2181" 22 | private val DEFAULT_ZK_SESSION_TIMEOUT_MS = 10 * 1000 23 | private val DEFAULT_ZK_CONNECTION_TIMEOUT_MS = 8 * 1000 24 | 25 | def createTempDir: File = Files.createTempDir 26 | 27 | def apply(config: Properties): KafkaEmbedded = new KafkaEmbedded(config) 28 | } 29 | 30 | class KafkaEmbedded (val config: Properties){ 31 | 32 | /** 33 | * Creates and starts an embedded Kafka broker. 34 | * 35 | * @param config Broker configuration settings. Used to modify, for example, on which port the 36 | * broker should listen to. Note that you cannot change the `log.dirs` setting 37 | * currently. 38 | */ 39 | val logDir = KafkaEmbedded.createTempDir 40 | val effectiveConfig = effectiveConfigFrom(config) 41 | val loggingEnabled = true 42 | val kafkaConfig = new KafkaConfig(effectiveConfig, loggingEnabled) 43 | KafkaEmbedded.log.debug(s"Starting embedded Kafka broker (with log.dirs=$logDir and ZK ensemble at $zookeeperConnect) ...") 44 | val kafka = new KafkaServerStartable(kafkaConfig) 45 | val zkUtils = ZkUtils.apply(zookeeperConnect, KafkaEmbedded.DEFAULT_ZK_SESSION_TIMEOUT_MS, KafkaEmbedded.DEFAULT_ZK_CONNECTION_TIMEOUT_MS, false) 46 | KafkaEmbedded.log.debug(s"Startup of embedded Kafka broker at brokerList completed (with ZK ensemble at $zookeeperConnect) ...") 47 | 48 | private def effectiveConfigFrom(initialConfig: Properties) = { 49 | val effectiveConfig = new Properties 50 | effectiveConfig.put(KafkaConfig.BrokerIdProp, "0") 51 | effectiveConfig.put(KafkaConfig.HostNameProp, "127.0.0.1") 52 | effectiveConfig.put(KafkaConfig.PortProp, "9090") 53 | effectiveConfig.put(KafkaConfig.NumPartitionsProp, "1") 54 | effectiveConfig.put(KafkaConfig.AutoCreateTopicsEnableProp, "true") 55 | effectiveConfig.put(KafkaConfig.MessageMaxBytesProp, "1000000") 56 | effectiveConfig.put(KafkaConfig.ControlledShutdownEnableProp, "true") 57 | effectiveConfig.putAll(initialConfig) 58 | effectiveConfig.setProperty(KafkaConfig.LogDirProp, logDir.getAbsolutePath) 59 | effectiveConfig 60 | } 61 | 62 | /** 63 | * This broker's `metadata.broker.list` value. Example: `127.0.0.1:9092`. 64 | * 65 | * You can use this to tell Kafka producers and consumers how to connect to this instance. 66 | */ 67 | def brokerList: String = String.join(":", kafka.serverConfig.hostName, Integer.toString(kafka.serverConfig.port)) 68 | 69 | /** 70 | * The ZooKeeper connection string aka `zookeeper.connect`. 
71 | */ 72 | def zookeeperConnect: String = effectiveConfig.getProperty("zookeeper.connect", KafkaEmbedded.DEFAULT_ZK_CONNECT) 73 | 74 | /** 75 | * Start the broker. 76 | */ 77 | def start(): Unit = { 78 | KafkaEmbedded.log.debug(s"Starting embedded Kafka broker at $brokerList (with log.dirs=$logDir and ZK ensemble at $zookeeperConnect) ...") 79 | kafka.startup() 80 | KafkaEmbedded.log.debug(s"Startup of embedded Kafka broker at $brokerList completed (with ZK ensemble at $zookeeperConnect) ...") 81 | } 82 | 83 | /** 84 | * Stop the broker. 85 | */ 86 | def stop(): Unit = { 87 | KafkaEmbedded.log.debug(s"Shutting down embedded Kafka broker at $brokerList (with ZK ensemble at $zookeeperConnect) ...") 88 | kafka.shutdown() 89 | kafka.awaitShutdown() 90 | KafkaEmbedded.log.debug(s"Removing logs.dir at $logDir ...") 91 | val logDirs = Collections.singletonList(logDir.getAbsolutePath) 92 | logDir.delete 93 | CoreUtils.delete(scala.collection.JavaConversions.asScalaBuffer(logDirs).seq) 94 | KafkaEmbedded.log.debug(s"Shutdown of embedded Kafka broker at $brokerList completed (with ZK ensemble at $zookeeperConnect) ...") 95 | } 96 | 97 | /** 98 | * Create a Kafka topic with 1 partition and a replication factor of 1. 99 | * 100 | * @param topic The name of the topic. 101 | */ 102 | def createTopic(topic: String): Unit = { 103 | createTopic(topic, 1, 1, new Properties) 104 | } 105 | 106 | /** 107 | * Create a Kafka topic with the given parameters. 108 | * 109 | * @param topic The name of the topic. 110 | * @param partitions The number of partitions for this topic. 111 | * @param replication The replication factor for (the partitions of) this topic. 112 | */ 113 | def createTopic(topic: String, partitions: Int, replication: Int): Unit = { 114 | createTopic(topic, partitions, replication, new Properties) 115 | } 116 | 117 | /** 118 | * Create a Kafka topic with the given parameters. 119 | * 120 | * @param topic The name of the topic. 121 | * @param partitions The number of partitions for this topic. 122 | * @param replication The replication factor for (partitions of) this topic. 123 | * @param topicConfig Additional topic-level configuration settings. 124 | */ 125 | def createTopic(topic: String, partitions: Int, replication: Int, topicConfig: Properties): Unit = { 126 | KafkaEmbedded.log.debug(s"Creating topic { name: $topic, partitions: $partitions, replication: $replication, config: $topicConfig }") 127 | // val zkClient = new ZkClient(zookeeperConnect, KafkaEmbedded.DEFAULT_ZK_SESSION_TIMEOUT_MS, KafkaEmbedded.DEFAULT_ZK_CONNECTION_TIMEOUT_MS, ZKStringSerializer) 128 | // val zkUtils = new ZkUtils(zkClient, new ZkConnection(zookeeperConnect), false) 129 | AdminUtils.createTopic(zkUtils, topic, partitions, replication, topicConfig, RackAwareMode.Enforced) 130 | } 131 | } -------------------------------------------------------------------------------- /utils/src/main/scala/com/lightbend/modelserver/kafka/KafkaSupport.scala: -------------------------------------------------------------------------------- 1 | package com.lightbend.modelserver.kafka 2 | 3 | import org.apache.kafka.clients.consumer.ConsumerConfig 4 | import org.apache.kafka.common.serialization.ByteArrayDeserializer 5 | 6 | 7 | /** 8 | * Created by blublins on 9/7/16. 
9 | */ 10 | @SerialVersionUID(102L) 11 | object KafkaSupport extends Serializable { 12 | 13 | // Kafka consumer properties 14 | private val sessionTimeout: Int = 10 * 1000 15 | private val connectionTimeout: Int = 8 * 1000 16 | private val AUTOCOMMITINTERVAL: String = "1000" 17 | // Frequency of offset commits 18 | private val SESSIONTIMEOUT: String = "30000" 19 | // The timeout used to detect failures - should be greater than the processing time 20 | private val MAXPOLLRECORDS: String = "10" 21 | // Max number of records consumed in a single poll 22 | private val GROUPID: String = "Spark Streaming" // Consumer ID 23 | 24 | def getKafkaConsumerConfig(brokers: String): Map[String, String] = { 25 | Map[String, String]( 26 | ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG -> brokers, 27 | ConsumerConfig.GROUP_ID_CONFIG -> GROUPID, 28 | ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG -> "true", 29 | ConsumerConfig.AUTO_COMMIT_INTERVAL_MS_CONFIG -> AUTOCOMMITINTERVAL, 30 | ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG -> SESSIONTIMEOUT, 31 | ConsumerConfig.MAX_POLL_RECORDS_CONFIG -> MAXPOLLRECORDS, 32 | ConsumerConfig.AUTO_OFFSET_RESET_CONFIG -> "earliest", 33 | ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getTypeName, 34 | ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getTypeName) 35 | } 36 | } -------------------------------------------------------------------------------- /utils/src/main/scala/com/lightbend/modelserver/support/scala/DataReader.scala: -------------------------------------------------------------------------------- 1 | package com.lightbend.modelserver.support.scala 2 | 3 | import com.lightbend.model.winerecord.WineRecord 4 | 5 | import scala.util.Try 6 | 7 | /** 8 | * Created by boris on 5/8/17.
9 | */ 10 | object ModelToServe { 11 | def fromByteArray(message: Array[Byte]): Try[ModelToServe] = Try{ 12 | val m = ModelDescriptor.parseFrom(message) 13 | m.messageContent.isData match { 14 | case true => new ModelToServe(m.name, m.description, m.modeltype, m.getData.toByteArray, m.dataType) 15 | case _ => throw new Exception("Location based is not yet supported") 16 | } 17 | } 18 | } 19 | 20 | case class ModelToServe(name: String, description: String, 21 | modelType: ModelDescriptor.ModelType, 22 | model : Array[Byte], dataType : String) {} 23 | 24 | case class ModelToServeStats(name: String, description: String, modelType: String, 25 | since : Long, var usage : Long = 0, var duration : Double = .0, 26 | var min : Long = Long.MaxValue, var max : Long = Long.MinValue) { 27 | def this(m : ModelToServe) = this(m.name, m.description, m.modelType.name, System.currentTimeMillis()) 28 | def incrementUsage(execution : Long) : ModelToServeStats = { 29 | usage = usage + 1 30 | duration = duration + execution 31 | if(execution < min) min = execution 32 | if(execution > max) max = execution 33 | this 34 | } 35 | } 36 | 37 | object ModelToServeStats{ 38 | val empty = ModelToServeStats("None", "None", "None", 0, 0, .0, 0, 0) 39 | } --------------------------------------------------------------------------------
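A minimal usage sketch, not part of the repository: it shows how the utility classes above might be wired together in a local smoke test, using only calls defined in this section (EmbeddedSingleNodeKafkaCluster.start/createTopic/bootstrapServers/stop, ModelToServe.fromByteArray, ModelToServeStats.incrementUsage) and the topic names from ApplicationKafkaParameters referenced by SparkModelServer. The object name LocalServingSmokeTest and the modelBytes/executionTimeMs placeholders are hypothetical.

import com.lightbend.configuration.kafka.ApplicationKafkaParameters
import com.lightbend.modelserver.kafka.EmbeddedSingleNodeKafkaCluster
import com.lightbend.modelserver.support.scala.{ModelToServe, ModelToServeStats}

import scala.util.{Failure, Success}

object LocalServingSmokeTest {
  def main(args: Array[String]): Unit = {
    // Start a single-node broker (plus ZooKeeper) and create the two topics the servers consume.
    EmbeddedSingleNodeKafkaCluster.start()
    EmbeddedSingleNodeKafkaCluster.createTopic(ApplicationKafkaParameters.DATA_TOPIC)
    EmbeddedSingleNodeKafkaCluster.createTopic(ApplicationKafkaParameters.MODELS_TOPIC)
    println(s"Kafka is available at ${EmbeddedSingleNodeKafkaCluster.bootstrapServers}")

    // Parse a serialized ModelDescriptor the same way the servers do.
    // modelBytes is a placeholder; in a real test it would be a message read from the models topic.
    val modelBytes: Array[Byte] = Array.emptyByteArray
    ModelToServe.fromByteArray(modelBytes) match {
      case Success(m) =>
        // Track execution statistics for the newly received model.
        val executionTimeMs = 5L // hypothetical scoring latency
        println(new ModelToServeStats(m).incrementUsage(executionTimeMs))
      case Failure(t) =>
        println(s"Not a valid model record: $t")
    }

    EmbeddedSingleNodeKafkaCluster.stop()
  }
}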