├── .gitignore ├── README.md ├── build.sbt ├── data └── import_eventserver.py ├── engine.json ├── project ├── assembly.sbt └── build.properties ├── src └── main │ └── scala │ ├── Algorithm.scala │ ├── DataSource.scala │ ├── Engine.scala │ ├── Preparator.scala │ └── Serving.scala └── template.json /.gitignore: -------------------------------------------------------------------------------- 1 | data/*.txt 2 | manifest.json 3 | pio.log 4 | /pio.sbt 5 | target/ 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Complementary Purchase Engine Template 2 | 3 | ## Documentation 4 | 5 | Please refer to http://docs.prediction.io/templates/complementarypurchase/quickstart/ 6 | 7 | ## Versions 8 | 9 | ### v0.3.3 10 | 11 | - change to sort by confidence instead of lift score and change default parameter values 12 | 13 | ### v0.3.2 14 | 15 | - lower the default minSupport and minConfidence parameter setting 16 | 17 | ### v0.3.1 18 | 19 | - Add `minBasketSize` algorithm parameter. 20 | 21 | ### v0.3.0 22 | 23 | - update for PredictionIO 0.9.2, including: 24 | 25 | - use new PEventStore API 26 | - use appName in DataSource parameter 27 | 28 | ### v0.2.0 29 | 30 | - update build.sbt and template.json for PredictionIO 0.9.2 31 | 32 | ### v0.1.1 33 | 34 | - Template name typo fix. 
Rename from "Complimentary" to "Complementary" 35 | 36 | ### v0.1.0 37 | 38 | - initial version (requires PredictionIO >= 0.9.0) 39 | 40 | ## Development Notes 41 | 42 | ### import sample data 43 | 44 | ``` 45 | $ python data/import_eventserver.py --access_key YOUR_ACCESS_KEY 46 | ``` 47 | 48 | ### sample query 49 | 50 | ``` 51 | $ curl -H "Content-Type: application/json" \ 52 | -d '{ 53 | "items" : ["s2i1"], 54 | "num" : 3 }' \ 55 | http://localhost:8000/queries.json \ 56 | -w %{time_total} 57 | ``` 58 | 59 | 60 | ``` 61 | $ curl -H "Content-Type: application/json" \ 62 | -d '{ 63 | "items" : ["s2i1", "s2i3", "s1i2"], 64 | "num" : 3 }' \ 65 | http://localhost:8000/queries.json \ 66 | -w %{time_total} 67 | ``` 68 | 69 | ``` 70 | $ curl -H "Content-Type: application/json" \ 71 | -d '{ 72 | "items" : ["s1i2", "s1i1"], 73 | "num" : 4 }' \ 74 | http://localhost:8000/queries.json \ 75 | -w %{time_total} 76 | ``` 77 | 78 | 79 | ``` 80 | $ curl -H "Content-Type: application/json" \ 81 | -d '{ 82 | "items" : ["x", "s1i1"], 83 | "num" : 4 }' \ 84 | http://localhost:8000/queries.json \ 85 | -w %{time_total} 86 | ``` 87 | 88 | ``` 89 | $ curl -H "Content-Type: application/json" \ 90 | -d '{ 91 | "items" : ["i1"], 92 | "num" : 3 }' \ 93 | http://localhost:8000/queries.json \ 94 | -w %{time_total} 95 | ``` 96 | 97 | ``` 98 | $ curl -H "Content-Type: application/json" \ 99 | -d '{ 100 | "items" : ["p1", "p2", "p3"], 101 | "num" : 3 }' \ 102 | http://localhost:8000/queries.json \ 103 | -w %{time_total} 104 | ``` 105 | 106 | ``` 107 | $ curl -H "Content-Type: application/json" \ 108 | -d '{ 109 | "items" : ["i2", "i3"], 110 | "num" : 4 }' \ 111 | http://localhost:8000/queries.json \ 112 | -w %{time_total} 113 | ``` 114 | -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | name := "template-scala-parallel-complementary-purchase" 2 | 3 | scalaVersion := "2.11.8" 4 | 5 | 
libraryDependencies ++= Seq( 6 | "org.apache.predictionio" %% "apache-predictionio-core" % "0.11.0-incubating" % "provided", 7 | "org.apache.spark" %% "spark-core" % "2.1.0" % "provided", 8 | "org.apache.spark" %% "spark-mllib" % "2.1.0" % "provided") 9 | -------------------------------------------------------------------------------- /data/import_eventserver.py: -------------------------------------------------------------------------------- 1 | """ 2 | Import sample data for complementary purchase engine 3 | """ 4 | 5 | import predictionio 6 | import argparse 7 | import random 8 | import uuid 9 | from datetime import datetime 10 | from datetime import timedelta 11 | import pytz 12 | 13 | SEED = 3 14 | 15 | def import_events(client): 16 | random.seed(SEED) 17 | count = 0 18 | print client.get_status() 19 | print "Importing data..." 20 | 21 | # generate 10 users, with user ids u1,u2,....,u10 22 | user_ids = ["u%s" % i for i in range(1, 10+1)] 23 | 24 | # randomly generate 4 frequent item set 25 | item_sets = {} 26 | for i in range(1, 4+1): 27 | # each set contain 2 to 4 items 28 | iids = range(1, random.randint(2, 4)+1) 29 | item_sets[i] = ["s%si%s" % (i, j) for j in iids] 30 | 31 | # plus 20 other items not in any set. 
32 | other_items = ["i%s" % i for i in range(1, 20+1)] 33 | 34 | # 3 popular item every one buy 35 | pop_items = ["p%s" % i for i in range(1,3+1)] 36 | 37 | # each user have 5 basket purchases: 38 | for uid in user_ids: 39 | base_time = datetime( 40 | year = 2014, 41 | month = 10, 42 | day = random.randint(1,31), 43 | hour = 15, 44 | minute = 39, 45 | second = 45, 46 | microsecond = 618000, 47 | tzinfo = pytz.timezone('US/Pacific')) 48 | seconds = 0 49 | for basket in range(0, 5): 50 | # may or may not some random item 51 | if (random.choice([True, False])): 52 | buy_items = random.sample(other_items, random.randint(1, 3)) 53 | for iid in buy_items: 54 | event_time = base_time + timedelta(seconds=seconds, days=basket) 55 | print "User", uid, "buys item", iid, "at", event_time 56 | client.create_event( 57 | event = "buy", 58 | entity_type = "user", 59 | entity_id = uid, 60 | target_entity_type = "item", 61 | target_entity_id = iid, 62 | event_time = event_time 63 | ) 64 | seconds += 10 65 | count += 1 66 | 67 | # always buy one popular item 68 | buy_items = random.sample(pop_items, 1) 69 | for iid in buy_items: 70 | event_time = base_time + timedelta(seconds=seconds, days=basket) 71 | print "User", uid, "buys item", iid, "at", event_time 72 | client.create_event( 73 | event = "buy", 74 | entity_type = "user", 75 | entity_id = uid, 76 | target_entity_type = "item", 77 | target_entity_id = iid, 78 | event_time = event_time 79 | ) 80 | seconds += 10 81 | count += 1 82 | 83 | # always buy some something from one of the item set 84 | s = item_sets[random.choice(item_sets.keys())] 85 | buy_items = random.sample(s, random.randint(2, len(s))) 86 | for iid in buy_items: 87 | event_time = base_time + timedelta(seconds=seconds, days=basket) 88 | print "User", uid, "buys item", iid, "at", event_time 89 | client.create_event( 90 | event = "buy", 91 | entity_type = "user", 92 | entity_id = uid, 93 | target_entity_type = "item", 94 | target_entity_id = iid, 95 | event_time = 
event_time 96 | ) 97 | seconds += 10 98 | count += 1 99 | 100 | print "%s events are imported." % count 101 | 102 | if __name__ == '__main__': 103 | parser = argparse.ArgumentParser( 104 | description="Import sample data for similar product engine") 105 | parser.add_argument('--access_key', default='invald_access_key') 106 | parser.add_argument('--url', default="http://localhost:7070") 107 | 108 | args = parser.parse_args() 109 | print args 110 | 111 | client = predictionio.EventClient( 112 | access_key=args.access_key, 113 | url=args.url, 114 | threads=4, 115 | qsize=100) 116 | import_events(client) 117 | -------------------------------------------------------------------------------- /engine.json: -------------------------------------------------------------------------------- 1 | { 2 | "id": "default", 3 | "description": "Default settings", 4 | "engineFactory": "org.template.complementarypurchase.ComplementaryPurchaseEngine", 5 | "datasource": { 6 | "params" : { 7 | "appName": "INVALID_APP_NAME" 8 | } 9 | }, 10 | "algorithms": [ 11 | { 12 | "name": "algo", 13 | "params": { 14 | "basketWindow" : 120, 15 | "maxRuleLength" : 2, 16 | "minSupport": 0.001, 17 | "minConfidence": 0.1, 18 | "minLift" : 1.0, 19 | "minBasketSize" : 2, 20 | "maxNumRulesPerCond": 5 21 | } 22 | } 23 | ] 24 | } 25 | -------------------------------------------------------------------------------- /project/assembly.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.4") 2 | -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version=0.13.15 2 | -------------------------------------------------------------------------------- /src/main/scala/Algorithm.scala: -------------------------------------------------------------------------------- 1 | package 
org.template.complementarypurchase 2 | 3 | import org.apache.predictionio.controller.P2LAlgorithm 4 | import org.apache.predictionio.controller.Params 5 | 6 | import org.apache.spark.SparkContext 7 | import org.apache.spark.SparkContext._ 8 | import org.apache.spark.rdd.RDD 9 | 10 | import grizzled.slf4j.Logger 11 | 12 | case class AlgorithmParams( 13 | basketWindow: Int, // in seconds 14 | maxRuleLength: Int, 15 | minSupport: Double, 16 | minConfidence: Double, 17 | minLift: Double, 18 | minBasketSize: Int, 19 | maxNumRulesPerCond: Int // max number of rules per condition 20 | ) extends Params 21 | // Mines association rules (condition -> single-item consequence) from time-windowed buy baskets. 22 | class Algorithm(val ap: AlgorithmParams) 23 | extends P2LAlgorithm[PreparedData, Model, Query, PredictedResult] { 24 | 25 | @transient lazy val maxCondLength = ap.maxRuleLength - 1 26 | @transient lazy val logger = Logger[this.type] 27 | 28 | def train(sc: SparkContext, pd: PreparedData): Model = { 29 | val windowMillis = ap.basketWindow * 1000L // Long math: Int * 1000 overflows for windows > ~24 days 30 | require(ap.maxRuleLength >= 2, 31 | s"maxRuleLength must be at least 2. Current: ${ap.maxRuleLength}.") 32 | require((ap.minSupport >= 0 && ap.minSupport < 1), 33 | s"minSupport must be >= 0 and < 1. Current: ${ap.minSupport}.") 34 | require((ap.minConfidence >= 0 && ap.minConfidence < 1), 35 | s"minConfidence must be >= 0 and < 1. Current: ${ap.minConfidence}.") 36 | require((ap.minBasketSize >= 2), 37 | s"minBasketSize must be >= 2.
Current: ${ap.minBasketSize}.") 38 | 39 | val transactions: RDD[Set[String]] = pd.buyEvents 40 | .map (b => (b.user, new ItemAndTime(b.item, b.t))) 41 | .groupByKey 42 | // create RDD[Set[String]] // size 2 43 | .flatMap{ case (user, iter) => // user and iterable of ItemAndTime 44 | // sort by time and create List[ItemSet] based on time and window 45 | val sortedList = iter.toList.sortBy(_.t) 46 | val init = ItemSet[String](Set(sortedList.head.item), sortedList.head.t) 47 | val basketList = sortedList.tail 48 | .foldLeft(List[ItemSet[String]](init)) ( (list, itemAndTime) => 49 | // if current item time is within last item's time's window 50 | // add to same set. 51 | if ((itemAndTime.t - list.head.lastTime) <= windowMillis) 52 | (list.head + itemAndTime) :: list.tail 53 | else 54 | ItemSet(Set(itemAndTime.item), itemAndTime.t) :: list 55 | ) 56 | logger.debug(s"user ${user}: ${basketList}.") 57 | basketList.map(_.items).filter(_.size >= ap.minBasketSize) 58 | } 59 | .cache() 60 | 61 | val totalTransaction = transactions.count() 62 | val minSupportCount = ap.minSupport * totalTransaction 63 | 64 | logger.debug(s"transactions: ${transactions.collect.toList}") 65 | logger.info(s"totalTransaction: ${totalTransaction}") 66 | 67 | // generate item sets 68 | val itemSets: RDD[Set[String]] = transactions 69 | .flatMap { tran => 70 | (1 to ap.maxRuleLength).flatMap(n => tran.subsets(n)) 71 | } 72 | 73 | logger.debug(s"itemSets: ${itemSets.cache().collect.toList}") 74 | 75 | val itemSetCount: RDD[(Set[String], Int)] = itemSets.map(s => (s, 1)) 76 | .reduceByKey((a, b) => a + b) 77 | .filter(_._2 >= minSupportCount) 78 | .cache() 79 | 80 | logger.debug(s"itemSetCount: ${itemSetCount.collect.toList}") 81 | 82 | val rules: RDD[(Set[String], RuleScore)] = itemSetCount 83 | // a rule needs min set size >= 2 84 | .filter{ case (set, count) => set.size >= 2} 85 | .flatMap{ case (set, count) => 86 | set.map(i => (Set(i), (set - i, count))) 87 | } 88 | .join(itemSetCount) 89 | .map {
case (conseq, ((cond, ruleCnt), conseqCnt)) => 90 | (cond, (conseq, conseqCnt, ruleCnt)) 91 | } 92 | .join(itemSetCount) 93 | .map { case (cond, ((conseq, conseqCnt, ruleCnt), condCnt)) => 94 | val support = ruleCnt.toDouble / totalTransaction 95 | val confidence = ruleCnt.toDouble / condCnt 96 | val lift = (ruleCnt.toDouble / (condCnt.toDouble * conseqCnt)) * totalTransaction // Double product: Int condCnt * conseqCnt can overflow 97 | val ruleScore = RuleScore( 98 | conseq = conseq.head, // single item consequence 99 | support = support, 100 | confidence = confidence, 101 | lift = lift) 102 | (cond, ruleScore) 103 | } 104 | .filter{ case (cond, rs) => 105 | (rs.confidence >= ap.minConfidence) && (rs.lift >= ap.minLift) 106 | } 107 | 108 | val sortedRules = rules.groupByKey 109 | .mapValues(iter => 110 | iter.toVector 111 | .sortBy(_.confidence)(Ordering.Double.reverse) 112 | .take(ap.maxNumRulesPerCond) 113 | ) 114 | .collectAsMap.toMap 115 | 116 | new Model(sortedRules) 117 | } 118 | 119 | def predict(model: Model, query: Query): PredictedResult = { 120 | val conds = (1 to maxCondLength).flatMap(n => query.items.subsets(n)) 121 | 122 | val rules = conds.map { cond => 123 | model.rules.get(cond).map{ vec => 124 | val itemScores = vec.take(query.num).map { rs => 125 | new ItemScore( 126 | item = rs.conseq, 127 | support = rs.support, 128 | confidence = rs.confidence, 129 | lift = rs.lift 130 | ) 131 | }.toArray 132 | Rule(cond = cond, itemScores = itemScores) 133 | } 134 | }.flatten.toArray 135 | 136 | new PredictedResult(rules) 137 | } 138 | 139 | // item and time 140 | case class ItemAndTime[T](item: T, t: Long) 141 | 142 | // item set with last time of item 143 | case class ItemSet[T](items: Set[T], lastTime: Long) { 144 | def size = items.size 145 | 146 | def isEmpty = items.isEmpty 147 | 148 | def +(elem: ItemAndTime[T]): ItemSet[T] = { 149 | val newSet = items + elem.item 150 | val newLastTime = if (elem.t > lastTime) elem.t else lastTime 151 | new ItemSet(newSet, newLastTime) 152 | } 153 | } 154 | } 155 | 156 | case
class RuleScore( 157 | conseq: String, support: Double, confidence: Double, lift: Double) 158 | 159 | class Model( 160 | val rules: Map[Set[String], Vector[RuleScore]] 161 | ) extends Serializable { 162 | override def toString = s"rules: ${rules.size} ${rules.take(2)}..." 163 | } 164 | -------------------------------------------------------------------------------- /src/main/scala/DataSource.scala: -------------------------------------------------------------------------------- 1 | package org.template.complementarypurchase 2 | 3 | import org.apache.predictionio.controller.PDataSource 4 | import org.apache.predictionio.controller.EmptyEvaluationInfo 5 | import org.apache.predictionio.controller.EmptyActualResult 6 | import org.apache.predictionio.controller.Params 7 | import org.apache.predictionio.data.storage.Event 8 | import org.apache.predictionio.data.store.PEventStore 9 | 10 | import org.apache.spark.SparkContext 11 | import org.apache.spark.SparkContext._ 12 | import org.apache.spark.rdd.RDD 13 | 14 | import grizzled.slf4j.Logger 15 | 16 | case class DataSourceParams(appName: String) extends Params 17 | 18 | class DataSource(val dsp: DataSourceParams) 19 | extends PDataSource[TrainingData, 20 | EmptyEvaluationInfo, Query, EmptyActualResult] { 21 | 22 | @transient lazy val logger = Logger[this.type] 23 | 24 | override 25 | def readTraining(sc: SparkContext): TrainingData = { 26 | 27 | val buyEvents: RDD[BuyEvent] = PEventStore.find( 28 | appName = dsp.appName, 29 | entityType = Some("user"), 30 | eventNames = Some(List("buy")), 31 | targetEntityType = Some(Some("item")))(sc) 32 | .map { event => 33 | try { 34 | new BuyEvent( 35 | user = event.entityId, 36 | item = event.targetEntityId.get, 37 | t = event.eventTime.getMillis 38 | ) 39 | } catch { 40 | case e: Exception => { 41 | logger.error(s"Cannot convert ${event} to BuyEvent.
${e}") 42 | throw e 43 | } 44 | } 45 | }.cache() 46 | 47 | new TrainingData(buyEvents) 48 | } 49 | } 50 | 51 | case class BuyEvent(user: String, item: String, t: Long) 52 | 53 | class TrainingData( 54 | val buyEvents: RDD[BuyEvent] 55 | ) extends Serializable { 56 | override def toString = { 57 | s"buyEvents: [${buyEvents.count()}] (${buyEvents.take(2).toList}...)" 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /src/main/scala/Engine.scala: -------------------------------------------------------------------------------- 1 | package org.template.complementarypurchase 2 | 3 | import org.apache.predictionio.controller.EngineFactory 4 | import org.apache.predictionio.controller.Engine 5 | 6 | case class Query(items: Set[String], num: Int) 7 | extends Serializable 8 | 9 | case class PredictedResult(rules: Array[Rule]) 10 | extends Serializable 11 | 12 | //case class ItemScore(item: String, score: Double) extends Serializable 13 | case class Rule(cond: Set[String], itemScores: Array[ItemScore]) 14 | extends Serializable 15 | 16 | case class ItemScore( 17 | item: String, support: Double, confidence: Double, lift: Double 18 | ) extends Serializable 19 | 20 | object ComplementaryPurchaseEngine extends EngineFactory { 21 | def apply() = { 22 | new Engine( 23 | classOf[DataSource], 24 | classOf[Preparator], 25 | Map("algo" -> classOf[Algorithm]), 26 | classOf[Serving]) 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /src/main/scala/Preparator.scala: -------------------------------------------------------------------------------- 1 | package org.template.complementarypurchase 2 | 3 | import org.apache.predictionio.controller.PPreparator 4 | 5 | import org.apache.spark.SparkContext 6 | import org.apache.spark.SparkContext._ 7 | import org.apache.spark.rdd.RDD 8 | 9 | import grizzled.slf4j.Logger 10 | 11 | class Preparator 12 | extends PPreparator[TrainingData, PreparedData]
{ 13 | 14 | @transient lazy val logger = Logger[this.type] 15 | 16 | def prepare(sc: SparkContext, td: TrainingData): PreparedData = { 17 | new PreparedData(buyEvents = td.buyEvents) 18 | } 19 | } 20 | 21 | class PreparedData( 22 | val buyEvents: RDD[BuyEvent] 23 | ) extends Serializable 24 | -------------------------------------------------------------------------------- /src/main/scala/Serving.scala: -------------------------------------------------------------------------------- 1 | package org.template.complementarypurchase 2 | 3 | import org.apache.predictionio.controller.LServing 4 | 5 | import grizzled.slf4j.Logger 6 | 7 | class Serving 8 | extends LServing[Query, PredictedResult] { 9 | 10 | @transient lazy val logger = Logger[this.type] 11 | 12 | override 13 | def serve(query: Query, 14 | predictedResults: Seq[PredictedResult]): PredictedResult = { 15 | predictedResults.head 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /template.json: -------------------------------------------------------------------------------- 1 | {"pio": {"version": { "min": "0.11.0-incubating" }}} 2 | 3 | --------------------------------------------------------------------------------