├── .gitignore ├── LICENSE.txt ├── README.md ├── build.sbt ├── data ├── import_eventserver.py └── send_query.py ├── engine.json ├── project ├── assembly.sbt └── build.properties ├── src └── main │ └── java │ └── org │ └── example │ └── recommendation │ ├── Algorithm.java │ ├── AlgorithmParams.java │ ├── DataSource.java │ ├── DataSourceParams.java │ ├── Item.java │ ├── ItemScore.java │ ├── Model.java │ ├── PredictedResult.java │ ├── Preparator.java │ ├── PreparedData.java │ ├── Query.java │ ├── RecommendationEngine.java │ ├── Serving.java │ ├── TrainingData.java │ ├── User.java │ ├── UserItemEvent.java │ ├── UserItemEventType.java │ └── evaluation │ ├── EvaluationParameter.java │ ├── EvaluationSpec.java │ └── PrecisionMetric.java └── template.json /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | manifest.json 3 | target/ 4 | pio.log 5 | /pio.sbt 6 | 7 | # Eclipse 8 | .project 9 | .classpath 10 | .settings/ 11 | 12 | # IntelliJ 13 | *.iml 14 | .idea/ -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # E-Commerce Recommendation Template in Java 2 | 3 | ## Documentation 4 | 5 | Please refer to 6 | https://predictionio.apache.org/templates/javaecommercerecommendation/quickstart/. 
7 | 8 | ## Versions 9 | 10 | ### v0.14.0 11 | 12 | Update for Apache PredictionIO 0.14.0 13 | 14 | ### v0.13.0 15 | 16 | Update for Apache PredictionIO 0.13.0 17 | 18 | ### v0.12.0-incubating 19 | - Bump version number to track PredictionIO version 20 | - Sets default build targets according to PredictionIO 21 | - Fix compilation issue with Scala 2.11 22 | 23 | ### v0.11.0-incubating 24 | 25 | - Update to build with PredictionIO 0.11.0-incubating 26 | - Rename Java package name 27 | - Update SBT and plugin versions 28 | 29 | ### v0.1.2 30 | add "org.jblas" dependency in build.sbt 31 | 32 | ### v0.1.1 33 | - parallelize filtering valid items 34 | 35 | ### v0.1.0 36 | 37 | - initial version 38 | 39 | 40 | ## Development Notes 41 | 42 | ### Import Sample Data 43 | 44 | ``` 45 | $ python data/import_eventserver.py --access_key 46 | ``` 47 | 48 | ### Query 49 | 50 | normal: 51 | 52 | ``` 53 | $ curl -H "Content-Type: application/json" \ 54 | -d '{ 55 | "userEntityId" : "u1", 56 | "number" : 10 }' \ 57 | http://localhost:8000/queries.json 58 | ``` 59 | 60 | ``` 61 | $ curl -H "Content-Type: application/json" \ 62 | -d '{ 63 | "userEntityId" : "u1", 64 | "number": 10, 65 | "categories" : ["c4", "c3"] 66 | }' \ 67 | http://localhost:8000/queries.json 68 | ``` 69 | 70 | ``` 71 | curl -H "Content-Type: application/json" \ 72 | -d '{ 73 | "userEntityId" : "u1", 74 | "number": 10, 75 | "whitelist": ["i21", "i26", "i40"] 76 | }' \ 77 | http://localhost:8000/queries.json 78 | ``` 79 | 80 | ``` 81 | curl -H "Content-Type: application/json" \ 82 | -d '{ 83 | "userEntityId" : "u1", 84 | "number": 10, 85 | "blacklist": ["i21", "i26", "i40"] 86 | }' \ 87 | http://localhost:8000/queries.json 88 | ``` 89 | 90 | unknown user: 91 | 92 | ``` 93 | curl -H "Content-Type: application/json" \ 94 | -d '{ 95 | "userEntityId" : "unk1", 96 | "number": 10}' \ 97 | http://localhost:8000/queries.json 98 | ``` 99 | 100 | ### Handle New User 101 | 102 | new user: 103 | 104 | ``` 105 | curl -H 
"Content-Type: application/json" \ 106 | -d '{ 107 | "userEntityId" : "x1", 108 | "number": 10}' \ 109 | http://localhost:8000/queries.json 110 | ``` 111 | 112 | import some view events and try to get recommendation for x1 again. 113 | 114 | ``` 115 | accessKey= 116 | ``` 117 | 118 | ``` 119 | curl -i -X POST http://localhost:7070/events.json?accessKey=$accessKey \ 120 | -H "Content-Type: application/json" \ 121 | -d '{ 122 | "event" : "view", 123 | "entityType" : "user", 124 | "entityId" : "x1", 125 | "targetEntityType" : "item", 126 | "targetEntityId" : "i2", 127 | "eventTime" : "2015-02-17T02:11:21.934Z" 128 | }' 129 | 130 | curl -i -X POST http://localhost:7070/events.json?accessKey=$accessKey \ 131 | -H "Content-Type: application/json" \ 132 | -d '{ 133 | "event" : "view", 134 | "entityType" : "user", 135 | "entityId" : "x1", 136 | "targetEntityType" : "item", 137 | "targetEntityId" : "i3", 138 | "eventTime" : "2015-02-17T02:12:21.934Z" 139 | }' 140 | 141 | ``` 142 | 143 | ### Handle Unavailable Items 144 | 145 | Set the following items as unavailable (you need to specify the complete list each time this list changes): 146 | 147 | ``` 148 | curl -i -X POST http://localhost:7070/events.json?accessKey=$accessKey \ 149 | -H "Content-Type: application/json" \ 150 | -d '{ 151 | "event" : "$set", 152 | "entityType" : "constraint", 153 | "entityId" : "unavailableItems", 154 | "properties" : { 155 | "items": ["i43", "i20", "i37", "i3", "i4", "i5"] 156 | }, 157 | "eventTime" : "2015-02-17T02:11:21.934Z" 158 | }' 159 | ``` 160 | 161 | Set an empty list when no more items are unavailable: 162 | 163 | ``` 164 | curl -i -X POST http://localhost:7070/events.json?accessKey=$accessKey \ 165 | -H "Content-Type: application/json" \ 166 | -d '{ 167 | "event" : "$set", 168 | "entityType" : "constraint", 169 | "entityId" : "unavailableItems", 170 | "properties" : { 171 | "items": [] 172 | }, 173 | "eventTime" : "2015-02-18T02:11:21.934Z" 174 | }' 175 | ``` 176 | 
-------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | name := "template-java-parallel-ecommercerecommendation" 2 | 3 | scalaVersion := "2.11.12" 4 | libraryDependencies ++= Seq( 5 | "org.apache.predictionio" %% "apache-predictionio-core" % "0.14.0" % "provided", 6 | "org.apache.spark" %% "spark-mllib" % "2.4.0" % "provided", 7 | "org.jblas" % "jblas" % "1.2.4") 8 | -------------------------------------------------------------------------------- /data/import_eventserver.py: -------------------------------------------------------------------------------- 1 | """ 2 | Import sample data for E-Commerce Recommendation Engine Template 3 | """ 4 | 5 | import predictionio 6 | import argparse 7 | import random 8 | 9 | SEED = 3 10 | 11 | def import_events(client): 12 | random.seed(SEED) 13 | count = 0 14 | print(client.get_status()) 15 | print("Importing data...") 16 | 17 | # generate 10 users, with user ids u1,u2,....,u10 18 | user_ids = ["u%s" % i for i in range(1, 11)] 19 | for user_id in user_ids: 20 | print("Set user", user_id) 21 | client.create_event( 22 | event="$set", 23 | entity_type="user", 24 | entity_id=user_id 25 | ) 26 | count += 1 27 | 28 | # generate 50 items, with itemEntityId ids i1,i2,....,i50 29 | # random assign 1 to 4 categories among c1-c6 to items 30 | categories = ["c%s" % i for i in range(1, 7)] 31 | item_ids = ["i%s" % i for i in range(1, 51)] 32 | for item_id in item_ids: 33 | print("Set itemEntityId", item_id) 34 | client.create_event( 35 | event="$set", 36 | entity_type="item", 37 | entity_id=item_id, 38 | properties={ 39 | "categories" : random.sample(categories, random.randint(1, 4)) 40 | } 41 | ) 42 | count += 1 43 | 44 | # each user randomly viewed 10 items 45 | for user_id in user_ids: 46 | for viewed_item in random.sample(item_ids, 10): 47 | print("User", user_id ,"views itemEntityId", viewed_item) 48 | 
client.create_event( 49 | event="view", 50 | entity_type="user", 51 | entity_id=user_id, 52 | target_entity_type="item", 53 | target_entity_id=viewed_item 54 | ) 55 | count += 1 56 | 57 | # each user randomly bought 3 items 58 | for user_id in user_ids: 59 | for item in random.sample(item_ids, 3): 60 | print("User", user_id ,"buys itemEntityId", item) 61 | client.create_event( 62 | event="buy", 63 | entity_type="user", 64 | entity_id=user_id, 65 | target_entity_type="item", 66 | target_entity_id=item 67 | ) 68 | count += 1 69 | 70 | 71 | print("%s events are imported." % count) 72 | 73 | if __name__ == '__main__': 74 | parser = argparse.ArgumentParser( 75 | description="Import sample data for e-commerce recommendation engine") 76 | parser.add_argument('--access_key', default='invald_access_key') 77 | parser.add_argument('--url', default="http://localhost:7070") 78 | 79 | args = parser.parse_args() 80 | print(args) 81 | 82 | client = predictionio.EventClient( 83 | access_key=args.access_key, 84 | url=args.url, 85 | threads=5, 86 | qsize=500) 87 | import_events(client) 88 | -------------------------------------------------------------------------------- /data/send_query.py: -------------------------------------------------------------------------------- 1 | """ 2 | Send sample query to prediction engine 3 | """ 4 | 5 | import predictionio 6 | engine_client = predictionio.EngineClient(url="http://localhost:8000") 7 | print(engine_client.send_query({"userEntityId": "u1", "number": 4})) 8 | -------------------------------------------------------------------------------- /engine.json: -------------------------------------------------------------------------------- 1 | { 2 | "id": "default", 3 | "description": "Default settings", 4 | "engineFactory": "org.example.recommendation.RecommendationEngine", 5 | "datasource": { 6 | "params" : { 7 | "appName": "javadase" 8 | } 9 | }, 10 | "algorithms": [ 11 | { 12 | "name": "algo", 13 | "params": { 14 | "seed": 1, 15 | "rank": 10, 
16 | "iteration": 10, 17 | "lambda": 0.01, 18 | "appName": "javadase", 19 | "similarItemEvents": ["view"], 20 | "seenItemEvents": ["buy", "view"], 21 | "unseenOnly": true 22 | } 23 | } 24 | ] 25 | } 26 | -------------------------------------------------------------------------------- /project/assembly.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.9") 2 | -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version=1.2.8 2 | -------------------------------------------------------------------------------- /src/main/java/org/example/recommendation/Algorithm.java: -------------------------------------------------------------------------------- 1 | package org.example.recommendation; 2 | 3 | import com.google.common.collect.Sets; 4 | import org.apache.predictionio.controller.java.PJavaAlgorithm; 5 | import org.apache.predictionio.data.storage.Event; 6 | import org.apache.predictionio.data.store.java.LJavaEventStore; 7 | import org.apache.predictionio.data.store.java.OptionHelper; 8 | import org.apache.spark.SparkContext; 9 | import org.apache.spark.api.java.JavaPairRDD; 10 | import org.apache.spark.api.java.JavaRDD; 11 | import org.apache.spark.api.java.JavaSparkContext; 12 | import org.apache.spark.api.java.function.Function; 13 | import org.apache.spark.api.java.function.Function2; 14 | import org.apache.spark.api.java.function.PairFunction; 15 | import org.apache.spark.mllib.recommendation.ALS; 16 | import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; 17 | import org.apache.spark.mllib.recommendation.Rating; 18 | import org.apache.spark.rdd.RDD; 19 | import org.jblas.DoubleMatrix; 20 | import org.joda.time.DateTime; 21 | import org.slf4j.Logger; 22 | import org.slf4j.LoggerFactory; 23 | import scala.Option; 24 | 
import scala.Tuple2; 25 | import scala.concurrent.duration.Duration; 26 | 27 | import java.util.ArrayList; 28 | import java.util.Collections; 29 | import java.util.HashSet; 30 | import java.util.List; 31 | import java.util.Map; 32 | import java.util.Set; 33 | import java.util.concurrent.TimeUnit; 34 | 35 | public class Algorithm extends PJavaAlgorithm { 36 | 37 | private static final Logger logger = LoggerFactory.getLogger(Algorithm.class); 38 | private final AlgorithmParams ap; 39 | 40 | public Algorithm(AlgorithmParams ap) { 41 | this.ap = ap; 42 | } 43 | 44 | @Override 45 | public Model train(SparkContext sc, PreparedData preparedData) { 46 | TrainingData data = preparedData.getTrainingData(); 47 | 48 | // user stuff 49 | JavaPairRDD userIndexRDD = data.getUsers().map(new Function, String>() { 50 | @Override 51 | public String call(Tuple2 idUser) throws Exception { 52 | return idUser._1(); 53 | } 54 | }).zipWithIndex().mapToPair(new PairFunction, String, Integer>() { 55 | @Override 56 | public Tuple2 call(Tuple2 element) throws Exception { 57 | return new Tuple2<>(element._1(), element._2().intValue()); 58 | } 59 | }); 60 | final Map userIndexMap = userIndexRDD.collectAsMap(); 61 | 62 | // item stuff 63 | JavaPairRDD itemIndexRDD = data.getItems().map(new Function, String>() { 64 | @Override 65 | public String call(Tuple2 idItem) throws Exception { 66 | return idItem._1(); 67 | } 68 | }).zipWithIndex().mapToPair(new PairFunction, String, Integer>() { 69 | @Override 70 | public Tuple2 call(Tuple2 element) throws Exception { 71 | return new Tuple2<>(element._1(), element._2().intValue()); 72 | } 73 | }); 74 | final Map itemIndexMap = itemIndexRDD.collectAsMap(); 75 | JavaPairRDD indexItemRDD = itemIndexRDD.mapToPair(new PairFunction, Integer, String>() { 76 | @Override 77 | public Tuple2 call(Tuple2 element) throws Exception { 78 | return element.swap(); 79 | } 80 | }); 81 | final Map indexItemMap = indexItemRDD.collectAsMap(); 82 | 83 | // ratings stuff 84 | 
JavaRDD ratings = data.getViewEvents().mapToPair(new PairFunction, Integer>() { 85 | @Override 86 | public Tuple2, Integer> call(UserItemEvent viewEvent) throws Exception { 87 | Integer userIndex = userIndexMap.get(viewEvent.getUser()); 88 | Integer itemIndex = itemIndexMap.get(viewEvent.getItem()); 89 | 90 | return (userIndex == null || itemIndex == null) ? null : new Tuple2<>(new Tuple2<>(userIndex, itemIndex), 1); 91 | } 92 | }).filter(new Function, Integer>, Boolean>() { 93 | @Override 94 | public Boolean call(Tuple2, Integer> element) throws Exception { 95 | return (element != null); 96 | } 97 | }).reduceByKey(new Function2() { 98 | @Override 99 | public Integer call(Integer integer, Integer integer2) throws Exception { 100 | return integer + integer2; 101 | } 102 | }).map(new Function, Integer>, Rating>() { 103 | @Override 104 | public Rating call(Tuple2, Integer> userItemCount) throws Exception { 105 | return new Rating(userItemCount._1()._1(), userItemCount._1()._2(), userItemCount._2().doubleValue()); 106 | } 107 | }); 108 | 109 | if (ratings.isEmpty()) 110 | throw new AssertionError("Please check if your events contain valid user and item ID."); 111 | 112 | // MLlib ALS stuff 113 | MatrixFactorizationModel matrixFactorizationModel = ALS.trainImplicit(JavaRDD.toRDD(ratings), ap.getRank(), ap.getIteration(), ap.getLambda(), -1, 1.0, ap.getSeed()); 114 | JavaPairRDD userFeatures = matrixFactorizationModel.userFeatures().toJavaRDD().mapToPair(new PairFunction, Integer, double[]>() { 115 | @Override 116 | public Tuple2 call(Tuple2 element) throws Exception { 117 | return new Tuple2<>((Integer) element._1(), element._2()); 118 | } 119 | }); 120 | JavaPairRDD productFeaturesRDD = matrixFactorizationModel.productFeatures().toJavaRDD().mapToPair(new PairFunction, Integer, double[]>() { 121 | @Override 122 | public Tuple2 call(Tuple2 element) throws Exception { 123 | return new Tuple2<>((Integer) element._1(), element._2()); 124 | } 125 | }); 126 | 127 | // 
popularity scores 128 | JavaRDD itemPopularityScore = data.getBuyEvents().mapToPair(new PairFunction, Integer>() { 129 | @Override 130 | public Tuple2, Integer> call(UserItemEvent buyEvent) throws Exception { 131 | Integer userIndex = userIndexMap.get(buyEvent.getUser()); 132 | Integer itemIndex = itemIndexMap.get(buyEvent.getItem()); 133 | 134 | return (userIndex == null || itemIndex == null) ? null : new Tuple2<>(new Tuple2<>(userIndex, itemIndex), 1); 135 | } 136 | }).filter(new Function, Integer>, Boolean>() { 137 | @Override 138 | public Boolean call(Tuple2, Integer> element) throws Exception { 139 | return (element != null); 140 | } 141 | }).mapToPair(new PairFunction, Integer>, Integer, Integer>() { 142 | @Override 143 | public Tuple2 call(Tuple2, Integer> element) throws Exception { 144 | return new Tuple2<>(element._1()._2(), element._2()); 145 | } 146 | }).reduceByKey(new Function2() { 147 | @Override 148 | public Integer call(Integer integer, Integer integer2) throws Exception { 149 | return integer + integer2; 150 | } 151 | }).map(new Function, ItemScore>() { 152 | @Override 153 | public ItemScore call(Tuple2 element) throws Exception { 154 | return new ItemScore(indexItemMap.get(element._1()), element._2().doubleValue()); 155 | } 156 | }); 157 | 158 | JavaPairRDD> indexItemFeatures = indexItemRDD.join(productFeaturesRDD); 159 | 160 | return new Model(userFeatures, indexItemFeatures, userIndexRDD, itemIndexRDD, itemPopularityScore, data.getItems().collectAsMap()); 161 | } 162 | 163 | @Override 164 | public PredictedResult predict(Model model, final Query query) { 165 | final JavaPairRDD matchedUser = model.getUserIndex().filter(new Function, Boolean>() { 166 | @Override 167 | public Boolean call(Tuple2 userIndex) throws Exception { 168 | return userIndex._1().equals(query.getUserEntityId()); 169 | } 170 | }); 171 | 172 | double[] userFeature = null; 173 | if (!matchedUser.isEmpty()) { 174 | final Integer matchedUserIndex = matchedUser.first()._2(); 175 
| userFeature = model.getUserFeatures().filter(new Function, Boolean>() { 176 | @Override 177 | public Boolean call(Tuple2 element) throws Exception { 178 | return element._1().equals(matchedUserIndex); 179 | } 180 | }).first()._2(); 181 | } 182 | 183 | if (userFeature != null) { 184 | return new PredictedResult(topItemsForUser(userFeature, model, query)); 185 | } else { 186 | List recentProductFeatures = getRecentProductFeatures(query, model); 187 | if (recentProductFeatures.isEmpty()) { 188 | return new PredictedResult(mostPopularItems(model, query)); 189 | } else { 190 | return new PredictedResult(similarItems(recentProductFeatures, model, query)); 191 | } 192 | } 193 | } 194 | 195 | @Override 196 | public RDD> batchPredict(Model model, RDD> qs) { 197 | List> indexQueries = qs.toJavaRDD().collect(); 198 | List> results = new ArrayList<>(); 199 | 200 | for (Tuple2 indexQuery : indexQueries) { 201 | results.add(new Tuple2<>(indexQuery._1(), predict(model, indexQuery._2()))); 202 | } 203 | 204 | return new JavaSparkContext(qs.sparkContext()).parallelize(results).rdd(); 205 | } 206 | 207 | private List getRecentProductFeatures(Query query, Model model) { 208 | try { 209 | List result = new ArrayList<>(); 210 | 211 | List events = LJavaEventStore.findByEntity( 212 | ap.getAppName(), 213 | "user", 214 | query.getUserEntityId(), 215 | OptionHelper.none(), 216 | OptionHelper.some(ap.getSimilarItemEvents()), 217 | OptionHelper.some(OptionHelper.some("item")), 218 | OptionHelper.>none(), 219 | OptionHelper.none(), 220 | OptionHelper.none(), 221 | OptionHelper.some(10), 222 | true, 223 | Duration.apply(10, TimeUnit.SECONDS)); 224 | 225 | for (final Event event : events) { 226 | if (event.targetEntityId().isDefined()) { 227 | JavaPairRDD filtered = model.getItemIndex().filter(new Function, Boolean>() { 228 | @Override 229 | public Boolean call(Tuple2 element) throws Exception { 230 | return element._1().equals(event.targetEntityId().get()); 231 | } 232 | }); 233 | 234 | 
final Integer itemIndex = filtered.first()._2(); 235 | 236 | if (!filtered.isEmpty()) { 237 | 238 | JavaPairRDD> indexItemFeatures = model.getIndexItemFeatures().filter(new Function>, Boolean>() { 239 | @Override 240 | public Boolean call(Tuple2> element) throws Exception { 241 | return itemIndex.equals(element._1()); 242 | } 243 | }); 244 | 245 | List>> oneIndexItemFeatures = indexItemFeatures.collect(); 246 | if (oneIndexItemFeatures.size() > 0) { 247 | result.add(oneIndexItemFeatures.get(0)._2()._2()); 248 | } 249 | } 250 | } 251 | } 252 | 253 | return result; 254 | } catch (Exception e) { 255 | logger.error("Error reading recent events for user " + query.getUserEntityId()); 256 | throw new RuntimeException(e.getMessage(), e); 257 | } 258 | } 259 | 260 | private List topItemsForUser(double[] userFeature, Model model, Query query) { 261 | final DoubleMatrix userMatrix = new DoubleMatrix(userFeature); 262 | 263 | JavaRDD itemScores = model.getIndexItemFeatures().map(new Function>, ItemScore>() { 264 | @Override 265 | public ItemScore call(Tuple2> element) throws Exception { 266 | return new ItemScore(element._2()._1(), userMatrix.dot(new DoubleMatrix(element._2()._2()))); 267 | } 268 | }); 269 | 270 | itemScores = validScores(itemScores, query.getWhitelist(), query.getBlacklist(), query.getCategories(), model.getItems(), query.getUserEntityId()); 271 | return sortAndTake(itemScores, query.getNumber()); 272 | } 273 | 274 | private List similarItems(final List recentProductFeatures, Model model, Query query) { 275 | JavaRDD itemScores = model.getIndexItemFeatures().map(new Function>, ItemScore>() { 276 | @Override 277 | public ItemScore call(Tuple2> element) throws Exception { 278 | double similarity = 0.0; 279 | for (double[] recentFeature : recentProductFeatures) { 280 | similarity += cosineSimilarity(element._2()._2(), recentFeature); 281 | } 282 | 283 | return new ItemScore(element._2()._1(), similarity); 284 | } 285 | }); 286 | 287 | itemScores = 
validScores(itemScores, query.getWhitelist(), query.getBlacklist(), query.getCategories(), model.getItems(), query.getUserEntityId()); 288 | return sortAndTake(itemScores, query.getNumber()); 289 | } 290 | 291 | private List mostPopularItems(Model model, Query query) { 292 | JavaRDD itemScores = validScores(model.getItemPopularityScore(), query.getWhitelist(), query.getBlacklist(), query.getCategories(), model.getItems(), query.getUserEntityId()); 293 | return sortAndTake(itemScores, query.getNumber()); 294 | } 295 | 296 | private double cosineSimilarity(double[] a, double[] b) { 297 | DoubleMatrix matrixA = new DoubleMatrix(a); 298 | DoubleMatrix matrixB = new DoubleMatrix(b); 299 | 300 | return matrixA.dot(matrixB) / (matrixA.norm2() * matrixB.norm2()); 301 | } 302 | 303 | private List sortAndTake(JavaRDD all, int number) { 304 | return all.sortBy(new Function() { 305 | @Override 306 | public Double call(ItemScore itemScore) throws Exception { 307 | return itemScore.getScore(); 308 | } 309 | }, false, all.partitions().size()).take(number); 310 | } 311 | 312 | private JavaRDD validScores(JavaRDD all, final Set whitelist, final Set blacklist, final Set categories, final Map items, String userEntityId) { 313 | final Set seenItemEntityIds = seenItemEntityIds(userEntityId); 314 | final Set unavailableItemEntityIds = unavailableItemEntityIds(); 315 | 316 | return all.filter(new Function() { 317 | @Override 318 | public Boolean call(ItemScore itemScore) throws Exception { 319 | Item item = items.get(itemScore.getItemEntityId()); 320 | 321 | return (item != null 322 | && passWhitelistCriteria(whitelist, item.getEntityId()) 323 | && passBlacklistCriteria(blacklist, item.getEntityId()) 324 | && passCategoryCriteria(categories, item) 325 | && passUnseenCriteria(seenItemEntityIds, item.getEntityId()) 326 | && passAvailabilityCriteria(unavailableItemEntityIds, item.getEntityId())); 327 | } 328 | }); 329 | } 330 | 331 | private boolean passWhitelistCriteria(Set whitelist, 
String itemEntityId) { 332 | return (whitelist.isEmpty() || whitelist.contains(itemEntityId)); 333 | } 334 | 335 | private boolean passBlacklistCriteria(Set blacklist, String itemEntityId) { 336 | return !blacklist.contains(itemEntityId); 337 | } 338 | 339 | private boolean passCategoryCriteria(Set categories, Item item) { 340 | return (categories.isEmpty() || Sets.intersection(categories, item.getCategories()).size() > 0); 341 | } 342 | 343 | private boolean passUnseenCriteria(Set seen, String itemEntityId) { 344 | return !seen.contains(itemEntityId); 345 | } 346 | 347 | private boolean passAvailabilityCriteria(Set unavailableItemEntityIds, String entityId) { 348 | return !unavailableItemEntityIds.contains(entityId); 349 | } 350 | 351 | private Set unavailableItemEntityIds() { 352 | try { 353 | List unavailableConstraintEvents = LJavaEventStore.findByEntity( 354 | ap.getAppName(), 355 | "constraint", 356 | "unavailableItems", 357 | OptionHelper.none(), 358 | OptionHelper.some(Collections.singletonList("$set")), 359 | OptionHelper.>none(), 360 | OptionHelper.>none(), 361 | OptionHelper.none(), 362 | OptionHelper.none(), 363 | OptionHelper.some(1), 364 | true, 365 | Duration.apply(10, TimeUnit.SECONDS)); 366 | 367 | if (unavailableConstraintEvents.isEmpty()) return Collections.emptySet(); 368 | 369 | Event unavailableConstraint = unavailableConstraintEvents.get(0); 370 | 371 | List unavailableItems = unavailableConstraint.properties().getStringList("items"); 372 | 373 | return new HashSet<>(unavailableItems); 374 | } catch (Exception e) { 375 | logger.error("Error reading constraint events"); 376 | throw new RuntimeException(e.getMessage(), e); 377 | } 378 | } 379 | 380 | private Set seenItemEntityIds(String userEntityId) { 381 | if (!ap.isUnseenOnly()) return Collections.emptySet(); 382 | 383 | try { 384 | Set result = new HashSet<>(); 385 | List seenEvents = LJavaEventStore.findByEntity( 386 | ap.getAppName(), 387 | "user", 388 | userEntityId, 389 | 
OptionHelper.none(), 390 | OptionHelper.some(ap.getSeenItemEvents()), 391 | OptionHelper.some(OptionHelper.some("item")), 392 | OptionHelper.>none(), 393 | OptionHelper.none(), 394 | OptionHelper.none(), 395 | OptionHelper.none(), 396 | true, 397 | Duration.apply(10, TimeUnit.SECONDS)); 398 | 399 | for (Event event : seenEvents) { 400 | result.add(event.targetEntityId().get()); 401 | } 402 | 403 | return result; 404 | } catch (Exception e) { 405 | logger.error("Error reading seen events for user " + userEntityId); 406 | throw new RuntimeException(e.getMessage(), e); 407 | } 408 | } 409 | } 410 | -------------------------------------------------------------------------------- /src/main/java/org/example/recommendation/AlgorithmParams.java: -------------------------------------------------------------------------------- 1 | package org.example.recommendation; 2 | 3 | import org.apache.predictionio.controller.Params; 4 | 5 | import java.util.List; 6 | 7 | public class AlgorithmParams implements Params{ 8 | private final long seed; 9 | private final int rank; 10 | private final int iteration; 11 | private final double lambda; 12 | private final String appName; 13 | private final List similarItemEvents; 14 | private final boolean unseenOnly; 15 | private final List seenItemEvents; 16 | 17 | 18 | public AlgorithmParams(long seed, int rank, int iteration, double lambda, String appName, List similarItemEvents, boolean unseenOnly, List seenItemEvents) { 19 | this.seed = seed; 20 | this.rank = rank; 21 | this.iteration = iteration; 22 | this.lambda = lambda; 23 | this.appName = appName; 24 | this.similarItemEvents = similarItemEvents; 25 | this.unseenOnly = unseenOnly; 26 | this.seenItemEvents = seenItemEvents; 27 | } 28 | 29 | public long getSeed() { 30 | return seed; 31 | } 32 | 33 | public int getRank() { 34 | return rank; 35 | } 36 | 37 | public int getIteration() { 38 | return iteration; 39 | } 40 | 41 | public double getLambda() { 42 | return lambda; 43 | } 44 | 45 | 
public String getAppName() { 46 | return appName; 47 | } 48 | 49 | public List getSimilarItemEvents() { 50 | return similarItemEvents; 51 | } 52 | 53 | public boolean isUnseenOnly() { 54 | return unseenOnly; 55 | } 56 | 57 | public List getSeenItemEvents() { 58 | return seenItemEvents; 59 | } 60 | 61 | @Override 62 | public String toString() { 63 | return "AlgorithmParams{" + 64 | "seed=" + seed + 65 | ", rank=" + rank + 66 | ", iteration=" + iteration + 67 | ", lambda=" + lambda + 68 | ", appName='" + appName + '\'' + 69 | ", similarItemEvents=" + similarItemEvents + 70 | ", unseenOnly=" + unseenOnly + 71 | ", seenItemEvents=" + seenItemEvents + 72 | '}'; 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /src/main/java/org/example/recommendation/DataSource.java: -------------------------------------------------------------------------------- 1 | package org.example.recommendation; 2 | 3 | import com.google.common.collect.ImmutableMap; 4 | import com.google.common.collect.ImmutableSet; 5 | import org.apache.predictionio.controller.EmptyParams; 6 | import org.apache.predictionio.controller.java.PJavaDataSource; 7 | import org.apache.predictionio.data.storage.Event; 8 | import org.apache.predictionio.data.storage.PropertyMap; 9 | import org.apache.predictionio.data.store.java.OptionHelper; 10 | import org.apache.predictionio.data.store.java.PJavaEventStore; 11 | import org.apache.spark.SparkContext; 12 | import org.apache.spark.api.java.JavaPairRDD; 13 | import org.apache.spark.api.java.JavaRDD; 14 | import org.apache.spark.api.java.function.Function; 15 | import org.apache.spark.api.java.function.PairFunction; 16 | import org.apache.spark.rdd.RDD; 17 | import org.joda.time.DateTime; 18 | import scala.Option; 19 | import scala.Tuple2; 20 | import scala.Tuple3; 21 | import scala.collection.JavaConversions; 22 | import scala.collection.JavaConversions$; 23 | import scala.collection.Seq; 24 | 25 | import 
java.util.Collections; 26 | import java.util.HashMap; 27 | import java.util.HashSet; 28 | import java.util.List; 29 | import java.util.Map; 30 | import java.util.Set; 31 | 32 | public class DataSource extends PJavaDataSource> { 33 | 34 | private final DataSourceParams dsp; 35 | 36 | public DataSource(DataSourceParams dsp) { 37 | this.dsp = dsp; 38 | } 39 | 40 | @Override 41 | public TrainingData readTraining(SparkContext sc) { 42 | JavaPairRDD usersRDD = PJavaEventStore.aggregateProperties( 43 | dsp.getAppName(), 44 | "user", 45 | OptionHelper.none(), 46 | OptionHelper.none(), 47 | OptionHelper.none(), 48 | OptionHelper.>none(), 49 | sc) 50 | .mapToPair(new PairFunction, String, User>() { 51 | @Override 52 | public Tuple2 call(Tuple2 entityIdProperty) throws Exception { 53 | Set keys = JavaConversions$.MODULE$.setAsJavaSet(entityIdProperty._2().keySet()); 54 | Map properties = new HashMap<>(); 55 | for (String key : keys) { 56 | properties.put(key, entityIdProperty._2().get(key, String.class)); 57 | } 58 | 59 | User user = new User(entityIdProperty._1(), ImmutableMap.copyOf(properties)); 60 | 61 | return new Tuple2<>(user.getEntityId(), user); 62 | } 63 | }); 64 | 65 | JavaPairRDD itemsRDD = PJavaEventStore.aggregateProperties( 66 | dsp.getAppName(), 67 | "item", 68 | OptionHelper.none(), 69 | OptionHelper.none(), 70 | OptionHelper.none(), 71 | OptionHelper.>none(), 72 | sc) 73 | .mapToPair(new PairFunction, String, Item>() { 74 | @Override 75 | public Tuple2 call(Tuple2 entityIdProperty) throws Exception { 76 | List categories = entityIdProperty._2().getStringList("categories"); 77 | Item item = new Item(entityIdProperty._1(), ImmutableSet.copyOf(categories)); 78 | 79 | return new Tuple2<>(item.getEntityId(), item); 80 | } 81 | }); 82 | 83 | JavaRDD viewEventsRDD = PJavaEventStore.find( 84 | dsp.getAppName(), 85 | OptionHelper.none(), 86 | OptionHelper.none(), 87 | OptionHelper.none(), 88 | OptionHelper.some("user"), 89 | OptionHelper.none(), 90 | 
OptionHelper.some(Collections.singletonList("view")), 91 | OptionHelper.>none(), 92 | OptionHelper.>none(), 93 | sc) 94 | .map(new Function() { 95 | @Override 96 | public UserItemEvent call(Event event) throws Exception { 97 | return new UserItemEvent(event.entityId(), event.targetEntityId().get(), event.eventTime().getMillis(), UserItemEventType.VIEW); 98 | } 99 | }); 100 | 101 | JavaRDD buyEventsRDD = PJavaEventStore.find( 102 | dsp.getAppName(), 103 | OptionHelper.none(), 104 | OptionHelper.none(), 105 | OptionHelper.none(), 106 | OptionHelper.some("user"), 107 | OptionHelper.none(), 108 | OptionHelper.some(Collections.singletonList("buy")), 109 | OptionHelper.>none(), 110 | OptionHelper.>none(), 111 | sc) 112 | .map(new Function() { 113 | @Override 114 | public UserItemEvent call(Event event) throws Exception { 115 | return new UserItemEvent(event.entityId(), event.targetEntityId().get(), event.eventTime().getMillis(), UserItemEventType.BUY); 116 | } 117 | }); 118 | 119 | return new TrainingData(usersRDD, itemsRDD, viewEventsRDD, buyEventsRDD); 120 | } 121 | 122 | @Override 123 | public Seq>>>> readEval(SparkContext sc) { 124 | TrainingData all = readTraining(sc); 125 | double[] split = {0.5, 0.5}; 126 | JavaRDD[] trainingAndTestingViews = all.getViewEvents().randomSplit(split, 1); 127 | JavaRDD[] trainingAndTestingBuys = all.getBuyEvents().randomSplit(split, 1); 128 | 129 | RDD>> queryActual = JavaPairRDD.toRDD(trainingAndTestingViews[1].union(trainingAndTestingBuys[1]).groupBy(new Function() { 130 | @Override 131 | public String call(UserItemEvent event) throws Exception { 132 | return event.getUser(); 133 | } 134 | }).mapToPair(new PairFunction>, Query, Set>() { 135 | @Override 136 | public Tuple2> call(Tuple2> userEvents) throws Exception { 137 | Query query = new Query(userEvents._1(), 10, Collections.emptySet(), Collections.emptySet(), Collections.emptySet()); 138 | Set actualSet = new HashSet<>(); 139 | for (UserItemEvent event : userEvents._2()) { 140 | 
actualSet.add(event.getItem()); 141 | } 142 | return new Tuple2<>(query, actualSet); 143 | } 144 | })); 145 | 146 | Tuple3>>> setData = new Tuple3<>(new TrainingData(all.getUsers(), all.getItems(), trainingAndTestingViews[0], trainingAndTestingBuys[0]), new EmptyParams(), queryActual); 147 | 148 | return JavaConversions.iterableAsScalaIterable(Collections.singletonList(setData)).toSeq(); 149 | } 150 | } 151 | -------------------------------------------------------------------------------- /src/main/java/org/example/recommendation/DataSourceParams.java: -------------------------------------------------------------------------------- 1 | package org.example.recommendation; 2 | 3 | import org.apache.predictionio.controller.Params; 4 | 5 | public class DataSourceParams implements Params{ 6 | private final String appName; 7 | 8 | public DataSourceParams(String appName) { 9 | this.appName = appName; 10 | } 11 | 12 | public String getAppName() { 13 | return appName; 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /src/main/java/org/example/recommendation/Item.java: -------------------------------------------------------------------------------- 1 | package org.example.recommendation; 2 | 3 | import java.io.Serializable; 4 | import java.util.Set; 5 | 6 | public class Item implements Serializable{ 7 | private final Set categories; 8 | private final String entityId; 9 | 10 | public Item(String entityId, Set categories) { 11 | this.categories = categories; 12 | this.entityId = entityId; 13 | } 14 | 15 | public String getEntityId() { 16 | return entityId; 17 | } 18 | 19 | public Set getCategories() { 20 | return categories; 21 | } 22 | 23 | @Override 24 | public String toString() { 25 | return "Item{" + 26 | "categories=" + categories + 27 | ", entityId='" + entityId + '\'' + 28 | '}'; 29 | } 30 | 31 | } 32 | -------------------------------------------------------------------------------- 
/src/main/java/org/example/recommendation/ItemScore.java:
--------------------------------------------------------------------------------
package org.example.recommendation;

import java.io.Serializable;

/** An item entity id paired with its recommendation score; ordered by score. */
public class ItemScore implements Serializable, Comparable<ItemScore> {
    private final String itemEntityId;
    private final double score;

    public ItemScore(String itemEntityId, double score) {
        this.itemEntityId = itemEntityId;
        this.score = score;
    }

    public String getItemEntityId() {
        return itemEntityId;
    }

    public double getScore() {
        return score;
    }

    @Override
    public String toString() {
        return "ItemScore{" +
                "itemEntityId='" + itemEntityId + '\'' +
                ", score=" + score +
                '}';
    }

    @Override
    public int compareTo(ItemScore o) {
        // Double.compare avoids the boxing of Double.valueOf(..).compareTo(..).
        return Double.compare(score, o.score);
    }
}
--------------------------------------------------------------------------------
/src/main/java/org/example/recommendation/Model.java:
--------------------------------------------------------------------------------
package org.example.recommendation;

import org.apache.predictionio.controller.Params;
import org.apache.predictionio.controller.PersistentModel;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

import java.io.Serializable;
import java.util.Collections;
import java.util.Map;

/**
 * Trained recommendation model: ALS user/item factors, index mappings, a
 * popularity fallback and the item catalog. Persisted as object files under
 * /tmp/&lt;id&gt; (NOTE(review): /tmp is node-local and volatile — confirm this
 * is acceptable for the deployment).
 */
public class Model implements Serializable, PersistentModel<AlgorithmParams> {
    private static final Logger logger = LoggerFactory.getLogger(Model.class);
    private final JavaPairRDD<Integer, double[]> userFeatures;                     // user index -> latent factors
    private final JavaPairRDD<Integer, Tuple2<String, double[]>> indexItemFeatures; // item index -> (entity id, latent factors)
    private final JavaPairRDD<String, Integer> userIndex;                          // user entity id -> ALS index
    private final JavaPairRDD<String, Integer> itemIndex;                          // item entity id -> ALS index
    private final JavaRDD<ItemScore> itemPopularityScore;                          // popularity fallback scores
    private final Map<String, Item> items;                                         // item entity id -> catalog item

    public Model(JavaPairRDD<Integer, double[]> userFeatures, JavaPairRDD<Integer, Tuple2<String, double[]>> indexItemFeatures, JavaPairRDD<String, Integer> userIndex, JavaPairRDD<String, Integer> itemIndex, JavaRDD<ItemScore> itemPopularityScore, Map<String, Item> items) {
        this.userFeatures = userFeatures;
        this.indexItemFeatures = indexItemFeatures;
        this.userIndex = userIndex;
        this.itemIndex = itemIndex;
        this.itemPopularityScore = itemPopularityScore;
        this.items = items;
    }

    public JavaPairRDD<Integer, double[]> getUserFeatures() {
        return userFeatures;
    }

    public JavaPairRDD<Integer, Tuple2<String, double[]>> getIndexItemFeatures() {
        return indexItemFeatures;
    }

    public JavaPairRDD<String, Integer> getUserIndex() {
        return userIndex;
    }

    public JavaPairRDD<String, Integer> getItemIndex() {
        return itemIndex;
    }

    public JavaRDD<ItemScore> getItemPopularityScore() {
        return itemPopularityScore;
    }

    public Map<String, Item> getItems() {
        return items;
    }

    @Override
    public boolean save(String id, AlgorithmParams params, SparkContext sc) {
        userFeatures.saveAsObjectFile("/tmp/" + id + "/userFeatures");
        indexItemFeatures.saveAsObjectFile("/tmp/" + id + "/indexItemFeatures");
        userIndex.saveAsObjectFile("/tmp/" + id + "/userIndex");
        itemIndex.saveAsObjectFile("/tmp/" + id + "/itemIndex");
        itemPopularityScore.saveAsObjectFile("/tmp/" + id + "/itemPopularityScore");
        // The items map is small; persist it as a single-element RDD.
        new JavaSparkContext(sc).parallelize(Collections.singletonList(items)).saveAsObjectFile("/tmp/" + id + "/items");

        logger.info("Saved model to /tmp/" + id);
        return true;
    }

    /** Restores a model previously written by {@link #save}. */
    public static Model load(String id, Params params, SparkContext sc) {
        JavaSparkContext jsc = JavaSparkContext.fromSparkContext(sc);
        JavaPairRDD<Integer, double[]> userFeatures = JavaPairRDD.fromJavaRDD(jsc.<Tuple2<Integer, double[]>>objectFile("/tmp/" + id + "/userFeatures"));
        JavaPairRDD<Integer, Tuple2<String, double[]>> indexItemFeatures = JavaPairRDD.<Integer, Tuple2<String, double[]>>fromJavaRDD(jsc.<Tuple2<Integer, Tuple2<String, double[]>>>objectFile("/tmp/" + id + "/indexItemFeatures"));
        JavaPairRDD<String, Integer> userIndex = JavaPairRDD.fromJavaRDD(jsc.<Tuple2<String, Integer>>objectFile("/tmp/" + id + "/userIndex"));
        JavaPairRDD<String, Integer> itemIndex = JavaPairRDD.fromJavaRDD(jsc.<Tuple2<String, Integer>>objectFile("/tmp/" + id + "/itemIndex"));
        JavaRDD<ItemScore> itemPopularityScore = jsc.objectFile("/tmp/" + id + "/itemPopularityScore");
        Map<String, Item> items = jsc.<Map<String, Item>>objectFile("/tmp/" + id + "/items").collect().get(0);

        logger.info("loaded model");
        return new Model(userFeatures, indexItemFeatures, userIndex, itemIndex, itemPopularityScore, items);
    }
}
--------------------------------------------------------------------------------
/src/main/java/org/example/recommendation/PredictedResult.java:
--------------------------------------------------------------------------------
package org.example.recommendation;

import java.io.Serializable;
import java.util.List;

/** Engine response: the ranked list of scored items for one query. */
public class PredictedResult implements Serializable {
    private final List<ItemScore> itemScores;

    public PredictedResult(List<ItemScore> itemScores) {
        this.itemScores = itemScores;
    }

    public List<ItemScore> getItemScores() {
        return itemScores;
    }

    @Override
    public String toString() {
        return "PredictedResult{" +
                "itemScores=" + itemScores +
                '}';
    }
}
--------------------------------------------------------------------------------
/src/main/java/org/example/recommendation/Preparator.java:
--------------------------------------------------------------------------------
package org.example.recommendation;

import org.apache.predictionio.controller.java.PJavaPreparator;
import org.apache.spark.SparkContext;

/** Pass-through preparator: wraps the training data without transformation. */
public class Preparator extends PJavaPreparator<TrainingData, PreparedData> {

    @Override
    public PreparedData prepare(SparkContext sc, TrainingData trainingData) {
        return new PreparedData(trainingData);
    }
}
--------------------------------------------------------------------------------
/src/main/java/org/example/recommendation/PreparedData.java:
--------------------------------------------------------------------------------
package org.example.recommendation;

import java.io.Serializable;

/** Output of the preparator; currently just carries the raw training data. */
public class PreparedData implements Serializable {
    private final TrainingData trainingData;

    public PreparedData(TrainingData trainingData) {
        this.trainingData = trainingData;
    }

    public TrainingData getTrainingData() {
        return trainingData;
    }
}
--------------------------------------------------------------------------------
/src/main/java/org/example/recommendation/Query.java:
--------------------------------------------------------------------------------
package org.example.recommendation;

import java.io.Serializable;
import java.util.Collections;
import java.util.Set;

/**
 * Recommendation request: target user, result count and optional category,
 * whitelist and blacklist filters. Null filters are normalized to empty sets
 * by the getters so callers never see null.
 */
public class Query implements Serializable {
    private final String userEntityId;
    private final int number;                // max items to return
    private final Set<String> categories;    // optional; null treated as no filter
    private final Set<String> whitelist;     // optional; null treated as no filter
    private final Set<String> blacklist;     // optional; null treated as no filter

    public Query(String userEntityId, int number, Set<String> categories, Set<String> whitelist, Set<String> blacklist) {
        this.userEntityId = userEntityId;
        this.number = number;
        this.categories = categories;
        this.whitelist = whitelist;
        this.blacklist = blacklist;
    }

    public String getUserEntityId() {
        return userEntityId;
    }

    public int getNumber() {
        return number;
    }

    public Set<String> getCategories() {
        if (categories == null) return Collections.emptySet();
        return categories;
    }

    public Set<String> getWhitelist() {
        if (whitelist == null) return Collections.emptySet();
        return whitelist;
    }

    public Set<String> getBlacklist() {
        if (blacklist == null) return Collections.emptySet();
        return blacklist;
    }

    @Override
    public String toString() {
        return "Query{" +
                "userEntityId='" + userEntityId + '\'' +
                ", number=" + number +
                ", categories=" + categories +
                ", whitelist=" + whitelist +
                ", blacklist=" + blacklist +
                '}';
    }
}
--------------------------------------------------------------------------------
/src/main/java/org/example/recommendation/RecommendationEngine.java:
--------------------------------------------------------------------------------
package org.example.recommendation;

import org.apache.predictionio.controller.EmptyParams;
import org.apache.predictionio.controller.Engine;
import org.apache.predictionio.controller.EngineFactory;
import org.apache.predictionio.core.BaseAlgorithm;
import org.apache.predictionio.core.BaseEngine;

import java.util.Collections;

/** Wires data source, preparator, the single "algo" algorithm and serving. */
public class RecommendationEngine extends EngineFactory {

    @Override
    public BaseEngine<TrainingData, EmptyParams, Query, PredictedResult> apply() {
        return new Engine<>(
                DataSource.class,
                Preparator.class,
                Collections.<String, Class<? extends BaseAlgorithm<PreparedData, ?, Query, PredictedResult>>>singletonMap("algo", Algorithm.class),
                Serving.class
        );
    }
}
--------------------------------------------------------------------------------
/src/main/java/org/example/recommendation/Serving.java:
--------------------------------------------------------------------------------
package org.example.recommendation;

import org.apache.predictionio.controller.java.LJavaServing;
import scala.collection.Seq;

/** Single-algorithm serving layer: returns the first (only) prediction. */
public class Serving extends LJavaServing<Query, PredictedResult> {

    @Override
    public PredictedResult serve(Query query, Seq<PredictedResult> predictions) {
        return predictions.head();
    }
}
--------------------------------------------------------------------------------
/src/main/java/org/example/recommendation/TrainingData.java:
--------------------------------------------------------------------------------
package org.example.recommendation;

import org.apache.predictionio.controller.SanityCheck;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;

import java.io.Serializable;

/** Raw training inputs: users, items, and view/buy interaction events. */
public class TrainingData implements Serializable, SanityCheck {
    private final JavaPairRDD<String, User> users;
    private final JavaPairRDD<String, Item> items;
    private final JavaRDD<UserItemEvent> viewEvents;
    private final JavaRDD<UserItemEvent> buyEvents;

    public TrainingData(JavaPairRDD<String, User> users, JavaPairRDD<String, Item> items, JavaRDD<UserItemEvent> viewEvents, JavaRDD<UserItemEvent> buyEvents) {
        this.users = users;
        this.items = items;
        this.viewEvents = viewEvents;
        this.buyEvents = buyEvents;
    }

    public JavaPairRDD<String, User> getUsers() {
        return users;
    }

    public JavaPairRDD<String, Item> getItems() {
        return items;
    }

    public JavaRDD<UserItemEvent> getViewEvents() {
        return viewEvents;
    }

    public JavaRDD<UserItemEvent> getBuyEvents() {
        return buyEvents;
    }

    @Override
    public void sanityCheck() {
        if (users.isEmpty()) {
            throw new AssertionError("User data is empty");
        }
        if (items.isEmpty()) {
            throw new AssertionError("Item data is empty");
        }
        // NOTE(review): buyEvents is intentionally not checked — training can
        // proceed with views only; confirm this is the desired behavior.
        if (viewEvents.isEmpty()) {
            throw new AssertionError("View Event data is empty");
        }
    }
}
--------------------------------------------------------------------------------
/src/main/java/org/example/recommendation/User.java:
--------------------------------------------------------------------------------
package org.example.recommendation;

import java.io.Serializable;
import java.util.Map;

/** Immutable user: entity id plus its aggregated string properties. */
public class User implements Serializable {
    private final String entityId;
    private final Map<String, String> properties;

    public User(String entityId, Map<String, String> properties) {
        this.entityId = entityId;
        this.properties = properties;
    }

    public String getEntityId() {
        return entityId;
    }

    public Map<String, String> getProperties() {
        return properties;
    }

    @Override
    public String toString() {
        return "User{" +
                "entityId='" + entityId + '\'' +
                ", properties=" + properties +
                '}';
    }
}
--------------------------------------------------------------------------------
/src/main/java/org/example/recommendation/UserItemEvent.java:
--------------------------------------------------------------------------------
package org.example.recommendation;

import java.io.Serializable;

/** A single user/item interaction (view or buy) with its event timestamp. */
public class UserItemEvent implements Serializable {
    private final String user;
    private final String item;
    private final long time; // epoch millis of the event
    private final UserItemEventType type;

    public UserItemEvent(String user, String item, long time, UserItemEventType type) {
        this.user = user;
        this.item = item;
        this.time = time;
        this.type = type;
    }

    public String getUser() {
        return user;
    }

    public String getItem() {
        return item;
    }

    public long getTime() {
        return time;
    }

    public UserItemEventType getType() {
        return type;
    }

    @Override
    public String toString() {
        return "UserItemEvent{" +
                "user='" + user + '\'' +
                ", item='" + item + '\'' +
                ", time=" + time +
                ", type=" + type +
                '}';
    }
}
--------------------------------------------------------------------------------
/src/main/java/org/example/recommendation/UserItemEventType.java:
--------------------------------------------------------------------------------
package org.example.recommendation;

/** Kind of user/item interaction recorded in the event store. */
public enum UserItemEventType {
    VIEW, BUY
}
--------------------------------------------------------------------------------
/src/main/java/org/example/recommendation/evaluation/EvaluationParameter.java:
--------------------------------------------------------------------------------
package org.example.recommendation.evaluation;

import org.apache.predictionio.controller.EmptyParams;
import org.apache.predictionio.controller.EngineParams;
import org.apache.predictionio.controller.java.JavaEngineParamsGenerator;
import org.example.recommendation.AlgorithmParams;
import org.example.recommendation.DataSourceParams;

import java.util.Arrays;
import java.util.Collections;

/** Single engine-parameter set used when running the evaluation. */
public class EvaluationParameter extends JavaEngineParamsGenerator {
    public EvaluationParameter() {
        this.setEngineParamsList(
                Collections.singletonList(
                        new EngineParams(
                                "",
                                new DataSourceParams("javadase"),
                                "",
                                new EmptyParams(),
                                // seed=1, rank=10, iterations=10, lambda=0.01
                                Collections.singletonMap("algo", new AlgorithmParams(1, 10, 10, 0.01, "javadase", Collections.singletonList("view"), true, Arrays.asList("buy", "view"))),
                                "",
                                new EmptyParams()
                        )
                )
        );
    }
}
--------------------------------------------------------------------------------
/src/main/java/org/example/recommendation/evaluation/EvaluationSpec.java:
--------------------------------------------------------------------------------
package org.example.recommendation.evaluation;

import org.apache.predictionio.controller.Engine;
import org.apache.predictionio.controller.java.JavaEvaluation;
import org.apache.predictionio.core.BaseAlgorithm;
import org.example.recommendation.Algorithm;
import org.example.recommendation.DataSource;
import org.example.recommendation.PredictedResult;
import org.example.recommendation.Preparator;
import org.example.recommendation.PreparedData;
import org.example.recommendation.Query;
import org.example.recommendation.Serving;

import java.util.Collections;

/** Binds the engine to the precision metric for `pio eval`. */
public class EvaluationSpec extends JavaEvaluation {
    public EvaluationSpec() {
        this.setEngineMetric(
                new Engine<>(
                        DataSource.class,
                        Preparator.class,
                        Collections.<String, Class<? extends BaseAlgorithm<PreparedData, ?, Query, PredictedResult>>>singletonMap("algo", Algorithm.class),
                        Serving.class
                ),
                new PrecisionMetric()
        );
    }
}
--------------------------------------------------------------------------------
/src/main/java/org/example/recommendation/evaluation/PrecisionMetric.java:
--------------------------------------------------------------------------------
package org.example.recommendation.evaluation;

import org.apache.predictionio.controller.EmptyParams;
import org.apache.predictionio.controller.Metric;
import org.apache.predictionio.controller.java.SerializableComparator;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.rdd.RDD;
import org.example.recommendation.ItemScore;
import org.example.recommendation.PredictedResult;
import org.example.recommendation.Query;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.JavaConversions;
import scala.collection.Seq;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
 * Mean precision: per query, the fraction of predicted items that appear in
 * the actual (held-out) item set, averaged over all evaluated queries.
 */
public class PrecisionMetric extends Metric<EmptyParams, Query, PredictedResult, Set<String>, Double> {

    private static final class MetricComparator implements SerializableComparator<Double> {
        @Override
        public int compare(Double o1, Double o2) {
            return o1.compareTo(o2);
        }
    }

    public PrecisionMetric() {
        super(new MetricComparator());
    }

    @Override
    public Double calculate(SparkContext sc, Seq<Tuple2<EmptyParams, RDD<Tuple3<Query, PredictedResult, Set<String>>>>> qpas) {
        List<Tuple2<EmptyParams, RDD<Tuple3<Query, PredictedResult, Set<String>>>>> sets = JavaConversions.seqAsJavaList(qpas);
        List<Double> allSetResults = new ArrayList<>();

        for (Tuple2<EmptyParams, RDD<Tuple3<Query, PredictedResult, Set<String>>>> set : sets) {
            List<Double> setResults = set._2().toJavaRDD().map(new Function<Tuple3<Query, PredictedResult, Set<String>>, Double>() {
                @Override
                public Double call(Tuple3<Query, PredictedResult, Set<String>> qpa) throws Exception {
                    List<ItemScore> itemScores = qpa._2().getItemScores();
                    // BUGFIX: an empty prediction used to yield 0/0 = NaN,
                    // poisoning the final average; score it as 0 precision.
                    if (itemScores.isEmpty()) return 0.0;

                    Set<String> predicted = new HashSet<>();
                    for (ItemScore itemScore : itemScores) {
                        predicted.add(itemScore.getItemEntityId());
                    }
                    Set<String> intersection = new HashSet<>(predicted);
                    intersection.retainAll(qpa._3());

                    return 1.0 * intersection.size() / itemScores.size();
                }
            }).collect();

            allSetResults.addAll(setResults);
        }
        // BUGFIX: guard against an empty evaluation set (0/0 = NaN).
        if (allSetResults.isEmpty()) return 0.0;

        double sum = 0.0;
        for (Double value : allSetResults) sum += value;

        return sum / allSetResults.size();
    }
}
--------------------------------------------------------------------------------
/template.json:
--------------------------------------------------------------------------------
{"pio": {"version": { "min": "0.11.0-incubating" }}}
--------------------------------------------------------------------------------