├── project
│   ├── build.properties
│   └── plugins.sbt
├── src
│   └── main
│       ├── resources
│       │   ├── elecNormData.txt.dist
│       │   ├── elecNormNew.arff
│       │   ├── script.txt
│       │   └── elecNormData.txt
│       └── scala
│           ├── pl
│           │   └── gosub
│           │       └── akka
│           │           └── online
│           │               ├── recursive
│           │               │   └── least
│           │               │       └── squares
│           │               │           ├── RecursiveLeastSquresFilter.scala
│           │               │           ├── RecursiveLeastSquaresMain.scala
│           │               │           └── RecursiveLeastSquaresStage.scala
│           │               ├── StdoutSink.scala
│           │               ├── follow
│           │               │   └── the
│           │               │       └── leader
│           │               │           ├── FollowTheLeaderLogic.scala
│           │               │           ├── FollowTheLeaderMain.scala
│           │               │           └── FollowTheLeaderStage.scala
│           │               ├── HoeffdingTreeWithAlpakka.scala
│           │               ├── KadaneFlowActor.scala
│           │               ├── KadaneFlowStage.scala
│           │               ├── BloomFilterCrossStage.scala
│           │               ├── HoeffdingTreeFlowStream.scala
│           │               ├── BloomFilterCrossMatStage.scala
│           │               ├── HoeffdingTreeFlowActor.scala
│           │               ├── Main.scala
│           │               └── SuffixTreeTripodMatStage.scala
│           └── org
│               └── apache
│                   └── spark
│                       └── streamdm
│                           ├── classifiers
│                           │   ├── model
│                           │   │   ├── Regularizer.scala
│                           │   │   ├── ZeroRegularizer.scala
│                           │   │   ├── L2Regularizer.scala
│                           │   │   ├── L1Regularizer.scala
│                           │   │   ├── PerceptronLoss.scala
│                           │   │   ├── SquaredLoss.scala
│                           │   │   ├── HingeLoss.scala
│                           │   │   ├── LogisticLoss.scala
│                           │   │   ├── Loss.scala
│                           │   │   └── LinearModel.scala
│                           │   ├── OnlineClassifier.scala
│                           │   ├── trees
│                           │   │   ├── FeatureSplit.scala
│                           │   │   ├── Utils.scala
│                           │   │   ├── GaussianEstimator.scala
│                           │   │   ├── ConditionalTest.scala
│                           │   │   ├── SplitCriterion.scala
│                           │   │   ├── FeatureClassObserver.scala
│                           │   │   └── Node.scala
│                           │   └── bayes
│                           │       └── MultinomialNaiveBayes.scala
│                           ├── core
│                           │   ├── Model.scala
│                           │   ├── ClassificationModel.scala
│                           │   ├── OnlineLearner.scala
│                           │   ├── Instance.scala
│                           │   ├── specification
│                           │   │   ├── FeatureSpecification.scala
│                           │   │   ├── ExampleSpecification.scala
│                           │   │   ├── InstanceSpecification.scala
│                           │   │   └── SpecificationParser.scala
│                           │   ├── NullInstance.scala
│                           │   ├── Example.scala
│                           │   ├── ExampleParser.scala
│                           │   ├── TextInstance.scala
│                           │   ├── DenseInstance.scala
│                           │   └── SparseInstance.scala
│                           └── utils
│                               └── Utils.scala
├── lib
│   └── suffixtree-1.0.0-SNAPSHOT.jar
├── .gitignore
├── README.md
└── LICENSE

--------------------------------------------------------------------------------
/project/build.properties:
--------------------------------------------------------------------------------
sbt.version=0.13.12

--------------------------------------------------------------------------------
/src/main/resources/elecNormData.txt.dist:
--------------------------------------------------------------------------------
Pipe to this file with the script from script.txt.

--------------------------------------------------------------------------------
/lib/suffixtree-1.0.0-SNAPSHOT.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gosubpl/akka-online/HEAD/lib/suffixtree-1.0.0-SNAPSHOT.jar

--------------------------------------------------------------------------------
/src/main/resources/elecNormNew.arff:
--------------------------------------------------------------------------------
Download a sample arff file from http://moa.cs.waikato.ac.nz/datasets/
and extract it here (replacing this placeholder file).

--------------------------------------------------------------------------------
/src/main/resources/script.txt:
--------------------------------------------------------------------------------
for x in `cat elecNormNew.arff | tail -n +14 | head`; do echo $x; sleep 1; done
for x in `cat elecNormNew.arff | tail -n +14 | head`; do rnd=$(( ( RANDOM % 10 ) + 1 )); if [ $rnd == 1 ] ; then action="QUERY" ; else action="EXAMPLE" ; fi; echo $action';'$x; sleep 1; done
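# Both loops replay the first ten data rows of elecNormNew.arff (tail -n +14 skips the
# Arff header), one row per second. The second loop additionally prefixes each row with
# EXAMPLE or QUERY (roughly one row in ten becomes a QUERY), producing the "ACTION;row"
# format seen in elecNormData.txt and consumed by HoeffdingTreeWithAlpakka.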
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.class
*.log

# sbt specific
.cache
.history
.lib/
dist/*
target/
lib_managed/
src_managed/
project/boot/
project/plugins/project/

# Scala-IDE specific
.scala_dependencies
.worksheet
/.idea
target

*.orig

--------------------------------------------------------------------------------
/project/plugins.sbt:
--------------------------------------------------------------------------------
//addSbtPlugin("com.typesafe.sbt" % "sbt-multi-jvm" % "0.3.8")

//addSbtPlugin("pl.project13.scala" % "sbt-jmh" % "0.2.16")

//addSbtPlugin("pl.project13.sbt" % "sbt-jol" % "0.1.1")

addSbtPlugin("com.typesafe.sbt" % "sbt-native-packager" % "1.0.0-RC1")

//addSbtPlugin("com.timushev.sbt" % "sbt-updates" % "0.1.8")
//
//addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.3")

--------------------------------------------------------------------------------
/src/main/scala/pl/gosub/akka/online/recursive/least/squares/RecursiveLeastSquresFilter.scala:
--------------------------------------------------------------------------------
package pl.gosub.akka.online.recursive.least.squares

// Scalar recursive least squares: fits y = weight * x incrementally, refining
// the gain term theta and the weight estimate with every new observation.
class RecursiveLeastSquresFilter(private val x0: Double, private val y0: Double) {

  private var theta: Double = 1.0

  private var weight = if (x0 == 0) 0.0 else y0 / x0

  def predict(x: Double, previousResult: Double): Double = {

    // shrink the gain as more evidence accumulates
    theta = theta - (theta * theta * x * x) / (1 + x * x * theta)

    // correct the weight by the gain-scaled error on the previous result
    weight = weight - theta * x * (weight * x - previousResult)

    weight * x
  }
}

--------------------------------------------------------------------------------
/src/main/scala/pl/gosub/akka/online/StdoutSink.scala:
--------------------------------------------------------------------------------
package pl.gosub.akka.online

import akka.stream.{Attributes, Inlet, SinkShape}
import akka.stream.stage.{GraphStage, GraphStageLogic, InHandler}

class StdoutSink extends GraphStage[SinkShape[Char]] {
  val in: Inlet[Char] = Inlet("StdoutSink")
  override val shape: SinkShape[Char] = SinkShape(in)

  override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
    new GraphStageLogic(shape) {

      // This requests one element at the Sink startup.
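      // Without it the stage would never receive anything: a Sink has no
      // downstream whose demand could otherwise trigger the first pull.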
      override def preStart(): Unit = pull(in)

      setHandler(in, new InHandler {
        override def onPush(): Unit = {
          print(grab(in))
          pull(in)
        }
      })
    }
}

--------------------------------------------------------------------------------
/src/main/scala/pl/gosub/akka/online/follow/the/leader/FollowTheLeaderLogic.scala:
--------------------------------------------------------------------------------
package pl.gosub.akka.online.follow.the.leader

// Follow the Leader: keep a sliding window of past observations and predict
// with the hypothesis that has the smallest cumulative loss over that window.
class FollowTheLeaderLogic(val hypotheses: Seq[Double => Double], lossFunction: ((Double, Double) => Double), windowSize: Int) {

  private var past: Seq[(Double, Double)] = Seq.empty

  private var pastX = 0.0

  def predict(x: Double, y: Double): Double = {

    // the freshly observed y belongs to the previously seen x
    past = past :+ ((pastX, y))

    // keep only the most recent windowSize observations
    past = past.takeRight(windowSize)

    val leader = if (past.isEmpty) {
      hypotheses.head
    } else {
      hypotheses
        .map(hypothesis => (hypothesis, past.map { case (px, py) => lossFunction(hypothesis(px), py) }.sum))
        .minBy(_._2)
        ._1
    }

    val prediction = leader(x)

    pastX = x

    prediction
  }
}

--------------------------------------------------------------------------------
/src/main/scala/pl/gosub/akka/online/recursive/least/squares/RecursiveLeastSquaresMain.scala:
--------------------------------------------------------------------------------
package pl.gosub.akka.online.recursive.least.squares

import akka.actor.ActorSystem
import akka.stream.scaladsl.{GraphDSL, RunnableGraph, Sink, Source}
import akka.stream.{ActorMaterializer, ClosedShape}

import scala.util.Random

object RecursiveLeastSquaresMain extends App {

  implicit val system = ActorSystem()
  implicit val mat = ActorMaterializer()

  val rlsStage = new RecursiveLeastSquaresStage(new RecursiveLeastSquresFilter(1.0, 10.0))

  val graph = RunnableGraph.fromGraph(GraphDSL.create() { implicit builder =>

    import GraphDSL.Implicits._

    val cross = builder.add(rlsStage)

    val x = Source.fromIterator(() => Iterator.iterate(0.0)(x => x + 1))
    val y = Source.fromIterator(() => Iterator.iterate(-10.0)(x => x + 10 + random(5.0)))
    val p = Sink.foreach(println)

    x ~> cross.in0
    y ~> cross.in1
    p <~ cross.out

    ClosedShape
  }).run

  def random(sigma: Double): Double = (Random.nextDouble() * sigma) - (sigma / 2.0)
}

--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/streamdm/classifiers/model/Regularizer.scala:
--------------------------------------------------------------------------------
/*
 * Copyright (C) 2015 Holmes Team at HUAWEI Noah's Ark Lab.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package org.apache.spark.streamdm.classifiers.model

/**
 * A regularizer trait defines the gradient operation for computing regularized
 * models.
 */
trait Regularizer extends Serializable {
  /** Computes the value of the gradient function
   * @param weight the weight for which the gradient is computed
   * @return the gradient value
   */
  def gradient(weight: Double): Double
}

--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/streamdm/classifiers/model/ZeroRegularizer.scala:
--------------------------------------------------------------------------------
/*
 * Copyright (C) 2015 Holmes Team at HUAWEI Noah's Ark Lab.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package org.apache.spark.streamdm.classifiers.model

/**
 * The ZeroRegularizer simply returns 0 as the gradient.
 */
class ZeroRegularizer extends Regularizer with Serializable {
  /** Computes the value of the gradient function
   * @param weight the weight for which the gradient is computed
   * @return the gradient value
   */
  def gradient(weight: Double): Double = 0.0
}

--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/streamdm/classifiers/model/L2Regularizer.scala:
--------------------------------------------------------------------------------
/*
 * Copyright (C) 2015 Holmes Team at HUAWEI Noah's Ark Lab.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package org.apache.spark.streamdm.classifiers.model

/**
 * The L2Regularizer returns the weight itself as the gradient of the
 * regularization term.
 */
class L2Regularizer extends Regularizer with Serializable {
  /** Computes the value of the gradient function
   * @param weight the weight for which the gradient is computed
   * @return the gradient value
   */
  def gradient(weight: Double): Double = weight
}

--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/streamdm/classifiers/model/L1Regularizer.scala:
--------------------------------------------------------------------------------
/*
 * Copyright (C) 2015 Holmes Team at HUAWEI Noah's Ark Lab.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package org.apache.spark.streamdm.classifiers.model

/**
 * The L1Regularizer returns the sign of the weight as the gradient of the
 * regularization term.
 */
class L1Regularizer extends Regularizer with Serializable {
  /** Computes the value of the gradient function
   * @param weight the weight for which the gradient is computed
   * @return the gradient value
   */
  def gradient(weight: Double): Double = if (weight >= 0) 1.0 else -1.0
}

--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/streamdm/core/Model.scala:
--------------------------------------------------------------------------------
/*
 * Copyright (C) 2015 Holmes Team at HUAWEI Noah's Ark Lab.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package org.apache.spark.streamdm.core

/**
 * A Model trait defines the needed operations on any learning Model. It
 * provides a method for updating the model.
 */
trait Model extends Serializable {

  type T <: Model

  /**
   * Update the model, depending on the Example given for training.
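   * Implementations (LinearModel, for instance) return a new, updated Model
   * rather than mutating the current one in place.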
   *
   * @param change the Example based on which the Model is updated
   * @return the updated Model
   */
  def update(change: Example): T
}

--------------------------------------------------------------------------------
/src/main/scala/pl/gosub/akka/online/follow/the/leader/FollowTheLeaderMain.scala:
--------------------------------------------------------------------------------
package pl.gosub.akka.online.follow.the.leader

import akka.actor.ActorSystem
import akka.stream.scaladsl.{GraphDSL, RunnableGraph, Sink, Source}
import akka.stream.{ActorMaterializer, ClosedShape}

import scala.util.Random

object FollowTheLeaderMain extends App {

  implicit val system = ActorSystem()
  implicit val mat = ActorMaterializer()

  // candidate hypotheses x => a * x, for a = 0..30
  val hypotheses = (0 to 30).map(a => { (x: Double) => a * x })

  val ftlStage = new FollowTheLeaderStage(new FollowTheLeaderLogic(
    hypotheses,
    { (prediction: Double, y: Double) => Math.abs(prediction - y) },
    5
  ))

  val graph = RunnableGraph.fromGraph(GraphDSL.create() { implicit builder =>

    import GraphDSL.Implicits._

    val cross = builder.add(ftlStage)

    val x = Source.fromIterator(() => Iterator.iterate(0.0)(x => x + 1))
    val y = Source.fromIterator(() => Iterator.iterate(0.0)(x => x + 10 + random(5.0)))
    val p = Sink.foreach(println)

    x ~> cross.in0
    y ~> cross.in1
    p <~ cross.out

    ClosedShape
  }).run

  def random(sigma: Double): Double = (Random.nextDouble() * sigma) - (sigma / 2.0)
}

--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/streamdm/classifiers/model/PerceptronLoss.scala:
--------------------------------------------------------------------------------
/*
 * Copyright (C) 2015 Holmes Team at HUAWEI Noah's Ark Lab.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package org.apache.spark.streamdm.classifiers.model

/**
 * Implementation of the perceptron loss function. Essentially, the perceptron
 * uses the squared loss function with a different "gradient".
 */

class PerceptronLoss extends SquaredLoss with Serializable {
  /** Computes the value of the perceptron update function
   * @param label the label against which the update is computed
   * @param dot the dot product of the linear model and the instance
   * @return the update value
   */
  override def gradient(label: Double, dot: Double): Double =
    label - predict(dot)

}

--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/streamdm/classifiers/OnlineClassifier.scala:
--------------------------------------------------------------------------------
/*
 * Portions Copyright (C) 2015 Holmes Team at HUAWEI Noah's Ark Lab.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

/*
 This file has been changed by gosubpl
 */

package org.apache.spark.streamdm.classifiers

import org.apache.spark.streamdm.core._

/**
 * A Classifier trait defines the needed operations on any implemented
 * classifier. It is a subtrait of OnlineLearner and adds a method for
 * predicting the class of an input Example.
 */
trait OnlineClassifier extends OnlineLearner with Serializable {

  /* Predict the label of an Example, given the current Model
   *
   * @param input the input Example
   * @return a tuple containing the original Example and the predicted value
   */
  def predictSingle(input: Example): (Example, Double)
}

--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/streamdm/core/ClassificationModel.scala:
--------------------------------------------------------------------------------
/*
 * Copyright (C) 2015 Holmes Team at HUAWEI Noah's Ark Lab.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package org.apache.spark.streamdm.core

/**
 * A ClassificationModel trait defines the needed operations on any
 * classification Model. It provides methods for updating the model and for
 * predicting the label of a given Example.
 */
trait ClassificationModel extends Model {

  /* Predict the label of the Example, given the current Model
   *
   * @param instance the Example which needs a class predicted
   * @return a Double representing the class predicted
   */
  def predict(instance: Example): Double

  /** Computes the probability for a given label class, given the current Model
   *
   * @param instance the Example which needs a class predicted
   * @return the predicted probability
   */

  def prob(instance: Example): Double

}

--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/streamdm/classifiers/model/SquaredLoss.scala:
--------------------------------------------------------------------------------
/*
 * Copyright (C) 2015 Holmes Team at HUAWEI Noah's Ark Lab.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package org.apache.spark.streamdm.classifiers.model

/**
 * Implementation of the squared loss function.
 */

class SquaredLoss extends Loss with Serializable {
  /** Computes the value of the loss function
   * @param label the label against which the loss is computed
   * @param dot the dot product of the linear model and the instance
   * @return the loss value
   */
  def loss(label: Double, dot: Double): Double =
    0.5 * (dot - label) * (dot - label)

  /** Computes the value of the gradient function
   * @param label the label against which the gradient is computed
   * @param dot the dot product of the linear model and the instance
   * @return the gradient value
   */
  def gradient(label: Double, dot: Double): Double =
    dot - label

  /** Computes the binary prediction based on a dot product
   * @param dot the dot product of the linear model and the instance
   * @return the predicted binary class
   */
  def predict(dot: Double): Double =
    if (dot >= 0) 1 else 0
}

--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/streamdm/classifiers/trees/FeatureSplit.scala:
--------------------------------------------------------------------------------
/*
 * Copyright (C) 2015 Holmes Team at HUAWEI Noah's Ark Lab.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package org.apache.spark.streamdm.classifiers.trees

import scala.math.Ordered

/**
 * Class for recording a split suggestion.
 */
class FeatureSplit(val conditionalTest: ConditionalTest, val merit: Double,
                   val result: Array[Array[Double]]) extends Ordered[FeatureSplit] {

  /** Compares two FeatureSplit objects.
   * @param that the comparison feature
   * @return comparison result
   */
  override def compare(that: FeatureSplit): Int = {
    if (this.merit < that.merit) -1
    else if (this.merit > that.merit) 1
    else 0
  }

  /**
   * Returns the number of the split.
   * @return the number of the split
   */
  def numSplit(): Int = result.length

  /**
   * Returns the distribution of the split index.
   *
   * @param splitIndex the split index
   * @return an Array containing the distribution
   */
  def distributionFromSplit(splitIndex: Int): Array[Double] = result(splitIndex)

  override def toString(): String = "FeatureSplit, merit=" + merit + ", " + conditionalTest

}

--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/streamdm/classifiers/model/HingeLoss.scala:
--------------------------------------------------------------------------------
/*
 * Copyright (C) 2015 Holmes Team at HUAWEI Noah's Ark Lab.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package org.apache.spark.streamdm.classifiers.model

/**
 * Implementation of the hinge loss function.
 */

class HingeLoss extends Loss with Serializable {
  /** Computes the value of the loss function
   * @param label the label against which the loss is computed
   * @param dot the dot product of the linear model and the instance
   * @return the loss value
   */
  def loss(label: Double, dot: Double): Double = {
    val l = if (label == 0) -1.0 else 1
    val v = 1.0 - l * dot
    if (v < 0) 0 else v
  }

  /** Computes the value of the gradient function
   * @param label the label against which the gradient is computed
   * @param dot the dot product of the linear model and the instance
   * @return the gradient value
   */
  def gradient(label: Double, dot: Double): Double = {
    val l = if (label == 0) -1.0 else 1
    val d = l * dot
    if (d < 1) -l else 0
  }

  /** Computes the binary prediction based on a dot product
   * @param dot the dot product of the linear model and the instance
   * @return the predicted binary class
   */
  def predict(dot: Double): Double =
    if (dot >= 0) 1 else 0
}

--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/streamdm/classifiers/model/LogisticLoss.scala:
--------------------------------------------------------------------------------
/*
 * Copyright (C) 2015 Holmes Team at HUAWEI Noah's Ark Lab.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package org.apache.spark.streamdm.classifiers.model

import scala.math

/**
 * Implementation of the logistic loss function.
 */

class LogisticLoss extends Loss with Serializable {
  /** Computes the value of the loss function
   * @param label the label against which the loss is computed
   * @param dot the dot product of the linear model and the instance
   * @return the loss value
   */
  def loss(label: Double, dot: Double): Double = {
    val y = if (label <= 0) -1 else 1
    math.log(1 + math.exp(-y * dot))
  }

  /** Computes the value of the gradient function
   * @param label the label against which the gradient is computed
   * @param dot the dot product of the linear model and the instance
   * @return the gradient value
   */
  def gradient(label: Double, dot: Double): Double = {
    val y = if (label <= 0) -1 else 1
    -y * (1.0 - 1.0 / (1.0 + math.exp(-y * dot)))
  }

  /** Computes the binary prediction based on a dot product
   * @param dot the dot product of the linear model and the instance
   * @return the predicted binary class
   */
  def predict(dot: Double): Double = {
    val f = 1.0 / (1.0 + math.exp(-dot))
    if (f > 0.5) 1 else 0
  }
}

--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/streamdm/core/OnlineLearner.scala:
--------------------------------------------------------------------------------
/*
 * Portions Copyright (C) 2015 Holmes Team at HUAWEI Noah's Ark Lab.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

/*
 This file has been changed by gosubpl
 */

package org.apache.spark.streamdm.core

import org.apache.spark.streamdm.core._
import com.github.javacliparser.Configurable
import org.apache.spark.streamdm.core.specification.ExampleSpecification


/**
 * A Learner trait defines the needed operations for any learner algorithm
 * implemented. It provides methods for training the model from a stream of
 * Examples.
 * Any Learner will contain a data structure derived from Model.
 */
trait OnlineLearner extends Configurable with Serializable {

  type T <: Model

  /**
   * Init the model based on the algorithm implemented in the learner.
   *
   * @param exampleSpecification the ExampleSpecification of the input stream.
   */
  def init(exampleSpecification: ExampleSpecification): Unit

  /**
   * Train the model based on the algorithm implemented in the learner,
   * from the Example given for training.
   *
   * @param input an Example
   */
  def trainIncremental(input: Example): Unit

  /**
   * Gets the current Model used for the Learner.
   *
   * @return the Model object used for training
   */
  def getModel: T
}

--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/streamdm/classifiers/model/Loss.scala:
--------------------------------------------------------------------------------
/*
 * Copyright (C) 2015 Holmes Team at HUAWEI Noah's Ark Lab.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package org.apache.spark.streamdm.classifiers.model

/**
 * A Loss trait defines the operations needed to compute the loss function, the
 * prediction function, and the gradient for use in a LinearModel.
 */
trait Loss extends Serializable {
  /** Computes the value of the loss function
   * @param label the label against which the loss is computed
   * @param dot the dot product of the linear model and the instance
   * @return the loss value
   */
  def loss(label: Double, dot: Double): Double

  /** Computes the value of the gradient function
   * @param label the label against which the gradient is computed
   * @param dot the dot product of the linear model and the instance
   * @return the gradient value
   */
  def gradient(label: Double, dot: Double): Double

  /** Computes the binary prediction based on a dot product
   * @param dot the dot product of the linear model and the instance
   * @return the predicted binary class
   */
  def predict(dot: Double): Double

  /** Computes the probability of a binary prediction based on a dot product
   * @param dot the dot product of the linear model and the instance
   * @return the predicted probability
   */
  def prob(dot: Double): Double = loss(1, dot)
}

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# akka-online
Online and streaming algorithms with Akka

For a good overview of streaming algorithms, go to [this debasishg gist](https://gist.github.com/debasishg/8172796).

In `org.apache.spark.streamdm` there is code copied from the [Huawei Noah's Ark Lab streamDM](https://github.com/huawei-noah/streamDM) project (the license is Apache 2.0 too). The code has been adapted to work with Akka Streams instead of the Spark Streaming module by removing the dependencies on Spark. The tested path is `HoeffdingTree` model usage, which works best with input sourced from `Arff` files via `SpecificationParser` / `ExampleParser`; see other parts of `akka-online` for usage examples.
If you'd like to know more about the HoeffdingTree on-line classifier, please read the [HDT docs](http://huawei-noah.github.io/streamDM/docs/HDT.html) or go to the [Massive Online Analysis](http://moa.cms.waikato.ac.nz/) website, which provides more background.
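
A minimal sketch of that pipeline, condensed from `HoeffdingTreeWithAlpakka.scala` in this repository (it assumes an implicit materializer is in scope and that `lines` is a `Source[String, NotUsed]` of `ACTION;row` strings, e.g. tailed from a file with Alpakka's `FileTailSource`):

```scala
val exampleSpec = new SpecificationParser().fromArff("/path/to/elecNormNew.arff")

lines
  .map { line =>
    // each line is "EXAMPLE;row" (train on it) or "QUERY;row" (predict for it)
    val Array(action, row) = line.split(";", 2)
    LearnerQuery(action, ExampleParser.fromArff(row, exampleSpec))
  }
  .statefulMapConcat { () =>
    val proc = new HoeffdingTreeProcessor(exampleSpec) // one tree per materialization
    proc.process(_)
  }
  .runForeach(print)
```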

Sample `Arff` files can be found in the [MOA dataset repository](http://moa.cs.waikato.ac.nz/datasets/).

For a theoretical exposition of the `HoeffdingTree`, go to [the original paper](http://homes.cs.washington.edu/%7Epedrod/papers/kdd00.pdf).

If you are asking yourself the question _why Akka, not Spark?_, please read the classic [bigger data, same laptop](http://www.frankmcsherry.org/graph/scalability/cost/2015/02/04/COST2.html) to see what we can gain by _going lightweight_. In the previous century, a 4.5k-record [machine learning data set](http://informatique.umons.ac.be/ssi/teaching/dwdm/spambase.arff) could be considered sizeable, hence all the growth of _clustering_ for data processing. But with today's hardware I believe _we can do better_.

In the `/lib` directory of the project I have put `suffixtree-1.0.0-SNAPSHOT.jar`, a compiled artifact of the [Ukkonen's on-line Suffix Tree implementation](https://github.com/abahgat/suffixtree). The license for this project is also Apache 2.0.

In this project I also use the Guava implementation of a `Bloom Filter` (see project dependencies), but there are other [options](https://github.com/alexandrnikitin/bloom-filter-scala).
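
For reference, the core of the Guava API looks like this (a minimal sketch; the capacity and false-positive rate below are illustrative, not values used by this project):

```scala
import java.nio.charset.StandardCharsets
import com.google.common.hash.{BloomFilter, Funnels}

// expect roughly 10000 insertions with a ~1% false-positive rate
val seen = BloomFilter.create[CharSequence](
  Funnels.stringFunnel(StandardCharsets.UTF_8), 10000, 0.01)

seen.put("alpha")
seen.mightContain("alpha") // true
seen.mightContain("beta")  // false with high probability: a Bloom filter can
                           // return false positives, but never false negatives
```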

--------------------------------------------------------------------------------
/src/main/scala/pl/gosub/akka/online/recursive/least/squares/RecursiveLeastSquaresStage.scala:
--------------------------------------------------------------------------------
package pl.gosub.akka.online.recursive.least.squares

import akka.Done
import akka.stream._
import akka.stream.stage._

import scala.concurrent.{Future, Promise}

class RecursiveLeastSquaresStage(val rls: RecursiveLeastSquresFilter)
  extends GraphStageWithMaterializedValue[FanInShape2[Double, Double, Double], Future[Done]] {

  // Stage syntax
  val dataIn: Inlet[Double] = Inlet("RecursiveLeastSquaresStage.dataIn")
  val resultsIn: Inlet[Double] = Inlet("RecursiveLeastSquaresStage.resultsIn")
  val predictionsOut: Outlet[Double] = Outlet("RecursiveLeastSquaresStage.predictionsOut")

  override val shape: FanInShape2[Double, Double, Double] = new FanInShape2(dataIn, resultsIn, predictionsOut)

  // Stage semantics
  override def createLogicAndMaterializedValue(inheritedAttributes: Attributes) = {
    // Completion notification (note: the promise is never completed below,
    // so the materialized Future is a placeholder)
    val p: Promise[Done] = Promise()

    val logic = new GraphStageLogic(shape) {

      setHandler(resultsIn, new InHandler {
        @scala.throws[Exception](classOf[Exception])
        override def onPush(): Unit = {
          val nextResult = grab(resultsIn)
          read(dataIn)({ x =>
            if (isAvailable(predictionsOut)) push(predictionsOut, rls.predict(x, nextResult))
          }, () => {})
        }
      })

      setHandler(dataIn, new InHandler {
        override def onPush(): Unit = {
          val x = grab(dataIn)
          read(resultsIn)({ previousResult =>
            if (isAvailable(predictionsOut)) push(predictionsOut, rls.predict(x, previousResult))
          }, () => {})
        }

        override def onUpstreamFinish(): Unit = {
          completeStage()
        }
      })

      setHandler(predictionsOut, new OutHandler {
        override def onPull(): Unit = {
          pull(dataIn)
        }
      })
    }

    (logic, p.future)
  }
}

--------------------------------------------------------------------------------
/src/main/scala/pl/gosub/akka/online/follow/the/leader/FollowTheLeaderStage.scala:
--------------------------------------------------------------------------------
package pl.gosub.akka.online.follow.the.leader

import akka.Done
import akka.stream.stage.{GraphStageLogic, GraphStageWithMaterializedValue, InHandler, OutHandler}
import akka.stream.{Attributes, FanInShape2, Inlet, Outlet}

import scala.concurrent.{Future, Promise}

class FollowTheLeaderStage(private val ftl: FollowTheLeaderLogic) extends GraphStageWithMaterializedValue[FanInShape2[Double, Double, Double], Future[Done]] {

  // Stage syntax
  val dataIn: Inlet[Double] = Inlet("FollowTheLeaderStage.dataIn")
  val resultsIn: Inlet[Double] = Inlet("FollowTheLeaderStage.resultsIn")
  val predictionsOut: Outlet[Double] = Outlet("FollowTheLeaderStage.predictionsOut")

  override val shape: FanInShape2[Double, Double, Double] = new FanInShape2(dataIn, resultsIn, predictionsOut)

  @scala.throws[Exception](classOf[Exception])
  override def createLogicAndMaterializedValue(inheritedAttributes: Attributes): (GraphStageLogic, Future[Done]) = {
    // Completion notification (as above, the promise is never completed)
    val p: Promise[Done] = Promise()

    val logic = new GraphStageLogic(shape) {

      setHandler(resultsIn, new InHandler {
        @scala.throws[Exception](classOf[Exception])
        override def onPush(): Unit = {
          val nextResult = grab(resultsIn)
          read(dataIn)({ x =>
            if (isAvailable(predictionsOut)) push(predictionsOut, ftl.predict(x, nextResult))
          }, () => {})
        }
      })

      setHandler(dataIn, new InHandler {
        override def onPush(): Unit = {
          val x = grab(dataIn)
          read(resultsIn)({ previousResult =>
            if (isAvailable(predictionsOut)) push(predictionsOut, ftl.predict(x, previousResult))
          }, () => {})
        }

        override def onUpstreamFinish(): Unit = {
          completeStage()
        }
      })

      setHandler(predictionsOut, new OutHandler {
        override def onPull(): Unit = {
          pull(dataIn)
        }
      })
    }

    (logic, p.future)
  }
}

--------------------------------------------------------------------------------
/src/main/scala/pl/gosub/akka/online/HoeffdingTreeWithAlpakka.scala:
--------------------------------------------------------------------------------
import java.nio.file.FileSystems

import akka.NotUsed
import akka.actor.ActorSystem
import akka.stream.{ActorMaterializer, ClosedShape}
import akka.stream.scaladsl.{GraphDSL, RunnableGraph, Sink, Source}
import akka.stream.alpakka.file.scaladsl
import org.apache.spark.streamdm.core.ExampleParser
import org.apache.spark.streamdm.core.specification.SpecificationParser
import pl.gosub.akka.online.{HoeffdingTreeProcessor, LearnerQuery}

import scala.concurrent.Await
import scala.concurrent.duration.Duration
import scala.concurrent.duration._


object HoeffdingTreeWithAlpakka extends App {
  implicit val system = ActorSystem()
  implicit val mat = ActorMaterializer()

  val specParser = new SpecificationParser

  val arffPath = this.getClass.getResource("/elecNormNew.arff").getPath // on Windows, add .replaceFirst("^/(.:/)", "$1")

  val exampleSpec = specParser.fromArff(arffPath)

  val fsPath = this.getClass.getResource("/elecNormData.txt").getPath // on Windows, add .replaceFirst("^/(.:/)", "$1")
  println(fsPath)


  val fs = FileSystems.getDefault
  val lines: Source[String, NotUsed] =
    scaladsl.FileTailSource.lines(
      path = fs.getPath(fsPath),
      maxLineSize = 8192,
      pollingInterval = 250.millis
    )

  // If the lines below do not work, make sure the linefeed convention of the
  // data file matches your operating system (LF vs CRLF).
  // lines.map(line => LearnerQuery(line.split(";").apply(0), ExampleParser.fromArff(line.split(";").apply(1), exampleSpec)))
  //   .runForeach(line => System.out.println(line))

  val masterControlProgram = RunnableGraph.fromGraph(GraphDSL.create(Sink.foreach(print)) { implicit builder =>
    outMatches =>
      import GraphDSL.Implicits._
      // tag each incoming line with its action: EXAMPLE trains the tree, QUERY asks it for a prediction
      val taggedInput = lines.map(line => LearnerQuery(line.split(";").apply(0), ExampleParser.fromArff(line.split(";").apply(1), exampleSpec)))

      taggedInput.statefulMapConcat(() => {
        val proc = new HoeffdingTreeProcessor(exampleSpec)
        proc.process(_)
      }) ~> outMatches
      ClosedShape
  }).run()

  import scala.concurrent.ExecutionContext.Implicits.global

  masterControlProgram.onComplete(_ => system.terminate())
  Await.ready(system.whenTerminated, Duration.Inf)

}

--------------------------------------------------------------------------------
/src/main/scala/pl/gosub/akka/online/KadaneFlowActor.scala:
--------------------------------------------------------------------------------
package pl.gosub.akka.online

import akka.actor.{Actor, ActorRef, ActorSystem, Props}
import akka.event.Logging
import akka.stream.{ActorMaterializer, OverflowStrategy, QueueOfferResult, ThrottleMode}
import akka.stream.scaladsl.{Sink, Source, SourceQueueWithComplete}
import akka.pattern.pipe // for the pipeTo pattern

import scala.concurrent.Await
import scala.concurrent.duration.Duration
import scala.util.Random

class KadaneFlowActor(queue: SourceQueueWithComplete[Int]) extends Actor {
  val log = Logging(context.system, this)
  // state
  var max_ending_here = 0
  var max_so_far = 0
  var upStream: ActorRef = self

  // Execution context for scheduling of the piped futures
  import scala.concurrent.ExecutionContext.Implicits.global

  def receive = {
    case "TEST" => log.info("Received TEST message")
      upStream = sender()
      sender() ! "ACK" // ask for the first element

    case "END" => log.info("Received END message, terminating actor")
      queue.complete()

    case elem: Int => // onPush() + grab(in)
      // "Business" logic: Kadane's maximum-subarray update
      max_ending_here = Math.max(0, max_ending_here + elem)
      max_so_far = Math.max(max_so_far, max_ending_here)

      // Thread.sleep(100) // FIXME: don't do this at home

      log.info(s"Received: $elem, sending out: $max_so_far")

      val offF = queue.offer(max_so_far) // push the element downstream
      offF pipeTo self // generate backpressure in the Actor

    case e: QueueOfferResult =>
      upStream ! "ACK" // ask for the next element
"ACK" // ask for next element 45 | 46 | case _ => log.info("Unrecognised message, ignoring") 47 | } 48 | } 49 | 50 | object KadaneFlowActorMain extends App { 51 | implicit val system = ActorSystem() 52 | implicit val mat = ActorMaterializer() 53 | 54 | import scala.concurrent.ExecutionContext.Implicits.global 55 | 56 | val actorQueue = Source.queue[Int](10, OverflowStrategy.backpressure) 57 | val actorSink = actorQueue 58 | .throttle(1, Duration(100, "millisecond"), 1, ThrottleMode.shaping) 59 | .to(Sink.foreach((x) => println(x))).run() 60 | 61 | val done = actorSink.watchCompletion() 62 | 63 | val kadaneFlowActor = system.actorOf(Props(new KadaneFlowActor(actorSink))) 64 | 65 | Source.repeat(1).take(100).map(_ => Random.nextInt(1100) - 1000) 66 | .runWith(Sink.actorRefWithAck(kadaneFlowActor, "TEST", "ACK", "END", _ => "FAIL")) 67 | 68 | done.onComplete(_ => system.terminate()) 69 | Await.ready(system.whenTerminated, Duration.Inf) 70 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/streamdm/classifiers/trees/Utils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2015 Holmes Team at HUAWEI Noah's Ark Lab. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | package org.apache.spark.streamdm.classifiers.trees 18 | 19 | import scala.math.{ max, min } 20 | object Utils { 21 | /* 22 | * Add two arrays and return the min length of two input arrays. 
   *
   * @param array1 double Array
   * @param array2 double Array
   * @return an Array with the element-wise sums
   */
  def addArrays(array1: Array[Double], array2: Array[Double]): Array[Double] = {
    val merge = new Array[Double](min(array1.length, array2.length))
    for (i <- 0 until merge.length)
      merge(i) = array1(i) + array2(i)
    merge
  }

  /*
   * Return a string for a matrix
   *
   * @param m matrix in form of 2-D array
   * @param split string of the split
   * @param head string of matrix head and each line's head
   * @param tail string of matrix tail and each line's tail
   * @return string format of input matrix
   */
  def matrixtoString[T](m: Array[Array[T]], split: String = ",", head: String = "{", tail: String = "}"): String = {
    val sb = new StringBuffer(head)
    for (i <- 0 until m.length) {
      sb.append(head)
      for (j <- 0 until m(i).length) {
        sb.append(m(i)(j))
        if (j < m(i).length - 1)
          sb.append(split)
      }
      sb.append(tail)
      if (i < m.length - 1)
        sb.append(split)
    }
    sb.append(tail).toString()
  }

  /*
   * Return a string for an array
   *
   * @param pre array
   * @param split string of the split
   * @param head string of array head
   * @param tail string of array tail
   * @return string format of input array
   */
  def arraytoString[T](pre: Array[T], split: String = ",", head: String = "{", tail: String = "}"): String = {
    val sb = new StringBuffer(head)
    for (i <- 0 until pre.length) {
      sb.append(pre(i))
      if (i < pre.length - 1)
        sb.append(split)
    }
    sb.append(tail).toString()
  }
}

--------------------------------------------------------------------------------
/src/main/scala/pl/gosub/akka/online/KadaneFlowStage.scala:
--------------------------------------------------------------------------------
package pl.gosub.akka.online

import akka.actor.ActorSystem
import akka.stream.scaladsl.{Sink, Source}
import akka.stream._
import akka.stream.stage.{GraphStage, GraphStageLogic, InHandler, OutHandler}

import scala.concurrent.Await
import scala.concurrent.duration.Duration
import scala.util.Random

/*
Kadane's maximum-subarray algorithm, in Python-style pseudocode:

def max_subarray(A):
    max_ending_here = max_so_far = 0
    for x in A:
        max_ending_here = max(0, max_ending_here + x)
        max_so_far = max(max_so_far, max_ending_here)
    return max_so_far
*/

class KadaneFlowStage extends GraphStage[FlowShape[Int, Int]] {

  /*
   This stage has a Flow shape
        +-------+
        |       |
   ---> >Inlet > Logic > Outlet> --->
        |       |
        +-------+
   */

  // Shape definition
  val in: Inlet[Int] = Inlet("KadaneFlowStage.in")
  val out: Outlet[Int] = Outlet("KadaneFlowStage.out")
  override val shape: FlowShape[Int, Int] = FlowShape(in, out)

  // Logic for the stage
  override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
    new GraphStageLogic(shape) {
      // state
      var maxEndingHere = 0
      var maxSoFar = 0

      // Handler(s) for the Inlet
      setHandler(in, new InHandler {
        // what to do when a new element is ready to be consumed
        override def onPush(): Unit = {
          val elem = grab(in)

          // "Business" logic
          maxEndingHere = Math.max(0, maxEndingHere + elem)
          maxSoFar = Math.max(maxSoFar, maxEndingHere)

          // onPush is only invoked after the downstream has pulled, so out
          // should always be available here; the guard is defensive and, if it
          // ever triggered, would drop the value rather than fail the stage
          if (isAvailable(out))
            push(out, maxSoFar)
        }

        override def onUpstreamFinish(): Unit = {
          completeStage()
        }
      })

      // Handler for the Outlet
      setHandler(out, new OutHandler {
        override def onPull(): Unit = {
          if (!hasBeenPulled(in))
            pull(in)
        }
      })

    }

}

object KadaneFlowMain extends App {
  implicit val system = ActorSystem()
  implicit val mat = ActorMaterializer()

  val kadaneFlowStage = new KadaneFlowStage

  val done = Source.repeat(1).take(100).map(_ => Random.nextInt(1100) - 1000)
    // .throttle(1, Duration(100, "millisecond"), 1, ThrottleMode.shaping)
    .via(kadaneFlowStage)
    .throttle(1, Duration(100, "millisecond"), 1, ThrottleMode.shaping)
    .runWith(Sink.foreach(println))

  import scala.concurrent.ExecutionContext.Implicits.global

  done.onComplete(_ => system.terminate())
  Await.ready(system.whenTerminated, Duration.Inf)
}

--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/streamdm/classifiers/model/LinearModel.scala:
--------------------------------------------------------------------------------
/*
 * Copyright (C) 2015 Holmes Team at HUAWEI Noah's Ark Lab.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package org.apache.spark.streamdm.classifiers.model

import org.apache.spark.streamdm.core._

/**
 * A LinearModel implements a linear classification model built around a Loss.
 * It provides methods for updating the model and for predicting the label of a
 * given Instance.
 */
class LinearModel(lossFunction: Loss, initialModel: Instance, numberFeatures: Int)
  extends ClassificationModel with Serializable {

  type T = LinearModel

  val loss = lossFunction
  val modelInstance = initialModel
  val numFeatures = numberFeatures

  /* Update the model, depending on an Instance given for training
   *
   * @param change the Example based on which the Model is updated
   * @return the updated Model
   */
  override def update(change: Example): LinearModel =
    new LinearModel(loss, modelInstance.add(change.in), numFeatures)

  /* Predict the label of the Instance, given the current Model
   *
   * @param instance the Instance which needs a class predicted
   * @return a Double representing the class predicted
   */
  override def predict(instance: Example): Double =
    loss.predict(modelInstance.dot(instance.in.set(numFeatures, 1.0)))

  /* Compute the gradient of the loss in the direction of the change
   *
   * @param instance the Instance for which the gradient is computed
   * @return an Instance containing the gradients for every feature
   */
  def gradient(instance: Example): Instance = {
    // compute the gradient based on the dot product, then compute the changes
    val ins = instance.in.set(numFeatures, 1.0)
    val ch = -loss.gradient(instance.labelAt(0), modelInstance.dot(ins))
    ins.map(x => ch * x)
  }

  def regularize(regularizer: Regularizer): Instance =
    modelInstance.map(x => -regularizer.gradient(x))

  /** Computes the probability for a given label class, given the current Model
   *
   * @param instance the Instance which needs a class predicted
   * @return the predicted probability
   */

  def prob(instance: Example): Double =
    loss.prob(modelInstance.dot(instance.in.set(numFeatures, 1.0)))

  override def toString = "Model %s".format(modelInstance.toString)
}

--------------------------------------------------------------------------------
/src/main/resources/elecNormData.txt:
--------------------------------------------------------------------------------
EXAMPLE;0,2,0,0.056443,0.439155,0.003467,0.422915,0.414912,UP
EXAMPLE;0,2,0,0.056443,0.439155,0.003467,0.422915,0.414912,UP
EXAMPLE;0,2,0.021277,0.051699,0.415055,0.003467,0.422915,0.414912,UP
EXAMPLE;0,2,0.042553,0.051489,0.385004,0.003467,0.422915,0.414912,UP
EXAMPLE;0,2,0.06383,0.045485,0.314639,0.003467,0.422915,0.414912,UP
EXAMPLE;0,2,0.085106,0.042482,0.251116,0.003467,0.422915,0.414912,DOWN
EXAMPLE;0,2,0.106383,0.041161,0.207528,0.003467,0.422915,0.414912,DOWN
EXAMPLE;0,2,0.12766,0.041161,0.171824,0.003467,0.422915,0.414912,DOWN
EXAMPLE;0,2,0.148936,0.041161,0.152782,0.003467,0.422915,0.414912,DOWN
QUERY;0,2,0.170213,0.041161,0.13493,0.003467,0.422915,0.414912,DOWN
EXAMPLE;0,2,0.191489,0.041161,0.140583,0.003467,0.422915,0.414912,DOWN
EXAMPLE;0,2,0.042553,0.051489,0.385004,0.003467,0.422915,0.414912,UP
QUERY;0,2,0.06383,0.045485,0.314639,0.003467,0.422915,0.414912,UP
EXAMPLE;0,2,0.085106,0.042482,0.251116,0.003467,0.422915,0.414912,DOWN
EXAMPLE;0,2,0.106383,0.041161,0.207528,0.003467,0.422915,0.414912,DOWN
EXAMPLE;0,2,0,0.056443,0.439155,0.003467,0.422915,0.414912,UP
EXAMPLE;0,2,0.021277,0.051699,0.415055,0.003467,0.422915,0.414912,UP
EXAMPLE;0,2,0.042553,0.051489,0.385004,0.003467,0.422915,0.414912,UP 19 | EXAMPLE;0,2,0.06383,0.045485,0.314639,0.003467,0.422915,0.414912,UP 20 | EXAMPLE;0,2,0.085106,0.042482,0.251116,0.003467,0.422915,0.414912,DOWN 21 | EXAMPLE;0,2,0.106383,0.041161,0.207528,0.003467,0.422915,0.414912,DOWN 22 | EXAMPLE;0,2,0.12766,0.041161,0.171824,0.003467,0.422915,0.414912,DOWN 23 | EXAMPLE;0,2,0.148936,0.041161,0.152782,0.003467,0.422915,0.414912,DOWN 24 | EXAMPLE;0,2,0.170213,0.041161,0.13493,0.003467,0.422915,0.414912,DOWN 25 | EXAMPLE;0,2,0.191489,0.041161,0.140583,0.003467,0.422915,0.414912,DOWN 26 | EXAMPLE;0,2,0,0.056443,0.439155,0.003467,0.422915,0.414912,UP 27 | EXAMPLE;0,2,0.021277,0.051699,0.415055,0.003467,0.422915,0.414912,UP 28 | QUERY;0,2,0.042553,0.051489,0.385004,0.003467,0.422915,0.414912,UP 29 | EXAMPLE;0,2,0.06383,0.045485,0.314639,0.003467,0.422915,0.414912,UP 30 | EXAMPLE;0,2,0.085106,0.042482,0.251116,0.003467,0.422915,0.414912,DOWN 31 | EXAMPLE;0,2,0.106383,0.041161,0.207528,0.003467,0.422915,0.414912,DOWN 32 | QUERY;0,2,0.12766,0.041161,0.171824,0.003467,0.422915,0.414912,DOWN 33 | EXAMPLE;0,2,0.148936,0.041161,0.152782,0.003467,0.422915,0.414912,DOWN 34 | EXAMPLE;0,2,0.170213,0.041161,0.13493,0.003467,0.422915,0.414912,DOWN 35 | EXAMPLE;0,2,0.191489,0.041161,0.140583,0.003467,0.422915,0.414912,DOWN 36 | EXAMPLE;0,2,0,0.056443,0.439155,0.003467,0.422915,0.414912,UP 37 | EXAMPLE;0,2,0.021277,0.051699,0.415055,0.003467,0.422915,0.414912,UP 38 | QUERY;0,2,0.042553,0.051489,0.385004,0.003467,0.422915,0.414912,UP 39 | QUERY;0,2,0.06383,0.045485,0.314639,0.003467,0.422915,0.414912,UP 40 | EXAMPLE;0,2,0.085106,0.042482,0.251116,0.003467,0.422915,0.414912,DOWN 41 | EXAMPLE;0,2,0.106383,0.041161,0.207528,0.003467,0.422915,0.414912,DOWN 42 | EXAMPLE;0,2,0.12766,0.041161,0.171824,0.003467,0.422915,0.414912,DOWN 43 | EXAMPLE;0,2,0.148936,0.041161,0.152782,0.003467,0.422915,0.414912,DOWN 44 | EXAMPLE;0,2,0.170213,0.041161,0.13493,0.003467,0.422915,0.414912,DOWN 45 | EXAMPLE;0,2,0.191489,0.041161,0.140583,0.003467,0.422915,0.414912,DOWN 46 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/streamdm/core/Instance.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2015 Holmes Team at HUAWEI Noah's Ark Lab. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package org.apache.spark.streamdm.core 19 | 20 | /** 21 | * An Instance represents the input or output of any learning algorithm. It is 22 | * normally composed of a feature vector (having various implementations). 
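 * For example (a sketch; the exact parse format lives in the implementations):
 * a DenseInstance keeps every feature value in an array, while a SparseInstance
 * keeps only (index, value) pairs; both satisfy this trait:
 * {{{
 *   val v: Instance = DenseInstance.parse("1.0,2.0,3.0") // assuming comma-separated input
 *   v.dot(v)       // 14.0
 *   v.map(_ * 2.0) // feature-wise transform
 * }}}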
23 |  */
24 | 
25 | trait Instance extends Serializable {
26 | 
27 |   type T <: Instance
28 | 
29 |   /**
30 |    * Get the value present at position index
31 |    *
32 |    * @param index the index of the features
33 |    * @return a Double representing the feature value, or 0 if the index is not
34 |    * present in the underlying data structure
35 |    */
36 |   def apply(index: Int): Double
37 | 
38 |   /**
39 |    * Return an array of features and indexes
40 |    *
41 |    * @return an array of Tuple2(value, index)
42 |    */
43 |   def getFeatureIndexArray(): Array[(Double, Int)]
44 | 
45 |   /**
46 |    * Perform a dot product between two instances
47 |    *
48 |    * @param input an Instance with which the dot product is performed
49 |    * @return a Double representing the dot product
50 |    */
51 |   def dot(input: Instance): Double
52 | 
53 |   /**
54 |    * Compute the Euclidean distance to another Instance
55 |    *
56 |    * @param input the Instance to which the distance is computed
57 |    * @return a Double representing the distance value
58 |    */
59 |   def distanceTo(input: Instance): Double
60 | 
61 |   /**
62 |    * Perform an element by element addition between two instances
63 |    *
64 |    * @param input an Instance which is added up
65 |    * @return an Instance representing the added Instances
66 |    */
67 |   def add(input: Instance): T
68 | 
69 |   /**
70 |    * Perform an element by element multiplication between two instances
71 |    *
72 |    * @param input an Instance which is multiplied
73 |    * @return an Instance representing the Hadamard product
74 |    */
75 |   def hadamard(input: Instance): T
76 | 
77 |   /**
78 |    * Add a feature to the instance
79 |    *
80 |    * @param index the index at which the value is added
81 |    * @param input the feature value which is added up
82 |    * @return an Instance representing the new feature vector
83 |    */
84 |   def set(index: Int, input: Double): T
85 | 
86 |   /**
87 |    * Apply an operation to every feature of the Instance
88 |    *
89 |    * @param func the function for the transformation
90 |    * @return a new Instance with the transformed features
91 |    */
92 |   def map(func: Double => Double): T
93 | 
94 |   /**
95 |    * Aggregate the values of an instance
96 |    *
97 |    * @param func the function for the transformation
98 |    * @return the reduced value
99 |    */
100 |   def reduce(func: (Double, Double) => Double): Double
101 | }
102 | 
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/streamdm/core/specification/FeatureSpecification.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (C) 2015 Holmes Team at HUAWEI Noah's Ark Lab.
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *    http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  *
16 |  */
17 | 
18 | 
19 | package org.apache.spark.streamdm.core.specification
20 | 
21 | import scala.collection.mutable.Map
22 | 
23 | /**
24 |  * trait FeatureSpecification.
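 * Implementations say whether a feature is nominal or numeric and, for nominal
 * features, expose the value range; a sketch using the classes defined below:
 * {{{
 *   val label = new NominalFeatureSpecification(Array("UP", "DOWN"))
 *   label.isNominal() // true
 *   label.range()     // 2
 *   label("DOWN")     // 1 (nominal value mapped to its numeric index)
 * }}}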
25 |  *
26 |  */
27 | trait FeatureSpecification {
28 | 
29 |   /**
30 |    * whether the feature is nominal
31 |    *
32 |    * @return true if the feature is nominal
33 |    */
34 |   def isNominal(): Boolean
35 | 
36 |   /**
37 |    * whether the feature is numeric
38 |    *
39 |    * @return true if the feature is numeric
40 |    */
41 |   def isNumeric(): Boolean
42 |   /**
43 |    * if a feature is numeric, return -1, else return the nominal values size
44 |    */
45 |   def range(): Int
46 | }
47 | 
48 | /**
49 |  * class NumericFeatureSpecification.
50 |  *
51 |  */
52 | class NumericFeatureSpecification extends FeatureSpecification with Serializable {
53 |   /**
54 |    * whether the feature is nominal
55 |    *
56 |    * @return true if the feature is nominal
57 |    */
58 |   override def isNominal(): Boolean = false
59 |   /**
60 |    * whether the feature is numeric
61 |    *
62 |    * @return true if the feature is numeric
63 |    */
64 |   override def isNumeric(): Boolean = !isNominal()
65 |   override def range(): Int = -1
66 | }
67 | 
68 | /**
69 |  * A NominalFeatureSpecification contains information about its nominal values.
70 |  *
71 |  */
72 | 
73 | class NominalFeatureSpecification(nominalValues: Array[String]) extends FeatureSpecification with Serializable {
74 |   val values = nominalValues
75 | 
76 |   val nominalToNumericMap = Map[String, Int]()
77 |   values.zipWithIndex.foreach { case (element, index) =>
78 |     nominalToNumericMap += (element -> index) }
79 | 
80 |   /** Get the nominal string value present at position index
81 |    *
82 |    * @param index the index of the feature value
83 |    * @return a string containing the nominal value of the feature
84 |    */
85 |   def apply(index: Int): String = values(index)
86 | 
87 |   /** Get the position index given the nominal string value
88 |    *
89 |    * @param string a string containing the nominal value of the feature
90 |    * @return the index of the feature value
91 |    */
92 |   def apply(string: String): Int = nominalToNumericMap(string)
93 | 
94 |   /**
95 |    * whether the feature is nominal
96 |    *
97 |    * @return true if the feature is nominal
98 |    */
99 |   override def isNominal(): Boolean = true
100 |   /**
101 |    * whether the feature is numeric
102 |    *
103 |    * @return true if the feature is numeric
104 |    */
105 |   override def isNumeric(): Boolean = !isNominal()
106 |   /**
107 |    * return the nominal values size
108 |    */
109 |   override def range(): Int = values.length
110 | }
111 | 
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/streamdm/core/NullInstance.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (C) 2015 Holmes Team at HUAWEI Noah's Ark Lab.
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *    http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  *
16 |  */
17 | 
18 | package org.apache.spark.streamdm.core
19 | 
20 | /**
21 |  * A NullInstance is an Instance which does not contain anything in it.
It is
22 |  * present to aid the design of Example, and to allow cases when we have
23 |  * instances which do not have output values.
24 |  * Every operation on a NullInstance will either return a NullInstance or a
25 |  * value of 0.
26 |  */
27 | 
28 | case class NullInstance() extends Instance with Serializable {
29 | 
30 |   type T = NullInstance
31 | 
32 |   /* Get the feature value present at position index
33 |    *
34 |    * @param index the index of the desired value
35 |    * @return a value of 0
36 |    */
37 |   override def apply(index: Int): Double = 0.0
38 | 
39 |   /*
40 |    * Return an array of features and indexes
41 |    *
42 |    * @return an array of Tuple2(value, index)
43 |    */
44 |   def getFeatureIndexArray(): Array[(Double, Int)] = new Array[(Double, Int)](0)
45 | 
46 |   /* Perform a dot product between two instances
47 |    *
48 |    * @param input an Instance with which the dot
49 |    * product is performed
50 |    * @return a value of 0
51 |    */
52 |   override def dot(input: Instance): Double = 0.0
53 | 
54 |   /**
55 |    * Compute the Euclidean distance to another Instance
56 |    *
57 |    * @param input the Instance to which the distance is computed
58 |    * @return an infinite distance (implemented as Double.MaxValue)
59 |    */
60 |   def distanceTo(input: Instance): Double = Double.MaxValue
61 | 
62 |   /**
63 |    * Perform an element by element addition between two instances
64 |    *
65 |    * @param input an Instance which is added up
66 |    * @return a NullInstance
67 |    */
68 |   override def add(input: Instance): NullInstance = this
69 | 
70 |   /**
71 |    * Perform an element by element multiplication between two instances
72 |    *
73 |    * @param input an Instance which is multiplied
74 |    * @return a NullInstance
75 |    */
76 |   override def hadamard(input: Instance): NullInstance = this
77 | 
78 |   /**
79 |    * Add a feature to the instance
80 |    *
81 |    * @param index the index at which the value is added
82 |    * @param input the feature value which is added up
83 |    * @return a NullInstance
84 |    */
85 |   override def set(index: Int, input: Double): NullInstance = this
86 | 
87 |   /**
88 |    * Apply an operation to every feature of a NullInstance
89 |    * @param func the function for the transformation
90 |    * @return a NullInstance
91 |    */
92 |   override def map(func: Double => Double): NullInstance = this
93 | 
94 |   /**
95 |    * Aggregate the values of a NullInstance
96 |    *
97 |    * @param func the function for the transformation
98 |    * @return a value of 0
99 |    */
100 |   override def reduce(func: (Double, Double) => Double): Double = 0.0
101 | 
102 |   override def toString = ""
103 | }
104 | 
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/streamdm/core/specification/ExampleSpecification.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (C) 2015 Holmes Team at HUAWEI Noah's Ark Lab.
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *    http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 | * 16 | */ 17 | 18 | package org.apache.spark.streamdm.core.specification 19 | 20 | /** 21 | * An ExampleSpecification contains information about the input and output 22 | * features. It contains a reference to an input InstanceSpecification and an 23 | * output InstanceSpecification, and provides setters and getters for the 24 | * feature specification properties. 25 | */ 26 | 27 | class ExampleSpecification(inInstanceSpecification: InstanceSpecification, 28 | outInstanceSpecification: InstanceSpecification) 29 | extends Serializable { 30 | 31 | val in = inInstanceSpecification 32 | val out = outInstanceSpecification 33 | 34 | /** 35 | * Gets the input FeatureSpecification value present at position index 36 | * 37 | * @param index the index of the specification 38 | * @return a FeatureSpecification representing the specification for the 39 | * feature 40 | */ 41 | def inputFeatureSpecification(index: Int): FeatureSpecification = in(index) 42 | 43 | /** 44 | * Gets the output FeatureSpecification value present at position index 45 | * 46 | * @param index the index of the specification 47 | * @return a FeatureSpecification representing the specification for the 48 | * feature 49 | */ 50 | def outputFeatureSpecification(index: Int): FeatureSpecification = out(index) 51 | 52 | /** 53 | * Evaluates whether an input feature is numeric 54 | * 55 | * @param index the index of the feature 56 | * @return true if the feature is numeric 57 | */ 58 | def isNumericInputFeature(index: Int): Boolean = in.isNumeric(index) 59 | 60 | /** 61 | * Evaluates whether an output feature is numeric 62 | * 63 | * @param index the index of the feature 64 | * @return true if the feature is numeric 65 | */ 66 | def isNumericOutputFeature(index: Int): Boolean = out.isNumeric(index) 67 | 68 | /** 69 | * Gets the input name of the feature at position index 70 | * 71 | * @param index the index of the class 72 | * @return a string representing the name of the feature 73 | */ 74 | def nameInputFeature(index: Int): String = in.name(index) 75 | 76 | /** 77 | * Gets the output name of the feature at position index 78 | * 79 | * @param index the index of the class 80 | * @return a string representing the name of the feature 81 | */ 82 | def nameOutputFeature(index: Int): String = out.name(index) 83 | 84 | /** 85 | * Gets the number of input features 86 | * 87 | * @return an Integer representing the number of input features 88 | */ 89 | def numberInputFeatures(): Int = in.size 90 | 91 | /** 92 | * Gets the number of output features 93 | * 94 | * @return an Integer representing the number of output features 95 | */ 96 | def numberOutputFeatures(): Int = out.size 97 | } 98 | 99 | 100 | 101 | -------------------------------------------------------------------------------- /src/main/scala/pl/gosub/akka/online/BloomFilterCrossStage.scala: -------------------------------------------------------------------------------- 1 | package pl.gosub.akka.online 2 | 3 | import akka.NotUsed 4 | import akka.actor.ActorSystem 5 | import akka.stream._ 6 | import akka.stream.scaladsl.{GraphDSL, RunnableGraph, Sink, Source} 7 | import akka.stream.stage.{GraphStage, GraphStageLogic, InHandler, OutHandler} 8 | import com.google.common.hash.{BloomFilter, Funnels} 9 | 10 | import scala.concurrent.duration.Duration 11 | import scala.util.Random 12 | 13 | // Cross Shape is actually BidiShape - for shape the semantics doesn't count, only syntax 14 | class BloomFilterCrossStage extends GraphStage[BidiShape[Int, Int, Int, String]] { 15 | 16 | // Stage 
syntax 17 | val dataIn: Inlet[Int] = Inlet("BloomFilterCrossStage.dataIn") 18 | val dataOut: Outlet[Int] = Outlet("BloomFilterCrossStage.dataOut") 19 | val queriesIn: Inlet[Int] = Inlet("BloomFilterCrossStage.queriesIn") 20 | val answersOut: Outlet[String] = Outlet("BloomFilterCrossStage.answersOut") 21 | override val shape: BidiShape[Int, Int, Int, String] = BidiShape(dataIn, dataOut, queriesIn, answersOut) 22 | 23 | // Stage semantics 24 | override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = 25 | new GraphStageLogic(shape) { 26 | 27 | // State 28 | val filter = BloomFilter.create[Integer](Funnels.integerFunnel(), 1000, 0.01) 29 | 30 | setHandler(dataIn, new InHandler { 31 | override def onPush(): Unit = { 32 | val elem = grab(dataIn) 33 | filter.put(elem) 34 | if (isAvailable(dataOut)) 35 | push(dataOut, elem) 36 | } 37 | 38 | override def onUpstreamFinish(): Unit = { 39 | completeStage() 40 | } 41 | }) 42 | 43 | setHandler(dataOut, new OutHandler { 44 | override def onPull(): Unit = { 45 | if (!hasBeenPulled(dataIn)) 46 | pull(dataIn) 47 | } 48 | }) 49 | 50 | setHandler(queriesIn, new InHandler { 51 | override def onPush(): Unit = { 52 | val x = grab(queriesIn) 53 | val answer = if (filter.mightContain(x)) { 54 | s"MAYBE, filter probably contains $x" 55 | } else { 56 | s"NO, filter definitely does not contain $x" 57 | } 58 | if (isAvailable(answersOut)) 59 | push(answersOut, answer) 60 | } 61 | 62 | override def onUpstreamFinish(): Unit = { 63 | completeStage() 64 | } 65 | }) 66 | 67 | setHandler(answersOut, new OutHandler { 68 | override def onPull(): Unit = { 69 | if (!hasBeenPulled(queriesIn)) 70 | pull(queriesIn) 71 | } 72 | }) 73 | 74 | } 75 | 76 | } 77 | 78 | object BloomFilterCrossStageMain extends App { 79 | implicit val system = ActorSystem() 80 | implicit val mat = ActorMaterializer() 81 | 82 | val crossStage = new BloomFilterCrossStage 83 | 84 | val graph = RunnableGraph.fromGraph(GraphDSL.create() { implicit builder: GraphDSL.Builder[NotUsed] => 85 | import GraphDSL.Implicits._ 86 | val inData = Source.repeat(1).take(1000).map(_ => Random.nextInt(1000)).throttle(1, Duration(100, "millisecond"), 1, ThrottleMode.shaping) 87 | val outData = Sink.foreach(println) 88 | val inControl = Source.repeat(1).take(100).map(_ => Random.nextInt(2000) - 1000).throttle(1, Duration(1500, "millisecond"), 1, ThrottleMode.shaping) 89 | val outControl = Sink.foreach(println) 90 | 91 | val cross = builder.add(crossStage) 92 | 93 | inData ~> cross.in1; cross.out1 ~> outData 94 | inControl ~> cross.in2; cross.out2 ~> outControl 95 | ClosedShape 96 | }).run() 97 | } 98 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/streamdm/core/specification/InstanceSpecification.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2015 Holmes Team at HUAWEI Noah's Ark Lab. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 |  *
16 |  */
17 | 
18 | package org.apache.spark.streamdm.core.specification
19 | 
20 | import scala.collection.mutable.Map
21 | 
22 | /**
23 |  * An InstanceSpecification contains information about the features. It keeps
24 |  * the specifications of the nominal (non-numeric) features, and the names of
25 |  * all features, numeric and nominal.
26 |  */
27 | 
28 | class InstanceSpecification extends Serializable {
29 |   val nominalFeatureSpecificationMap = Map[Int, FeatureSpecification]()
30 |   val featureNameMap = Map[Int, String]()
31 |   val numericFeatureSpecification: NumericFeatureSpecification = new NumericFeatureSpecification
32 | 
33 |   /**
34 |    * Gets the FeatureSpecification value present at position index
35 |    *
36 |    * @param index the index of the position
37 |    * @return a FeatureSpecification representing the specification for the
38 |    * feature
39 |    */
40 |   def apply(index: Int): FeatureSpecification = {
41 |     if (nominalFeatureSpecificationMap.contains(index))
42 |       nominalFeatureSpecificationMap(index)
43 |     else numericFeatureSpecification
44 |   }
45 | 
46 |   /**
47 |    * Removes the FeatureSpecification value present at position index
48 |    *
49 |    * @param index the index of the position
50 |    * @return Unit
51 |    */
52 |   def removeFeatureSpecification(index: Int): Unit = {
53 |     if (nominalFeatureSpecificationMap.contains(index))
54 |       nominalFeatureSpecificationMap.remove(index)
55 |     featureNameMap.remove(index)
56 |   }
57 | 
58 |   /**
59 |    * Adds a specification for the instance feature
60 |    * @param index the index at which the specification is added
61 |    * @param name the name of the feature
62 |    * @param fSpecification the feature specification to add (numeric if null)
63 |    */
64 |   def addFeatureSpecification(index: Int, name: String, fSpecification: FeatureSpecification = null): Unit = {
65 |     if (fSpecification != null && fSpecification.isInstanceOf[NominalFeatureSpecification]) {
66 |       nominalFeatureSpecificationMap += (index -> fSpecification)
67 |     }
68 |     featureNameMap += (index -> name)
69 |   }
70 |   /**
71 |    * Evaluates whether a feature is nominal
72 |    *
73 |    * @param index the index of the feature
74 |    * @return true if the feature is nominal
75 |    */
76 |   def isNominal(index: Int): Boolean =
77 |     this(index).isNominal()
78 | 
79 |   /**
80 |    * Evaluates whether a feature is numeric
81 |    *
82 |    * @param index the index of the feature
83 |    * @return true if the feature is numeric
84 |    */
85 |   def isNumeric(index: Int): Boolean =
86 |     !isNominal(index)
87 | 
88 |   /**
89 |    * Gets the name of the feature at position index
90 |    *
91 |    * @param index the index of the class
92 |    * @return a string representing the name of the feature
93 |    */
94 |   def name(index: Int): String = featureNameMap(index)
95 | 
96 |   /**
97 |    * Gets the number of features
98 |    *
99 |    * @return the number of features
100 |    */
101 |   def size(): Int = featureNameMap.size
102 | }
103 | 
--------------------------------------------------------------------------------
/src/main/scala/pl/gosub/akka/online/HoeffdingTreeFlowStream.scala:
--------------------------------------------------------------------------------
1 | package pl.gosub.akka.online
2 | 
3 | import akka.actor.{Actor, ActorRef, ActorSystem, Props}
4 | import akka.stream._
5 | import akka.stream.scaladsl.{GraphDSL, Merge, RunnableGraph, Sink, Source}
6 | 
7 | import org.apache.spark.streamdm.classifiers.trees.HoeffdingTree
8 | import org.apache.spark.streamdm.core.{Example, ExampleParser}
9 | import org.apache.spark.streamdm.core.specification.{ExampleSpecification,
SpecificationParser} 10 | 11 | import scala.concurrent.Await 12 | import scala.concurrent.duration.Duration 13 | 14 | class HoeffdingTreeProcessor(val schema: ExampleSpecification) { 15 | // state 16 | val hTree = new HoeffdingTree 17 | hTree.init(schema) 18 | var count = 0 19 | 20 | def process(query: LearnerQuery) = { 21 | query.queryType match { 22 | case "EXAMPLE" => 23 | hTree.trainIncremental(query.elem) 24 | count += 1 25 | "Received: " + count + " " + query.elem.toString + ", learned\n" 26 | 27 | case "QUERY" => 28 | val answer = hTree.predictSingle(query.elem) 29 | count += 1 30 | "Received: " + count + " " + query.elem.toString + ", generated prediction " + answer._2 + "\n" 31 | 32 | case _ => "Unrecognised message, ignoring" 33 | } 34 | } 35 | } 36 | 37 | object HoeffdingTreeFlowStreamMain extends App { 38 | implicit val system = ActorSystem() 39 | implicit val mat = ActorMaterializer() 40 | 41 | val specParser = new SpecificationParser 42 | 43 | val arffPath = this.getClass.getResource("/elecNormNew.arff").getPath 44 | 45 | val exampleSpec = specParser.fromArff(arffPath) 46 | 47 | val example1 = ExampleParser.fromArff("0,2,0.085106,0.042482,0.251116,0.003467,0.422915,0.414912,DOWN", exampleSpec) 48 | val example2 = ExampleParser.fromArff("0,2,0.255319,0.051489,0.298721,0.003467,0.422915,0.414912,UP", exampleSpec) 49 | val example3 = ExampleParser.fromArff("0.424627,6,0.234043,0.070854,0.108004,0.003467,0.422915,0.414912,DOWN", exampleSpec) 50 | val example4 = ExampleParser.fromArff("0.424627,6,0.170213,0.070854,0.070515,0.003467,0.422915,0.414912,UP", exampleSpec) 51 | 52 | /* 53 | val learn1 = LearnerQuery("EXAMPLE", example1) 54 | val learn2 = LearnerQuery("EXAMPLE", example2) 55 | val query1 = LearnerQuery("QUERY", example1) 56 | val query2 = LearnerQuery("QUERY", example2) 57 | val query3 = LearnerQuery("QUERY", example3) 58 | val query4 = LearnerQuery("QUERY", example4) 59 | */ 60 | 61 | val examples = List(example1, example2) 62 | val queries = List(example1, example2, example3, example4, example2) 63 | 64 | val masterControlProgram = RunnableGraph.fromGraph(GraphDSL.create(Sink.foreach(print)) { implicit builder => 65 | outMatches => 66 | import GraphDSL.Implicits._ 67 | val inExamples = Source.fromIterator(() => examples.toIterator).throttle(1, Duration(2500, "millisecond"), 1, ThrottleMode.shaping) 68 | val inQueries = Source.fromIterator(() => queries.toIterator).throttle(1, Duration(1000, "millisecond"), 1, ThrottleMode.shaping) 69 | 70 | val taggedExamples = inExamples.map(LearnerQuery("EXAMPLE", _)) 71 | val taggedQueries = inQueries.map(LearnerQuery("QUERY", _)) 72 | val fanIn = builder.add(Merge[LearnerQuery](2)) 73 | 74 | taggedExamples ~> fanIn.in(0) 75 | fanIn.out.statefulMapConcat(() => { 76 | val proc = new HoeffdingTreeProcessor(exampleSpec) 77 | proc.process(_) 78 | }) ~> outMatches 79 | taggedQueries ~> fanIn.in(1) 80 | ClosedShape 81 | }).run() 82 | 83 | import scala.concurrent.ExecutionContext.Implicits.global 84 | 85 | masterControlProgram.onComplete(_ => system.terminate()) 86 | Await.ready(system.whenTerminated, Duration.Inf) 87 | 88 | } 89 | -------------------------------------------------------------------------------- /src/main/scala/pl/gosub/akka/online/BloomFilterCrossMatStage.scala: -------------------------------------------------------------------------------- 1 | package pl.gosub.akka.online 2 | 3 | import akka.{Done, NotUsed} 4 | import akka.actor.ActorSystem 5 | import akka.stream._ 6 | import akka.stream.scaladsl.{GraphDSL, 
RunnableGraph, Sink, Source}
7 | import akka.stream.stage._
8 | import com.google.common.hash.{BloomFilter, Funnels}
9 | 
10 | import scala.concurrent.{Await, Future, Promise}
11 | import scala.concurrent.duration.Duration
12 | import scala.util.Random
13 | 
14 | // A cross shape is really a BidiShape: a Shape only fixes the port layout (syntax), not what the ports mean (semantics)
15 | class BloomFilterCrossMatStage extends GraphStageWithMaterializedValue[BidiShape[Int, Int, Int, String], Future[Done]] {
16 | 
17 |   // Stage syntax
18 |   val dataIn: Inlet[Int] = Inlet("BloomFilterCrossMatStage.dataIn")
19 |   val dataOut: Outlet[Int] = Outlet("BloomFilterCrossMatStage.dataOut")
20 |   val queriesIn: Inlet[Int] = Inlet("BloomFilterCrossMatStage.queriesIn")
21 |   val answersOut: Outlet[String] = Outlet("BloomFilterCrossMatStage.answersOut")
22 |   override val shape: BidiShape[Int, Int, Int, String] = BidiShape(dataIn, dataOut, queriesIn, answersOut)
23 | 
24 |   // Stage semantics
25 |   override def createLogicAndMaterializedValue(inheritedAttributes: Attributes) = {
26 |     // Completion notification
27 |     val p: Promise[Done] = Promise()
28 | 
29 |     val logic = new GraphStageLogic(shape) {
30 |       // State
31 |       val filter = BloomFilter.create[Integer](Funnels.integerFunnel(), 1000, 0.01)
32 | 
33 |       setHandler(dataIn, new InHandler {
34 |         override def onPush(): Unit = {
35 |           val elem = grab(dataIn)
36 |           filter.put(elem)
37 |           if (isAvailable(dataOut))
38 |             push(dataOut, elem)
39 |         }
40 | 
41 |         override def onUpstreamFinish(): Unit = {
42 |           completeStage()
43 |         }
44 |       })
45 | 
46 |       setHandler(dataOut, new OutHandler {
47 |         override def onPull(): Unit = {
48 |           if (!hasBeenPulled(dataIn))
49 |             pull(dataIn)
50 |         }
51 |       })
52 | 
53 |       setHandler(queriesIn, new InHandler {
54 |         override def onPush(): Unit = {
55 |           val x = grab(queriesIn)
56 |           val answer = if (filter.mightContain(x)) {
57 |             s"MAYBE, filter probably contains $x"
58 |           } else {
59 |             s"NO, filter definitely does not contain $x"
60 |           }
61 |           if (isAvailable(answersOut))
62 |             push(answersOut, answer)
63 |         }
64 | 
65 |         override def onUpstreamFinish(): Unit = {
66 |           p.trySuccess(Done) // we are done when no more queries
67 |           completeStage()
68 |         }
69 | 
70 |         override def onUpstreamFailure(ex: Throwable): Unit = {
71 |           p.tryFailure(ex)
72 |           failStage(ex)
73 |         }
74 |       })
75 | 
76 |       setHandler(answersOut, new OutHandler {
77 |         override def onPull(): Unit = {
78 |           if (!hasBeenPulled(queriesIn))
79 |             pull(queriesIn)
80 |         }
81 |       })
82 | 
83 |     }
84 |     (logic, p.future)
85 |   }
86 | }
87 | 
88 | object BloomFilterCrossStageMatMain extends App {
89 |   implicit val system = ActorSystem()
90 |   implicit val mat = ActorMaterializer()
91 | 
92 |   val crossStage = new BloomFilterCrossMatStage // the Mat variant defined above, not the plain BloomFilterCrossStage
93 | 
94 |   val graph = RunnableGraph.fromGraph(GraphDSL.create(Sink.foreach(println)) { implicit builder => outControl =>
95 |     import GraphDSL.Implicits._
96 |     val inData = Source.repeat(1).take(100).map(_ => Random.nextInt(1000)).throttle(1, Duration(100, "millisecond"), 1, ThrottleMode.shaping)
97 |     val outData = Sink.foreach(println)
98 |     val inControl = Source.repeat(1).take(10).map(_ => Random.nextInt(2000) - 1000).throttle(1, Duration(1500, "millisecond"), 1, ThrottleMode.shaping)
99 |     //val outControl = Sink.foreach(println) // Moved to foreach/builder construct
100 | 
101 |     val cross = builder.add(crossStage)
102 | 
103 |     inData ~> cross.in1; cross.out1 ~> outData
104 |     inControl ~> cross.in2; cross.out2 ~> outControl
105 |     ClosedShape
106 |   }).run()
107 | 
108 |   import scala.concurrent.ExecutionContext.Implicits.global
109 | 
110 | 
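// graph below is the Future[Done] materialised by the Sink.foreach passed to
// GraphDSL.create above, so completion of the answers stream shuts the app down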
graph.onComplete(_ => system.terminate()) 111 | Await.ready(system.whenTerminated, Duration.Inf) 112 | } 113 | -------------------------------------------------------------------------------- /src/main/scala/pl/gosub/akka/online/HoeffdingTreeFlowActor.scala: -------------------------------------------------------------------------------- 1 | package pl.gosub.akka.online 2 | 3 | import akka.actor.{Actor, ActorRef, ActorSystem, Props} 4 | import akka.event.Logging 5 | import akka.stream.{ActorMaterializer, OverflowStrategy, QueueOfferResult, ThrottleMode} 6 | import akka.stream.scaladsl.{Sink, Source, SourceQueueWithComplete} 7 | import org.apache.spark.streamdm.classifiers.trees.HoeffdingTree 8 | import org.apache.spark.streamdm.core.{Example, ExampleParser} 9 | import org.apache.spark.streamdm.core.specification.{ExampleSpecification, SpecificationParser} 10 | 11 | import scala.concurrent.{Await, Future} 12 | import scala.concurrent.duration.Duration 13 | import akka.pattern.pipe // For the pipeTo pattern 14 | 15 | case class LearnerQuery(queryType : String, elem : Example) 16 | 17 | class HoeffdingTreeFlowActor (queue: SourceQueueWithComplete[(Example, Double)], schema: ExampleSpecification) extends Actor { 18 | val log = Logging(context.system, this) 19 | // state 20 | val hTree = new HoeffdingTree 21 | hTree.init(schema) 22 | var upStream : ActorRef = self 23 | 24 | // Execution context for scheduling of the piped futures 25 | import scala.concurrent.ExecutionContext.Implicits.global 26 | 27 | def receive = { 28 | case "TEST" => log.info("Received TEST message") 29 | upStream = sender() 30 | sender() ! "ACK" // ask for first element 31 | 32 | case "END" => log.info("Received END message, terminating actor") 33 | queue.complete() 34 | 35 | case query : LearnerQuery => // onPush() + grab(in) 36 | // "Business" logic 37 | 38 | query.queryType match { 39 | case "EXAMPLE" => 40 | hTree.trainIncremental(query.elem) 41 | log.info(s"Received: ${query.elem.toString}, learned") 42 | val futureF = Future.successful(QueueOfferResult.Enqueued) 43 | futureF pipeTo self 44 | 45 | case "QUERY" => 46 | val answer = hTree.predictSingle(query.elem) 47 | log.info(s"Received: ${query.elem.toString}, generating prediction") 48 | val offF = queue.offer(answer) // push element downstream 49 | offF pipeTo self // generate backpressure in the Actor 50 | 51 | case _ => // ignore but continue consuming 52 | val futureF = Future.successful(QueueOfferResult.Enqueued) 53 | futureF pipeTo self 54 | 55 | } 56 | 57 | case e : QueueOfferResult => 58 | upStream ! 
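// every QueueOfferResult piped back to self acknowledges upstream, so the
// Sink.actorRefWithAck source in the main below only sends the next element
// once the previous one has been fully processed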
"ACK" // ask for next element 59 | 60 | case _ => log.info("Unrecognised message, ignoring") 61 | } 62 | } 63 | 64 | object HoeffdingTreeFlowActorMain extends App { 65 | implicit val system = ActorSystem() 66 | implicit val mat = ActorMaterializer() 67 | 68 | import scala.concurrent.ExecutionContext.Implicits.global 69 | 70 | val actorQueue = Source.queue[(Example, Double)](10, OverflowStrategy.backpressure) 71 | val actorSink = actorQueue 72 | .throttle(1, Duration(100, "millisecond"), 1, ThrottleMode.shaping) 73 | .to(Sink.foreach((x) => println(x.toString()))).run() 74 | 75 | val done = actorSink.watchCompletion() 76 | 77 | val specParser = new SpecificationParser 78 | 79 | val arffPath = this.getClass.getResource("/elecNormNew.arff").getPath 80 | 81 | val exampleSpec = specParser.fromArff(arffPath) 82 | 83 | val hoeffdingTreeFlowActor = system.actorOf(Props(new HoeffdingTreeFlowActor(actorSink, exampleSpec))) 84 | 85 | val example1 = ExampleParser.fromArff("0,2,0.085106,0.042482,0.251116,0.003467,0.422915,0.414912,DOWN", exampleSpec) 86 | val example2 = ExampleParser.fromArff("0,2,0.255319,0.051489,0.298721,0.003467,0.422915,0.414912,UP", exampleSpec) 87 | val example3 = ExampleParser.fromArff("0.424627,6,0.234043,0.070854,0.108004,0.003467,0.422915,0.414912,DOWN", exampleSpec) 88 | val example4 = ExampleParser.fromArff("0.424627,6,0.170213,0.070854,0.070515,0.003467,0.422915,0.414912,UP", exampleSpec) 89 | 90 | 91 | val learn1 = LearnerQuery("EXAMPLE", example1) 92 | val learn2 = LearnerQuery("EXAMPLE", example2) 93 | val query1 = LearnerQuery("QUERY", example1) 94 | val query2 = LearnerQuery("QUERY", example2) 95 | val query3 = LearnerQuery("QUERY", example3) 96 | val query4 = LearnerQuery("QUERY", example4) 97 | 98 | val examples = List(learn1, query1, query2, query3, query4, learn2, query2) 99 | 100 | Source.fromIterator(() => examples.toIterator) 101 | .throttle(1, Duration(1000, "millisecond"), 1, ThrottleMode.shaping) 102 | .runWith(Sink.actorRefWithAck(hoeffdingTreeFlowActor, "TEST", "ACK", "END", _ => "FAIL")) 103 | 104 | done.onComplete(_ => system.terminate()) 105 | Await.ready(system.whenTerminated, Duration.Inf) 106 | } -------------------------------------------------------------------------------- /src/main/scala/pl/gosub/akka/online/Main.scala: -------------------------------------------------------------------------------- 1 | package pl.gosub.akka.online 2 | 3 | import akka.NotUsed 4 | import akka.actor.{Actor, ActorRef, ActorSystem, PoisonPill, Props} 5 | import akka.event.Logging 6 | import akka.http.scaladsl.Http 7 | import akka.http.scaladsl.model.{HttpMethods, HttpRequest, HttpResponse, Uri} 8 | import akka.stream.QueueOfferResult.Enqueued 9 | import akka.stream.scaladsl.{Flow, GraphDSL, Keep, RunnableGraph, Sink, Source, SourceQueueWithComplete} 10 | import akka.stream._ 11 | 12 | import scala.concurrent.Future 13 | import akka.pattern.pipe 14 | import akka.stream.stage.{GraphStage, GraphStageLogic, InHandler, OutHandler} 15 | import com.abahgat.suffixtree.GeneralizedSuffixTree 16 | import com.google.common.hash.{BloomFilter, Funnels} 17 | import org.apache.spark.streamdm.classifiers.trees.HoeffdingTree 18 | import org.apache.spark.streamdm.core.specification._ 19 | import org.apache.spark.streamdm.core.{Example, ExampleParser, Instance, TextInstance} 20 | 21 | import scala.concurrent.Await 22 | import scala.concurrent.duration.Duration 23 | import scala.util.Random 24 | 25 | 26 | object Main { 27 | implicit val system = ActorSystem() 28 | implicit val mat = 
ActorMaterializer() 29 | 30 | import scala.concurrent.ExecutionContext.Implicits.global 31 | 32 | def main(args: Array[String]): Unit = { 33 | println("Hello from main") 34 | 35 | val reqResponseFlow = Flow[HttpRequest].map[HttpResponse] (_ match { 36 | case HttpRequest(HttpMethods.GET, Uri.Path("/"), _, _, _) => 37 | HttpResponse(200, entity = "Hello!") 38 | 39 | case _ => HttpResponse(200, entity = "Ooops, not found") 40 | }) 41 | 42 | Http().bindAndHandle(reqResponseFlow, "localhost", 8888) 43 | 44 | // system.scheduler.schedule(Duration(100, "millisecond"), Duration(50, "millisecond"), myActor, "KABOOM !!!") 45 | // 46 | // val stdoutSink = new StdoutSink 47 | // 48 | // val done = 49 | // Source 50 | // .repeat("Hello") 51 | // .zip(Source.fromIterator(() => Iterator.from(0))) 52 | // .take(7) 53 | // .mapConcat{ 54 | // case (s, n) => 55 | // val i = " " * n 56 | // f"$i$s%n" 57 | // } 58 | // .throttle(42, Duration(1500, "millisecond"), 1, ThrottleMode.Shaping) 59 | // .runWith(Sink.actorRefWithAck(myActor, "test", "ack", "end", _ => "fail")) 60 | 61 | //done.onComplete(_ => system.terminate()) 62 | 63 | 64 | val crossStage = new BloomFilterCrossStage 65 | 66 | // Source.repeat(1).take(100).map(_ => Random.nextInt(1100) - 1000).via(kadaneStage).runWith(Sink.foreach(println)) 67 | 68 | // val g = RunnableGraph.fromGraph(GraphDSL.create() { implicit builder: GraphDSL.Builder[NotUsed] => 69 | // import GraphDSL.Implicits._ 70 | // val inData = Source.repeat(1).take(1000).map(_ => Random.nextInt(1000)).throttle(1, Duration(100, "millisecond"), 1, ThrottleMode.shaping) 71 | // val outData = Sink.foreach(println) 72 | // val inControl = Source.repeat(1).take(100).map(_ => Random.nextInt(2000) - 1000).throttle(1, Duration(1500, "millisecond"), 1, ThrottleMode.shaping) 73 | // val outControl = Sink.foreach(println) 74 | // 75 | // val cross = builder.add(crossStage) 76 | // 77 | // inData ~> cross.in1; cross.out1 ~> outData 78 | // inControl ~> cross.in2; cross.out2 ~> outControl 79 | // ClosedShape 80 | // }).run() 81 | 82 | println("Now trying the Hoeffding Tree") 83 | 84 | println("arffStuff") 85 | 86 | val specParser = new SpecificationParser 87 | 88 | val exampleSpec = specParser.fromArff("/home/janek/Downloads/arff/elecNormNew.arff") 89 | 90 | val example1 = ExampleParser.fromArff("0,2,0.085106,0.042482,0.251116,0.003467,0.422915,0.414912,DOWN", exampleSpec) 91 | 92 | val example2 = ExampleParser.fromArff("0,2,0.255319,0.051489,0.298721,0.003467,0.422915,0.414912,UP", exampleSpec) 93 | 94 | println("example Arff " + example1.in.toString + " / " + example1.out.toString) 95 | println("example Arff2 " + example2.in.toString + " / " + example2.out.toString) 96 | 97 | println("Spec " + exampleSpec.in.size() + " " + exampleSpec.out.size() + " " + exampleSpec.out.isNominal(0)) 98 | 99 | println("after arff") 100 | 101 | val hTree = new HoeffdingTree 102 | 103 | hTree.init(exampleSpec) 104 | 105 | hTree.trainIncremental(example1) 106 | 107 | println(hTree.predictSingle(example1)._2) 108 | 109 | hTree.trainIncremental(example2) 110 | 111 | println(hTree.predictSingle(example2)._2) 112 | 113 | println("now suffix tree") 114 | 115 | val suffixTree = new GeneralizedSuffixTree() 116 | 117 | suffixTree.put("cacao", 0) 118 | 119 | println("Searching: " + suffixTree.search("cac")) 120 | 121 | println("Searching: " + suffixTree.search("caco")) 122 | 123 | Await.ready(system.whenTerminated, Duration.Inf) 124 | } 125 | } 126 | 
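// Note: the absolute /home/janek/... path above only resolves on the author's
// machine; the other mains in this repo load the schema from the classpath
// instead, e.g. (same approach as in HoeffdingTreeFlowActorMain):
//
//   val arffPath = this.getClass.getResource("/elecNormNew.arff").getPath
//   val exampleSpec = new SpecificationParser().fromArff(arffPath)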
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/streamdm/core/Example.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (C) 2015 Holmes Team at HUAWEI Noah's Ark Lab.
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *    http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  *
16 |  */
17 | 
18 | package org.apache.spark.streamdm.core
19 | 
20 | /**
21 |  * An Example is a wrapper on top of the Instance class hierarchy. It contains a
22 |  * reference to an input Instance and an output Instance, and provides setters
23 |  * and getters for the features and labels. This is done so that the DStream
24 |  * accepts any type of Instance in the parameters, and that the same DStream can
25 |  * be allowed to contain multiple types of Instance.
26 |  */
27 | class Example(inInstance: Instance, outInstance: Instance = new NullInstance,
28 |   weightValue: Double = 1.0)
29 |   extends Serializable {
30 | 
31 |   val in = inInstance
32 |   val out = outInstance
33 |   val weight = weightValue
34 | 
35 |   /** Get the input value present at position index
36 |    *
37 |    * @param index the index of the value
38 |    * @return a Double representing the feature value
39 |    */
40 |   def featureAt(index: Int): Double = in(index)
41 | 
42 |   /** Get the output value present at position index
43 |    *
44 |    * @param index the index of the value
45 |    * @return a Double representing the value
46 |    */
47 |   def labelAt(index: Int): Double = out(index)
48 | 
49 |   /** Set the weight of the Example
50 |    *
51 |    * @param value the weight value
52 |    * @return an Example containing the new weight
53 |    */
54 |   def setWeight(value: Double): Example =
55 |     new Example(in, out, value)
56 | 
57 |   /** Add a feature to the input instance in the example
58 |    *
59 |    * @param index the index at which the value is added
60 |    * @param input the feature value which is added up
61 |    * @return an Example containing an Instance with the new features
62 |    */
63 |   def setFeature(index: Int, input: Double): Example =
64 |     new Example(in.set(index, input), out, weight)
65 | 
66 |   /** Add a label to the output instance in the example
67 |    *
68 |    * @param index the index at which the value is added
69 |    * @param input the label value which is added up
70 |    * @return an Example containing an Instance with the new labels
71 |    */
72 |   def setLabel(index: Int, input: Double): Example =
73 |     new Example(in, out.set(index, input), weight)
74 | 
75 |   override def toString = {
76 |     val inString = in.toString
77 |     val weightString = if (weight == 1.0) "" else " %f".format(weight)
78 |     val outString = out match {
79 |       case NullInstance() => ""
80 |       case _ => "%s ".format(out.toString)
81 |     }
82 |     "%s%s%s".format(outString, inString, weightString)
83 |   }
84 | }
85 | 
86 | object Example extends Serializable {
87 | 
88 |   /** Parse the input string as an Example.
89 |    * The input and output instances are separated by a whitespace character,
90 |    * in the form "output_instance input_instance weight". The output
91 |    * and the weight can be missing.
92 |    * @param input the String line to be read
93 |    * @param inType String specifying the format of the input instance
94 |    * @param outType String specifying the format of the output instance
95 |    * @return an Example parsed from the input
96 |    */
97 |   def parse(input: String, inType: String, outType: String): Example = {
98 |     val tokens = input.split("\\s+")
99 |     val numTokens = tokens.length
100 |     if (numTokens == 1)
101 |       new Example(getInstance(tokens.head, inType))
102 |     else if (numTokens == 2)
103 |       new Example(getInstance(tokens.last, inType),
104 |         getInstance(tokens.head, outType))
105 |     else
106 |       new Example(getInstance(tokens.tail.head, inType),
107 |         getInstance(tokens.head, outType), tokens.last.toDouble)
108 |   }
109 | 
110 |   /** Parse the input string based on the type of Instance, by calling the
111 |    * associated .parse static method
112 |    * @param input the String to be parsed
113 |    * @param instType the type of instance to be parsed ("dense" or "sparse")
114 |    * @return the parsed Instance, or null if the type is not properly specified
115 |    */
116 |   private def getInstance(input: String, instType: String): Instance =
117 |     instType match {
118 |       case "dense" => DenseInstance.parse(input)
119 |       case "sparse" => SparseInstance.parse(input)
120 |       case _ => null
121 |     }
122 | }
123 | 
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/streamdm/core/ExampleParser.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (C) 2015 Holmes Team at HUAWEI Noah's Ark Lab.
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *    http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  *
16 |  */
17 | 
18 | package org.apache.spark.streamdm.core
19 | 
20 | import org.apache.spark.streamdm.core.specification._
21 | 
22 | /*
23 |  * object ExampleParser helps to parse Examples from and to different data formats.
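 * e.g. for the dense ARFF rows used in this repo,
 * fromArff("0,2,0.085106,0.042482,0.251116,0.003467,0.422915,0.414912,DOWN", spec)
 * yields an Example whose output instance holds the class label DOWN.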
24 |  */
25 | object ExampleParser {
26 | 
27 |   /*
28 |    * create Example from arff format string
29 |    *
30 |    * @param e arff format string
31 |    * @param spec the ExampleSpecification
32 |    * @return Example
33 |    */
34 |   def fromArff(e: String, spec: ExampleSpecification): Example = {
35 |     var input: String = e
36 |     var output: String = null
37 |     val isSparse: Boolean = e.contains("{") && e.contains("}")
38 |     if (isSparse) {
39 |       // sparse data
40 |       input = e.substring(e.indexOf("{") + 1, e.indexOf("}")).trim()
41 |     }
42 |     if (spec.out.size() == 1) {
43 |       val pos = input.lastIndexOf(",")
44 |       output = input.substring(pos + 1).trim()
45 |       input = input.substring(0, pos)
46 |     }
47 |     new Example(arffToInstance(input, spec.in, isSparse), arffToInstance(output, spec.out, isSparse))
48 |   }
49 | 
50 |   /*
51 |    * create Instance from arff format string
52 |    *
53 |    * @param data arff format string
54 |    * @param spec the InstanceSpecification
55 |    * @return Instance
56 |    */
57 |   private def arffToInstance(data: String, spec: InstanceSpecification, isSparse: Boolean): Instance = {
58 |     if (data == null || data.length() == 0) null
59 |     else {
60 |       if (isSparse)
61 |         arffToSparseInstance(data, spec)
62 |       else
63 |         arffToDenseInstance(data, spec)
64 |     }
65 |   }
66 | 
67 |   /*
68 |    * create SparseInstance from arff format string
69 |    *
70 |    * @param data arff format string
71 |    * @param spec the InstanceSpecification
72 |    * @return SparseInstance
73 |    */
74 |   private def arffToSparseInstance(data: String, spec: InstanceSpecification): SparseInstance = {
75 |     val tokens = data.split(",[\\s]?")
76 |     val values = new Array[Double](tokens.length) // was Array[Double](tokens.length), a one-element array
77 |     val features = tokens.map(_.split("\\s+"))
78 |     for (index <- 0 until tokens.length) {
79 |       values(index) = spec(index) match {
80 |         case nominal: NominalFeatureSpecification => nominal(features(index)(1))
81 |         case _ => features(index)(1).toDouble
82 |       }
83 |     }
84 |     new SparseInstance(features.map(_(0).toInt), values)
85 |   }
86 | 
87 |   /*
88 |    * create DenseInstance from arff format string
89 |    *
90 |    * @param data arff format string
91 |    * @param spec the InstanceSpecification
92 |    * @return DenseInstance
93 |    */
94 |   private def arffToDenseInstance(data: String, spec: InstanceSpecification): DenseInstance = {
95 |     val stringValues = data.split(",[\\s]?")
96 |     val values = new Array[Double](stringValues.length)
97 |     for (index <- 0 until stringValues.length) {
98 |       values(index) = spec(index) match {
99 |         case nominal: NominalFeatureSpecification => nominal(stringValues(index))
100 |         case _ => stringValues(index).toDouble
101 |       }
102 |     }
103 |     new DenseInstance(values)
104 |   }
105 | 
106 |   /*
107 |    * create arff format string for Example
108 |    *
109 |    * @param e Example
110 |    * @param spec the ExampleSpecification
111 |    * @return arff format string
112 |    */
113 |   def toArff(e: Example, spec: ExampleSpecification): String = {
114 |     val isSparse = e.in.isInstanceOf[SparseInstance]
115 |     val sb = new StringBuffer()
116 |     sb.append(instanceToArff(e.in, spec.in, 0))
117 |     if (spec.out.size() != 0)
118 |       sb.append(", " + instanceToArff(e.out, spec.out, spec.in.size()))
119 |     sb.delete(sb.lastIndexOf(","), sb.length())
120 |     if (isSparse) {
121 |       sb.insert(0, "{")
122 |       sb.append("}")
123 |     }
124 |     sb.toString()
125 |   }
126 |   /*
127 |    * create a partial arff format string for an Instance
128 |    *
129 |    * @param instance the Instance
130 |    * @param insSpec the InstanceSpecification
131 |    * @return partial arff format string
132 |    */
133 |   def instanceToArff(instance: Instance, insSpec: InstanceSpecification, startIndex: Int): String = {
134 |     val isSparse =
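// sparse output renders as "index value" pairs (wrapped in { } by toArff), dense output as bare values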
instance.isInstanceOf[SparseInstance] 135 | val sb = new StringBuffer() 136 | instance.getFeatureIndexArray().foreach(token => 137 | { 138 | if (isSparse) { 139 | sb.append(startIndex + token._2 + " ") 140 | } 141 | val value = insSpec(token._2) match { 142 | case nominal: NominalFeatureSpecification => nominal(token._1.toInt) 143 | case numeric: NumericFeatureSpecification => token._1 144 | } 145 | sb.append(value + ", ") 146 | }) 147 | sb.toString() 148 | } 149 | 150 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/streamdm/utils/Utils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2015 Holmes Team at HUAWEI Noah's Ark Lab. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package org.apache.spark.streamdm.utils 19 | 20 | import java.io._ 21 | import java.util.Random 22 | import scala.math.{ min, max, log } 23 | 24 | import org.apache.spark.streamdm.classifiers.OnlineClassifier 25 | 26 | /** 27 | * Utility methods. 28 | * 29 | */ 30 | object Utils { 31 | 32 | /* Copy a classifier using serialization 33 | * 34 | * @param classifier the original classifier to copy 35 | * @return the copy of the classifier 36 | */ 37 | def copyClassifier(classifier: OnlineClassifier): OnlineClassifier = { 38 | val baoStream: ByteArrayOutputStream = new ByteArrayOutputStream() 39 | val out: ObjectOutputStream = new ObjectOutputStream( 40 | new BufferedOutputStream(baoStream)) 41 | out.writeObject(classifier) 42 | out.flush() 43 | out.close() 44 | val byteArray: Array[Byte] = baoStream.toByteArray() 45 | val in: ObjectInputStream = new ObjectInputStream(new BufferedInputStream( 46 | new ByteArrayInputStream(byteArray))) 47 | val copy: OnlineClassifier = in.readObject().asInstanceOf[OnlineClassifier] 48 | in.close() 49 | copy 50 | } 51 | 52 | /* Compute a random value from a Poisson distribution 53 | * 54 | * @param lambda the mean of the Poisson distribution 55 | * @param r the random generator 56 | * @return a random value sampled from the distribution 57 | */ 58 | def poisson(lambda: Double, r: Random) = { 59 | if (lambda < 100.0) { 60 | var product = 1.0 61 | var sum = 1.0 62 | val threshold = r.nextDouble() * Math.exp(lambda) 63 | var i = 1.0 64 | var max = Math.max(100, 10 * Math.ceil(lambda).toInt) 65 | while ((i < max) && (sum <= threshold)) { 66 | product *= (lambda / i) 67 | sum += product 68 | i += 1.0 69 | } 70 | i - 1.0 71 | } else { 72 | val x = lambda + Math.sqrt(lambda) * r.nextGaussian() 73 | if (x < 0.0) 0.0 else Math.floor(x) 74 | } 75 | } 76 | 77 | /* Get the most frequent value of an array of numeric values 78 | * 79 | * @param array the Array of numeric values 80 | * @return the argument of the most frequent value 81 | */ 82 | def majorityVote(array: Array[Double], size: Integer): Double = { 83 | val frequencyArray: Array[Double] = Array.fill(size)(0) 84 | for (i <- 0 until array.length) 85 | 
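// tally one vote per element; the integer-valued label doubles as the bucket index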
frequencyArray(array(i).toInt) += 1
86 |     argmax(frequencyArray)
87 |   }
88 | 
89 |   /* Get the argument of the maximum value of an array of numeric values
90 |    *
91 |    * @param array the Array of numeric values
92 |    * @return the argument of the maximum value
93 |    */
94 |   def argmax(array: Array[Double]): Double = array.zipWithIndex.maxBy(_._1)._2
95 |   /*
96 |    * Get the log2 of input
97 |    *
98 |    * @param v double value
99 |    * @return the log2 of v
100 |    */
101 |   def log2(v: Double): Double = log(v) / log(2)
102 | 
103 |   /* Transpose a matrix
104 |    *
105 |    * @param input matrix in form of 2-D array
106 |    * @return the transpose of input matrix
107 |    */
108 |   def transpose(input: Array[Array[Double]]): Array[Array[Double]] = {
109 |     val output: Array[Array[Double]] = Array.fill(input(0).length)(new Array[Double](input.length))
110 |     input.zipWithIndex.map {
111 |       row =>
112 |         row._1.zipWithIndex.map {
113 |           col => output(col._2)(row._2) = input(row._2)(col._2)
114 |         }
115 |     }
116 |     output
117 |   }
118 |   /*
119 |    * Split a matrix with the input index, merge other columns into one column and transpose
120 |    *
121 |    * @param input matrix in form of 2-D array
122 |    * @param fIndex index of columns
123 |    * @return a matrix of 2 rows
124 |    */
125 |   def splitTranspose(input: Array[Array[Double]], fIndex: Int): Array[Array[Double]] = {
126 |     val output: Array[Array[Double]] = Array.fill(2)(new Array[Double](input.length))
127 |     input.zipWithIndex.map {
128 |       row =>
129 |         row._1.zipWithIndex.map {
130 |           col =>
131 |             if (col._2 == fIndex) output(0)(row._2) = input(row._2)(col._2)
132 |             else output(1)(row._2) += input(row._2)(col._2)
133 |         }
134 |     }
135 |     output
136 |   }
137 | 
138 |   /*
139 |    * Normalize input matrix
140 |    *
141 |    * @param input matrix in form of 2-D array
142 |    * @return normalized matrix
143 |    */
144 |   def normal(input: Array[Array[Double]]): Array[Array[Double]] = {
145 |     val total = input.map(_.sum).sum
146 |     input.map { row => row.map { _ / total } }
147 |   }
148 | 
149 |   /*
150 |    * Normalize input array
151 |    *
152 |    * @param input double array
153 |    * @return normalized array
154 |    */
155 |   def normal(input: Array[Double]): Array[Double] = {
156 |     val total = input.sum
157 |     input.map { { _ / total } }
158 |   }
159 | 
160 | }
161 | 
--------------------------------------------------------------------------------
/src/main/scala/pl/gosub/akka/online/SuffixTreeTripodMatStage.scala:
--------------------------------------------------------------------------------
1 | package pl.gosub.akka.online
2 | 
3 | import akka.Done
4 | import akka.actor.ActorSystem
5 | import akka.stream._
6 | import akka.stream.scaladsl.{GraphDSL, RunnableGraph, Sink, Source}
7 | import akka.stream.stage.{GraphStageLogic, GraphStageWithMaterializedValue, InHandler, OutHandler}
8 | import com.abahgat.suffixtree.GeneralizedSuffixTree
9 | 
10 | import scala.annotation.unchecked.uncheckedVariance
11 | import scala.collection.immutable
12 | import scala.concurrent.duration.Duration
13 | import scala.concurrent.{Await, Future, Promise}
14 | 
15 | /**
16 |  * A Y-shaped flow of elements that consequently has two inputs
17 |  * and one output, arranged like this:
18 |  *
19 |  * {{{
20 |  *          +--------+
21 |  *  In1 ~>  |        |
22 |  *          | tripod |~> Out
23 |  *  In2 ~>  |        |
24 |  *          +--------+
25 |  * }}}
26 |  */
27 | final case class TripodShape[-In1, -In2, +Out](
28 |   in1: Inlet[In1 @uncheckedVariance],
29 |   in2: Inlet[In2 @uncheckedVariance],
30 |   out: Outlet[Out @uncheckedVariance]) extends Shape {
31 |   override val inlets: immutable.Seq[Inlet[_]] =
List(in1, in2) 32 | override val outlets: immutable.Seq[Outlet[_]] = List(out) 33 | 34 | override def deepCopy(): TripodShape[In1, In2, Out] = 35 | TripodShape(in1.carbonCopy(), in2.carbonCopy(), out.carbonCopy()) 36 | override def copyFromPorts(inlets: immutable.Seq[Inlet[_]], outlets: immutable.Seq[Outlet[_]]): Shape = { 37 | require(inlets.size == 2, s"proposed inlets [${inlets.mkString(", ")}] do not fit TripodShape") 38 | require(outlets.size == 1, s"proposed outlets [${outlets.mkString(", ")}] do not fit TripodShape") 39 | TripodShape(inlets(0), inlets(1), outlets(0)) 40 | } 41 | def reversed: Shape = copyFromPorts(inlets.reverse, outlets.reverse) 42 | } 43 | 44 | object TripodShape { 45 | def of[In1, In2, Out]( 46 | in1: Inlet[In1 @uncheckedVariance], 47 | in2: Inlet[In2 @uncheckedVariance], 48 | out: Outlet[Out @uncheckedVariance]): TripodShape[In1, In2, Out] = 49 | TripodShape(in1, in2, out) 50 | 51 | } 52 | 53 | 54 | class SuffixTreeTripodMatStage extends GraphStageWithMaterializedValue[TripodShape[String, String, List[Int]], Future[Done]] { 55 | 56 | // Stage syntax 57 | val stringsIn: Inlet[String] = Inlet("SuffixTreeTripodMatStage.stringsIn") 58 | val searchesIn: Inlet[String] = Inlet("SuffixTreeTripodMatStage.searchesIn") 59 | val matchesOut: Outlet[List[Int]] = Outlet("SuffixTreeTripodMatStage.matchesOut") 60 | override val shape: TripodShape[String, String, List[Int]] = TripodShape(stringsIn, searchesIn, matchesOut) 61 | 62 | // Stage semantics 63 | override def createLogicAndMaterializedValue(inheritedAttributes: Attributes) = { 64 | // Completion notification 65 | val p: Promise[Done] = Promise() 66 | 67 | val logic = new GraphStageLogic(shape) { 68 | // State 69 | val sTree = new GeneralizedSuffixTree 70 | var index = 0 71 | 72 | // stringsIn effectively is a Sink, 73 | // so we need to kick it off 74 | override def preStart(): Unit = pull(stringsIn) 75 | 76 | setHandler(stringsIn, new InHandler { 77 | override def onPush(): Unit = { 78 | val elem = grab(stringsIn) 79 | println(s"Getting $elem") 80 | sTree.put(elem, index) 81 | index += 1 82 | pull(stringsIn) 83 | } 84 | 85 | override def onUpstreamFinish(): Unit = { 86 | completeStage() 87 | } 88 | }) 89 | 90 | setHandler(searchesIn, new InHandler { 91 | override def onPush(): Unit = { 92 | val s = grab(searchesIn) 93 | import scala.collection.JavaConverters._ 94 | val mi: Iterable[Integer] = sTree.search(s).asScala 95 | val m = mi.toList.map(_.toInt) 96 | if (isAvailable(matchesOut)) 97 | push(matchesOut, m) 98 | } 99 | 100 | override def onUpstreamFinish(): Unit = { 101 | p.trySuccess(Done) // we are done when no more queries 102 | completeStage() 103 | } 104 | 105 | override def onUpstreamFailure(ex: Throwable): Unit = { 106 | p.tryFailure(ex) 107 | failStage(ex) 108 | } 109 | }) 110 | 111 | setHandler(matchesOut, new OutHandler { 112 | override def onPull(): Unit = { 113 | if (!hasBeenPulled(searchesIn)) 114 | pull(searchesIn) 115 | } 116 | }) 117 | 118 | } 119 | (logic, p.future) 120 | } 121 | } 122 | 123 | object SuffixTreeTripodMatStageMatMain extends App { 124 | implicit val system = ActorSystem() 125 | implicit val mat = ActorMaterializer() 126 | 127 | val tripodStage = new SuffixTreeTripodMatStage 128 | 129 | val gen = List("aaagtc", "aaaggtc", "aaaatttccdg", "aaggtgta", "abb", "ggggttaacca", "attgttaca", "gttacgggga") 130 | val sequence = List.fill(12)(gen).flatten 131 | 132 | println(sequence) 133 | 134 | val graph = RunnableGraph.fromGraph(GraphDSL.create(Sink.foreach(println)) { implicit builder =>
outMatches => 135 | import GraphDSL.Implicits._ 136 | val inStrings = Source.fromIterator(() => sequence.toIterator).throttle(1, Duration(100, "millisecond"), 1, ThrottleMode.shaping) 137 | val inSearches = Source.repeat(1).take(10).map(_ => "aaa").throttle(1, Duration(300, "millisecond"), 1, ThrottleMode.shaping) 138 | 139 | val tripod = builder.add(tripodStage) 140 | 141 | inStrings ~> tripod.in1; tripod.out ~> outMatches 142 | inSearches ~> tripod.in2 143 | ClosedShape 144 | }).run() 145 | 146 | import scala.concurrent.ExecutionContext.Implicits.global 147 | 148 | graph.onComplete(_ => system.terminate()) 149 | Await.ready(system.whenTerminated, Duration.Inf) 150 | } 151 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/streamdm/classifiers/trees/GaussianEstimator.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2015 Holmes Team at HUAWEI Noah's Ark Lab. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package org.apache.spark.streamdm.classifiers.trees 19 | 20 | import scala.math.{ sqrt, Pi, pow, exp, max } 21 | import org.apache.spark.streamdm.utils.Statistics 22 | 23 | /** 24 | * Gaussian incremental estimator that uses incremental method, more resilient 25 | * to floating point imprecision. 26 | * For more info see Donald Knuth's "The Art of Computer Programming, Volume 2: 27 | * Seminumerical Algorithms", section 4.2.2. 28 | */ 29 | 30 | class GaussianEstimator(var weightSum: Double = 0.0, var mean: Double = 0.0, 31 | var varianceSum: Double = 0.0) extends Serializable { 32 | val normal_constant: Double = sqrt(2 * Pi) 33 | var blockWeightSum: Double = 0.0 34 | var blockMean: Double = 0.0 35 | var blockVarianceSum: Double = 0.0 36 | 37 | def this(that: GaussianEstimator) { 38 | this(that.weightSum, that.mean, that.varianceSum) 39 | } 40 | /** 41 | * Observe the data and update the Gaussian estimator 42 | * 43 | * @param value value of a feature 44 | * @param weight weight of the Example 45 | */ 46 | def observe(value: Double, weight: Double): Unit = { 47 | if (!value.isInfinite() && !value.isNaN() && weight > 0) { 48 | if (blockWeightSum == 0) { 49 | blockMean = value 50 | blockWeightSum = weight 51 | } else { 52 | blockWeightSum += weight 53 | val lastMean = blockMean 54 | blockMean += weight * (value - lastMean) / blockWeightSum 55 | blockVarianceSum += weight * (value - lastMean) * (value - blockMean) 56 | } 57 | } 58 | } 59 | /** 60 | * Merge current GaussianEstimator with another one. 
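The `observe()` update above is the weighted form of Knuth's incremental mean/variance recurrence that the class comment cites. A minimal standalone sketch (illustrative object name, unit weights) that checks the recurrence against a naive two-pass computation:

```scala
// Sketch of the incremental (Welford/Knuth) update used by observe(),
// checked against a two-pass computation. Unit weights for simplicity.
object WelfordCheck extends App {
  val data = Seq(1.0, 4.0, 2.0, 8.0, 5.0)

  var weightSum = 0.0
  var mean = 0.0
  var varianceSum = 0.0 // running sum of squared deviations ("M2")
  for (value <- data) {
    weightSum += 1.0
    val lastMean = mean
    mean += (value - lastMean) / weightSum
    varianceSum += (value - lastMean) * (value - mean)
  }

  val batchMean = data.sum / data.size
  val batchM2 = data.map(x => (x - batchMean) * (x - batchMean)).sum
  println(s"incremental: mean=$mean varianceSum=$varianceSum")
  println(s"two-pass:    mean=$batchMean varianceSum=$batchM2") // should agree
}
```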
61 | * 62 | * @param that the GaussianEstimator to be merged 63 | * @param trySplit flag indicating whether the Hoeffding Tree tries to split 64 | * @return the new GaussianEstimator 65 | */ 66 | def merge(that: GaussianEstimator, trySplit: Boolean): GaussianEstimator = { 67 | if (!trySplit) { 68 | //add to block variables 69 | if (this.blockWeightSum == 0) { 70 | blockWeightSum = that.blockWeightSum 71 | blockMean = that.blockMean 72 | blockVarianceSum = that.blockVarianceSum 73 | } else { 74 | val newBlockWeightSum = blockWeightSum + that.blockWeightSum 75 | val newBlockMean = (this.blockMean * blockWeightSum + that.blockMean * 76 | that.blockWeightSum) / newBlockWeightSum 77 | val newBlockVarianceSum = this.blockVarianceSum + that.blockVarianceSum + 78 | pow(this.blockMean - that.blockMean, 2) * 79 | this.blockWeightSum * that.blockWeightSum / (this.blockWeightSum + that.blockWeightSum) 80 | blockWeightSum = newBlockWeightSum 81 | blockMean = newBlockMean 82 | blockVarianceSum = newBlockVarianceSum 83 | } 84 | } else { 85 | //add to the total variables 86 | if (this.weightSum == 0) { 87 | weightSum = that.blockWeightSum 88 | mean = that.blockMean 89 | varianceSum = that.blockVarianceSum 90 | } else { 91 | val newWeightSum = weightSum + that.blockWeightSum 92 | val newMean = (this.mean * weightSum + that.blockMean * that.blockWeightSum) / 93 | newWeightSum 94 | val newVarianceSum = this.varianceSum + that.blockVarianceSum + pow(this.mean - 95 | that.blockMean, 2) * this.weightSum * that.blockWeightSum / 96 | (this.weightSum + that.blockWeightSum) 97 | weightSum = newWeightSum 98 | mean = newMean 99 | varianceSum = newVarianceSum 100 | } 101 | } 102 | this 103 | } 104 | 105 | /** 106 | * Returns the total weight 107 | * 108 | * @return the total weight 109 | */ 110 | def totalWeight(): Double = { 111 | weightSum 112 | } 113 | /** 114 | * Returns the mean value 115 | * 116 | * @return the mean value 117 | */ 118 | def getMean(): Double = { 119 | mean 120 | } 121 | /** 122 | * Returns the standard deviation 123 | * 124 | * @return the standard deviation 125 | */ 126 | def stdDev(): Double = { 127 | sqrt(variance()) 128 | } 129 | 130 | /** 131 | * Returns the variance 132 | * 133 | * @return the variance 134 | */ 135 | def variance(): Double = { 136 | if (weightSum <= 1.0) 0 137 | else varianceSum / (weightSum - 1) 138 | } 139 | 140 | /** 141 | * Returns the cumulative probability of the input value in the current 142 | * distribution. 143 | * 144 | * @param value the value 145 | * @return the cumulative probability 146 | */ 147 | 148 | def probabilityDensity(value: Double): Double = { 149 | if (weightSum == 0) 0.0 150 | else { 151 | val stddev = stdDev() 152 | if (stddev > 0) { 153 | val diff = value - mean 154 | exp(-pow(diff / stddev, 2) / 2) / (normal_constant * stddev) 155 | } else { 156 | if (value == mean) 1.0 else 0 157 | } 158 | } 159 | } 160 | 161 | /** 162 | * Returns an array of weights which have the sum less than, equal to, and 163 | * greater than the split value. 
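`merge()` combines two (weightSum, mean, varianceSum) summaries using the standard parallel-variance correction term `delta^2 * w1 * w2 / (w1 + w2)`. A self-contained sketch (hypothetical `Summary` type) that verifies the combination against a single pass over the concatenated data:

```scala
// Sketch of the pairwise combination performed by merge(): two
// (weight, mean, M2) summaries combine into one exact summary.
object MergeCheck extends App {
  final case class Summary(w: Double, mean: Double, m2: Double)

  def summarize(xs: Seq[Double]): Summary = {
    val m = xs.sum / xs.size
    Summary(xs.size, m, xs.map(x => (x - m) * (x - m)).sum)
  }

  def combine(a: Summary, b: Summary): Summary = {
    val w = a.w + b.w
    val mean = (a.mean * a.w + b.mean * b.w) / w
    val delta = a.mean - b.mean
    Summary(w, mean, a.m2 + b.m2 + delta * delta * a.w * b.w / w)
  }

  val (left, right) = (Seq(1.0, 4.0, 2.0), Seq(8.0, 5.0))
  println(combine(summarize(left), summarize(right)))
  println(summarize(left ++ right)) // should match the combined summary
}
```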
164 | * 165 | * @param splitValue the value of the split 166 | * @return the resulting Array of values 167 | */ 168 | def tripleWeights(splitValue: Double): Array[Double] = { 169 | //equal weights sum 170 | val eqWeight = probabilityDensity(splitValue) * weightSum 171 | //less than weights sum 172 | val lsWeight = { 173 | if (stdDev() > 0) { 174 | Statistics.normalProbability((splitValue - getMean()) / stdDev()) 175 | } else { 176 | if (splitValue < getMean()) weightSum - eqWeight 177 | else 0.0 178 | } 179 | } 180 | //greater than weights sum 181 | val gtWeight = max(0, weightSum - eqWeight - lsWeight) 182 | Array[Double](lsWeight, eqWeight, gtWeight) 183 | } 184 | } 185 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/streamdm/core/TextInstance.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2015 Holmes Team at HUAWEI Noah's Ark Lab. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package org.apache.spark.streamdm.core 19 | 20 | /** 21 | * A TextInstance is an Instance in which the features are a map of text keys, 22 | * each associated with a value. 23 | */ 24 | 25 | case class TextInstance(inFeatures: Map[String, Double]) 26 | extends Instance with Serializable { 27 | 28 | type T = TextInstance 29 | 30 | val features = inFeatures 31 | 32 | /* Get the feature value for a given index 33 | * 34 | * @param index the key of the features 35 | * @return a Double representing the feature value 36 | */ 37 | def apply(index: Int): Double = 38 | valueAt(index.toString) 39 | 40 | /* 41 | * Return an array of features and indexes 42 | * 43 | * @return an array of turple2(value,index) 44 | */ 45 | def getFeatureIndexArray(): Array[(Double, Int)] = features.toArray.map{x =>(x._2,x._1.toInt)} 46 | 47 | /* Get the feature value for a given key 48 | * 49 | * @param key the key of the features 50 | * @return a Double representing the feature value 51 | */ 52 | def valueAt(key: String): Double = 53 | features.getOrElse(key, 0.0) 54 | 55 | /* Perform a dot product between two instances 56 | * 57 | * @param input an Instance with which the dot 58 | * product is performed 59 | * @return a Double representing the dot product 60 | */ 61 | override def dot(input: Instance): Double = input match { 62 | case TextInstance(f) => 63 | dotTupleArrays(f.toArray, features.toArray) 64 | case _ => 0.0 65 | } 66 | 67 | /** 68 | * Perform an element by element addition between two instances 69 | * 70 | * @param input an Instance which is added up 71 | * @return an Instance representing the added Instances 72 | */ 73 | override def add(input: Instance): TextInstance = input match { 74 | case TextInstance(f) => { 75 | val addedInstance = addTupleArrays(f.toArray, features.toArray) 76 | new TextInstance(arrayToMap(addedInstance)) 77 | } 78 | case _ => this 79 | } 80 | 81 | /** 82 | * Perform an element by element multiplication 
between two instances 83 | * 84 | * @param input an Instance which is multiplied 85 | * @return an Instance representing the Hadamard product 86 | */ 87 | override def hadamard(input: Instance): TextInstance = input match { 88 | case TextInstance(f) => { 89 | val addedInstance = mulTupleArrays(f.toArray, features.toArray) 90 | new TextInstance(arrayToMap(addedInstance)) 91 | } 92 | case _ => this 93 | 94 | } 95 | 96 | /** 97 | * Compute the Euclidean distance to another Instance 98 | * 99 | * @param input the Instance to which the distance is computed 100 | * @return a Double representing the distance value 101 | */ 102 | override def distanceTo(input: Instance): Double = input match { 103 | case TextInstance(f) => { 104 | var sum: Double = 0.0 105 | for ((k, v) <- f) 106 | if (v != 0) sum += math.pow(valueAt(k) - v, 2.0) 107 | for ((k, v) <- features) 108 | if (f.getOrElse(k, 0.0) == 0) sum += math.pow(v, 2.0) 109 | math.sqrt(sum) 110 | } 111 | case _ => Double.MaxValue 112 | } 113 | 114 | /** 115 | * Append a feature to the instance 116 | * 117 | * @param key the key on which the feature is set 118 | * @param value the value on which the feature is set 119 | * @return an Instance representing the new feature vector 120 | */ 121 | def setFeature(key: String, value: Double): TextInstance = 122 | new TextInstance((features - key) + (key -> value)) 123 | 124 | /** 125 | * Append a feature to the instance 126 | * 127 | * @param index the index at which the feature is set 128 | * @param input the new value of the feature 129 | * @return an Instance representing the new feature vector 130 | */ 131 | def set(index: Int, input: Double): TextInstance = 132 | setFeature(index.toString, input) 133 | 134 | /** 135 | * Apply an operation to every feature of the Instance (essentially a map) 136 | * TODO try to extend map to this case 137 | * @param func the function for the transformation 138 | * @return a new Instance with the transformed features 139 | */ 140 | override def map(func: Double => Double): TextInstance = 141 | new TextInstance(features.mapValues { case x => func(x) }) 142 | 143 | /** 144 | * Aggregate the values of an instance 145 | * 146 | * @param func the function for the transformation 147 | * @return the reduced value 148 | */ 149 | override def reduce(func: (Double, Double) => Double): Double = 150 | features.valuesIterator.reduce(func) 151 | 152 | private def dotTupleArrays(l1: Array[(String, Double)], 153 | l2: Array[(String, Double)]): Double = 154 | (l1 ++ l2).groupBy(_._1).filter { case (k, v) => v.length == 2 }. 155 | map { case (k, v) => (k, v.map(_._2).reduce(_ * _)) }.toArray.unzip._2.sum 156 | 157 | private def addTupleArrays(l1: Array[(String, Double)], 158 | l2: Array[(String, Double)]): Array[(String, Double)] = 159 | (l1 ++ l2).groupBy(_._1).map { case (k, v) => (k, v.map(_._2).sum) }. 160 | toArray.filter(_._2 != 0) 161 | 162 | private def mulTupleArrays(l1: Array[(String, Double)], 163 | l2: Array[(String, Double)]): Array[(String, Double)] = 164 | (l1 ++ l2).groupBy(_._1).map { case (k, v) => (k, v.map(_._2).product) }. 
165 | toArray.filter(_._2 != 0) 166 | 167 | private def arrayToMap(l: Array[(String, Double)]): Map[String, Double] = 168 | l.groupBy(_._1).map { case (k, v) => (k, v.head._2) } 169 | 170 | } 171 | 172 | object TextInstance extends Serializable { 173 | 174 | /** 175 | * Parse the input string as a TextInstance class 176 | * 177 | * @param input the String line to be read 178 | * @return a TextInstance which is parsed from the input 179 | */ 180 | def parse(input: String): TextInstance = { 181 | val tokens = input.split(",") 182 | val features = tokens.map(_.split(":")) 183 | val featMap = features.groupBy(_.head).map { 184 | case (k, v) => (k, 185 | v.head.tail.head.toDouble) 186 | } 187 | new TextInstance(featMap) 188 | } 189 | } 190 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/streamdm/core/specification/SpecificationParser.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Portions Copyright (C) 2015 Holmes Team at HUAWEI Noah's Ark Lab. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | /* 19 | This file has been changed by gosubpl 20 | */ 21 | 22 | 23 | package org.apache.spark.streamdm.core.specification 24 | 25 | import scala.io.Source 26 | import org.apache.spark.streamdm.classifiers.trees.Utils.{ arraytoString } 27 | 28 | /* 29 | * The SpecificationParser class helps to generate and parse the data head (header) 30 | */ 31 | class SpecificationParser { 32 | 33 | /* 34 | * Get the string head from an ExampleSpecification by head type; it will be saved to the head file. 35 | * 36 | * @param spec ExampleSpecification 37 | * @param t data head type 38 | * @return the string of the head 39 | * 40 | */ 41 | def getHead(spec: ExampleSpecification, t: String = "arff"): String = { 42 | if (t.equalsIgnoreCase("arff")) 43 | toArff(spec) 44 | else if (t.equalsIgnoreCase("csv")) 45 | toCsv(spec) 46 | else 47 | toArff(spec) 48 | } 49 | 50 | /* 51 | * Get the ExampleSpecification from a head file by head type.
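The ARFF round trip here (`toArff` / `fromArff`) hinges on parsing `@attribute` lines and treating the last attribute as the class. A standalone sketch of that header handling, using plain tuples instead of `InstanceSpecification` (the sample attribute names are made up):

```scala
// Standalone sketch of the @attribute parsing done by fromArff:
// nominal attributes carry a {v1,v2,...} value set, everything else
// is numeric, and the final attribute becomes the output (class).
object ArffHeaderSketch extends App {
  val header = Seq(
    "@relation sample-data",
    "@attribute nswprice numeric",
    "@attribute day {1,2,3,4,5,6,7}",
    "@attribute class {UP,DOWN}",
    "@data")

  val attrs = header.filter(_.toLowerCase.startsWith("@attribute")).map { line =>
    val tokens = line.split("\\s+") // same tokenization as fromArff
    val spec = tokens(2)
    if (spec.startsWith("{")) {
      val values = spec.substring(spec.indexOf("{") + 1, spec.indexOf("}"))
        .trim.split(",[\\s]?")
      tokens(1) -> ("nominal " + values.mkString("|"))
    } else
      tokens(1) -> "numeric"
  }

  attrs.init.foreach { case (name, kind) => println(s"input  $name: $kind") }
  println(s"output ${attrs.last._1}: ${attrs.last._2}") // class attribute
}
```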
52 | * 53 | * @param fileName name of head file 54 | * @param t data head type 55 | * @return ExampleSpecification of data 56 | * 57 | */ 58 | def getSpecification(fileName: String, t: String = "arff"): ExampleSpecification = { 59 | if (t.equalsIgnoreCase("arff")) 60 | fromArff(fileName) 61 | else if (t.equalsIgnoreCase("csv")) 62 | fromCsv(fileName) 63 | else fromArff(fileName) 64 | } 65 | 66 | def toArff(spec: ExampleSpecification): String = { 67 | val sb = new StringBuffer() 68 | sb.append("@relation sample-data\n") 69 | val inputIS = spec.in 70 | val outputIs = spec.out 71 | val atr = "@attribute" 72 | val nu = "numeric" 73 | // add arff attributes of input 74 | for (index <- 0 until inputIS.size()) { 75 | val featureName: String = inputIS.name(index) 76 | val featureSpec: FeatureSpecification = inputIS(index) 77 | 78 | val line = featureSpec match { 79 | case numeric: NumericFeatureSpecification => { nu } 80 | case nominal: NominalFeatureSpecification => { arraytoString(nominal.values) } 81 | } 82 | sb.append(atr + " " + featureName + " " + line + "\n") 83 | } 84 | // add arff attributes of outnput 85 | sb.append(atr + " " + outputIs.name(0) + " " + 86 | arraytoString(outputIs(0).asInstanceOf[NominalFeatureSpecification].values)) 87 | sb.toString() 88 | } 89 | 90 | def fromArff(fileName: String): ExampleSpecification = { 91 | val lines = Source.fromFile(fileName).getLines() 92 | var line: String = lines.next() 93 | while (line == null || line.length() == 0 || line.startsWith(" ") || 94 | line.startsWith("%") || "@relation".equalsIgnoreCase(line.substring(0, 9))) { 95 | line = lines.next() 96 | } 97 | var finished: Boolean = false 98 | var index: Int = 0 99 | val inputIS = new InstanceSpecification() 100 | val outputIS = new InstanceSpecification() 101 | while (!finished && line.startsWith("@")) { 102 | if ("@data".equalsIgnoreCase(line.substring(0, 5))) { 103 | finished = true 104 | } else if ("@attribute".equalsIgnoreCase(line.substring(0, 10))) { 105 | 106 | val featureInfos: Array[String] = line.split("\\s+") 107 | val name: String = featureInfos(1) 108 | if (!isArffNumeric(featureInfos(2))) { 109 | val featurevalues: Array[String] = featureInfos(2).substring( 110 | featureInfos(2).indexOf("{") + 1, featureInfos(2).indexOf("}")). 
111 | trim().split(",[\\s]?") 112 | 113 | val fSpecification = new NominalFeatureSpecification(featurevalues) 114 | 115 | inputIS.addFeatureSpecification(index, "Norminal" + index, fSpecification) 116 | } else { 117 | inputIS.addFeatureSpecification(index, "Numeric" + index) 118 | } 119 | index += 1 120 | } 121 | if (lines.hasNext) 122 | line = lines.next() 123 | else 124 | finished = true 125 | } 126 | val fSpecification: FeatureSpecification = inputIS(index - 1) 127 | outputIS.addFeatureSpecification(0, "class", fSpecification) 128 | inputIS.removeFeatureSpecification(index - 1) 129 | new ExampleSpecification(inputIS, outputIS) 130 | 131 | } 132 | 133 | def fromArffStrings(arffData: Seq[String]): ExampleSpecification = { 134 | val lines = arffData.iterator 135 | var line: String = lines.next() 136 | while (line == null || line.length() == 0 || line.startsWith(" ") || 137 | line.startsWith("%") || "@relation".equalsIgnoreCase(line.substring(0, 9))) { 138 | line = lines.next() 139 | } 140 | var finished: Boolean = false 141 | var index: Int = 0 142 | val inputIS = new InstanceSpecification() 143 | val outputIS = new InstanceSpecification() 144 | while (!finished && line.startsWith("@")) { 145 | if ("@data".equalsIgnoreCase(line.substring(0, 5))) { 146 | finished = true 147 | } else if ("@attribute".equalsIgnoreCase(line.substring(0, 10))) { 148 | 149 | val featureInfos: Array[String] = line.split("\\s+") 150 | val name: String = featureInfos(1) 151 | if (!isArffNumeric(featureInfos(2))) { 152 | val featurevalues: Array[String] = featureInfos(2).substring( 153 | featureInfos(2).indexOf("{") + 1, featureInfos(2).indexOf("}")). 154 | trim().split(",[\\s]?") 155 | 156 | val fSpecification = new NominalFeatureSpecification(featurevalues) 157 | 158 | inputIS.addFeatureSpecification(index, "Norminal" + index, fSpecification) 159 | } else { 160 | inputIS.addFeatureSpecification(index, "Numeric" + index) 161 | } 162 | index += 1 163 | } 164 | if (lines.hasNext) 165 | line = lines.next() 166 | else 167 | finished = true 168 | } 169 | val fSpecification: FeatureSpecification = inputIS(index - 1) 170 | outputIS.addFeatureSpecification(0, "class", fSpecification) 171 | inputIS.removeFeatureSpecification(index - 1) 172 | new ExampleSpecification(inputIS, outputIS) 173 | 174 | } 175 | 176 | def isArffNumeric(t: String): Boolean = { 177 | if ("numeric".equalsIgnoreCase(t)) true 178 | else if ("integer".equalsIgnoreCase(t)) true 179 | else if ("real".equalsIgnoreCase(t)) true 180 | else false 181 | } 182 | 183 | def toCsv(spec: ExampleSpecification): String = { 184 | //todo 185 | "" 186 | } 187 | 188 | def fromCsv(fileName: String): ExampleSpecification = { 189 | //todo 190 | null 191 | } 192 | 193 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/streamdm/core/DenseInstance.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2015 Holmes Team at HUAWEI Noah's Ark Lab. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package org.apache.spark.streamdm.core 19 | 20 | import math._ 21 | 22 | /** 23 | * A DenseInstance is an Instance in which the features are dense, i.e., there 24 | * exists a value for (almost) every feature. 25 | * The DenseInstance will keep an Array of the values and the 26 | * corresponding operations will be based on this data structure. 27 | */ 28 | case class DenseInstance(inVector: Array[Double]) 29 | extends Instance with Serializable { 30 | 31 | type T = DenseInstance 32 | 33 | val features = inVector 34 | 35 | /* Get the feature value present at position index 36 | * 37 | * @param index the index of the desired value 38 | * @return a Double representing the feature value 39 | */ 40 | override def apply(index: Int): Double = 41 | if (index >= features.length || index < 0) 0.0 else features(index) 42 | 43 | /* 44 | * Return an array of features and indexes 45 | * 46 | * @return an array of turple2(value,index) 47 | */ 48 | def getFeatureIndexArray(): Array[(Double, Int)] = features.zipWithIndex 49 | 50 | /* Perform a dot product between two instances 51 | * 52 | * @param input an DenseSingleLabelInstance with which the dot 53 | * product is performed 54 | * @return a Double representing the dot product 55 | */ 56 | override def dot(input: Instance): Double = input match { 57 | case DenseInstance(f) => { 58 | var sum: Double = 0.0 59 | var i: Int = 0 60 | while (i < features.length) { 61 | sum += f(i) * this(i) 62 | i += 1 63 | } 64 | sum 65 | } 66 | //using imperative version for efficiency reasons 67 | //normally it should be implemented as below 68 | // (0 until features.length).foldLeft(0.0)((d,i) => d + features(i)*f(i)) 69 | case SparseInstance(i, v) => 70 | input.dot(this) 71 | case _ => 0.0 72 | } 73 | 74 | /** 75 | * Perform an element by element addition between two instances 76 | * 77 | * @param input an Instance which is added up 78 | * @return an Instance representing the added Instances 79 | */ 80 | override def add(input: Instance): DenseInstance = input match { 81 | case DenseInstance(f) => { 82 | var newF: Array[Double] = Array.fill(features.length)(0.0) 83 | var i: Int = 0 84 | while (i < features.length) { 85 | newF(i) = features(i) + f(i) 86 | i += 1 87 | } 88 | new DenseInstance(newF) 89 | } 90 | //val addedInstance = (0 until features.length).map(i => features(i)+f(i)) 91 | //new DenseSingleLabelInstance(addedInstance.toArray, label) 92 | case SparseInstance(ind, v) => { 93 | var newF: Array[Double] = Array.fill(features.length)(0.0) 94 | var i: Int = 0 95 | while (i < features.length) { 96 | newF(i) = features(i) 97 | i += 1 98 | } 99 | i = 0 100 | while (i < ind.length) { 101 | newF(ind(i)) += v(i) 102 | i += 1 103 | } 104 | new DenseInstance(newF) 105 | } 106 | case _ => new DenseInstance(features) 107 | } 108 | 109 | /** 110 | * Perform an element by element multiplication between two instances 111 | * 112 | * @param input an Instance which is multiplied 113 | * @return an Instance representing the Hadamard product 114 | */ 115 | override def hadamard(input: Instance): DenseInstance = input match { 116 | case DenseInstance(f) => { 117 | var newF: Array[Double] = Array.fill(features.length)(0.0) 118 | var i: Int = 0 119 | while (i < features.length) { 120 | newF(i) = features(i) * f(i) 121 | i += 1 122 | } 123 | new DenseInstance(newF) 124 | } 125 | case SparseInstance(ind, v) => { 126 | var newF: Array[Double] = 
Array.fill(features.length)(0.0) 127 | var i: Int = 0 128 | while (i < ind.length) { 129 | newF(ind(i)) = features(ind(i)) * v(i) 130 | i += 1 131 | } 132 | new DenseInstance(newF) 133 | } 134 | case _ => new DenseInstance(features) 135 | } 136 | 137 | /** 138 | * Compute the Euclidean distance to another Instance 139 | * 140 | * @param input the Instance to which the distance is computed 141 | * @return a Double representing the distance value 142 | */ 143 | override def distanceTo(input: Instance): Double = input match { 144 | case DenseInstance(f) => { 145 | var sum: Double = 0.0 146 | var i: Int = 0 147 | while (i < features.length) { 148 | sum += math.pow(features(i) - f(i), 2.0) 149 | i += 1 150 | } 151 | math.sqrt(sum) 152 | } 153 | case SparseInstance(ind, v) => { 154 | var sum: Double = 0.0 155 | var i: Int = 0 156 | while (i < ind.length) { 157 | if (v(i) != 0) sum += math.pow(features(ind(i)) - v(i), 2.0) 158 | i += 1 159 | } 160 | i = 0 161 | while (i < features.length) { 162 | if (input(i) == 0) sum += math.pow(features(i), 2.0) 163 | i += 1 164 | } 165 | math.sqrt(sum) 166 | } 167 | case _ => Double.MaxValue 168 | } 169 | 170 | /** 171 | * Add a feature to the instance 172 | * 173 | * @param index the index at which the value is added 174 | * @param input the feature value which is added up 175 | * @return an Instance representing the new feature vector 176 | */ 177 | override def set(index: Int, input: Double): DenseInstance = { 178 | var returnInstance = this 179 | if (index >= 0 && index < features.length) { 180 | features(index) = input 181 | returnInstance = new DenseInstance(features) 182 | } else if (index == features.length) 183 | returnInstance = new DenseInstance(features :+ input) 184 | returnInstance 185 | } 186 | 187 | /** 188 | * Apply an operation to every feature of the Instance 189 | * @param func the function for the transformation 190 | * @return a new Instance with the transformed features 191 | */ 192 | override def map(func: Double => Double): DenseInstance = 193 | new DenseInstance(features.map { case x => func(x) }) 194 | 195 | /** 196 | * Aggregate the values of an instance 197 | * 198 | * @param func the function for the transformation 199 | * @return the aggregated value 200 | */ 201 | override def reduce(func: (Double, Double) => Double): Double = 202 | features.reduce(func) 203 | 204 | override def toString = features.mkString(",") 205 | } 206 | 207 | object DenseInstance extends Serializable { 208 | 209 | /** 210 | * Parse the input string as an DenseInstance class, where each features is 211 | * separated by a comma (CSV format). 212 | * 213 | * @param input the String line to be read 214 | * @return a DenseInstance which is parsed from input 215 | */ 216 | def parse(input: String): DenseInstance = { 217 | val tokens = input.split(",") 218 | new DenseInstance(tokens.map(_.toDouble)) 219 | } 220 | } 221 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/streamdm/core/SparseInstance.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2015 Holmes Team at HUAWEI Noah's Ark Lab. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package org.apache.spark.streamdm.core 19 | 20 | import math._ 21 | 22 | /** 23 | * A SparseInstance is an Instance in which the features are sparse, i.e., most 24 | * features will not have any value. 25 | * The SparseInstance will keep two Arrays: one with the values and one with the 26 | * corresponding indexes. The implementation will be based on these two data 27 | * structures. 28 | */ 29 | 30 | case class SparseInstance(inIndexes:Array[Int], inValues:Array[Double]) 31 | extends Instance with Serializable { 32 | 33 | type T = SparseInstance 34 | 35 | val indexes = inIndexes 36 | val values = inValues 37 | 38 | /* Get the value present at position index 39 | * 40 | * @param index the index of the features 41 | * @return a Double representing the value, or 0.0 if not found 42 | */ 43 | def apply(index: Int): Double = { 44 | var i: Int = 0 45 | var value: Double = 0.0 46 | var found = false 47 | while(i { 87 | var i: Int = 0 88 | var addedFeatures: Array[Double] = Array() 89 | var addedIndexes: Array[Int] = Array() 90 | while(i { 110 | var i: Int = 0 111 | var addedFeatures: Array[Double] = Array() 112 | var addedIndexes: Array[Int] = Array() 113 | while(i new SparseInstance(indexes, values) 124 | } 125 | 126 | /** Perform an element by element multiplication between two instances 127 | * 128 | * @param input an Instance which is multiplied 129 | * @return a SparseInstance representing the Hadamard product 130 | */ 131 | override def hadamard(input: Instance): SparseInstance = input match { 132 | case SparseInstance(ind,v) => { 133 | var i: Int = 0 134 | var addedFeatures: Array[Double] = Array() 135 | var addedIndexes: Array[Int] = Array() 136 | while(i { 147 | var i: Int = 0 148 | var addedFeatures: Array[Double] = Array() 149 | var addedIndexes: Array[Int] = Array() 150 | while(i new SparseInstance(indexes, values) 161 | } 162 | 163 | /** Compute the Euclidean distance to another Instance 164 | * 165 | * @param input the Instance to which the distance is computed 166 | * @return a Double representing the distance value 167 | */ 168 | override def distanceTo(input: Instance): Double = input match { 169 | case SparseInstance(ind,v) => { 170 | var i: Int = 0 171 | var sum: Double = 0.0 172 | while(i input.distanceTo(this) 185 | case _ => Double.MaxValue 186 | } 187 | 188 | /** Append a feature to the instance 189 | * 190 | * @param input the value which is added up 191 | * @return a SparseInstance representing the new feature vector 192 | */ 193 | override def set(index: Int, input: Double): SparseInstance = 194 | new SparseInstance(indexes:+index,values:+input) 195 | 196 | /** Apply an operation to every feature of the Instance (essentially a map) 197 | * @param func the function for the transformation 198 | * @return a new SparseInstance with the transformed features 199 | */ 200 | override def map(func: Double=>Double): SparseInstance = 201 | new SparseInstance(indexes, values.map{case x => func(x)}) 202 | 203 | /** Aggregate the values of an instance 204 | * 205 | * @param func the function for the transformation 206 | 
* @return the aggregated value 207 | */ 208 | override def reduce(func: (Double,Double)=>Double): Double = 209 | values.reduce(func) 210 | 211 | override def toString = (indexes zip values).map{ case (i,v) => 212 | "%d:%f".format(i+1,v)}.mkString(",") 213 | 214 | } 215 | 216 | object SparseInstance extends Serializable { 217 | 218 | /** Parse the input string as a SparseInstance, in LibSVM 219 | * comma-separated format, where each feature is of the form "i:v" where i is 220 | * the index of the feature (starting at 1), and v is the value of the 221 | * feature. 222 | * 223 | * @param input the String line to be read, in LibSVM format 224 | * @return a SparseInstance which is parsed from the input 225 | */ 226 | def parse(input: String): SparseInstance = { 227 | val tokens = input.split(",") 228 | val features = tokens.map(_.split(":")) 229 | new SparseInstance(features.map(_(0).toInt-1),features.map(_(1).toDouble)) 230 | } 231 | } 232 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/streamdm/classifiers/trees/ConditionalTest.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2015 Holmes Team at HUAWEI Noah's Ark Lab. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | package org.apache.spark.streamdm.classifiers.trees 19 | 20 | import org.apache.spark.streamdm.core.{ Example } 21 | 22 | /** 23 | * ConditionalTest is an abstract class for conditional tests, used for 24 | * splitting nodes in Hoeffding trees. 25 | */ 26 | abstract class ConditionalTest(var fIndex: Int) extends Serializable { 27 | 28 | /** 29 | * Returns the number of the branch for an example, -1 if unknown. 30 | * 31 | * @param example the input Example 32 | * @return the number of the branch for an example, -1 if unknown. 33 | */ 34 | def branch(example: Example): Int 35 | 36 | /** 37 | * Gets whether the number of the branch for an example is known. 38 | * 39 | * @param example the input Example 40 | * @return true if the number of the branch for an example is known 41 | */ 42 | def hasResult(example: Example): Boolean = { branch(example) >= 0 } 43 | 44 | /** 45 | * Gets the number of maximum branches, -1 if unknown. 46 | * 47 | * @return the number of maximum branches, -1 if unknown. 48 | */ 49 | def maxBranches(): Int 50 | 51 | /** 52 | * Returns the index of the feature 53 | * 54 | * @return the index of the feature 55 | */ 56 | def featureIndex(): Int = fIndex 57 | 58 | /** 59 | * Get the conditional test description. 60 | * 61 | * @return an Array containing the description 62 | */ 63 | def description(): Array[String] 64 | 65 | } 66 | 67 | /** 68 | * Numeric binary conditional test for splitting nodes in Hoeffding trees.
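`parse()` maps LibSVM-style `"i:v"` tokens to 1-based indexes and values. A standalone sketch (plain arrays, not the Instance hierarchy above) of that parsing and of a sparse-against-dense dot product, where only the stored positions can contribute:

```scala
// Standalone sketch of LibSVM-style sparse parsing and a sparse·dense
// dot product; illustrative object name, plain arrays for clarity.
object SparseParseSketch extends App {
  def parse(input: String): (Array[Int], Array[Double]) = {
    val features = input.split(",").map(_.split(":"))
    (features.map(_(0).toInt - 1), features.map(_(1).toDouble)) // to 0-based
  }

  val (indexes, values) = parse("1:0.5,3:2.0,7:1.0") // 1-based on the wire
  val dense = Array.fill(8)(1.0)

  // Only stored positions contribute; all other coordinates are zero.
  val dot = indexes.zip(values).map { case (i, v) => dense(i) * v }.sum
  println(dot) // 0.5 + 2.0 + 1.0 = 3.5
}
```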
69 | */ 70 | 71 | class NumericBinaryTest(fIndex: Int, val value: Double, val isequalTest: Boolean) 72 | extends ConditionalTest(fIndex) with Serializable { 73 | 74 | /** 75 | * Returns the number of the branch for an example, -1 if unknown. 76 | * 77 | * @param example the input Example. 78 | * @return the number of the branch for an example, -1 if unknown. 79 | */ 80 | override def branch(example: Example): Int = { 81 | // todo process missing value 82 | val v = example.featureAt(fIndex) 83 | if (isequalTest) { 84 | if (v == value) 0 else 1 85 | } else { 86 | if (v < value) 0 else 1 87 | } 88 | } 89 | 90 | /** 91 | * Gets the number of maximum branches, -1 if unknown. 92 | * 93 | * @return the number of maximum branches, -1 if unknown.. 94 | */ 95 | override def maxBranches(): Int = 2 96 | 97 | /** 98 | * Returns the index of the tested feature 99 | * 100 | * @return the index of the tested feature 101 | */ 102 | override def featureIndex(): Int = fIndex 103 | 104 | /** 105 | * Get the conditional test description. 106 | * 107 | * @return an Array containing the description 108 | */ 109 | override def description(): Array[String] = { 110 | val des = new Array[String](2) 111 | val ops = if (isequalTest) Array("==", "!=") else Array("<", ">=") 112 | des(0) = "[feature " + fIndex + " numeric 0] " + ops(0) + " " + value 113 | des(1) = "[feature " + fIndex + " numeric 1] " + ops(1) + " " + value 114 | des 115 | } 116 | 117 | override def toString = "NumericBinaryTest(" + isequalTest + ") feature[" + fIndex + "] = " + 118 | value 119 | } 120 | /** 121 | * Nominal binary conditional test for splitting nodes in Hoeffding trees. 122 | */ 123 | class NominalBinaryTest(fIndex: Int, val value: Double) 124 | extends ConditionalTest(fIndex) with Serializable { 125 | 126 | /** 127 | * Returns the number of the branch for an example, -1 if unknown. 128 | * 129 | * @param example the input example 130 | * @return the number of the branch for an example, -1 if unknown. 131 | */ 132 | override def branch(example: Example): Int = { 133 | // todo process missing value 134 | if (example.featureAt(fIndex) == value) 0 else 1 135 | } 136 | 137 | /** 138 | * Gets the number of maximum branches, -1 if unknown. 139 | * 140 | * @return the number of maximum branches, -1 if unknown. 141 | */ 142 | override def maxBranches(): Int = 2 143 | 144 | override def toString(): String = { 145 | "NominalBinaryTest feature[" + fIndex + "] = " + value 146 | 147 | } 148 | 149 | /** 150 | * Get the conditional test description. 151 | * 152 | * @return an Array containing the description 153 | */ 154 | override def description(): Array[String] = { 155 | val des = new Array[String](2) 156 | des(0) = "[feature " + fIndex + " nominal 0] == " + value 157 | des(1) = "[feature " + fIndex + " nominal 1] != " + value 158 | des 159 | } 160 | } 161 | 162 | /** 163 | * Nominal multi-way conditional test for splitting nodes in Hoeffding trees. 164 | */ 165 | class NominalMultiwayTest(fIndex: Int, val numFeatureValues: Int) 166 | extends ConditionalTest(fIndex) with Serializable { 167 | 168 | /** 169 | * Returns the number of the branch for an example, -1 if unknown. 170 | * 171 | * @param example the input example 172 | * @return the number of the branch for an example, -1 if unknown. 173 | */ 174 | override def branch(example: Example): Int = { 175 | // todo process missing value 176 | example.featureAt(fIndex).toInt 177 | } 178 | 179 | /** 180 | * Gets the number of maximum branches, -1 if unknown. 
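The integer returned by `branch()` is all the tree needs to route an example to a child node. A toy standalone illustration of the two-way numeric routing implemented above (names are illustrative):

```scala
// Toy illustration of NumericBinaryTest routing: branch 0 on
// equality/less-than, branch 1 otherwise.
object BranchSketch extends App {
  def numericBranch(v: Double, split: Double, equalTest: Boolean): Int =
    if (equalTest) { if (v == split) 0 else 1 }
    else { if (v < split) 0 else 1 }

  println(numericBranch(0.3, 0.5, equalTest = false)) // 0: routed left
  println(numericBranch(0.7, 0.5, equalTest = false)) // 1: routed right
  println(numericBranch(0.5, 0.5, equalTest = true))  // 0: equality branch
}
```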
181 | * 182 | * @return the number of maximum branches, -1 if unknown. 183 | */ 184 | override def maxBranches(): Int = numFeatureValues 185 | 186 | override def toString(): String = "NominalMultiwayTest" + "feature[" + fIndex + "] " + 187 | numFeatureValues 188 | 189 | /** 190 | * Get the conditional test description. 191 | * 192 | * @return an Array containing the description 193 | */ 194 | override def description(): Array[String] = { 195 | val des = new Array[String](numFeatureValues) 196 | for (i <- 0 until numFeatureValues) 197 | des(i) = "[feature " + fIndex + " nominal " + i + "] == " + i 198 | des 199 | } 200 | } 201 | 202 | /** 203 | * Numeric binary rule predicate test for splitting nodes in Hoeffding trees. 204 | */ 205 | class NumericBinaryRulePredicate(fIndex: Int, val value: Double, val operator: Int) 206 | extends ConditionalTest(fIndex) with Serializable { 207 | 208 | /** 209 | * Returns the number of the branch for an example, -1 if unknown. 210 | * 211 | * @param example the input example 212 | * @return the number of the branch for an example, -1 if unknown. 213 | */ 214 | override def branch(example: Example): Int = { 215 | // todo process missing value 216 | val v = example.featureAt(fIndex) 217 | operator match { 218 | // operator: 0 ==, 1 <=, 2 < different from MOA which is > 219 | case 0 => if (v == value) 0 else 1 220 | case 1 => if (v <= value) 0 else 1 221 | case 2 => if (v < value) 0 else 1 222 | case _ => if (v == value) 0 else 1 223 | } 224 | } 225 | 226 | /** 227 | * Gets the number of maximum branches, -1 if unknown. 228 | * 229 | * @return the number of maximum branches, -1 if unknown. 230 | */ 231 | override def maxBranches(): Int = 2 232 | 233 | /** 234 | * Get the conditional test description. 235 | * 236 | * @return an Array containing the description 237 | */ 238 | override def description(): Array[String] = { 239 | val des = new Array[String](2) 240 | val ops = if (operator == 0) Array("==", "!=") else if (operator == 1) 241 | Array("<=", ">") else Array("<", ">=") 242 | des(0) = "[feature " + fIndex + " numeric 0] " + ops(0) + " " + value 243 | des(1) = "[feature " + fIndex + " numeric 1] " + ops(1) + " " + value 244 | des 245 | } 246 | } 247 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/streamdm/classifiers/trees/SplitCriterion.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2015 Holmes Team at HUAWEI Noah's Ark Lab. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | * 16 | */ 17 | 18 | package org.apache.spark.streamdm.classifiers.trees 19 | 20 | import scala.math.{ max,min } 21 | 22 | import org.apache.spark.streamdm.utils.Utils.{log2} 23 | 24 | trait SplitCriterionType 25 | 26 | case class InfoGainSplitCriterionType() extends SplitCriterionType 27 | 28 | case class GiniSplitCriterionType() extends SplitCriterionType 29 | 30 | case class VarianceReductionSplitCriterionType() extends SplitCriterionType 31 | 32 | /** 33 | * Trait for computing splitting criteria with respect to distributions of class values. 34 | * The split criterion is used as a parameter on decision trees and decision stumps. 35 | * The two split criteria most used are Information Gain and Gini. 36 | */ 37 | 38 | trait SplitCriterion extends Serializable { 39 | 40 | /** 41 | * Computes the merit of splitting for a given distribution before and after the split. 42 | * 43 | * @param pre the class distribution before the split 44 | * @param post the class distribution after the split 45 | * @return value of the merit of splitting 46 | */ 47 | def merit(pre: Array[Double], post: Array[Array[Double]]): Double 48 | 49 | /** 50 | * Computes the range of splitting merit 51 | * 52 | * @param pre the class distribution before the split 53 | * @return value of the range of splitting merit 54 | */ 55 | def rangeMerit(pre: Array[Double]): Double 56 | 57 | } 58 | 59 | /** 60 | * Class for computing splitting criteria using information gain with respect to 61 | * distributions of class values. 62 | */ 63 | class InfoGainSplitCriterion extends SplitCriterion with Serializable { 64 | 65 | var minBranch: Double = 0.01 66 | 67 | def this(minBranch: Double) { 68 | this() 69 | this.minBranch = minBranch 70 | } 71 | 72 | /** 73 | * Computes the merit of splitting for a given distribution before and after the split. 
74 | * 75 | * @param pre the class distribution before the split 76 | * @param post the class distribution after the split 77 | * @return value of the merit of splitting 78 | */ 79 | override def merit(pre: Array[Double], post: Array[Array[Double]]): Double = { 80 | val num = numGTFrac(post, minBranch) 81 | if (num < 2) Double.NegativeInfinity 82 | else { 83 | val merit = entropy(pre) - entropy(post) 84 | merit 85 | } 86 | } 87 | 88 | /** 89 | * Computes the range of splitting merit 90 | * 91 | * @param pre the class distribution before the split 92 | * @return value of the range of splitting merit 93 | */ 94 | override def rangeMerit(pre: Array[Double]): Double = log2(max(pre.length, 2)) 95 | 96 | /** 97 | * Returns the entropy of a distribution 98 | * 99 | * @param pre an Array containing the distribution 100 | * @return the entropy 101 | */ 102 | def entropy(pre: Array[Double]): Double = { 103 | if (pre == null || pre.sum <= 0 || hasNegative(pre)) 0.0 104 | else log2(pre.sum) - pre.filter(_ > 0).map(x => x * log2(x)).sum / pre.sum 105 | } 106 | 107 | /** 108 | * Computes the entropy of a matrix 109 | * 110 | * @param post the matrix as an Array of Array 111 | * @return the entropy 112 | */ 113 | def entropy(post: Array[Array[Double]]): Double = { 114 | if (post == null || post.length == 0 || post(0).length == 0) 0 115 | else { 116 | post.map { row => (row.sum * entropy(row)) }.sum / post.map(_.sum).sum 117 | } 118 | 119 | } 120 | 121 | /** 122 | * Returns the number of subsets whose weight is greater than the minFrac fraction of the total weight 123 | * 124 | * @param post the matrix as an Array of Array 125 | * @param minFrac the min threshold 126 | * @return the number of subsets 127 | */ 128 | def numGTFrac(post: Array[Array[Double]], minFrac: Double): Int = { 129 | if (post == null || post.length == 0) { 130 | 0 131 | } else { 132 | val sums = post.map { _.sum } 133 | sums.filter(_ > sums.sum * minFrac).length 134 | } 135 | } 136 | 137 | /** 138 | * Returns whether an array has a negative value 139 | * 140 | * @param pre an Array to be valued 141 | * @return whether an array has a negative value 142 | */ 143 | private[trees] def hasNegative(pre: Array[Double]): Boolean = pre.filter(x => x < 0).length > 0 144 | 145 | } 146 | 147 | /** 148 | * Class for computing splitting criteria using Gini with respect to 149 | * distributions of class values. 150 | */ 151 | 152 | class GiniSplitCriterion extends SplitCriterion with Serializable { 153 | 154 | /** 155 | * Computes the merit of splitting for a given distribution before and after the split.
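`merit()` here is classical information gain: the entropy of the parent distribution minus the size-weighted entropy of the child distributions. A standalone numeric check (illustrative counts), mirroring the entropy formula above:

```scala
// Standalone information-gain check mirroring InfoGainSplitCriterion:
// H(parent) minus the weighted entropy of the child blocks.
object InfoGainSketch extends App {
  def log2(v: Double): Double = math.log(v) / math.log(2)

  def entropy(dist: Array[Double]): Double =
    if (dist.sum <= 0) 0.0
    else log2(dist.sum) - dist.filter(_ > 0).map(x => x * log2(x)).sum / dist.sum

  val pre = Array(8.0, 8.0)                          // parent: 8 vs 8 -> 1 bit
  val post = Array(Array(7.0, 1.0), Array(1.0, 7.0)) // a fairly pure split
  val weighted = post.map(r => r.sum * entropy(r)).sum / post.map(_.sum).sum
  println(entropy(pre) - weighted) // ~0.456 bits of information gain
}
```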
156 | * 157 | * @param pre the class distribution before the split 158 | * @param post the class distribution after the split 159 | * @return value of the merit of splitting 160 | */ 161 | override def merit(pre: Array[Double], post: Array[Array[Double]]): Double = { 162 | val sums = post.map(_.sum) 163 | val totalWeight = sums.sum 164 | val ginis: Array[Double] = post.zip(sums).map { 165 | case (x, y) => computeGini(x, y) * y / totalWeight 166 | } 167 | 1.0 - ginis.sum 168 | } 169 | 170 | /** 171 | * Computes the range of splitting merit 172 | * 173 | * @param pre the class distribution before the split 174 | * @return value of the range of splitting merit 175 | */ 176 | override def rangeMerit(pre: Array[Double]): Double = 1.0 177 | 178 | /** 179 | * Computes the gini of an array 180 | * 181 | * @param dist an array to be computed 182 | * @param sum the sum of the array 183 | * @return the gini of an array 184 | */ 185 | private[trees] def computeGini(dist: Array[Double], sum: Double): Double = 186 | 1.0 - dist.map { x => x * x / sum / sum }.sum 187 | 188 | } 189 | 190 | /** 191 | * Class for computing splitting criteria using variance reduction with respect 192 | * to distributions of class values. 193 | */ 194 | class VarianceReductionSplitCriterion extends SplitCriterion with Serializable { 195 | 196 | val magicNumber = 5.0 197 | 198 | /** 199 | * Computes the merit of splitting for a given distribution before and after the split. 200 | * 201 | * @param pre the class distribution before the split 202 | * @param post the class distribution after the split 203 | * @return value of the merit of splitting 204 | */ 205 | override def merit(pre: Array[Double], post: Array[Array[Double]]): Double = { 206 | val count = post.map { row => if (row(0) >= magicNumber) 1 else 0 }.sum 207 | if (count != post.length) 0 208 | else { 209 | var sdr = computeSD(pre) 210 | post.foreach { row => sdr -= (row(0) / pre(0)) * computeSD(row) } 211 | sdr 212 | } 213 | } 214 | 215 | /** 216 | * Computes the range of splitting merit 217 | * 218 | * @param pre the class distribution before the split 219 | * @return value of the range of splitting merit 220 | */ 221 | override def rangeMerit(pre: Array[Double]): Double = 1.0 222 | 223 | /** 224 | * Computes the standard deviation of a distribution 225 | * 226 | * @param pre an Array containing the distribution 227 | * @return the standard deviation 228 | */ 229 | private[trees] def computeSD(pre: Array[Double]): Double = { 230 | val n = pre(0).toInt 231 | val sum = pre(1) 232 | val sumSq = pre(2) 233 | (sumSq - ((sum * sum) / n)) / n 234 | } 235 | } 236 | 237 | object SplitCriterion { 238 | 239 | /** 240 | * Return a new SplitCriterion, by default InfoGainSplitCriterion. 
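`computeGini()` is the usual impurity `1 - sum(p_i^2)` per block, which `merit()` then weights by block size. A quick standalone check (illustrative counts):

```scala
// Quick check of the Gini impurity used by GiniSplitCriterion.
object GiniSketch extends App {
  def gini(dist: Array[Double]): Double = {
    val s = dist.sum
    1.0 - dist.map(x => (x / s) * (x / s)).sum
  }
  println(gini(Array(8.0, 8.0))) // 0.5   (maximally mixed, two classes)
  println(gini(Array(7.0, 1.0))) // ~0.219 (fairly pure)
}
```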
241 | * @param scType the type of the split criterion 242 | * @param minBranch branch parameter 243 | * @return the new SplitCriterion 244 | */ 245 | def createSplitCriterion( 246 | scType: SplitCriterionType, minBranch: Double = 0.01): SplitCriterion = scType match { 247 | case infoGrain: InfoGainSplitCriterionType => new InfoGainSplitCriterion(minBranch) 248 | case gini: GiniSplitCriterionType => new GiniSplitCriterion() 249 | case vr: VarianceReductionSplitCriterionType => new VarianceReductionSplitCriterion() 250 | case _ => new InfoGainSplitCriterion(minBranch) 251 | } 252 | } 253 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/streamdm/classifiers/bayes/MultinomialNaiveBayes.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Portions Copyright (C) 2015 Holmes Team at HUAWEI Noah's Ark Lab. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | * 16 | */ 17 | 18 | /* 19 | This file has been changed by gosubpl 20 | */ 21 | 22 | package org.apache.spark.streamdm.classifiers.bayes 23 | 24 | import scala.math.{ log, log10 } 25 | import com.github.javacliparser.IntOption 26 | import org.apache.spark.streamdm.classifiers.OnlineClassifier 27 | import org.apache.spark.streamdm.core._ 28 | import org.apache.spark.streamdm.classifiers.trees._ 29 | import org.apache.spark.streamdm.core.specification.ExampleSpecification 30 | /** 31 | * Incremental Multinomial Naive Bayes learner. Builds a bayesian text 32 | * classifier making the naive assumption that all inputs are independent and 33 | * that feature values represent the frequencies with words occur. For more 34 | * information see,
Andrew McCallum, Kamal Nigam: A Comparison of 35 | * Event Models for Naive Bayes Text Classification. In: AAAI-98 Workshop on 36 | * 'Learning for Text Categorization', 1998. 37 | * 38 | * It uses the following options: 39 | * 40 | * - Number of features (-f) 41 | * - Number of classes (-c) 42 | * - Laplace smoothing parameter (-s) 43 | *
44 | */ 45 | class MultinomialNaiveBayes extends OnlineClassifier { 46 | 47 | type T = MultinomialNaiveBayesModel 48 | 49 | val numClassesOption: IntOption = new IntOption("numClasses", 'c', 50 | "Number of Classes", 2, 2, Integer.MAX_VALUE) 51 | 52 | val numFeaturesOption: IntOption = new IntOption("numFeatures", 'f', 53 | "Number of Features", 3, 1, Integer.MAX_VALUE) 54 | 55 | val laplaceSmoothingFactorOption: IntOption = new IntOption( 56 | "laplaceSmoothingFactor", 's', "Laplace Smoothing Factor", 1, 1, 57 | Integer.MAX_VALUE) 58 | 59 | var model: MultinomialNaiveBayesModel = null 60 | 61 | var exampleLearnerSpecification: ExampleSpecification = null 62 | 63 | /** 64 | * Init the model based on the algorithm implemented in the learner. 65 | * 66 | * @param exampleSpecification the ExampleSpecification of the input stream. 67 | */ 68 | override def init(exampleSpecification: ExampleSpecification): Unit = { 69 | exampleLearnerSpecification = exampleSpecification 70 | model = new MultinomialNaiveBayesModel( 71 | numClassesOption.getValue, numFeaturesOption.getValue, 72 | laplaceSmoothingFactorOption.getValue) 73 | } 74 | 75 | /* Train the model incrementally with a single instance given for training. 76 | * 77 | * @param input an instance 78 | * @return Unit 79 | */ 80 | override def trainIncremental(input: Example): Unit = { 81 | val tmodel = new MultinomialNaiveBayesModel(model.numClasses, model.numFeatures, 82 | model.laplaceSmoothingFactor) 83 | tmodel.update(input) 84 | model = model.merge(tmodel) 85 | } 86 | 87 | /* Predict the label of the Instance, given the current model 88 | * 89 | * @param instance the Instance which needs a class predicted 90 | * @return a tuple containing the original instance and the predicted value 91 | */ 92 | def predictSingle(input: Example): (Example, Double) = (input, model.predict(input)) 93 | 94 | /* Gets the current Model used for the Learner. 95 | * 96 | * @return the Model object used for training 97 | */ 98 | override def getModel: MultinomialNaiveBayesModel = model 99 | } 100 | 101 | /** 102 | * The Model used for the multinomial Naive Bayes. It contains the class 103 | * statistics and the class feature statistics. 104 | */ 105 | class MultinomialNaiveBayesModel(val numClasses: Int, val numFeatures: Int, 106 | val laplaceSmoothingFactor: Int) 107 | extends Model with Serializable { 108 | type T = MultinomialNaiveBayesModel 109 | 110 | var classStatistics: Array[Double] = new Array[Double](numClasses) 111 | var classFeatureStatistics: Array[Array[Double]] = Array.fill(numClasses)( 112 | new Array[Double](numFeatures)) 113 | 114 | @transient var isReady: Boolean = false 115 | 116 | // variables used for prediction 117 | @transient var logNumberDocuments: Double = 0 118 | @transient var logProbability: Array[Double] = null 119 | @transient var logConditionalProbability: Array[Array[Double]] = null 120 | @transient var logNumberDocumentsOfClass: Double = 0 121 | 122 | def this(numClasses: Int, numFeatures: Int, laplaceSmoothingFactor: Int, 123 | classStatistics: Array[Double], 124 | classFeatureStatistics: Array[Array[Double]]) { 125 | this(numClasses, numFeatures, laplaceSmoothingFactor) 126 | this.classStatistics = classStatistics 127 | this.classFeatureStatistics = classFeatureStatistics 128 | } 129 | 130 | /** 131 | * Update the model, depending on the Instance given for training. 
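The model keeps per-class document counts (`classStatistics`) and per-class word counts (`classFeatureStatistics`); prediction turns these into Laplace-smoothed log-probabilities (see `prepare()` below), so unseen words keep non-zero mass. A standalone sketch of that smoothing for a single class (illustrative counts):

```scala
// Standalone sketch of the Laplace-smoothed log conditional probabilities
// computed in prepare(): log((count + s) / (classTotal + numFeatures * s)).
object LaplaceSketch extends App {
  val laplace = 1.0
  val numFeatures = 3
  val classFeatureCounts = Array(4.0, 0.0, 2.0) // word counts in one class

  val logDenominator = math.log(classFeatureCounts.sum + numFeatures * laplace)
  val logCondProb = classFeatureCounts.map(c => math.log(c + laplace) - logDenominator)

  // Even the unseen word (count 0) retains probability mass:
  println(logCondProb.map(math.exp).mkString(", ")) // ~0.556, 0.111, 0.333
}
```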
132 |    *
133 |    * @param instance the Example based on which the Model is updated
134 |    * @return the updated Model
135 |    */
136 |   override def update(instance: Example): MultinomialNaiveBayesModel = {
137 |     if (isReady) {
138 |       isReady = false
139 |     }
140 |     classStatistics(instance.labelAt(0).toInt) += 1
141 |     for (i <- 0 until numFeatures) {
142 |       classFeatureStatistics(instance.labelAt(0).toInt)(i) +=
143 |         instance.featureAt(i)
144 |     }
145 |     this
146 |   }
147 | 
148 |   /**
149 |    * Prepare the model by computing the log-probability tables from the accumulated statistics.
150 |    *
151 |    * @return a boolean indicating whether the model is ready
152 |    */
153 |   private def prepare(): Boolean = {
154 |     if (!isReady) {
155 |       logProbability = new Array[Double](numClasses)
156 |       logConditionalProbability = Array.fill(numClasses)(new Array[Double](
157 |         numFeatures))
158 |       val totalNum = classFeatureStatistics.map { x => x.sum }.sum
159 |       logNumberDocuments = math.log(totalNum + numClasses *
160 |         laplaceSmoothingFactor)
161 |       for (i <- 0 until numClasses) {
162 |         val logNumberDocumentsOfClass = math.log(classFeatureStatistics(i).sum +
163 |           numFeatures * laplaceSmoothingFactor)
164 |         logProbability(i) = math.log(classStatistics(i) +
165 |           laplaceSmoothingFactor) - logNumberDocuments
166 |         for (j <- 0 until numFeatures) {
167 |           logConditionalProbability(i)(j) =
168 |             math.log(classFeatureStatistics(i)(j) + laplaceSmoothingFactor) -
169 |               logNumberDocumentsOfClass
170 |         }
171 |       }
172 |       isReady = true
173 |     }
174 |     isReady
175 |   }
176 | 
177 |   /**
178 |    * Predict the label of the Instance, given the current Model.
179 |    *
180 |    * @param instance the Example which needs a class predicted
181 |    * @return a Double representing the class predicted
182 |    */
183 |   def predict(instance: Example): Double = {
184 |     if (prepare()) {
185 |       val predictLogProbability = new Array[Double](numClasses)
186 |       for (i <- 0 until numClasses) {
187 |         predictLogProbability(i) = logProbability(i)
188 |         for (j <- 0 until numFeatures) {
189 |           predictLogProbability(i) += logConditionalProbability(i)(j) *
190 |             instance.featureAt(j)
191 |         }
192 |       }
193 |       argMax(predictLogProbability)
194 |     } else 0
195 |   }
196 | 
197 |   /**
198 |    * Index corresponding to the maximum value of an array of Doubles.
199 |    *
200 |    * @param array the input Array
201 |    * @return the argmax index
202 |    */
203 |   private def argMax(array: Array[Double]): Double = array.zipWithIndex.
204 |     maxBy(_._1)._2
205 | 
206 |   /**
207 |    * Merge the statistics of another model into the current model.
208 |    *
209 |    * @param mod2 the model whose statistics will be merged in
210 |    * @return the merged MultinomialNaiveBayesModel
211 |    */
212 |   def merge(mod2: MultinomialNaiveBayesModel): MultinomialNaiveBayesModel = {
213 |     val mergeClassStatistics = this.classStatistics.
214 |       zip(mod2.classStatistics).map { case (x, y) => x + y }
215 |     val mergeClassFeatureStatistics = this.classFeatureStatistics.
216 | zip(mod2.classFeatureStatistics).map { case (a1, a2) => a1.zip(a2).map { 217 | case (x, y) => x + y } } 218 | new MultinomialNaiveBayesModel( 219 | mod2.numClasses, mod2.numFeatures, mod2.laplaceSmoothingFactor, 220 | mergeClassStatistics, mergeClassFeatureStatistics) 221 | } 222 | 223 | } 224 | 225 | object NaiveBayes { 226 | // predict the probabilities of all features for Hoeffding Tree 227 | def predict(point: Example, classDistribution: Array[Double], featureObservers: Array[FeatureClassObserver]): Array[Double] = { 228 | val votes: Array[Double] = classDistribution.map { _ / classDistribution.sum } 229 | for (i <- 0 until votes.length; j <- 0 until featureObservers.length) { 230 | votes(i) *= featureObservers(j).probability(i, point.featureAt(j)) 231 | } 232 | votes 233 | } 234 | 235 | // predict the log10 probabilities of all features for Hoeffding Tree 236 | def predictLog10(point: Example, classDistribution: Array[Double], featureObservers: Array[FeatureClassObserver]): Array[Double] = { 237 | val votes: Array[Double] = classDistribution.map { x => log10(x / classDistribution.sum) } 238 | for (i <- 0 until votes.length; j <- 0 until featureObservers.length) { 239 | votes(i) += log10(featureObservers(j).probability(i, point.featureAt(j))) 240 | } 241 | votes 242 | } 243 | } 244 | 245 | 246 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/streamdm/classifiers/trees/FeatureClassObserver.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2015 Holmes Team at HUAWEI Noah's Ark Lab. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 |  * You may obtain a copy of the License at
7 |  *
8 |  * http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  *
16 |  */
17 | 
18 | package org.apache.spark.streamdm.classifiers.trees
19 | 
20 | import scala.collection.mutable.TreeSet
21 | import scala.math.{ min, max }
22 | 
23 | import org.apache.spark.streamdm.core._
24 | import org.apache.spark.streamdm.core.specification._
25 | import org.apache.spark.streamdm.utils.Utils.{ normal, transpose, splitTranspose }
26 | /**
27 |  * Trait FeatureClassObserver for observing the class distribution of one feature.
28 |  * The observer monitors the class distribution of a given feature.
29 |  * Used in Naive Bayes and decision trees to monitor data statistics on leaves.
30 |  */
31 | trait FeatureClassObserver extends Serializable {
32 | 
33 |   /**
34 |    * Updates statistics of this observer given a feature value, a class index
35 |    * and the weight of the example observed
36 |    *
37 |    * @param cIndex the index of the class
38 |    * @param fValue the value of the feature
39 |    * @param weight the weight of the example
40 |    */
41 |   def observeClass(cIndex: Double, fValue: Double, weight: Double): Unit
42 | 
43 |   /**
44 |    * Gets the probability for an attribute value given a class
45 |    *
46 |    * @param cIndex the index of the class
47 |    * @param fValue the value of the feature
48 |    * @return probability for a feature value given a class
49 |    */
50 |   def probability(cIndex: Double, fValue: Double): Double
51 | 
52 |   /**
53 |    * Gets the best split suggestion given a criterion and a class distribution
54 |    *
55 |    * @param criterion the split criterion to use
56 |    * @param pre the class distribution before the split
57 |    * @param fValue the value of the feature
58 |    * @param isBinarySplit true to use binary splits
59 |    * @return suggestion of best feature split
60 |    */
61 |   def bestSplit(criterion: SplitCriterion, pre: Array[Double], fValue: Double,
62 |     isBinarySplit: Boolean): FeatureSplit
63 | 
64 |   /**
65 |    * Merges the given FeatureClassObserver into the current FeatureClassObserver
66 |    *
67 |    * @param that the FeatureClassObserver to be merged
68 |    * @param trySplit whether this is called while a Hoeffding tree tries to split
69 |    * @return the current FeatureClassObserver
70 |    */
71 |   def merge(that: FeatureClassObserver, trySplit: Boolean): FeatureClassObserver
72 | 
73 |   /**
74 |    * Not yet supported.
75 |    */
76 |   def observeTarget(fValue: Double, weight: Double): Unit = {}
77 | 
78 | }
79 | /**
80 |  * NullFeatureClassObserver is a null class for observers.
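 * It ignores all observations and reports probability 0.0; it is installed,
 * e.g., by disableFeature for features that should be ignored.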
81 |  */
82 | class NullFeatureClassObserver extends FeatureClassObserver with Serializable {
83 | 
84 |   /**
85 |    * Updates statistics of this observer given a feature value, a class index
86 |    * and the weight of the example observed
87 |    *
88 |    * @param cIndex the index of the class
89 |    * @param fValue the value of the feature
90 |    * @param weight the weight of the example
91 |    */
92 |   override def observeClass(cIndex: Double, fValue: Double, weight: Double): Unit = {}
93 | 
94 |   /**
95 |    * Gets the probability for an attribute value given a class
96 |    *
97 |    * @param cIndex the index of the class
98 |    * @param fValue the value of the feature
99 |    * @return probability for a feature value given a class
100 |    */
101 |   override def probability(cIndex: Double, fValue: Double): Double = 0.0
102 | 
103 |   /**
104 |    * Gets the best split suggestion given a criterion and a class distribution
105 |    *
106 |    * @param criterion the split criterion to use
107 |    * @param pre the class distribution before the split
108 |    * @param fValue the value of the feature
109 |    * @param isBinarySplit true to use binary splits
110 |    * @return suggestion of best feature split
111 |    */
112 |   override def bestSplit(criterion: SplitCriterion, pre: Array[Double], fValue: Double,
113 |     isBinarySplit: Boolean): FeatureSplit = { null }
114 | 
115 |   /**
116 |    * Merges the given FeatureClassObserver into the current FeatureClassObserver
117 |    *
118 |    * @param that the FeatureClassObserver to be merged
119 |    * @param trySplit whether this is called while a Hoeffding tree tries to split
120 |    * @return the current FeatureClassObserver
121 |    */
122 |   override def merge(that: FeatureClassObserver, trySplit: Boolean): FeatureClassObserver = this
123 | }
124 | /**
125 |  * Class for observing the class distribution of a nominal feature.
126 |  * The observer monitors the class distribution of a given feature.
127 |  * Used in Naive Bayes and decision trees to monitor data statistics on leaves.
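 * Probability estimates are Laplace-smoothed:
 * P(fValue | class) = (count(class, fValue) + s) / (count(class) + numFeatureValues * s),
 * where s is laplaceSmoothingFactor.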
128 |  */
129 | class NominalFeatureClassObserver(val numClasses: Int, val fIndex: Int, val numFeatureValues: Int,
130 |   val laplaceSmoothingFactor: Int = 1) extends FeatureClassObserver with Serializable {
131 | 
132 |   var classFeatureStatistics: Array[Array[Double]] = Array.fill(numClasses)(new Array[Double](numFeatureValues))
133 | 
134 |   var blockClassFeatureStatistics: Array[Array[Double]] = Array.fill(numClasses)(new Array[Double](numFeatureValues))
135 | 
136 |   var totalWeight: Double = 0.0
137 |   var blockWeight: Double = 0.0
138 | 
139 |   def this(that: NominalFeatureClassObserver) {
140 |     this(that.numClasses, that.fIndex, that.numFeatureValues, that.laplaceSmoothingFactor)
141 |     for (i <- 0 until numClasses; j <- 0 until numFeatureValues) {
142 |       classFeatureStatistics(i)(j) = that.classFeatureStatistics(i)(j) +
143 |         that.blockClassFeatureStatistics(i)(j)
144 |     }
145 |     totalWeight = that.totalWeight + that.blockWeight
146 |   }
147 |   /**
148 |    * Updates statistics of this observer given a feature value, a class index
149 |    * and the weight of the example observed
150 |    *
151 |    * @param cIndex the index of the class
152 |    * @param fValue the value of the feature
153 |    * @param weight the weight of the example
154 |    */
155 |   override def observeClass(cIndex: Double, fValue: Double, weight: Double): Unit = {
156 |     blockClassFeatureStatistics(cIndex.toInt)(fValue.toInt) += weight
157 |     blockWeight += weight
158 |   }
159 | 
160 |   /**
161 |    * Gets the probability for an attribute value given a class
162 |    *
163 |    * @param cIndex the index of the class
164 |    * @param fValue the value of the feature
165 |    * @return probability for a feature value given a class
166 |    */
167 |   override def probability(cIndex: Double, fValue: Double): Double = {
168 |     val sum = classFeatureStatistics(cIndex.toInt).sum
169 |     if (sum == 0) 0.0 else {
170 |       (classFeatureStatistics(cIndex.toInt)(fValue.toInt) + laplaceSmoothingFactor) /
171 |         (sum + numFeatureValues * laplaceSmoothingFactor)
172 |     }
173 |   }
174 | 
175 |   /**
176 |    * Gets the best split suggestion given a criterion and a class distribution
177 |    *
178 |    * @param criterion the split criterion to use
179 |    * @param pre the class distribution before the split
180 |    * @param fValue the value of the feature
181 |    * @param isBinarySplit true to use binary splits
182 |    * @return suggestion of best feature split
183 |    */
184 |   override def bestSplit(criterion: SplitCriterion, pre: Array[Double],
185 |     fValue: Double, isBinarySplit: Boolean): FeatureSplit = {
186 |     var fSplit: FeatureSplit = null
187 |     for (i <- 0 until pre.length) {
188 |       val post: Array[Array[Double]] = binarySplit(i)
189 |       val merit = criterion.merit(normal(pre), normal(post))
190 |       if (fSplit == null || fSplit.merit < merit) {
191 |         fSplit = new FeatureSplit(new NominalBinaryTest(fIndex, i), merit, post)
192 |       }
193 |     }
194 |     if (!isBinarySplit) {
195 |       val post = multiwaySplit()
196 |       val merit = criterion.merit(pre, post)
197 |       if (fSplit.merit < merit)
198 |         fSplit = new FeatureSplit(new NominalMultiwayTest(fIndex, numFeatureValues), merit, post)
199 |     }
200 |     fSplit
201 |   }
202 |   /**
203 |    * Merges the given FeatureClassObserver into the current FeatureClassObserver
204 |    *
205 |    * @param that the FeatureClassObserver to be merged
206 |    * @param trySplit whether this is called while a Hoeffding tree tries to split
207 |    * @return the current FeatureClassObserver
208 |    */
209 |   override def merge(that: FeatureClassObserver, trySplit: Boolean): FeatureClassObserver = {
210 |     if (!that.isInstanceOf[NominalFeatureClassObserver])
211 |       this
212 |     else {
213 |       val observer = that.asInstanceOf[NominalFeatureClassObserver]
214 |       if (numClasses != observer.numClasses || fIndex != observer.fIndex ||
215 |         numFeatureValues != observer.numFeatureValues ||
216 |         laplaceSmoothingFactor != observer.laplaceSmoothingFactor) this
217 |       else {
218 |         if (!trySplit) {
219 |           totalWeight += observer.blockWeight
220 |           for (
221 |             i <- 0 until blockClassFeatureStatistics.length; j <- 0 until
222 |               blockClassFeatureStatistics(0).length
223 |           ) {
224 |             blockClassFeatureStatistics(i)(j) += observer.blockClassFeatureStatistics(i)(j)
225 |           }
226 |         } else {
227 |           totalWeight += observer.totalWeight
228 |           for (
229 |             i <- 0 until classFeatureStatistics.length; j <- 0 until
230 |               classFeatureStatistics(0).length
231 |           ) {
232 |             classFeatureStatistics(i)(j) += observer.blockClassFeatureStatistics(i)(j)
233 |           }
234 |         }
235 |         this
236 |       }
237 |     }
238 |   }
239 |   /**
240 |    * Binary split the tree data depending on the input value
241 |    * @param fValue the input value
242 |    * @return an Array encoding the split
243 |    */
244 |   private[trees] def binarySplit(fValue: Double): Array[Array[Double]] =
245 |     { splitTranspose(classFeatureStatistics, fValue.toInt) }
246 |   /**
247 |    * Split the data globally.
248 |    * @return an Array encoding the split
249 |    */
250 |   private[trees] def multiwaySplit(): Array[Array[Double]] =
251 |     { transpose(classFeatureStatistics) }
252 | }
253 | /**
254 |  * Trait for the numeric feature observers.
255 |  */
256 | trait NumericFeatureClassObserver extends FeatureClassObserver
257 | 
258 | /**
259 |  * Class GaussianNumericFeatureClassObserver for observing the class data distribution for a numeric feature using Gaussian estimators.
260 |  * This observer monitors the class distribution of a given feature.
261 |  * Used in Naive Bayes and decision trees to monitor data statistics on leaves.
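 * Each class keeps a GaussianEstimator; probability() evaluates its density at the
 * feature value, and bestSplit() scores up to numBins - 1 equally spaced candidate
 * split points between the observed minimum and maximum.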
262 |  */
263 | 
264 | class GaussianNumericFeatureClassObserver(val numClasses: Int, val fIndex: Int, val numBins: Int = 10)
265 |   extends NumericFeatureClassObserver with Serializable {
266 | 
267 |   val estimators: Array[GaussianEstimator] = Array.fill(numClasses)(new GaussianEstimator())
268 |   val minValuePerClass: Array[Double] = Array.fill(numClasses)(Double.PositiveInfinity)
269 |   val maxValuePerClass: Array[Double] = Array.fill(numClasses)(Double.NegativeInfinity)
270 | 
271 |   def this(that: GaussianNumericFeatureClassObserver) {
272 |     this(that.numClasses, that.fIndex, that.numBins)
273 |     for (i <- 0 until numClasses) estimators(i) = new GaussianEstimator(that.estimators(i))
274 |   }
275 | 
276 |   /**
277 |    * Updates statistics of this observer given a feature value, a class index
278 |    * and the weight of the example observed
279 |    *
280 |    * @param cIndex the index of the class
281 |    * @param fValue the value of the feature
282 |    * @param weight the weight of the example
283 |    */
284 |   override def observeClass(cIndex: Double, fValue: Double, weight: Double): Unit = {
285 |     // todo: process missing feature values; for now every observed value
286 |     // is treated as present and fed to the per-class estimator
287 | 
288 | 
289 |     if (minValuePerClass(cIndex.toInt) > fValue)
290 |       minValuePerClass(cIndex.toInt) = fValue
291 |     if (maxValuePerClass(cIndex.toInt) < fValue)
292 |       maxValuePerClass(cIndex.toInt) = fValue
293 |     estimators(cIndex.toInt).observe(fValue, weight)
294 | 
295 |   }
296 | 
297 |   /**
298 |    * Gets the probability for an attribute value given a class
299 |    *
300 |    * @param cIndex the index of the class
301 |    * @param fValue the value of the feature
302 |    * @return probability for a feature value given a class
303 |    */
304 |   override def probability(cIndex: Double, fValue: Double): Double = {
305 |     if (estimators(cIndex.toInt) == null) 0.0
306 |     else estimators(cIndex.toInt).probabilityDensity(fValue)
307 |   }
308 | 
309 |   /**
310 |    * Gets the best split suggestion given a criterion and a class distribution
311 |    *
312 |    * @param criterion the split criterion to use
313 |    * @param pre the class distribution before the split
314 |    * @param fValue the value of the feature
315 |    * @param isBinarySplit true to use binary splits
316 |    * @return suggestion of best feature split
317 |    */
318 |   override def bestSplit(criterion: SplitCriterion, pre: Array[Double],
319 |     fValue: Double, isBinarySplit: Boolean): FeatureSplit = {
320 |     var fSplit: FeatureSplit = null
321 |     val points: Array[Double] = splitPoints()
322 |     for (splitValue: Double <- points) {
323 |       val post: Array[Array[Double]] = binarySplit(splitValue)
324 |       val merit = criterion.merit(normal(pre), normal(post))
325 |       if (fSplit == null || fSplit.merit < merit)
326 |         fSplit = new FeatureSplit(new NumericBinaryTest(fIndex, splitValue, false), merit, post)
327 |     }
328 |     fSplit
329 |   }
330 | 
331 |   /**
332 |    * Binary split the tree data depending on the input value
333 |    * @param splitValue the input value
334 |    * @return an Array encoding the split
335 |    */
336 |   private[trees] def binarySplit(splitValue: Double): Array[Array[Double]] = {
337 |     val rst: Array[Array[Double]] = Array.fill(2)(new Array(numClasses))
338 |     estimators.zipWithIndex.foreach {
339 |       case (es, i) => {
340 |         if (splitValue < minValuePerClass(i)) {
341 |           rst(1)(i) += es.totalWeight()
342 |         } else if (splitValue >= maxValuePerClass(i)) {
343 |           rst(0)(i) += es.totalWeight()
344 |         } else {
345 |           val weights: Array[Double] = es.tripleWeights(splitValue)
346 |           rst(0)(i) += weights(0) + weights(1)
347 |           rst(1)(i) += weights(2)
348 |         }
349 |       }
350 |     }
351 |     rst
352 |   }
353 | 
354 |   private[trees] def splitPoints(): Array[Double] = {
355 |     var minValue = Double.PositiveInfinity
356 |     var maxValue = Double.NegativeInfinity
357 |     val points = new TreeSet[Double]()
358 |     minValuePerClass.foreach { x => minValue = min(minValue, x) }
359 |     maxValuePerClass.foreach { x => maxValue = max(maxValue, x) }
360 |     if (minValue < Double.PositiveInfinity) {
361 |       val range = maxValue - minValue
362 |       for (i <- 0 until numBins) {
363 |         val splitValue = range * (i + 1) / numBins + minValue
364 |         if (splitValue > minValue && splitValue < maxValue)
365 |           points.add(splitValue)
366 |       }
367 |     }
368 |     points.toArray
369 |   }
370 | 
371 |   /**
372 |    * Merges the given FeatureClassObserver into the current FeatureClassObserver
373 |    *
374 |    * @param that the FeatureClassObserver to be merged
375 |    * @param trySplit whether this is called while a Hoeffding tree tries to split
376 |    * @return the current FeatureClassObserver
377 |    */
378 |   override def merge(that: FeatureClassObserver, trySplit: Boolean): FeatureClassObserver = {
379 |     if (!that.isInstanceOf[GaussianNumericFeatureClassObserver]) this
380 |     else {
381 |       val observer = that.asInstanceOf[GaussianNumericFeatureClassObserver]
382 |       if (numClasses == observer.numClasses && fIndex == observer.fIndex) {
383 |         for (i <- 0 until numClasses) {
384 |           estimators(i) = estimators(i).merge(observer.estimators(i), trySplit)
385 |           minValuePerClass(i) = min(minValuePerClass(i), observer.minValuePerClass(i))
386 |           maxValuePerClass(i) = max(maxValuePerClass(i), observer.maxValuePerClass(i))
387 |         }
388 |       }
389 |       this
390 |     }
391 |   }
392 | }
393 | 
394 | /**
395 |  * Companion object FeatureClassObserver for creating FeatureClassObserver instances.
396 |  */
397 | object FeatureClassObserver {
398 |   def createFeatureClassObserver(numClasses: Int, fIndex: Int,
399 |     featureSpec: FeatureSpecification): FeatureClassObserver = {
400 |     if (featureSpec.isNominal())
401 |       new NominalFeatureClassObserver(numClasses, fIndex, featureSpec.range())
402 |     else
403 |       new GaussianNumericFeatureClassObserver(numClasses, fIndex)
404 |   }
405 | 
406 |   def createFeatureClassObserver(observer: FeatureClassObserver): FeatureClassObserver = {
407 |     if (observer.isInstanceOf[NominalFeatureClassObserver])
408 |       new NominalFeatureClassObserver(observer.asInstanceOf[NominalFeatureClassObserver])
409 |     else if (observer.isInstanceOf[GaussianNumericFeatureClassObserver])
410 |       new GaussianNumericFeatureClassObserver(observer.asInstanceOf[GaussianNumericFeatureClassObserver])
411 |     else new NullFeatureClassObserver
412 |   }
413 | }
414 | 
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/streamdm/classifiers/trees/Node.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (C) 2015 Holmes Team at HUAWEI Noah's Ark Lab.
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  * http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  *
16 |  */
17 | 
18 | package org.apache.spark.streamdm.classifiers.trees
19 | 
20 | import scala.collection.mutable.ArrayBuffer
21 | import scala.math.{ max }
22 | 
23 | import org.apache.spark.streamdm.core._
24 | import org.apache.spark.streamdm.core.specification._
25 | import org.apache.spark.streamdm.classifiers.bayes._
26 | import org.apache.spark.streamdm.utils.Utils.{ argmax }
27 | 
28 | /**
29 |  * Abstract class containing the node information for the Hoeffding trees.
30 |  */
31 | abstract class Node(val classDistribution: Array[Double]) extends Serializable {
32 | 
33 |   var dep: Int = 0
34 |   // stores class distribution of a block of RDD
35 |   val blockClassDistribution: Array[Double] = new Array[Double](classDistribution.length)
36 | 
37 |   /**
38 |    * Filter the data to the related leaf node
39 |    *
40 |    * @param example the input Example
41 |    * @param parent the parent of current node
42 |    * @param index the index of current node in the parent children
43 |    * @return a FoundNode containing the leaf node
44 |    */
45 |   def filterToLeaf(example: Example, parent: SplitNode, index: Int): FoundNode
46 | 
47 |   /**
48 |    * Return the class distribution
49 |    * @return an Array containing the class distribution
50 |    */
51 |   def classVotes(ht: HoeffdingTreeModel, example: Example): Array[Double] =
52 |     classDistribution.clone()
53 | 
54 |   /**
55 |    * Checks whether a node is a leaf
56 |    * @return true if a node is a leaf, false otherwise
57 |    */
58 |   def isLeaf(): Boolean = true
59 | 
60 |   /**
61 |    * Returns height of the tree
62 |    *
63 |    * @return the height
64 |    */
65 |   def height(): Int = 0
66 | 
67 |   /**
68 |    * Returns depth of current node in the tree
69 |    *
70 |    * @return the depth
71 |    */
72 |   def depth(): Int = dep
73 | 
74 |   /**
75 |    * Set the depth of current node
76 |    *
77 |    * @param depth the new depth
78 |    */
79 |   def setDepth(depth: Int): Unit = {
80 |     dep = depth
81 |     if (this.isInstanceOf[SplitNode]) {
82 |       val splitNode = this.asInstanceOf[SplitNode]
83 |       splitNode.children.foreach { _.setDepth(depth + 1) }
84 |     }
85 |   }
86 | 
87 |   /**
88 |    * Merge two nodes
89 |    *
90 |    * @param that the node which will be merged
91 |    * @param trySplit flag indicating whether the node will be split
92 |    * @return new node
93 |    */
94 |   def merge(that: Node, trySplit: Boolean): Node
95 | 
96 |   /**
97 |    * Returns number of children
98 |    *
99 |    * @return number of children
100 |    */
101 |   def numChildren(): Int = 0
102 | 
103 |   /**
104 |    * Returns the node description
105 |    * @return String containing the description
106 |    */
107 |   def description(): String = {
108 |     " " * dep + "Leaf" + " weight = " +
109 |       Utils.arraytoString(classDistribution) + "\n"
110 |   }
111 | 
112 | }
113 | 
114 | /**
115 |  * The container of a node.
116 |  */
117 | class FoundNode(val node: Node, val parent: SplitNode, val index: Int) extends Serializable {
118 | 
119 | }
120 | 
121 | /**
122 |  * Branch node of the Hoeffding tree.
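 * A SplitNode routes each incoming Example to one of its children, using the
 * branch chosen by its ConditionalTest.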
123 |  */
124 | class SplitNode(classDistribution: Array[Double], val conditionalTest: ConditionalTest)
125 |   extends Node(classDistribution) with Serializable {
126 | 
127 |   val children: ArrayBuffer[Node] = new ArrayBuffer[Node]()
128 | 
129 |   def this(that: SplitNode) {
130 |     this(Utils.addArrays(that.classDistribution, that.blockClassDistribution),
131 |       that.conditionalTest)
132 |   }
133 | 
134 |   /**
135 |    * Filter the data to the related leaf node
136 |    *
137 |    * @param example input example
138 |    * @param parent the parent of current node
139 |    * @param index the index of current node in the parent children
140 |    * @return FoundNode containing the leaf node
141 |    */
142 |   override def filterToLeaf(example: Example, parent: SplitNode, index: Int): FoundNode = {
143 |     val cIndex = childIndex(example)
144 |     if (cIndex >= 0) {
145 |       if (cIndex < children.length && children(cIndex) != null) {
146 |         children(cIndex).filterToLeaf(example, this, cIndex)
147 |       } else new FoundNode(null, this, cIndex)
148 |     } else new FoundNode(this, parent, index)
149 |   }
150 | 
151 |   def childIndex(example: Example): Int = {
152 |     conditionalTest.branch(example)
153 |   }
154 | 
155 |   def setChild(index: Int, node: Node): Unit = {
156 |     if (children.length > index) {
157 |       children(index) = node
158 |       node.setDepth(dep + 1)
159 |     } else if (children.length == index) {
160 |       children.append(node)
161 |       node.setDepth(dep + 1)
162 |     } else {
163 |       assert(children.length < index)
164 |     }
165 |   }
166 |   /**
167 |    * Returns whether a node is a leaf
168 |    */
169 |   override def isLeaf() = false
170 | 
171 |   /**
172 |    * Returns height of the tree
173 |    *
174 |    * @return the height
175 |    */
176 |   override def height(): Int = {
177 |     var height = 0
178 |     for (child: Node <- children) {
179 |       height = max(height, child.height() + 1)
180 |     }
181 |     height
182 |   }
183 | 
184 |   /**
185 |    * Returns number of children
186 |    *
187 |    * @return number of children
188 |    */
189 |   override def numChildren(): Int = children.filter { _ != null }.length
190 | 
191 |   /**
192 |    * Merge two nodes
193 |    *
194 |    * @param that the node which will be merged
195 |    * @param trySplit flag indicating whether the node will be split
196 |    * @return new node
197 |    */
198 |   override def merge(that: Node, trySplit: Boolean): Node = {
199 |     if (!that.isInstanceOf[SplitNode]) this
200 |     else {
201 |       val splitNode = that.asInstanceOf[SplitNode]
202 |       for (i <- 0 until children.length)
203 |         this.children(i) = (this.children(i)).merge(splitNode.children(i), trySplit)
204 |       this
205 |     }
206 |   }
207 | 
208 |   /**
209 |    * Returns the node description
210 |    * @return String containing the description
211 |    */
212 |   override def description(): String = {
213 |     val sb = new StringBuffer(" " * dep + "\n")
214 |     val testDes = conditionalTest.description()
215 |     for (i <- 0 until children.length) {
216 |       sb.append(" " * dep + " if " + testDes(i) + "\n")
217 |       sb.append(" " * dep + children(i).description())
218 |     }
219 |     sb.toString()
220 |   }
221 | 
222 |   override def toString(): String = "level[" + dep + "] SplitNode"
223 | 
224 | }
225 | /**
226 |  * Learning node class type for Hoeffding trees.
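 * A learning node both issues predictions (classVotes) and updates its
 * statistics from incoming examples (learn).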
227 |  */
228 | abstract class LearningNode(classDistribution: Array[Double]) extends Node(classDistribution)
229 |   with Serializable {
230 | 
231 |   /**
232 |    * Learn and update the node
233 |    *
234 |    * @param ht a Hoeffding tree model
235 |    * @param example the input Example
236 |    */
237 |   def learn(ht: HoeffdingTreeModel, example: Example): Unit
238 | 
239 |   /**
240 |    * Return whether a learning node is active
241 |    */
242 |   def isActive(): Boolean
243 | 
244 |   /**
245 |    * Filter the data to the related leaf node
246 |    *
247 |    * @param example the input example
248 |    * @param parent the parent of current node
249 |    * @param index the index of current node in the parent children
250 |    * @return FoundNode containing the leaf node
251 |    */
252 |   override def filterToLeaf(example: Example, parent: SplitNode, index: Int): FoundNode =
253 |     new FoundNode(this, parent, index)
254 | 
255 | }
256 | 
257 | /**
258 |  * Basic majority class active learning node for Hoeffding tree
259 |  */
260 | class ActiveLearningNode(classDistribution: Array[Double])
261 |   extends LearningNode(classDistribution) with Serializable {
262 | 
263 |   var addonWeight: Double = 0
264 | 
265 |   var blockAddonWeight: Double = 0
266 | 
267 |   var instanceSpecification: InstanceSpecification = null
268 | 
269 |   var featureObservers: Array[FeatureClassObserver] = null
270 | 
271 |   def this(classDistribution: Array[Double], instanceSpecification: InstanceSpecification) {
272 |     this(classDistribution)
273 |     this.instanceSpecification = instanceSpecification
274 |     init()
275 |   }
276 | 
277 |   def this(that: ActiveLearningNode) {
278 |     this(Utils.addArrays(that.classDistribution, that.blockClassDistribution),
279 |       that.instanceSpecification)
280 |     this.addonWeight = that.addonWeight
281 |   }
282 |   /**
283 |    * Initializes the featureObservers array
284 |    */
285 |   def init(): Unit = {
286 |     if (featureObservers == null) {
287 |       featureObservers = new Array(instanceSpecification.size())
288 |       for (i <- 0 until instanceSpecification.size()) {
289 |         val featureSpec: FeatureSpecification = instanceSpecification(i)
290 |         featureObservers(i) = FeatureClassObserver.createFeatureClassObserver(
291 |           classDistribution.length, i, featureSpec)
292 |       }
293 |     }
294 |   }
295 | 
296 |   /**
297 |    * Learn and update the node
298 |    *
299 |    * @param ht a Hoeffding tree model
300 |    * @param example the input example
301 |    */
302 |   override def learn(ht: HoeffdingTreeModel, example: Example): Unit = {
303 |     init()
304 |     blockClassDistribution(example.labelAt(0).toInt) += example.weight
305 |     featureObservers.zipWithIndex.foreach {
306 |       x => x._1.observeClass(example.labelAt(0).toInt, example.featureAt(x._2), example.weight)
307 |     }
308 |   }
309 |   /**
310 |    * Disable a feature at a given index
311 |    *
312 |    * @param fIndex the index of the feature
313 |    */
314 |   def disableFeature(fIndex: Int): Unit = {
315 |     // not supported yet
316 |   }
317 | 
318 |   /**
319 |    * Returns whether a node is active.
320 |    *
321 |    */
322 |   override def isActive(): Boolean = true
323 | 
324 |   /**
325 |    * Returns whether a node is pure, which means it only has examples belonging
326 |    * to a single class.
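   * For example, a node with classDistribution Array(5.0, 0.0) and an all-zero
   * block distribution is pure.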
327 |    */
328 |   def isPure(): Boolean = {
329 |     this.classDistribution.filter(_ > 0).length <= 1 &&
330 |       this.blockClassDistribution.filter(_ > 0).length <= 1
331 |   }
332 | 
333 |   def weight(): Double = { classDistribution.sum + blockClassDistribution.sum }
334 | 
335 |   def blockWeight(): Double = blockClassDistribution.sum
336 | 
337 |   def addOnWeight(): Double = {
338 |     addonWeight
339 |   }
340 | 
341 |   /**
342 |    * Merge two nodes
343 |    *
344 |    * @param that the node which will be merged
345 |    * @param trySplit flag indicating whether the node will be split
346 |    * @return new node
347 |    */
348 |   override def merge(that: Node, trySplit: Boolean): Node = {
349 |     if (that.isInstanceOf[ActiveLearningNode]) {
350 |       val node = that.asInstanceOf[ActiveLearningNode]
351 |       // merge addonWeight and class distribution
352 |       if (!trySplit) {
353 |         this.blockAddonWeight += that.blockClassDistribution.sum
354 |         for (i <- 0 until blockClassDistribution.length)
355 |           this.blockClassDistribution(i) += that.blockClassDistribution(i)
356 |       } else {
357 |         this.addonWeight += node.blockAddonWeight
358 |         for (i <- 0 until classDistribution.length)
359 |           this.classDistribution(i) += that.blockClassDistribution(i)
360 |       }
361 |       // merge feature class observers
362 |       for (i <- 0 until featureObservers.length)
363 |         featureObservers(i) = featureObservers(i).merge(node.featureObservers(i), trySplit)
364 |     }
365 |     this
366 |   }
367 |   /**
368 |    * Returns split suggestions for all features.
369 |    *
370 |    * @param splitCriterion the SplitCriterion used
371 |    * @param ht a Hoeffding tree model
372 |    * @return an array of FeatureSplit
373 |    */
374 |   def getBestSplitSuggestions(splitCriterion: SplitCriterion, ht: HoeffdingTreeModel): Array[FeatureSplit] = {
375 |     val bestSplits = new ArrayBuffer[FeatureSplit]()
376 |     featureObservers.zipWithIndex.foreach(x =>
377 |       bestSplits.append(x._1.bestSplit(splitCriterion, classDistribution, x._2, ht.binaryOnly)))
378 |     if (!ht.noPrePrune) {
379 |       bestSplits.append(new FeatureSplit(null, splitCriterion.merit(classDistribution,
380 |         Array.fill(1)(classDistribution)), new Array[Array[Double]](0)))
381 |     }
382 |     bestSplits.toArray
383 |   }
384 | 
385 |   override def toString(): String = "level[" + dep + "] ActiveLearningNode: " + weight()
386 | }
387 | /**
388 |  * Inactive learning node for Hoeffding trees
389 |  */
390 | class InactiveLearningNode(classDistribution: Array[Double])
391 |   extends LearningNode(classDistribution) with Serializable {
392 | 
393 |   def this(that: InactiveLearningNode) {
394 |     this(Utils.addArrays(that.classDistribution, that.blockClassDistribution))
395 |   }
396 | 
397 |   /**
398 |    * Learn and update the node. No action is taken for InactiveLearningNode
399 |    *
400 |    * @param ht HoeffdingTreeModel
401 |    * @param example an Example to be processed
402 |    */
403 |   override def learn(ht: HoeffdingTreeModel, example: Example): Unit = {}
404 | 
405 |   /**
406 |    * Return whether a learning node is active
407 |    */
408 |   override def isActive(): Boolean = false
409 | 
410 |   /**
411 |    * Merge two nodes
412 |    *
413 |    * @param that the node which will be merged
414 |    * @param trySplit flag indicating whether the node will be split
415 |    * @return new node
416 |    */
417 |   override def merge(that: Node, trySplit: Boolean): Node = this
418 | 
419 |   override def toString(): String = "level[" + dep + "] InactiveLearningNode"
420 | }
421 | /**
422 |  * Naive Bayes based learning node.
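 * It votes with the majority-class distribution until the observed weight
 * exceeds ht.nbThreshold, after which it uses Naive Bayes predictions.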
423 |  */
424 | class LearningNodeNB(classDistribution: Array[Double], instanceSpecification: InstanceSpecification)
425 |   extends ActiveLearningNode(classDistribution, instanceSpecification) with Serializable {
426 | 
427 |   def this(that: LearningNodeNB) {
428 |     this(Utils.addArrays(that.classDistribution, that.blockClassDistribution),
429 |       that.instanceSpecification)
430 |     //init()
431 |   }
432 | 
433 |   /**
434 |    * Returns the predicted class distribution
435 |    *
436 |    * @param ht a Hoeffding tree model
437 |    * @param example the Example to be evaluated
438 |    * @return the predicted class distribution
439 |    */
440 |   override def classVotes(ht: HoeffdingTreeModel, example: Example): Array[Double] = {
441 |     if (weight() > ht.nbThreshold)
442 |       NaiveBayes.predict(example, classDistribution, featureObservers)
443 |     else super.classVotes(ht, example)
444 |   }
445 | 
446 |   /**
447 |    * Disable the feature at a given index
448 |    *
449 |    * @param fIndex the index of the feature
450 |    */
451 |   override def disableFeature(fIndex: Int): Unit = {
452 |     featureObservers(fIndex) = new NullFeatureClassObserver()
453 |   }
454 | }
455 | 
456 | /**
457 |  * Adaptive Naive Bayes learning node.
458 |  */
459 | 
460 | class LearningNodeNBAdaptive(classDistribution: Array[Double],
461 |   instanceSpecification: InstanceSpecification)
462 |   extends ActiveLearningNode(classDistribution, instanceSpecification) with Serializable {
463 | 
464 |   var mcCorrectWeight: Double = 0
465 |   var nbCorrectWeight: Double = 0
466 | 
467 |   var mcBlockCorrectWeight: Double = 0
468 |   var nbBlockCorrectWeight: Double = 0
469 | 
470 |   def this(that: LearningNodeNBAdaptive) {
471 |     this(Utils.addArrays(that.classDistribution, that.blockClassDistribution),
472 |       that.instanceSpecification)
473 |     addonWeight = that.addonWeight
474 |     mcCorrectWeight = that.mcCorrectWeight
475 |     nbCorrectWeight = that.nbCorrectWeight
476 |     init()
477 |   }
478 | 
479 |   /**
480 |    * Learn and update the node.
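   * Besides updating the statistics, it also tracks how much weight the
   * majority-class and the Naive Bayes predictions each classify correctly,
   * so classVotes can use whichever has been more accurate so far.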
481 |    *
482 |    * @param ht a Hoeffding tree model
483 |    * @param example an input example
484 |    */
485 |   override def learn(ht: HoeffdingTreeModel, example: Example): Unit = {
486 |     super.learn(ht, example)
487 |     if (argmax(classDistribution) == example.labelAt(0))
488 |       mcBlockCorrectWeight += example.weight
489 |     if (argmax(NaiveBayes.predict(example, classDistribution, featureObservers)) ==
490 |       example.labelAt(0))
491 |       nbBlockCorrectWeight += example.weight
492 |   }
493 | 
494 |   /**
495 |    * Merge two nodes
496 |    *
497 |    * @param that the node which will be merged
498 |    * @param trySplit flag indicating whether the node will be split
499 |    * @return new node
500 |    */
501 |   override def merge(that: Node, trySplit: Boolean): Node = {
502 |     if (that.isInstanceOf[LearningNodeNBAdaptive]) {
503 |       val nbaNode = that.asInstanceOf[LearningNodeNBAdaptive]
504 |       // merge weights and class distribution
505 |       if (!trySplit) {
506 |         this.blockAddonWeight += nbaNode.blockClassDistribution.sum
507 |         mcBlockCorrectWeight += nbaNode.mcBlockCorrectWeight
508 |         nbBlockCorrectWeight += nbaNode.nbBlockCorrectWeight
509 |         for (i <- 0 until blockClassDistribution.length)
510 |           this.blockClassDistribution(i) += that.blockClassDistribution(i)
511 |       } else {
512 |         this.addonWeight += nbaNode.blockAddonWeight
513 |         mcCorrectWeight += nbaNode.mcBlockCorrectWeight
514 |         nbCorrectWeight += nbaNode.nbBlockCorrectWeight
515 |         for (i <- 0 until classDistribution.length)
516 |           this.classDistribution(i) += that.blockClassDistribution(i)
517 |       }
518 |       // merge feature class observers
519 |       for (i <- 0 until featureObservers.length)
520 |         featureObservers(i) = featureObservers(i).merge(nbaNode.featureObservers(i), trySplit)
521 | 
522 |     }
523 |     this
524 |   }
525 | 
526 |   /**
527 |    * Returns the predicted class distribution
528 |    *
529 |    * @param ht a Hoeffding tree model
530 |    * @param example the input example
531 |    * @return the predicted class distribution
532 |    */
533 |   override def classVotes(ht: HoeffdingTreeModel, example: Example): Array[Double] = {
534 |     if (mcCorrectWeight > nbCorrectWeight) super.classVotes(ht, example)
535 |     else NaiveBayes.predict(example, classDistribution, featureObservers)
536 |   }
537 | }
538 | 
--------------------------------------------------------------------------------