├── .gitignore
├── README.md
├── build.sbt
├── data
│   ├── import_eventserver.py
│   └── send_query.py
├── engine.json
├── project
│   ├── assembly.sbt
│   └── pio-build.sbt
├── src
│   └── main
│       └── scala
│           ├── ALSAlgorithm.scala
│           ├── DataSource.scala
│           ├── Engine.scala
│           ├── Preparator.scala
│           └── Serving.scala
└── template.json

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
manifest.json
target/
pio.log
/pio.sbt

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Product Ranking Engine Template

## Documentation

Please refer to http://predictionio.incubator.apache.org/templates/productranking/quickstart/

## Version

### v0.3.0

- update for PredictionIO 0.9.2, including:

  - use new PEventStore API
  - use appName in DataSource parameter

### v0.2.0

- update build.sbt and template.json for PredictionIO 0.9.2

### v0.1.1

- update for PredictionIO 0.9.0

### v0.1.0

- initial version

## Development Notes

### import sample data

```
$ python data/import_eventserver.py --access_key <your_access_key>
```

### query

normal:

```
curl -H "Content-Type: application/json" \
-d '{ "user": "u2", "items": ["i1", "i3", "i10", "i2", "i5", "i31", "i9"]}' \
http://localhost:8000/queries.json \
-w %{time_connect}:%{time_starttransfer}:%{time_total}
```

unknown user:

```
curl -H "Content-Type: application/json" \
-d '{ "user": "unknown_user", "items": ["i1", "i3", "i10", "i2", "i5", "i31", "i9"]}' \
http://localhost:8000/queries.json \
-w %{time_connect}:%{time_starttransfer}:%{time_total}
```

unknown item:

```
curl -H "Content-Type: application/json" \
-d '{ "user": "u3", "items": ["unk1", "i3", "i10", "i2", "i9"]}' \
http://localhost:8000/queries.json \
-w %{time_connect}:%{time_starttransfer}:%{time_total}
```

all unknown items:

```
curl -H "Content-Type: application/json" \
-d '{ "user": "u4", "items": ["unk1", "unk2", "unk3", "unk4"]}' \
http://localhost:8000/queries.json \
-w %{time_connect}:%{time_starttransfer}:%{time_total}
```
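
### response

Each query returns the engine's `PredictedResult` (see `src/main/scala/Engine.scala`) serialized as JSON. A response has roughly this shape; the scores below are illustrative values, not actual output:

```
{
  "itemScores": [
    {"item": "i5", "score": 1.32},
    {"item": "i2", "score": 0.87},
    {"item": "i1", "score": 0.24}
  ],
  "isOriginal": false
}
```

`isOriginal` is `true` when the items could not be ranked at all (unknown user, or no queried item has a feature vector); in that case the items come back in the original query order with score 0.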

--------------------------------------------------------------------------------
/build.sbt:
--------------------------------------------------------------------------------
import AssemblyKeys._

assemblySettings

name := "template-scala-parallel-productranking"

organization := "io.prediction"

libraryDependencies ++= Seq(
  "io.prediction" %% "core" % pioVersion.value % "provided",
  "org.apache.spark" %% "spark-core" % "1.3.0" % "provided",
  "org.apache.spark" %% "spark-mllib" % "1.3.0" % "provided")

--------------------------------------------------------------------------------
/data/import_eventserver.py:
--------------------------------------------------------------------------------
"""
Import sample data for product ranking engine
"""

import predictionio
import argparse
import random

SEED = 7

def import_events(client):
  random.seed(SEED)
  count = 0
  print client.get_status()
  print "Importing data..."

  # generate 10 users, with user ids u1,u2,...,u10
  user_ids = ["u%s" % i for i in range(1, 11)]
  for user_id in user_ids:
    print "Set user", user_id
    client.create_event(
      event="$set",
      entity_type="user",
      entity_id=user_id
    )
    count += 1

  # generate 50 items, with item ids i1,i2,...,i50
  item_ids = ["i%s" % i for i in range(1, 51)]
  for item_id in item_ids:
    print "Set item", item_id
    client.create_event(
      event="$set",
      entity_type="item",
      entity_id=item_id
    )
    count += 1

  # each user randomly views 10 items
  for user_id in user_ids:
    for viewed_item in random.sample(item_ids, 10):
      print "User", user_id, "views item", viewed_item
      client.create_event(
        event="view",
        entity_type="user",
        entity_id=user_id,
        target_entity_type="item",
        target_entity_id=viewed_item
      )
      count += 1

  print "%s events are imported." % count

if __name__ == '__main__':
  parser = argparse.ArgumentParser(
    description="Import sample data for product ranking engine")
  parser.add_argument('--access_key', default='invalid_access_key')
  parser.add_argument('--url', default="http://localhost:7070")

  args = parser.parse_args()
  print args

  client = predictionio.EventClient(
    access_key=args.access_key,
    url=args.url,
    threads=5,
    qsize=500)
  import_events(client)

--------------------------------------------------------------------------------
/data/send_query.py:
--------------------------------------------------------------------------------
"""
Send sample query to prediction engine
"""

import predictionio
engine_client = predictionio.EngineClient(url="http://localhost:8000")
print engine_client.send_query({
  "user": "u2",
  "items": ["i1", "i3", "i10", "i2", "i5", "i31", "i9"]
})

--------------------------------------------------------------------------------
/engine.json:
--------------------------------------------------------------------------------
{
  "id": "default",
  "description": "Default settings",
  "engineFactory": "org.template.productranking.ProductRankingEngine",
  "datasource": {
    "params": {
      "appName": "INVALID_APP_NAME"
    }
  },
  "algorithms": [
    {
      "name": "als",
      "params": {
        "rank": 10,
        "numIterations": 20,
        "lambda": 0.01,
        "seed": 3
      }
    }
  ]
}
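
The `als` parameters above are passed straight through to MLlib's implicit-feedback ALS trainer in `ALSAlgorithm.scala`. As a rough standalone sketch (assuming an existing `SparkContext` named `sc` and made-up view counts), the equivalent direct call looks like this:

```
import org.apache.spark.mllib.recommendation.{ALS, Rating}

// made-up implicit "ratings": view counts per (user index, item index)
val ratings = sc.parallelize(Seq(
  Rating(1, 1, 2.0), // user 1 viewed item 1 twice
  Rating(1, 2, 1.0),
  Rating(2, 1, 1.0)))

val model = ALS.trainImplicit(
  ratings = ratings,
  rank = 10,        // "rank": number of latent features
  iterations = 20,  // "numIterations"
  lambda = 0.01,    // "lambda": regularization strength
  blocks = -1,      // auto-configure parallelism
  alpha = 1.0,      // confidence weight for implicit feedback
  seed = 3L)        // "seed": fixed for reproducible training
```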

--------------------------------------------------------------------------------
/project/assembly.sbt:
--------------------------------------------------------------------------------
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2")

--------------------------------------------------------------------------------
/project/pio-build.sbt:
--------------------------------------------------------------------------------
addSbtPlugin("io.prediction" % "pio-build" % "0.9.0")

--------------------------------------------------------------------------------
/src/main/scala/ALSAlgorithm.scala:
--------------------------------------------------------------------------------
package org.template.productranking

import io.prediction.controller.P2LAlgorithm
import io.prediction.controller.Params
import io.prediction.data.storage.BiMap

import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.recommendation.ALS
import org.apache.spark.mllib.recommendation.{Rating => MLlibRating}

import grizzled.slf4j.Logger

import scala.collection.parallel.immutable.ParVector

case class ALSAlgorithmParams(
  rank: Int,
  numIterations: Int,
  lambda: Double,
  seed: Option[Long]) extends Params

class ALSModel(
  val rank: Int,
  val userFeatures: Map[Int, Array[Double]],
  val productFeatures: Map[Int, Array[Double]],
  val userStringIntMap: BiMap[String, Int],
  val itemStringIntMap: BiMap[String, Int]
) extends Serializable {

  @transient lazy val itemIntStringMap = itemStringIntMap.inverse

  override def toString = {
    s" rank: ${rank}" +
    s" userFeatures: [${userFeatures.size}]" +
    s"(${userFeatures.take(2).toList}...)" +
    s" productFeatures: [${productFeatures.size}]" +
    s"(${productFeatures.take(2).toList}...)" +
    s" userStringIntMap: [${userStringIntMap.size}]" +
    s"(${userStringIntMap.take(2).toString}...)" +
    s" itemStringIntMap: [${itemStringIntMap.size}]" +
    s"(${itemStringIntMap.take(2).toString}...)"
  }
}

class ALSAlgorithm(val ap: ALSAlgorithmParams)
  extends P2LAlgorithm[PreparedData, ALSModel, Query, PredictedResult] {

  @transient lazy val logger = Logger[this.type]

  def train(sc: SparkContext, data: PreparedData): ALSModel = {
    require(!data.viewEvents.take(1).isEmpty,
      s"viewEvents in PreparedData cannot be empty." +
      " Please check if DataSource generates TrainingData" +
      " and Preparator generates PreparedData correctly.")
    require(!data.users.take(1).isEmpty,
      s"users in PreparedData cannot be empty." +
      " Please check if DataSource generates TrainingData" +
      " and Preparator generates PreparedData correctly.")
    require(!data.items.take(1).isEmpty,
      s"items in PreparedData cannot be empty." +
      " Please check if DataSource generates TrainingData" +
      " and Preparator generates PreparedData correctly.")
    // create user and item String ID to integer index BiMaps
    val userStringIntMap = BiMap.stringInt(data.users.keys)
    val itemStringIntMap = BiMap.stringInt(data.items.keys)

    val mllibRatings = data.viewEvents
      .map { r =>
        // convert user and item String IDs to Int indices for MLlib
        val uindex = userStringIntMap.getOrElse(r.user, -1)
        val iindex = itemStringIntMap.getOrElse(r.item, -1)

        if (uindex == -1)
          logger.info(s"Couldn't convert nonexistent user ID ${r.user}"
            + " to Int index.")

        if (iindex == -1)
          logger.info(s"Couldn't convert nonexistent item ID ${r.item}"
            + " to Int index.")

        ((uindex, iindex), 1)
      }.filter { case ((u, i), v) =>
        // keep events with valid user and item indices
        (u != -1) && (i != -1)
      }.reduceByKey(_ + _) // aggregate all view events of same user-item pair
      .map { case ((u, i), v) =>
        // MLlibRating requires integer indices for user and item
        MLlibRating(u, i, v)
      }

    // MLlib ALS cannot handle empty training data.
    require(!mllibRatings.take(1).isEmpty,
      s"mllibRatings cannot be empty." +
      " Please check if your events contain valid user and item IDs.")

    // seed for MLlib ALS
    val seed = ap.seed.getOrElse(System.nanoTime)

    val m = ALS.trainImplicit(
      ratings = mllibRatings,
      rank = ap.rank,
      iterations = ap.numIterations,
      lambda = ap.lambda,
      blocks = -1,
      alpha = 1.0,
      seed = seed)

    new ALSModel(
      rank = m.rank,
      userFeatures = m.userFeatures.collectAsMap.toMap,
      productFeatures = m.productFeatures.collectAsMap.toMap,
      userStringIntMap = userStringIntMap,
      itemStringIntMap = itemStringIntMap
    )
  }
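
  // How ranking works: for a known user u and a known item i, the predicted
  // preference is the dot product of their latent feature vectors:
  //   score(u, i) = userFeatures(u) . productFeatures(i)
  // Queried items without a feature vector fall back to score 0. If the user
  // is unknown, or none of the queried items has a feature vector, the items
  // are returned in their original order with isOriginal = true.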

  def predict(model: ALSModel, query: Query): PredictedResult = {

    val itemStringIntMap = model.itemStringIntMap
    val productFeatures = model.productFeatures

    // default itemScores array if items are not ranked at all
    lazy val notRankedItemScores =
      query.items.map(i => ItemScore(i, 0)).toArray

    model.userStringIntMap.get(query.user).map { userIndex =>
      // look up the userFeature for this user
      model.userFeatures.get(userIndex)
    }.flatten // flatten Option[Option[Array[Double]]] to Option[Array[Double]]
    .map { userFeature =>
      val scores: Vector[Option[Double]] = query.items.toVector
        .par // convert to parallel collection for parallel lookup
        .map { iid =>
          // convert the query item ID to its index
          val featureOpt: Option[Array[Double]] = itemStringIntMap.get(iid)
            // productFeatures may not contain the item
            .map(index => productFeatures.get(index))
            // flatten Option[Option[Array[Double]]] to Option[Array[Double]]
            .flatten

          featureOpt.map(f => dotProduct(f, userFeature))
        }.seq // convert back to sequential collection

      // check if all scores are None (drop all None and see if empty)
      val isAllNone = scores.flatten.isEmpty

      if (isAllNone) {
        logger.info(s"No productFeature for any of the items ${query.items}.")
        PredictedResult(
          itemScores = notRankedItemScores,
          isOriginal = true
        )
      } else {
        // sort by score in descending order
        val ord = Ordering.by[ItemScore, Double](_.score).reverse
        val sorted = query.items.zip(scores).map { case (iid, scoreOpt) =>
          ItemScore(
            item = iid,
            score = scoreOpt.getOrElse[Double](0)
          )
        }.sorted(ord).toArray

        PredictedResult(
          itemScores = sorted,
          isOriginal = false
        )
      }
    }.getOrElse {
      logger.info(s"No userFeature found for user ${query.user}.")
      PredictedResult(
        itemScores = notRankedItemScores,
        isOriginal = true
      )
    }

  }

  private
  def dotProduct(v1: Array[Double], v2: Array[Double]): Double = {
    val size = v1.size
    var i = 0
    var d: Double = 0
    while (i < size) {
      d += v1(i) * v2(i)
      i += 1
    }
    d
  }

}
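
// Illustrative only (not part of the original template): a self-contained toy
// sketch of the same score-and-sort step that predict() performs above, with
// made-up feature values. Unknown items fall back to a score of 0.
object RankingSketch {
  // same dot product as ALSAlgorithm, written with collection combinators
  def dotProduct(v1: Array[Double], v2: Array[Double]): Double =
    v1.zip(v2).map { case (a, b) => a * b }.sum

  def main(args: Array[String]): Unit = {
    val userFeature = Array(0.5, 1.0)
    val productFeatures = Map(
      "i1" -> Array(0.2, 0.9), // score = 0.5*0.2 + 1.0*0.9 = 1.00
      "i3" -> Array(1.0, 0.1)) // score = 0.5*1.0 + 1.0*0.1 = 0.60
    val queryItems = List("i1", "i3", "unk1")

    val ranked = queryItems
      .map { i =>
        (i, productFeatures.get(i).map(dotProduct(_, userFeature)).getOrElse(0.0))
      }
      .sortBy { case (_, score) => -score } // highest score first

    ranked.foreach { case (i, s) => println(f"$i%-5s $s%.2f") }
    // prints: i1 1.00, i3 0.60, unk1 0.00
  }
}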

--------------------------------------------------------------------------------
/src/main/scala/DataSource.scala:
--------------------------------------------------------------------------------
package org.template.productranking

import io.prediction.controller.PDataSource
import io.prediction.controller.EmptyEvaluationInfo
import io.prediction.controller.EmptyActualResult
import io.prediction.controller.Params
import io.prediction.data.storage.Event
import io.prediction.data.store.PEventStore

import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD

import grizzled.slf4j.Logger

case class DataSourceParams(appName: String) extends Params

class DataSource(val dsp: DataSourceParams)
  extends PDataSource[TrainingData,
      EmptyEvaluationInfo, Query, EmptyActualResult] {

  @transient lazy val logger = Logger[this.type]

  override
  def readTraining(sc: SparkContext): TrainingData = {

    // create an RDD of (entityID, User)
    val usersRDD: RDD[(String, User)] = PEventStore.aggregateProperties(
      appName = dsp.appName,
      entityType = "user"
    )(sc).map { case (entityId, properties) =>
      val user = try {
        // placeholder for expanding user properties
        User()
      } catch {
        case e: Exception => {
          logger.error(s"Failed to get properties ${properties} of" +
            s" user ${entityId}. Exception: ${e}.")
          throw e
        }
      }
      (entityId, user)
    }.cache()

    // create an RDD of (entityID, Item)
    val itemsRDD: RDD[(String, Item)] = PEventStore.aggregateProperties(
      appName = dsp.appName,
      entityType = "item"
    )(sc).map { case (entityId, properties) =>
      val item = try {
        // placeholder for expanding item properties
        Item()
      } catch {
        case e: Exception => {
          logger.error(s"Failed to get properties ${properties} of" +
            s" item ${entityId}. Exception: ${e}.")
          throw e
        }
      }
      (entityId, item)
    }.cache()

    // get all "user" "view" "item" events
    val viewEventsRDD: RDD[ViewEvent] = PEventStore.find(
      appName = dsp.appName,
      entityType = Some("user"),
      eventNames = Some(List("view")),
      // targetEntityType is an optional field of an event.
      targetEntityType = Some(Some("item")))(sc)
      // PEventStore.find() returns RDD[Event]
      .map { event =>
        val viewEvent = try {
          event.event match {
            case "view" => ViewEvent(
              user = event.entityId,
              item = event.targetEntityId.get,
              t = event.eventTime.getMillis)
            case _ => throw new Exception(s"Unexpected event ${event} is read.")
          }
        } catch {
          case e: Exception => {
            logger.error(s"Cannot convert ${event} to ViewEvent." +
              s" Exception: ${e}.")
            throw e
          }
        }
        viewEvent
      }.cache()

    new TrainingData(
      users = usersRDD,
      items = itemsRDD,
      viewEvents = viewEventsRDD
    )
  }
}

case class User()

case class Item()

case class ViewEvent(user: String, item: String, t: Long)

class TrainingData(
  val users: RDD[(String, User)],
  val items: RDD[(String, Item)],
  val viewEvents: RDD[ViewEvent]
) extends Serializable {
  override def toString = {
    s"users: [${users.count()} (${users.take(2).toList}...)]" +
    s" items: [${items.count()} (${items.take(2).toList}...)]" +
    s" viewEvents: [${viewEvents.count()}] (${viewEvents.take(2).toList}...)"
  }
}

--------------------------------------------------------------------------------
/src/main/scala/Engine.scala:
--------------------------------------------------------------------------------
package org.template.productranking

import io.prediction.controller.IEngineFactory
import io.prediction.controller.Engine

case class Query(
  user: String,
  items: List[String]
) extends Serializable

case class PredictedResult(
  itemScores: Array[ItemScore],
  isOriginal: Boolean // set to true if the items are not ranked at all
) extends Serializable

case class ItemScore(
  item: String,
  score: Double
) extends Serializable

object ProductRankingEngine extends IEngineFactory {
  def apply() = {
    new Engine(
      classOf[DataSource],
      classOf[Preparator],
      Map("als" -> classOf[ALSAlgorithm]),
      classOf[Serving])
  }
}

--------------------------------------------------------------------------------
/src/main/scala/Preparator.scala:
--------------------------------------------------------------------------------
package org.template.productranking

import io.prediction.controller.PPreparator

import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD

class Preparator
  extends PPreparator[TrainingData, PreparedData] {

  def prepare(sc: SparkContext, trainingData: TrainingData): PreparedData = {
    new PreparedData(
      users = trainingData.users,
      items = trainingData.items,
      viewEvents = trainingData.viewEvents)
  }
}

class PreparedData(
  val users: RDD[(String, User)],
  val items: RDD[(String, Item)],
  val viewEvents: RDD[ViewEvent]
) extends Serializable

--------------------------------------------------------------------------------
/src/main/scala/Serving.scala:
--------------------------------------------------------------------------------
package org.template.productranking

import io.prediction.controller.LServing

class Serving
  extends LServing[Query, PredictedResult] {

  override
  def serve(query: Query,
    predictedResults: Seq[PredictedResult]): PredictedResult = {
    predictedResults.head
  }
}

--------------------------------------------------------------------------------
/template.json:
--------------------------------------------------------------------------------
{"pio": {"version": { "min": "0.9.2" }}}
--------------------------------------------------------------------------------