├── project
│   ├── build.properties
│   └── plugins.sbt
├── forest-example
│   ├── run_spark_demo.sh
│   ├── build.sbt
│   ├── classify_user.py
│   ├── src
│   │   └── main
│   │       └── scala
│   │           └── ForestTest.scala
│   └── add_user_features.py
├── .gitignore
├── scripts
│   ├── jedis-ml-test.scala
│   ├── kmeans-example.scala
│   ├── gen_data.py
│   ├── new-ml-forest-example.scala
│   └── ml-forest-example.scala
├── LICENSE
├── README.md
├── Dockerfile
├── src
│   └── main
│       └── scala
│           └── com
│               └── redislabs
│                   └── provider
│                       └── redis
│                           └── ml
│                               └── package.scala
└── scalastyle-config.xml

--------------------------------------------------------------------------------
/project/build.properties:
--------------------------------------------------------------------------------
# This file should only contain the version of sbt to use.
sbt.version=0.13.6

--------------------------------------------------------------------------------
/project/plugins.sbt:
--------------------------------------------------------------------------------
// You may use this file to add plugin dependencies for sbt.
addSbtPlugin("org.spark-packages" % "sbt-spark-package" % "0.2.3")

resolvers += "bintray-spark-packages" at "https://dl.bintray.com/spark-packages/maven/"

--------------------------------------------------------------------------------
/forest-example/run_spark_demo.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

~/dev/spark/bin/spark-submit --master local[*] --jars lib/jedis-ml-1.0-SNAPSHOT.jar,lib/jedis-3.0.0-SNAPSHOT.jar,lib/spark-redis-ml-assembly-0.1.0.jar,lib/spark-mllib_2.11-2.2.0-SNAPSHOT.jar ./target/scala-2.11/forestexample_2.11-0.1.0.jar data/$1 $2

# for f in {1..10}; do ./run_spark_demo.sh $f 30; done
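The two positional arguments forwarded by the script map onto ForestTest.main as follows (see forest-example/src/main/scala/ForestTest.scala); a small illustration with placeholder values:

```scala
// How ForestTest.main interprets run_spark_demo.sh's arguments:
//   data/$1 -> args(0): libsvm dataset path; its basename becomes the Redis key suffix
//   $2      -> args(1): number of trees in the random forest
val args = Array("data/7", "30")      // e.g. ./run_spark_demo.sh 7 30
val movieId = args(0).split("/").last // "7" -> forest key "movie-7"
val nTrees = args(1).toInt            // 30
```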
--------------------------------------------------------------------------------
/forest-example/build.sbt:
--------------------------------------------------------------------------------
// Your sbt build file. Guides on how to write one can be found at
// http://www.scala-sbt.org/0.13/docs/index.html
name := "ForestExample"

scalaVersion := "2.11.6"

version := "0.1.0"

libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-core" % "2.1.1",
  "org.apache.spark" %% "spark-sql" % "2.1.1",
  "org.apache.spark" %% "spark-mllib" % "2.1.1",
  "redis.clients" % "jedis" % "2.7.2")

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
hs_err_pid*.log
nohup.out
.idea
*.iml
**/.idea
*/.classpath
*/.project
*/.settings
*/.cache
*/test-output/
*.log
*/*.versionsBackup
target/
*GitIgnored*
*.asc
*.gpg
/bin/

*.class
*.log
*.pyc
sbt/*.jar

# sbt specific
.cache/
.history/
.lib/
dist/*
target/
lib_managed/
src_managed/
project/boot/
project/plugins/project/
**.jar

--------------------------------------------------------------------------------
/forest-example/classify_user.py:
--------------------------------------------------------------------------------
#!/usr/bin/python

import operator
import redis
config = {"host":"localhost", "port":6379}
r = redis.StrictRedis(**config)

user_profile = r.get("user-1-profile")

results = {}

for i in range(1, 11):
    results[i] = r.execute_command("ML.FOREST.RUN", "movie-{}".format(i), user_profile)

print "Movies sorted by scores:"
sorted_results = sorted(results.items(), key=operator.itemgetter(1), reverse=True)
for k,v in sorted_results:
    print "movie-{}:{}".format(k,v)

print ""
print "Recommended movie: movie-{}".format(sorted_results[0][0])

--------------------------------------------------------------------------------
/scripts/jedis-ml-test.scala:
--------------------------------------------------------------------------------
import org.apache.spark.ml.regression.LinearRegression
import com.redislabs.client.redisml.MLClient
import redis.clients.jedis.{Jedis, _}

// Load training data and train
val training = spark.read.format("libsvm").load("data/mllib/sample_linear_regression_data.txt")
val lr = new LinearRegression().setMaxIter(10).setRegParam(0.3).setElasticNetParam(0.8)
val lrModel = lr.fit(training)
println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")

// Connect to Redis
val jedis = new Jedis("localhost")

// Load model to Redis
val cmd = "my_lr_model" +: lrModel.intercept.toString +: lrModel.coefficients.toArray.mkString(",").split(",")
jedis.getClient.sendCommand(MLClient.ModuleCommand.LINREG_SET, cmd: _*)
jedis.getClient().getStatusCodeReply

// Perform prediction with Redis
val cmd = Array("my_lr_model", "1", "2", "5")
jedis.getClient.sendCommand(MLClient.ModuleCommand.LINREG_PREDICT, cmd: _*)
jedis.getClient().getStatusCodeReply
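A note on the pattern above: every call pairs a raw sendCommand with a manual reply read. A small wrapper keeps the two together; this is a sketch, not part of jedis-ml — the helper name linregPredict is ours, and it assumes LINREG.PREDICT replies with a string that parses as a double:

```scala
import com.redislabs.client.redisml.MLClient
import redis.clients.jedis.Jedis

// Hypothetical convenience wrapper around the send/read pattern used above.
// Assumes the module replies with a numeric string.
def linregPredict(jedis: Jedis, modelKey: String, features: Seq[Double]): Double = {
  val args = modelKey +: features.map(_.toString)
  jedis.getClient.sendCommand(MLClient.ModuleCommand.LINREG_PREDICT, args: _*)
  jedis.getClient.getStatusCodeReply.toDouble // read the reply for the command just sent
}

val jedis = new Jedis("localhost")
println(linregPredict(jedis, "my_lr_model", Seq(1.0, 2.0, 5.0)))
```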
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Spark-Redis is provided under the 3-Clause BSD License: http://opensource.org/licenses/BSD-3-Clause

Copyright (c) 2015, Redis Labs, Inc.
All rights reserved.

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
---
#### Notice: RedisML is planned to be replaced by [RedisAI](http://redisai.io), which adds support for deep learning.
***

# Spark-Redis-ML

### A Spark package for loading Spark ML models into [Redis-ML](https://github.com/RedisLabsModules/redis-ml "Redis-ML")

## Requirements:

Apache [Spark](https://github.com/apache/spark) 2.0 or later

[Redis](https://github.com/antirez/redis) built from the unstable branch

[Jedis](https://github.com/xetorthio/jedis)

[Jedis-ml](https://github.com/RedisLabs/jedis-ml)

## Installation:

```sh
# get and build redis-ml
git clone https://github.com/RedisLabsModules/redis-ml.git
cd redis-ml/src
make

# get and build jedis
git clone https://github.com/xetorthio/jedis.git
cd jedis
mvn package -Dmaven.test.skip=true

# get and build jedis-ml
cd ..
git clone https://github.com/RedisLabs/jedis-ml.git
cd jedis-ml
mkdir lib
cp ../jedis/target/jedis-3.0.0-SNAPSHOT.jar lib/
mvn install

# get and build spark-redis-ml
cd ..
git clone https://github.com/RedisLabs/spark-redis-ml.git
cd spark-redis-ml
mkdir lib
cp ../jedis/target/jedis-3.0.0-SNAPSHOT.jar lib/
cp ../jedis-ml/target/jedis-ml-1.0-SNAPSHOT.jar lib/
sbt assembly
```
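Before moving on to usage, a quick connectivity check from the Spark shell can save debugging time. This is an optional sketch; it assumes a local Redis started with the redis-ml module as shown in the next section:

```scala
// Sanity check: ping Redis through Jedis before loading any models.
import redis.clients.jedis.Jedis
val jedis = new Jedis("localhost")
println(jedis.ping()) // expect "PONG"
```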
### Usage:

Run the Redis server with the redis-ml module loaded:

```sh
/path/to/redis-server --loadmodule ./redis-ml.so
```

From the Spark root directory, run the Spark shell with the required jars:

```sh
./bin/spark-shell --jars ../spark-redis-ml/target/scala-2.11/spark-redis-ml-assembly-0.1.0.jar,../spark-redis-ml/lib/jedis-3.0.0-SNAPSHOT.jar,../spark-redis-ml/lib/jedis-ml-1.0-SNAPSHOT.jar
```

In the Spark shell:

```sh
scala> :load "../spark-redis-ml/scripts/ml-forest-example.scala"
scala> benchmark(10)
```
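The example scripts under scripts/ and forest-example/ all follow the same core flow: train a Spark random forest, wrap its trees in this package's Forest class, and push them to redis-ml. A condensed, illustrative sketch (dataset path, tree count, and key name are placeholders; see the full scripts below for the complete pipelines):

```scala
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import com.redislabs.provider.redis.ml.Forest

// Condensed sketch of what the example scripts do.
val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
val rf = new RandomForestClassifier().setNumTrees(10)
val rfModel = rf.fit(data) // RandomForestClassificationModel

// Serialize every tree and load the ensemble into redis-ml as one forest key.
val forest = new Forest(rfModel.trees)
forest.loadToRedis("my-forest", "localhost")
```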
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
FROM ubuntu

RUN apt-get -y update && apt-get install -y build-essential git wget unzip python vim
RUN git clone https://github.com/xetorthio/jedis.git
RUN git clone https://github.com/RedisLabs/jedis-ml.git
RUN git clone https://github.com/RedisLabs/spark-redis-ml.git
RUN git clone https://github.com/shaynativ/spark.git

RUN apt-get install -y maven default-jdk

RUN cd jedis && mvn package -Dmaven.test.skip=true

RUN cd jedis-ml && mkdir lib && cp ../jedis/target/jedis-3.0.0-SNAPSHOT.jar lib/ && mvn install

RUN echo "deb http://dl.bintray.com/sbt/debian /" | tee -a /etc/apt/sources.list.d/sbt.list
RUN apt-key adv --keyserver hkp://keyserver.ubuntu.com:80 --recv 2EE0EA64E40A89B84B2DF73499E82A75642AC823
RUN apt-get -y update
RUN apt-get install -y sbt

RUN cd spark && mvn clean package -DskipTests=true

WORKDIR /spark-redis-ml

RUN mkdir lib &&\
    cp /spark/mllib/target/spark-mllib_2.11-2.2.0-SNAPSHOT.jar lib/ &&\
    cp ../jedis/target/jedis-3.0.0-SNAPSHOT.jar lib/ &&\
    cp ../jedis-ml/target/jedis-ml-1.0-SNAPSHOT.jar lib/

RUN sbt assembly

WORKDIR /spark-redis-ml/forest-example
RUN mkdir lib && cp ../lib/* lib/
RUN cp ../target/scala-2.11/spark-redis-ml-assembly-0.1.0.jar lib/
RUN git pull
RUN sbt package

WORKDIR /

RUN wget http://files.grouplens.org/datasets/movielens/ml-100k.zip &&\
    unzip ml-100k.zip &&\
    cp spark-redis-ml/scripts/gen_data.py ml-100k/ &&\
    mkdir ml-100k/out &&\
    cd ml-100k && ./gen_data.py &&\
    /bin/sh -c 'for i in `seq 1 20`; do cp /ml-100k/out/$i /spark/data/mllib/; done' &&\
    rm /ml-100k.zip && rm -rf /ml-100k

WORKDIR /spark-redis-ml/forest-example
CMD ["/spark/bin/spark-submit", "--master", "local[*]", "--jars", "lib/jedis-ml-1.0-SNAPSHOT.jar,lib/jedis-3.0.0-SNAPSHOT.jar,lib/spark-redis-ml-assembly-0.1.0.jar,lib/spark-mllib_2.11-2.2.0-SNAPSHOT.jar", "./target/scala-2.11/forestexample_2.11-0.1.0.jar", "/spark/data/mllib/10", "20"]

--------------------------------------------------------------------------------
/src/main/scala/com/redislabs/provider/redis/ml/package.scala:
--------------------------------------------------------------------------------
package com.redislabs.provider.redis.ml

import org.apache.spark.ml.tree
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import redis.clients.jedis.Protocol.Command
import redis.clients.jedis.{Jedis, _}
import com.redislabs.client.redisml.MLClient
import org.apache.spark.ml.tree.{CategoricalSplit, ContinuousSplit, InternalNode}

class Forest(trees: Array[DecisionTreeClassificationModel]) {

  // Serialize a subtree into the comma-separated node list expected by ML.FOREST.ADD;
  // each node is addressed by a path string ("." for the root, then "l"/"r" per branch).
  private def subtreeToRedisString(n: org.apache.spark.ml.tree.Node, path: String = "."): String = {
    val prefix: String = s",${path},"
    n.getClass.getSimpleName match {
      case "InternalNode" => {
        val in = n.asInstanceOf[InternalNode]
        val splitStr = in.split match {
          case contSplit: ContinuousSplit => s"numeric,${in.split.featureIndex},${contSplit.threshold}"
          case catSplit: CategoricalSplit => s"categoric,${in.split.featureIndex}," +
            catSplit.leftCategories.mkString(":")
        }
        prefix + splitStr + subtreeToRedisString(in.leftChild, path + "l") +
          subtreeToRedisString(in.rightChild, path + "r")
      }
      case "LeafNode" => {
        prefix + s"leaf,${n.prediction}" +
          s",stats,${n.getImpurityStats.mkString(":")}"
      }
    }
  }

  private def toRedisString: String = {
    trees.zipWithIndex.map { case (tree, treeIndex) =>
      s"${treeIndex}" + subtreeToRedisString(tree.rootNode, ".")
    }.fold("") { (a, b) => a + "\n" + b }
  }

  def toDebugArray: Array[String] = {
    toRedisString.split("\n").drop(1)
  }

  // Send the whole ensemble to Redis inside a MULTI/EXEC transaction, one FOREST.ADD per tree.
  def loadToRedis(forestId: String = "test_forest", host: String = "localhost") {
    val jedis = new Jedis(host)
    val commands = toRedisString.split("\n").drop(1)
    jedis.getClient.sendCommand(Command.MULTI)
    jedis.getClient().getStatusCodeReply
    for (cmd <- commands) {
      val cmdArray = forestId +: cmd.split(",")
      jedis.getClient.sendCommand(MLClient.ModuleCommand.FOREST_ADD, cmdArray: _*)
      jedis.getClient().getStatusCodeReply
    }
    jedis.getClient.sendCommand(Command.EXEC)
    jedis.getClient.getMultiBulkReply
  }
}

--------------------------------------------------------------------------------
/scripts/kmeans-example.scala:
--------------------------------------------------------------------------------
import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
import org.apache.spark.mllib.linalg.Vectors
import com.redislabs.client.redisml.MLClient
import redis.clients.jedis.{Jedis, _}

// Load and parse the data
val data = sc.textFile("data/mllib/kmeans_data.txt")
val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble))).cache()

// Cluster the data into four clusters using KMeans
val numClusters = 4
val numIterations = 20
val clusters = KMeans.train(parsedData, numClusters, numIterations)

// args: key, number of clusters, dimension, then the flattened cluster centers
val cmd = "my_km_model" +: numClusters.toString +: "3" +: clusters.clusterCenters.map(x => x.toArray).flatten.mkString(",").split(",")
val jedis = new Jedis("localhost")
jedis.getClient.sendCommand(MLClient.ModuleCommand.KMEANS_SET, cmd: _*)
jedis.getClient().getStatusCodeReply


var redisRes = 0L
var sparkRes = 0.0
var rtotal = 0.0
var stotal = 0.0
var diffs = 0.0

def benchmark(b: Int) {
  rtotal = 0.0
  stotal = 0.0
  diffs = 0.0
  val jedis = new Jedis("localhost")
  for (i <- 0 to b) {
    val rt0 = System.nanoTime()
    val cmd = Array("my_km_model", i.toString, i.toString, i.toString)
    jedis.getClient.sendCommand(MLClient.ModuleCommand.KMEANS_PREDICT, cmd: _*)
    redisRes =
jedis.getClient().getIntegerReply 37 | val rt1 = System.nanoTime() 38 | println(cmd.mkString(", ")) 39 | println("Redis time: " + (rt1 - rt0) / 1000000.0 + "ms, res=" + redisRes) 40 | val v = Vectors.dense(i, i, i) 41 | val p = Seq(v) 42 | val rdd = sc.parallelize(p,1) 43 | val st0 = System.nanoTime() 44 | val rawSparkRes = clusters.predict(rdd).collect 45 | val st1 = System.nanoTime() 46 | sparkRes = rawSparkRes(0) 47 | println("Spark time: " + (st1 - st0) / 1000000.0 + "ms, res=" + sparkRes) 48 | println("---------------------------------------"); 49 | if (sparkRes - redisRes.toFloat != 0) { 50 | diffs += 1 51 | } 52 | rtotal += (rt1 - rt0) / 1000000.0 53 | stotal += (st1 - st0) / 1000000.0 54 | } 55 | println("Classification averages:") 56 | println(s"redis: ${rtotal / b.toFloat} ms") 57 | println(s"spark: ${stotal / b.toFloat} ms") 58 | println(s"ratio: ${stotal / rtotal}") 59 | println(s"diffs: $diffs") 60 | } 61 | 62 | benchmark(20) -------------------------------------------------------------------------------- /scripts/gen_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from collections import defaultdict 3 | 4 | OCCUPATION_FILE = "./u.occupation" 5 | GENRE_FILE = "./u.genre" 6 | ITEM_FILE = "./u.item" 7 | USER_FILE = "./u.user" 8 | RATINGS_FILE = "./u.data" 9 | 10 | occupations_fo = open(OCCUPATION_FILE, "r") 11 | genres_fo = open(GENRE_FILE, "r") 12 | items_fo = open(ITEM_FILE, "r") 13 | users_fo = open(USER_FILE, "r") 14 | ratings_fo = open(RATINGS_FILE, "r") 15 | 16 | gender_map = {'M':'0', 'F':'1'} 17 | user_ratings_map = defaultdict(lambda: {}) 18 | item_raters = defaultdict(lambda: []) 19 | user_genres_avg = defaultdict(lambda: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) 20 | user_rating_count = defaultdict(lambda: 0) 21 | 22 | def main(): 23 | occupations_map = dict((x[:-1],str(i+1)) for i,x in enumerate(occupations_fo.readlines())) 24 | occupations_fo.close() 25 | 26 | item_genres_map = dict((x.split('|')[0], x[:-1].split('|')[5:]) for x in items_fo.readlines()) 27 | items_fo.close() 28 | 29 | # add user info to user_ratings_map. format: uid|age|gender|occupation|zip(excluded) 30 | for x in users_fo.readlines(): 31 | x = x[:-1].split('|') 32 | uid = int(x[0]) 33 | # 1700 = age, 1701 = gender, 1702 = occupation 34 | user_ratings_map[uid][1700] = x[1] 35 | user_ratings_map[uid][1701] = gender_map[x[2]] 36 | user_ratings_map[uid][1702] = occupations_map[x[3]] 37 | 38 | # add item ratings, count user events, and accumulate genres info. format: user,item,rating 39 | # item range is 1 to 1650 40 | for x in ratings_fo.readlines(): 41 | x = x[:-1].split('\t') 42 | uid = int(x[0]) 43 | item = int(x[1]) 44 | rating = int(x[2]) 45 | item_raters[int(x[1])].append(uid) 46 | user_rating_count[uid] = user_rating_count[uid] + 1 47 | user_ratings_map[uid][item] = rating 48 | user_genres_avg[uid] = [(int(x) * rating + user_genres_avg[uid][i]) for i,x in enumerate(item_genres_map[x[1]])] 49 | 50 | # calculate genre avg per user. 
1800 = count, 1801 - 1819 = genres
    for uid, count in user_rating_count.iteritems():
        user_ratings_map[uid][1800] = count
        factor = count if count > 1 else 1
        for i,x in enumerate(user_genres_avg[uid]):
            user_ratings_map[uid][1801+i] = "{:.2f}".format(1.0*x/factor)
        # print(user_ratings_map[uid])

    # generate a rating file for each item
    for item, users in item_raters.iteritems():
        f = open("./out/{}".format(item),'w')
        line = ''
        for uid in users:
            map = user_ratings_map[uid]
            zero_val = map.pop(item) # not using NULL; all keys must be present
            line = line + str(zero_val)
            for k,v in sorted(map.iteritems()): # libsvm requires sorted feature indices
                line = line + " {}:{}".format(k,v)
            line = line + "\n"
        f.write(line)
        f.close()

if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/scripts/new-ml-forest-example.scala:
--------------------------------------------------------------------------------
import scala.collection.mutable
import scala.language.reflectiveCalls
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor}
import org.apache.spark.ml.tree.{CategoricalSplit, ContinuousSplit, Split}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.{SparkSession, _}
import redis.clients.jedis.Protocol.Command
import redis.clients.jedis.{Jedis, _}
import com.redislabs.client.redisml.MLClient
import com.redislabs.provider.redis.ml.Forest
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

// REPL power mode, used to widen the maximum printed string length.
:power
vals.isettings
vals.isettings.maxPrintString = 255



// Load and parse the data file, converting it to a DataFrame.
//val data = spark.read.format("libsvm").load("data/mllib/small_test_10L_2F_np")
val data = spark.read.format("libsvm").load("data/mllib/10")

// Index labels, adding metadata to the label column.
// Fit on whole dataset to include all labels in index.
val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(data)
// Automatically identify categorical features, and index them.
// Set maxCategories so features with > 20 distinct values are treated as continuous.
val featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(20).fit(data)

// Split the data into training and test sets (20% held out for testing).
val Array(trainingData, test) = data.randomSplit(Array(0.8, 0.2))

// Train a RandomForest model.
val rf = new RandomForestClassifier().setFeatureSubsetStrategy("all").setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures").setNumTrees(20)

// Convert indexed labels back to original labels.
val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)

// Chain indexers and forest in a Pipeline.
val pipeline = new org.apache.spark.ml.Pipeline().setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))

// Train model.
This also runs the indexers. 47 | val model = pipeline.fit(trainingData) 48 | 49 | // Make predictions. 50 | val predictions = model.transform(test) 51 | 52 | // Select example rows to display. 53 | predictions.select("predictedLabel", "label", "features").show(5) 54 | 55 | // Select (prediction, true label) and compute test error. 56 | val evaluator = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("accuracy") 57 | val accuracy = evaluator.evaluate(predictions) 58 | println("Test Error = " + (1.0 - accuracy)) 59 | 60 | val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel] 61 | println("Learned classification forest model:\n" + rfModel.toDebugString) 62 | 63 | val f = new Forest(rfModel.trees) 64 | f.loadToRedis("forest-test", "localhost") 65 | 66 | val localData = featureIndexer.transform(test).collect 67 | 68 | def makeInputString(i: Int): String = { 69 | val sparseRecord = localData(i)(2).asInstanceOf[org.apache.spark.ml.linalg.SparseVector] 70 | val indices = sparseRecord.indices 71 | val values = sparseRecord.values 72 | var sep = "" 73 | var inputStr = "" 74 | for (i <- 0 to ((indices.length - 1))) { 75 | inputStr = inputStr + sep + indices(i).toString + ":" + values(i).toString 76 | sep = "," 77 | } 78 | inputStr 79 | } 80 | 81 | def makeDF(i: Int): org.apache.spark.sql.DataFrame = { 82 | test.sqlContext.createDataFrame(sc.parallelize(test.take(i + 1).slice(i, i + 1)), test.schema) 83 | } 84 | 85 | var redisRes = "" 86 | var sparkRes = 0.0 87 | var rtotal = 0.0 88 | var stotal = 0.0 89 | var diffs = 0.0 90 | def benchmark(b: Int) { 91 | rtotal = 0.0 92 | stotal = 0.0 93 | diffs = 0.0 94 | val jedis = new Jedis("localhost") 95 | for (i <- 0 to b) { 96 | val rt0 = System.nanoTime() 97 | jedis.getClient.sendCommand(MLClient.ModuleCommand.FOREST_RUN, "forest-test", makeInputString(i)) 98 | redisRes = jedis.getClient().getStatusCodeReply 99 | val rt1 = System.nanoTime() 100 | println("Redis time: " + (rt1 - rt0) / 1000000.0 + "ms, res=" + redisRes) 101 | val df = makeDF(i) 102 | val st0 = System.nanoTime() 103 | val rawSparkRes = model.transform(df) 104 | val st1 = System.nanoTime() 105 | sparkRes = rawSparkRes.select("prediction").asInstanceOf[org.apache.spark.sql.DataFrame].take(1)(0)(0).asInstanceOf[Double] 106 | println("Spark time: " + (st1 - st0) / 1000000.0 + "ms, res=" + sparkRes) 107 | println("---------------------------------------"); 108 | if (sparkRes - redisRes.toFloat != 0) { 109 | diffs += 1 110 | } 111 | rtotal += (rt1 - rt0) / 1000000.0 112 | stotal += (st1 - st0) / 1000000.0 113 | } 114 | println("Classification averages:") 115 | println(s"redis: ${rtotal / b.toFloat} ms") 116 | println(s"spark: ${stotal / b.toFloat} ms") 117 | println(s"ratio: ${stotal / rtotal}") 118 | println(s"diffs: $diffs") 119 | } 120 | 121 | vals.isettings.maxPrintString = Int.MaxValue 122 | 123 | def dbt(i: Int) = { 124 | rfModel.trees(i).toDebugString 125 | } 126 | 127 | def makeFV(i: Int):org.apache.spark.ml.linalg.SparseVector = {makeDF(i).take(1)(0)(1).asInstanceOf[org.apache.spark.ml.linalg.SparseVector]} 128 | 129 | 130 | -------------------------------------------------------------------------------- /forest-example/src/main/scala/ForestTest.scala: -------------------------------------------------------------------------------- 1 | import org.apache.spark.SparkContext 2 | import org.apache.spark.SparkContext._ 3 | import org.apache.spark.SparkConf 4 | import scala.collection.mutable 5 | import 
scala.language.reflectiveCalls
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor}
import org.apache.spark.ml.tree.{CategoricalSplit, ContinuousSplit, Split}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.{SparkSession, _}
import redis.clients.jedis.Protocol.Command
import redis.clients.jedis.{Jedis, _}
import com.redislabs.client.redisml.MLClient
import com.redislabs.provider.redis.ml.Forest
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

object ForestTest {
  def main(args: Array[String]) {
    val conf = new SparkConf().setAppName("Forest Example")
    val sc = new SparkContext(conf)
    sc.setLogLevel("WARN")

    val movieId = args(0).split("/").last
    val nTrees = args(1).toInt
    val spark = SparkSession
      .builder
      .getOrCreate()
    // Load and parse the data file, converting it to a DataFrame.
    //val data = spark.read.format("libsvm").load("data/mllib/small_test_10L_2F_np")
    val data = spark.read.format("libsvm").load(args(0))

    // Index labels, adding metadata to the label column.
    // Fit on whole dataset to include all labels in index.
    val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(data)
    // Automatically identify categorical features, and index them.
    // Set maxCategories so features with > 20 distinct values are treated as continuous.
    val featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(20).fit(data)

    // Split the data into training and test sets (20% held out for testing).
    val Array(trainingData, test) = data.randomSplit(Array(0.8, 0.2))

    // Train a RandomForest model.
    val rf = new RandomForestClassifier().setFeatureSubsetStrategy("all").setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures").setNumTrees(nTrees)

    // Convert indexed labels back to original labels.
    val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)

    // Chain indexers and forest in a Pipeline.
    val pipeline = new org.apache.spark.ml.Pipeline().setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))

    // Train model. This also runs the indexers.
    val model = pipeline.fit(trainingData)

    // Make predictions.
    val predictions = model.transform(test)

    // Select example rows to display.
    predictions.select("predictedLabel", "label", "features").show(5)

    // Select (prediction, true label) and compute test error.
64 | val evaluator = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("accuracy") 65 | val accuracy = evaluator.evaluate(predictions) 66 | println("Test Error = " + (1.0 - accuracy)) 67 | 68 | val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel] 69 | println("Learned classification forest model:\n" + rfModel.toDebugString) 70 | 71 | val f = new Forest(rfModel.trees) 72 | f.loadToRedis(s"movie-${movieId}", "localhost") 73 | 74 | val localData = featureIndexer.transform(test).collect 75 | 76 | def makeInputString(i: Int): String = { 77 | val sparseRecord = localData(i)(2).asInstanceOf[org.apache.spark.ml.linalg.SparseVector] 78 | val indices = sparseRecord.indices 79 | val values = sparseRecord.values 80 | var sep = "" 81 | var inputStr = "" 82 | for (i <- 0 to ((indices.length - 1))) { 83 | inputStr = inputStr + sep + indices(i).toString + ":" + values(i).toString 84 | sep = "," 85 | } 86 | inputStr 87 | } 88 | 89 | def makeDF(i: Int): org.apache.spark.sql.DataFrame = { 90 | test.sqlContext.createDataFrame(sc.parallelize(test.take(i + 1).slice(i, i + 1)), test.schema) 91 | } 92 | 93 | var redisRes = "" 94 | var sparkRes = 0.0 95 | var rtotal = 0.0 96 | var stotal = 0.0 97 | var diffs = 0.0 98 | def benchmark(b: Int) { 99 | rtotal = 0.0 100 | stotal = 0.0 101 | diffs = 0.0 102 | val jedis = new Jedis("localhost") 103 | for (i <- 0 to b) { 104 | val rt0 = System.nanoTime() 105 | jedis.getClient.sendCommand(MLClient.ModuleCommand.FOREST_RUN, s"movie-${movieId}", makeInputString(i)) 106 | redisRes = jedis.getClient().getStatusCodeReply 107 | val rt1 = System.nanoTime() 108 | println("Redis time: " + (rt1 - rt0) / 1000000.0 + "ms, res=" + redisRes) 109 | val df = makeDF(i) 110 | val st0 = System.nanoTime() 111 | val rawSparkRes = model.transform(df) 112 | val st1 = System.nanoTime() 113 | sparkRes = rawSparkRes.select("prediction").asInstanceOf[org.apache.spark.sql.DataFrame].take(1)(0)(0).asInstanceOf[Double] 114 | println("Spark time: " + (st1 - st0) / 1000000.0 + "ms, res=" + sparkRes) 115 | println("---------------------------------------"); 116 | if (sparkRes - redisRes.toFloat != 0) { 117 | diffs += 1 118 | } 119 | rtotal += (rt1 - rt0) / 1000000.0 120 | stotal += (st1 - st0) / 1000000.0 121 | } 122 | println("Classification averages:") 123 | println(s"redis: ${rtotal / b.toFloat} ms") 124 | println(s"spark: ${stotal / b.toFloat} ms") 125 | println(s"ratio: ${stotal / rtotal}") 126 | println(s"diffs: $diffs") 127 | } 128 | 129 | benchmark(30) 130 | sc.stop() 131 | } 132 | } -------------------------------------------------------------------------------- /forest-example/add_user_features.py: -------------------------------------------------------------------------------- 1 | import redis 2 | config = {"host":"localhost", "port":6379} 3 | r = redis.StrictRedis(**config) 4 | 
r.set("user-1-profile","12:1.0,13:1.0,14:3.0,15:1.0,17:1.0,18:1.0,19:1.0,20:1.0,23:1.0,24:5.0,92:1.0,99:3.0,102:1.0,103:1.0,104:1.0,105:2.0,106:1.0,107:1.0,108:1.0,110:3.0,111:1.0,115:1.0,116:2.0,117:2.0,119:1.0,120:4.0,121:2.0,122:2.0,123:1.0,124:2.0,125:1.0,128:2.0,129:1.0,136:1.0,145:1.0,146:1.0,147:2.0,148:1.0,149:1.0,150:2.0,219:3.0,220:1.0,221:4.0,223:1.0,224:3.0,234:1.0,235:1.0,236:5.0,239:1.0,241:1.0,242:1.0,244:2.0,250:1.0,255:1.0,257:3.0,258:1.0,259:1.0,260:1.0,261:2.0,262:1.0,263:1.0,265:1.0,267:1.0,268:1.0,269:3.0,272:1.0,273:4.0,274:2.0,275:2.0,276:1.0,277:2.0,278:1.0,279:4.0,280:2.0,281:4.0,282:2.0,283:2.0,284:2.0,285:1.0,286:1.0,287:4.0,288:4.0,289:2.0,290:2.0,291:1.0,293:2.0,298:1.0,299:3.0,300:1.0,301:1.0,302:1.0,303:1.0,304:1.0,305:1.0,306:1.0,307:1.0,318:3.0,320:2.0,321:1.0,322:2.0,323:1.0,324:2.0,325:1.0,326:2.0,327:3.0,328:1.0,329:1.0,330:1.0,331:2.0,332:3.0,333:1.0,334:1.0,335:2.0,336:1.0,357:2.0,358:1.0,359:1.0,362:1.0,367:1.0,368:3.0,369:2.0,404:4.0,405:1.0,406:2.0,407:1.0,408:1.0,409:1.0,410:3.0,411:2.0,412:2.0,423:1.0,454:1.0,455:1.0,456:1.0,457:3.0,458:1.0,459:1.0,470:1.0,471:1.0,472:1.0,474:2.0,475:4.0,476:1.0,507:2.0,543:1.0,545:2.0,546:1.0,590:3.0,592:1.0,594:1.0,595:3.0,596:2.0,597:1.0,618:3.0,619:1.0,627:2.0,675:3.0,677:2.0,679:1.0,680:1.0,681:4.0,682:1.0,684:1.0,686:1.0,687:1.0,689:3.0,695:1.0,712:1.0,716:1.0,717:1.0,739:1.0,740:1.0,741:3.0,742:1.0,743:2.0,747:1.0,748:1.0,755:1.0,757:1.0,759:1.0,761:1.0,762:1.0,763:1.0,765:1.0,766:1.0,812:1.0,814:2.0,817:1.0,818:2.0,819:1.0,822:2.0,823:1.0,824:1.0,825:1.0,826:2.0,827:1.0,828:1.0,830:1.0,831:1.0,832:1.0,833:1.0,839:1.0,840:1.0,843:1.0,844:3.0,845:2.0,846:1.0,863:1.0,865:1.0,869:1.0,870:1.0,871:1.0,872:1.0,873:1.0,874:3.0,875:1.0,876:1.0,877:1.0,878:1.0,879:1.0,880:1.0,881:1.0,882:1.0,883:1.0,884:1.0,885:1.0,886:1.0,918:1.0,919:1.0,921:1.0,923:2.0,924:1.0,925:1.0,926:1.0,927:3.0,928:1.0,929:1.0,930:1.0,931:1.0,932:1.0,933:2.0,936:2.0,937:1.0,947:1.0,949:1.0,951:1.0,973:3.0,974:1.0,975:1.0,976:1.0,977:1.0,978:1.0,979:1.0,980:1.0,981:1.0,982:1.0,983:1.0,984:1.0,985:1.0,987:2.0,988:1.0,989:1.0,990:1.0,994:1.0,1000:1.0,1001:1.0,1007:1.0,1008:1.0,1009:1.0,1010:1.0,1014:1.0,1016:1.0,1021:1.0,1024:1.0,1025:1.0,1027:2.0,1032:1.0,1033:1.0,1037:1.0,1039:1.0,1046:2.0,1047:2.0,1048:1.0,1050:1.0,1051:1.0,1053:2.0,1056:1.0,1058:1.0,1059:1.0,1060:1.0,1066:1.0,1067:1.0,1078:1.0,1080:1.0,1083:1.0,1084:1.0,1085:1.0,1086:1.0,1092:1.0,1093:1.0,1094:1.0,1096:1.0,1101:1.0,1113:1.0,1114:1.0,1116:2.0,1119:1.0,1127:1.0,1128:1.0,1131:1.0,1133:2.0,1136:1.0,1149:1.0,1150:1.0,1151:2.0,1160:1.0,1161:1.0,1162:2.0,1163:2.0,1164:1.0,1170:1.0,1172:1.0,1173:1.0,1186:1.0,1196:1.0,1197:1.0,1198:1.0,1201:1.0,1214:1.0,1241:1.0,1244:1.0,1251:1.0,1254:1.0,1258:1.0,1264:1.0,1271:1.0,1275:1.0,1276:1.0,1279:1.0,1280:1.0,1281:1.0,1283:1.0,1286:1.0,1287:1.0,1288:1.0,1290:1.0,1294:1.0,1295:1.0,1301:1.0,1311:1.0,1316:1.0,1317:1.0,1318:1.0,1319:1.0,1320:1.0,1321:1.0,1322:1.0,1323:1.0,1324:1.0,1325:1.0,1326:1.0,1327:1.0,1328:1.0,1329:1.0,1330:1.0,1331:1.0,1332:1.0,1333:1.0,1334:1.0,1335:1.0,1336:1.0,1337:1.0,1338:1.0,1339:1.0,1340:1.0,1341:1.0,1342:1.0,1343:1.0,1344:1.0,1345:1.0,1346:1.0,1347:1.0,1348:1.0,1349:1.0,1350:1.0,1351:1.0,1352:1.0,1353:1.0,1354:1.0,1355:1.0,1356:1.0,1357:1.0,1358:1.0,1359:1.0,1360:1.0,1361:1.0,1362:1.0,1363:1.0,1364:1.0,1365:1.0,1366:1.0,1367:1.0,1368:1.0,1369:1.0,1370:1.0,1371:1.0,1372:1.0,1373:1.0,1374:1.0,1375:1.0,1376:1.0,1377:1.0,1378:1.0,1379:1.0,1380:1.0,1381:1.0,1382:1.0,1383:1.0,1384:1.0,1385:1.0,1386:1.0,1387:1.0,1388:1.0,
1389:1.0,1390:1.0,1391:1.0,1392:1.0,1393:1.0,1394:1.0,1699:26.0,1701:6.0,1799:435.0,1801:0.2,1802:0.11,1803:0.04,1804:0.09,1805:0.43,1806:0.09,1807:3.0,1808:0.67,1809:2.0,1810:0.01,1811:0.06,1812:0.04,1813:0.07,1814:0.24,1815:0.09,1816:0.32,1817:0.06") 5 | r.set("user-2-profile","4:5,5:4,7:4,8:5,9:4,11:5,12:5,15:3,17:4,19:3,21:3,22:4,23:5,24:4,25:4,28:3,30:4,31:3,32:5,42:4,44:4,45:4,47:4,48:4,49:3,50:5,54:2,55:3,56:5,58:3,59:4,60:3,61:3,64:4,65:3,66:4,68:4,69:2,70:4,71:4,72:4,73:3,74:4,77:3,79:4,81:5,82:4,85:4,87:4,88:4,89:5,91:4,92:4,95:4,96:4,97:1,98:3,99:4,100:5,107:4,109:3,116:4,117:3,118:3,121:3,122:4,123:3,124:4,127:4,129:5,131:4,132:3,133:3,134:5,135:5,139:3,141:3,143:4,144:3,147:3,148:3,151:4,152:5,153:5,154:4,156:4,157:5,160:4,161:3,162:4,163:4,164:4,165:3,166:3,168:4,169:5,170:3,171:4,172:4,174:4,175:5,176:4,177:5,178:4,179:4,180:5,181:4,182:5,183:4,184:4,185:4,186:4,187:5,191:4,192:5,193:3,194:5,195:5,196:3,197:3,198:3,199:4,200:5,201:5,202:4,203:5,204:4,205:3,208:4,209:4,210:4,211:4,213:4,214:2,215:3,216:3,218:5,219:3,223:4,226:3,230:4,231:3,233:3,234:3,235:3,237:3,238:5,239:3,241:4,248:4,254:2,255:4,257:4,259:3,264:2,265:3,273:2,274:3,275:4,276:4,283:3,284:4,285:5,288:4,291:3,293:4,294:3,295:3,298:5,309:1,313:3,318:4,319:4,321:3,322:2,356:3,357:4,365:3,367:4,371:3,378:3,382:4,385:4,392:4,393:4,396:4,402:4,403:4,404:3,408:5,410:4,411:4,417:3,419:4,420:4,423:5,425:4,427:4,428:5,429:4,430:4,432:4,433:5,434:4,435:4,436:4,443:3,447:4,448:3,449:3,452:2,455:4,461:4,463:4,466:5,467:4,469:5,471:3,472:2,473:3,475:4,477:4,479:5,480:4,481:4,482:5,483:3,484:3,485:3,486:4,487:4,488:4,490:4,492:3,493:3,494:5,495:3,496:3,498:5,499:3,501:4,502:5,504:4,505:3,506:4,507:3,509:4,510:3,511:5,512:5,513:3,514:4,515:3,516:4,517:4,519:4,520:4,521:3,522:3,523:4,525:5,526:3,528:3,530:4,531:4,537:4,546:3,550:4,558:4,559:4,566:4,567:4,568:5,569:3,578:2,579:3,581:4,582:3,583:4,584:4,588:5,589:4,591:3,597:3,602:4,603:5,605:4,607:3,609:4,610:4,611:4,613:4,614:3,615:3,616:2,618:4,628:3,629:4,632:3,633:4,634:4,637:3,640:4,641:4,642:5,646:5,648:4,649:4,653:5,654:5,655:4,656:3,657:4,659:3,660:3,661:4,663:5,664:5,665:4,671:4,673:4,675:4,678:3,679:4,684:3,686:4,692:3,693:3,699:4,705:5,708:4,709:3,712:4,715:5,729:3,732:4,736:3,739:4,741:4,742:4,746:4,747:3,755:3,770:4,778:3,792:3,802:3,805:4,806:4,811:4,822:4,824:3,825:4,826:3,842:3,843:3,848:4,853:5,856:4,863:3,921:4,928:4,942:3,945:4,959:3,962:4,965:4,966:3,968:4,1006:4,1019:4,1021:4,1028:2,1045:4,1046:4,1047:3,1065:5,1073:3,1074:3,1118:4,1121:3,1126:3,1135:4,1140:4,1147:4,1154:2,1169:5,1197:4,1211:3,1252:3,1286:3,1404:4,1411:4,1421:4,1456:4,1515:4,1700:60,1701:0,1702:16,1800:397,1801:0.00,1802:0.69,1803:0.39,1804:0.12,1805:0.23,1806:1.23,1807:0.37,1808:0.06,1809:1.52,1810:0.06,1811:0.14,1812:0.25,1813:0.27,1814:0.18,1815:0.61,1816:0.39,1817:0.77,1818:0.28,1819:0.09") 6 | -------------------------------------------------------------------------------- /scripts/ml-forest-example.scala: -------------------------------------------------------------------------------- 1 | import scala.collection.mutable 2 | import scala.language.reflectiveCalls 3 | import org.apache.spark.ml.{Pipeline, PipelineStage} 4 | import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier} 5 | import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer} 6 | import org.apache.spark.ml.linalg.Vector 7 | import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor} 8 | import org.apache.spark.ml.tree.{CategoricalSplit, 
ContinuousSplit, Split} 9 | import org.apache.spark.mllib.util.MLUtils 10 | import org.apache.spark.sql.{SparkSession, _} 11 | import redis.clients.jedis.Protocol.Command 12 | import redis.clients.jedis.{Jedis, _} 13 | import com.redislabs.client.redisml.MLClient 14 | import com.redislabs.provider.redis.ml.Forest 15 | 16 | /** Load a dataset from the given path, using the given format */ 17 | def loadData( 18 | spark: SparkSession, 19 | path: String, 20 | format: String, 21 | expectedNumFeatures: Option[Int] = None): DataFrame = { 22 | import spark.implicits._ 23 | 24 | format match { 25 | case "dense" => MLUtils.loadLabeledPoints(spark.sparkContext, path).toDF() 26 | case "libsvm" => expectedNumFeatures match { 27 | case Some(numFeatures) => spark.read.option("numFeatures", numFeatures.toString) 28 | .format("libsvm").load(path) 29 | case None => spark.read.format("libsvm").load(path) 30 | } 31 | case _ => throw new IllegalArgumentException(s"Bad data format: $format") 32 | } 33 | } 34 | 35 | def loadDatasets( 36 | input: String, 37 | dataFormat: String, 38 | testInput: String, 39 | algo: String, 40 | fracTest: Double): (DataFrame, DataFrame) = { 41 | val spark = SparkSession 42 | .builder 43 | .getOrCreate() 44 | 45 | // Load training data 46 | val origExamples: DataFrame = loadData(spark, input, dataFormat) 47 | 48 | // Load or create test set 49 | val dataframes: Array[DataFrame] = if (testInput != "") { 50 | // Load testInput. 51 | val numFeatures = origExamples.first().getAs[Vector](1).size 52 | val origTestExamples: DataFrame = 53 | loadData(spark, testInput, dataFormat, Some(numFeatures)) 54 | Array(origExamples, origTestExamples) 55 | } else { 56 | // Split input into training, test. 57 | origExamples.randomSplit(Array(1.0 - fracTest, fracTest), seed = 12345) 58 | } 59 | 60 | val training = dataframes(0).cache() 61 | val test = dataframes(1).cache() 62 | 63 | val numTraining = training.count() 64 | val numTest = test.count() 65 | val numFeatures = training.select("features").first().getAs[Vector](0).size 66 | println("Loaded data:") 67 | println(s" numTraining = $numTraining, numTest = $numTest") 68 | println(s" numFeatures = $numFeatures") 69 | 70 | (training, test) 71 | } 72 | 73 | case class Params( 74 | // input: String = "file:///root/spark/data/mllib/sample_libsvm_data.txt", 75 | input: String = "data/mllib/10", 76 | testInput: String = "", 77 | dataFormat: String = "libsvm", 78 | algo: String = "classification", 79 | // algo: String = "regression", 80 | maxDepth: Int = 5, 81 | maxBins: Int = 32, 82 | minInstancesPerNode: Int = 1, 83 | minInfoGain: Double = 0.0, 84 | //numTrees: Int = 2000, 85 | numTrees: Int = 5, 86 | featureSubsetStrategy: String = "auto", 87 | fracTest: Double = 0.2, 88 | cacheNodeIds: Boolean = false, 89 | checkpointDir: Option[String] = None, 90 | checkpointInterval: Int = 10 91 | ) 92 | 93 | val params = Params() 94 | 95 | sc.setLogLevel("WARN") 96 | params.checkpointDir.foreach(sc.setCheckpointDir) 97 | val algo = params.algo.toLowerCase 98 | 99 | println(s"RandomForestExample with parameters:\n$params") 100 | 101 | // Load training and test data and cache it. 102 | val (training: DataFrame, test: DataFrame) = loadDatasets(params.input, 103 | params.dataFormat, params.testInput, algo, params.fracTest) 104 | 105 | // Set up Pipeline. 106 | val stages = new mutable.ArrayBuffer[PipelineStage]() 107 | // (1) For classification, re-index classes. 
108 | val labelColName = if (algo == "classification") "indexedLabel" else "label" 109 | if (algo == "classification") { 110 | val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol(labelColName) 111 | stages += labelIndexer 112 | } 113 | // (2) Identify categorical features using VectorIndexer. 114 | // Features with more than maxCategories values will be treated as continuous. 115 | val featuresIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(20) 116 | stages += featuresIndexer 117 | // (3) Learn Random Forest. 118 | val dt = new RandomForestClassifier(). 119 | setFeaturesCol("indexedFeatures"). 120 | setLabelCol(labelColName). 121 | setMaxDepth(params.maxDepth). 122 | setMaxBins(params.maxBins). 123 | setMinInstancesPerNode(params.minInstancesPerNode). 124 | setMinInfoGain(params.minInfoGain). 125 | setCacheNodeIds(params.cacheNodeIds). 126 | setCheckpointInterval(params.checkpointInterval). 127 | setFeatureSubsetStrategy(params.featureSubsetStrategy). 128 | setNumTrees(params.numTrees) 129 | 130 | stages += dt 131 | val pipeline = new org.apache.spark.ml.Pipeline().setStages(stages.toArray) 132 | 133 | // Fit the Pipeline. 134 | val startTime = System.nanoTime() 135 | val pipelineModel = pipeline.fit(training) 136 | val elapsedTime = (System.nanoTime() - startTime) / 1e9 137 | println(s"Training time: $elapsedTime seconds") 138 | 139 | val rfModel = pipelineModel.stages.last.asInstanceOf[RandomForestClassificationModel] 140 | if (rfModel.totalNumNodes < 30) { 141 | println(rfModel.toDebugString) // Print full model. 142 | } else { 143 | println(rfModel) // Print model summary. 144 | } 145 | 146 | val f = new Forest(rfModel.trees) 147 | f.loadToRedis("forest-test", "localhost") 148 | 149 | val localData = test.collect 150 | 151 | def makeInputString(i: Int): String = { 152 | val sparseRecord = localData(i)(1).asInstanceOf[org.apache.spark.ml.linalg.SparseVector] 153 | val indices = sparseRecord.indices 154 | val values = sparseRecord.values 155 | var sep = "" 156 | var inputStr = "" 157 | for (i <- 0 to ((indices.length - 1))) { 158 | inputStr = inputStr + sep + indices(i).toString + ":" + values(i).toString 159 | sep = "," 160 | } 161 | inputStr 162 | } 163 | 164 | def makeDF(i: Int): org.apache.spark.sql.DataFrame = { 165 | test.sqlContext.createDataFrame (sc.parallelize (test.take (i).slice (i - 1, i) ), test.schema) 166 | } 167 | 168 | val trans = pipelineModel.transform(test).collect() 169 | 170 | var redisRes = "" 171 | var sparkRes = 0.0 172 | var rtotal = 0.0 173 | var stotal = 0.0 174 | 175 | def benchmark(b: Int) { 176 | rtotal = 0.0 177 | stotal = 0.0 178 | val jedis = new Jedis("localhost") 179 | for (i <- 0 to b) { 180 | val rt0 = System.nanoTime() 181 | jedis.getClient.sendCommand(MLClient.ModuleCommand.FOREST_RUN, "forest-test", makeInputString(i)) 182 | // print("forest-test", makeInputString(i)) 183 | redisRes = jedis.getClient().getStatusCodeReply 184 | val rt1 = System.nanoTime() 185 | // println("Redis time: " + (rt1 - rt0) / 1000000.0 + "ms, res=" + redisRes) 186 | println("res = " + redisRes + "::" + trans(i)(6)) 187 | // val st0 = System.nanoTime() 188 | // sparkRes = rfModel.predict(localData(i).features) 189 | // val st1 = System.nanoTime() 190 | // println("Spark time: " + (st1 - st0) / 1000000.0 + "ms, res=" + sparkRes.toInt) 191 | println("---------------------------------------"); 192 | rtotal += (rt1 - rt0) / 1000000.0 193 | // stotal += (st1 - st0) / 1000000.0 194 | } 195 | 
println("Classification averages:") 196 | println("redis:" + rtotal / b.toFloat + "ms") 197 | // println("spark:" + stotal/b.toFloat+ "ms") 198 | } 199 | 200 | :power 201 | vals.isettings 202 | vals.isettings.maxPrintString = Int.MaxValue 203 | 204 | def dbt(i: Int) = {rfModel.trees(i).toDebugString} 205 | -------------------------------------------------------------------------------- /scalastyle-config.xml: -------------------------------------------------------------------------------- 1 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 11 | 16 | 17 | Scalastyle standard configuration 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | --------------------------------------------------------------------------------