├── FlinkRL ├── pom.xml └── src │ └── main │ ├── java │ └── com │ │ └── mass │ │ ├── agg │ │ └── CountAgg.java │ │ ├── entity │ │ ├── RecordEntity.java │ │ ├── TopItemEntity.java │ │ └── UserConsumed.java │ │ ├── feature │ │ ├── BuildFeature.java │ │ ├── Embedding.java │ │ └── IdConverter.java │ │ ├── map │ │ └── RecordMap.java │ │ ├── process │ │ └── TopNPopularItems.java │ │ ├── recommend │ │ └── fastapi │ │ │ └── FastapiRecommender.java │ │ ├── sink │ │ ├── MongodbRecommendSink.java │ │ ├── MongodbRecordSink.java │ │ └── TopNRedisSink.java │ │ ├── source │ │ ├── CustomFileSource.java │ │ └── KafkaSource.java │ │ ├── task │ │ ├── FileToKafka.java │ │ ├── IntervalPopularItems.java │ │ ├── ReadFromFile.java │ │ ├── RecordToMongoDB.java │ │ ├── RecordToMysql.java │ │ └── SeqRecommend.java │ │ ├── util │ │ ├── FormatTimestamp.java │ │ ├── Property.java │ │ ├── RecordToEntity.java │ │ └── TypeConvert.java │ │ └── window │ │ ├── CountProcessWindowFunction.java │ │ ├── ItemCollectWindowFunction.java │ │ └── ItemCollectWindowFunctionRedis.java │ └── resources │ ├── config.properties │ ├── db │ └── create_table.sql │ ├── features │ ├── item_map.json │ └── user_map.json │ └── log4j.properties ├── README.md ├── README_zh.md ├── pic ├── 5.png ├── 6.jpg ├── readme1.png └── readme4.png └── python_api ├── bcq.py ├── ddpg.py ├── net.py ├── reinforce.py └── utils.py /FlinkRL/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | 7 | com.mass 8 | FlinkRL 9 | 1.0-SNAPSHOT 10 | 11 | 12 | UTF-8 13 | 1.8 14 | 1.9.3 15 | 2.11 16 | 1.8 17 | 1.8 18 | 19 | 20 | 21 | 22 | junit 23 | junit 24 | 4.12 25 | test 26 | 27 | 28 | 29 | org.slf4j 30 | slf4j-log4j12 31 | 1.7.25 32 | 33 | 34 | 35 | mysql 36 | mysql-connector-java 37 | 8.0.17 38 | 39 | 40 | 41 | org.apache.flink 42 | flink-java 43 | ${flink.version} 44 | 45 | 46 | 47 | org.apache.flink 48 | flink-streaming-java_2.11 49 | ${flink.version} 50 | 51 | 52 | 53 | org.apache.flink 54 | flink-clients_2.11 55 | ${flink.version} 56 | 57 | 58 | 59 | org.apache.flink 60 | flink-connector-kafka_2.11 61 | ${flink.version} 62 | 63 | 64 | 65 | org.mongodb 66 | mongodb-driver-sync 67 | 4.0.5 68 | 69 | 70 | 71 | org.mongodb 72 | mongodb-driver-legacy 73 | 4.0.5 74 | 75 | 76 | 77 | org.apache.flink 78 | flink-connector-redis_2.11 79 | 1.1.5 80 | 81 | 82 | 83 | org.json 84 | json 85 | 20180813 86 | 87 | 88 | 89 | 90 | 91 | 92 | org.apache.maven.plugins 93 | maven-compiler-plugin 94 | 3.8.1 95 | 96 | 1.8 97 | 1.8 98 | 99 | 100 | 101 | org.apache.maven.plugins 102 | maven-assembly-plugin 103 | 3.3.0 104 | 105 | 106 | package 107 | 108 | single 109 | 110 | 111 | 112 | jar-with-dependencies 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | alimaven 124 | aliyun maven 125 | http://maven.aliyun.com/nexus/content/groups/public/ 126 | 127 | true 128 | 129 | 130 | false 131 | 132 | 133 | 134 | 135 | -------------------------------------------------------------------------------- /FlinkRL/src/main/java/com/mass/agg/CountAgg.java: -------------------------------------------------------------------------------- 1 | package com.mass.agg; 2 | 3 | import org.apache.flink.api.common.functions.AggregateFunction; 4 | import com.mass.entity.RecordEntity; 5 | 6 | public class CountAgg implements AggregateFunction { 7 | 8 | @Override 9 | public Long createAccumulator() { 10 | return 0L; 11 | } 12 | 13 | @Override 14 | public Long add(RecordEntity value, Long accumulator) { 15 | return accumulator + 1; 16 | } 17 | 18 | @Override 
19 | public Long getResult(Long accumulator) { 20 | return accumulator; 21 | } 22 | 23 | @Override 24 | public Long merge(Long a, Long b) { 25 | return a + b; 26 | } 27 | } 28 | 29 | -------------------------------------------------------------------------------- /FlinkRL/src/main/java/com/mass/entity/RecordEntity.java: -------------------------------------------------------------------------------- 1 | package com.mass.entity; 2 | 3 | public class RecordEntity { 4 | private int userId; 5 | private int itemId; 6 | private long time; 7 | 8 | public RecordEntity() { } 9 | 10 | public RecordEntity(int userId, int itemId, long time) { 11 | this.userId = userId; 12 | this.itemId = itemId; 13 | this.time = time; 14 | } 15 | 16 | public int getUserId() { 17 | return userId; 18 | } 19 | 20 | public void setUserId(int userId) { 21 | this.userId = userId; 22 | } 23 | 24 | public int getItemId() { 25 | return itemId; 26 | } 27 | 28 | public void setItemId(int itemId) { 29 | this.itemId = itemId; 30 | } 31 | 32 | public long getTime() { 33 | return time; 34 | } 35 | 36 | public void setTime(long time) { 37 | this.time = time; 38 | } 39 | 40 | public String toString() { 41 | return String.format("user: %5d, item: %5d, time: %5d", userId, itemId, time); 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /FlinkRL/src/main/java/com/mass/entity/TopItemEntity.java: -------------------------------------------------------------------------------- 1 | package com.mass.entity; 2 | 3 | public class TopItemEntity { 4 | private int itemId; 5 | private int counts; 6 | private long windowEnd; 7 | 8 | public static TopItemEntity of(Integer itemId, long end, Long count) { 9 | TopItemEntity res = new TopItemEntity(); 10 | res.setCounts(count.intValue()); 11 | res.setItemId(itemId); 12 | res.setWindowEnd(end); 13 | return res; 14 | } 15 | 16 | public long getWindowEnd() { 17 | return windowEnd; 18 | } 19 | 20 | public void setWindowEnd(long windowEnd) { 21 | this.windowEnd = windowEnd; 22 | } 23 | 24 | 25 | public int getItemId() { 26 | return itemId; 27 | } 28 | 29 | public void setItemId(int productId) { 30 | this.itemId = productId; 31 | } 32 | 33 | public int getCounts() { 34 | return counts; 35 | } 36 | 37 | public void setCounts(int actionTimes) { 38 | this.counts = actionTimes; 39 | } 40 | 41 | public String toString() { 42 | return String.format("TopItemEntity: itemId:%5d, count:%3d, windowEnd: %5d", itemId, counts, windowEnd); 43 | } 44 | } 45 | 46 | -------------------------------------------------------------------------------- /FlinkRL/src/main/java/com/mass/entity/UserConsumed.java: -------------------------------------------------------------------------------- 1 | package com.mass.entity; 2 | 3 | import java.util.List; 4 | 5 | public class UserConsumed { 6 | public int userId; 7 | public List items; 8 | public long windowEnd; 9 | 10 | public static UserConsumed of(int userId, List items, long windowEnd) { 11 | UserConsumed result = new UserConsumed(); 12 | result.userId = userId; 13 | result.items = items; 14 | result.windowEnd = windowEnd; 15 | return result; 16 | } 17 | 18 | public int getUserId() { 19 | return userId; 20 | } 21 | 22 | public List getItems() { 23 | return items; 24 | } 25 | 26 | public long getWindowEnd() { 27 | return windowEnd; 28 | } 29 | 30 | public String toString() { 31 | return String.format("userId:%5d, itemConsumed:%6s, windowEnd:%13s", 32 | userId, items, windowEnd); 33 | } 34 | } 35 | 36 | 
-------------------------------------------------------------------------------- /FlinkRL/src/main/java/com/mass/feature/BuildFeature.java: -------------------------------------------------------------------------------- 1 | package com.mass.feature; 2 | 3 | import org.json.JSONArray; 4 | import org.json.JSONObject; 5 | import org.json.JSONTokener; 6 | 7 | import java.io.IOException; 8 | import java.util.ArrayList; 9 | import java.util.List; 10 | 11 | public class BuildFeature { 12 | 13 | public static int getUserId(int userIndex) { 14 | return IdConverter.getUserId(userIndex); 15 | } 16 | 17 | public static int getItemId(int itemIndex) { 18 | return IdConverter.getItemId(itemIndex); 19 | } 20 | 21 | private static JSONArray getUserEmbedding(int userId) { 22 | // int userId = getUserId(userIndex); 23 | return Embedding.getEmbedding("user", userId); 24 | } 25 | 26 | private static JSONArray getItemEmbedding(int itemId) { 27 | // int itemId = getItemId(itemIndex); 28 | return Embedding.getEmbedding("item", itemId); 29 | } 30 | 31 | public static List getEmbedding(int user, List items) { 32 | List features = new ArrayList<>(); 33 | JSONArray userArray = getUserEmbedding(user); 34 | for (int i = 0; i < userArray.length(); i++) { 35 | features.add(userArray.getFloat(i)); 36 | } 37 | 38 | for (int i : items) { 39 | JSONArray itemArray = getItemEmbedding(i); 40 | for (int j = 0; j < itemArray.length(); j++) { 41 | features.add(itemArray.getFloat(j)); 42 | } 43 | } 44 | return features; 45 | } 46 | 47 | public static List getSeq(List items) { 48 | List seq = new ArrayList<>(); 49 | for (int i: items) { 50 | int itemId = getItemId(i); 51 | seq.add(itemId); 52 | } 53 | return seq; 54 | } 55 | 56 | public static List convertItems(List items) { 57 | return IdConverter.convertItems(items); 58 | } 59 | 60 | public static void close(Boolean withState) throws IOException { 61 | if (withState) { 62 | Embedding.close(); 63 | } 64 | IdConverter.close(); 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /FlinkRL/src/main/java/com/mass/feature/Embedding.java: -------------------------------------------------------------------------------- 1 | package com.mass.feature; 2 | 3 | import org.json.JSONArray; 4 | import org.json.JSONObject; 5 | import org.json.JSONTokener; 6 | 7 | import java.io.IOException; 8 | import java.io.InputStream; 9 | import java.io.Serializable; 10 | 11 | public class Embedding { 12 | private static InputStream embedStream; 13 | private static JSONObject embedJSON; 14 | 15 | static { 16 | embedStream = Thread.currentThread().getContextClassLoader().getResourceAsStream("features/embeddings.json"); 17 | try { 18 | embedJSON = new JSONObject(new JSONTokener(embedStream)); 19 | } catch (NullPointerException e) { 20 | e.printStackTrace(); 21 | } 22 | } 23 | 24 | public static JSONArray getEmbedding(String feat, int index) { 25 | JSONArray embeds = embedJSON.getJSONArray(feat); 26 | return embeds.getJSONArray(index); 27 | } 28 | 29 | public static void close() throws IOException { 30 | embedStream.close(); 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /FlinkRL/src/main/java/com/mass/feature/IdConverter.java: -------------------------------------------------------------------------------- 1 | package com.mass.feature; 2 | 3 | import com.mass.util.Property; 4 | import com.typesafe.config.ConfigException; 5 | import org.json.JSONObject; 6 | import org.json.JSONTokener; 7 | 8 | import java.io.*; 9 
| import java.util.ArrayList; 10 | import java.util.List; 11 | 12 | public class IdConverter { 13 | private static InputStream userStream; 14 | private static InputStream itemStream; 15 | private static JSONObject userJSON; 16 | private static JSONObject itemJSON; 17 | private static int numUsers; 18 | private static int numItems; 19 | 20 | static { 21 | try { 22 | userStream = new FileInputStream(Property.getStrValue("user_map")); 23 | itemStream = new FileInputStream(Property.getStrValue("item_map")); 24 | userJSON = new JSONObject(new JSONTokener(userStream)); 25 | itemJSON = new JSONObject(new JSONTokener(itemStream)); 26 | } catch (FileNotFoundException e) { 27 | e.printStackTrace(); 28 | } 29 | 30 | numUsers = userJSON.length(); 31 | numItems = itemJSON.length(); 32 | } 33 | 34 | public static int getUserId(int userIndex) { 35 | String user = String.valueOf(userIndex); 36 | return userJSON.has(user) ? userJSON.getInt(user) : numUsers; 37 | } 38 | 39 | public static int getItemId(int itemIndex) { 40 | String item = String.valueOf(itemIndex); 41 | return itemJSON.has(item) ? itemJSON.getInt(item) : numItems; 42 | } 43 | 44 | public static List convertItems(List items) { 45 | List converted = new ArrayList<>(); 46 | for (int i : items) { 47 | converted.add(getItemId(i)); 48 | } 49 | return converted; 50 | } 51 | 52 | public static void close() throws IOException { 53 | userStream.close(); 54 | itemStream.close(); 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /FlinkRL/src/main/java/com/mass/map/RecordMap.java: -------------------------------------------------------------------------------- 1 | package com.mass.map; 2 | 3 | import com.mass.entity.RecordEntity; 4 | import com.mass.util.Property; 5 | import com.mass.util.RecordToEntity; 6 | import org.apache.flink.api.common.functions.RichMapFunction; 7 | import org.apache.flink.configuration.Configuration; 8 | import org.json.JSONObject; 9 | import org.json.JSONTokener; 10 | 11 | import java.io.FileInputStream; 12 | import java.io.FileNotFoundException; 13 | import java.io.InputStream; 14 | 15 | public class RecordMap extends RichMapFunction { 16 | private static InputStream userStream; 17 | private static InputStream itemStream; 18 | private static JSONObject userJSON; 19 | private static JSONObject itemJSON; 20 | private static int numUsers; 21 | private static int numItems; 22 | 23 | @Override 24 | public void open(Configuration parameters) { 25 | try { 26 | userStream = new FileInputStream(Property.getStrValue("user_map")); 27 | itemStream = new FileInputStream(Property.getStrValue("item_map")); 28 | userJSON = new JSONObject(new JSONTokener(userStream)); 29 | itemJSON = new JSONObject(new JSONTokener(itemStream)); 30 | } catch (FileNotFoundException e) { 31 | e.printStackTrace(); 32 | } 33 | 34 | numUsers = userJSON.length(); 35 | numItems = itemJSON.length(); 36 | } 37 | 38 | @Override 39 | public RecordEntity map(String value) { 40 | RecordEntity record = RecordToEntity.getRecord(value, userJSON, itemJSON, numUsers, numItems); 41 | return record; 42 | } 43 | 44 | @Override 45 | public void close() throws Exception { 46 | userStream.close(); 47 | itemStream.close(); 48 | } 49 | } 50 | 51 | -------------------------------------------------------------------------------- /FlinkRL/src/main/java/com/mass/process/TopNPopularItems.java: -------------------------------------------------------------------------------- 1 | package com.mass.process; 2 | 3 | import 
com.mass.entity.TopItemEntity; 4 | import org.apache.flink.api.common.state.ListStateDescriptor; 5 | import org.apache.flink.api.java.tuple.Tuple; 6 | import org.apache.flink.configuration.Configuration; 7 | import org.apache.flink.streaming.api.functions.KeyedProcessFunction; 8 | import org.apache.flink.util.Collector; 9 | import org.apache.flink.api.common.state.ListState; 10 | 11 | import java.util.ArrayList; 12 | import java.util.List; 13 | import java.util.Comparator; 14 | 15 | public class TopNPopularItems extends KeyedProcessFunction> { 16 | 17 | private ListState itemState; 18 | private final int topSize; 19 | 20 | public TopNPopularItems(int topSize) { 21 | this.topSize = topSize; 22 | } 23 | 24 | @Override 25 | public void open(Configuration parameters) throws Exception { 26 | super.open(parameters); 27 | ListStateDescriptor itemStateDesc = new ListStateDescriptor<>("itemState", TopItemEntity.class); 28 | itemState = getRuntimeContext().getListState(itemStateDesc); 29 | } 30 | 31 | @Override 32 | public void processElement(TopItemEntity value, Context ctx, Collector> out) throws Exception { 33 | itemState.add(value); 34 | ctx.timerService().registerEventTimeTimer(value.getWindowEnd() + 1); 35 | } 36 | 37 | @Override 38 | public void onTimer(long timestamp, OnTimerContext ctx, Collector> out) throws Exception { 39 | List allItems = new ArrayList<>(); 40 | for (TopItemEntity item : itemState.get()) { 41 | allItems.add(item); 42 | } 43 | itemState.clear(); 44 | 45 | allItems.sort(new Comparator() { 46 | @Override 47 | public int compare(TopItemEntity o1, TopItemEntity o2) { 48 | return (int) (o2.getCounts() - o1.getCounts()); 49 | } 50 | }); 51 | 52 | List ret = new ArrayList<>(); 53 | long lastTime = ctx.getCurrentKey().getField(0); 54 | ret.add(String.valueOf(lastTime)); // add timestamp at first 55 | allItems.forEach(i -> ret.add(String.valueOf(i.getItemId()))); 56 | out.collect(ret); 57 | } 58 | } 59 | 60 | 61 | -------------------------------------------------------------------------------- /FlinkRL/src/main/java/com/mass/recommend/fastapi/FastapiRecommender.java: -------------------------------------------------------------------------------- 1 | package com.mass.recommend.fastapi; 2 | 3 | import com.mass.entity.UserConsumed; 4 | import com.mass.feature.BuildFeature; 5 | import com.mass.feature.IdConverter; 6 | import org.apache.commons.lang3.StringUtils; 7 | import org.apache.flink.api.common.functions.RichFlatMapFunction; 8 | import org.apache.flink.api.common.state.MapState; 9 | import org.apache.flink.api.common.state.MapStateDescriptor; 10 | import org.apache.flink.api.common.typeinfo.TypeHint; 11 | import org.apache.flink.api.common.typeinfo.TypeInformation; 12 | import org.apache.flink.api.java.tuple.Tuple4; 13 | import org.apache.flink.configuration.Configuration; 14 | import org.apache.flink.util.Collector; 15 | import org.json.JSONArray; 16 | import org.json.JSONObject; 17 | import redis.clients.jedis.Jedis; 18 | 19 | import java.io.BufferedReader; 20 | import java.io.DataOutputStream; 21 | import java.io.IOException; 22 | import java.io.InputStreamReader; 23 | import java.net.HttpURLConnection; 24 | import java.net.URL; 25 | import java.nio.charset.StandardCharsets; 26 | import java.util.ArrayList; 27 | import java.util.List; 28 | 29 | import static com.mass.util.FormatTimestamp.format; 30 | import static com.mass.util.TypeConvert.convertEmbedding; 31 | import static com.mass.util.TypeConvert.convertJSON; 32 | import static com.mass.util.TypeConvert.convertSeq; 33 | 
34 | public class FastapiRecommender extends RichFlatMapFunction, String, Integer>> { 35 | private final int histNum; 36 | private final int numRec; 37 | private final String algo; 38 | private final Boolean withState; 39 | private static Jedis jedis; 40 | private static HttpURLConnection con; 41 | private static MapState> lastRecState; 42 | 43 | public FastapiRecommender(int numRec, int histNum, String algo, Boolean withState) { 44 | this.histNum = histNum; 45 | this.numRec = numRec; 46 | this.algo = algo; 47 | this.withState = withState; 48 | } 49 | 50 | @Override 51 | public void open(Configuration parameters) { 52 | jedis = new Jedis("localhost", 6379); 53 | MapStateDescriptor> lastRecStateDesc = new MapStateDescriptor<>("lastRecState", 54 | TypeInformation.of(new TypeHint() {}), TypeInformation.of(new TypeHint>() {})); 55 | lastRecState = getRuntimeContext().getMapState(lastRecStateDesc); 56 | } 57 | 58 | @Override 59 | public void close() throws IOException { 60 | jedis.close(); 61 | lastRecState.clear(); 62 | BuildFeature.close(withState); 63 | } 64 | 65 | private void buildConnection() throws IOException { 66 | String path = String.format("http://127.0.0.1:8000/%s", algo); 67 | if (withState) { 68 | path += "/state"; 69 | } 70 | URL obj = new URL(path); 71 | con = (HttpURLConnection) obj.openConnection(); 72 | con.setRequestMethod("POST"); 73 | con.setRequestProperty("Content-Type", "application/json"); 74 | // con.setRequestProperty("Connection", "Keep-Alive"); 75 | con.setConnectTimeout(5000); 76 | con.setReadTimeout(5000); 77 | con.setDoOutput(true); 78 | } 79 | 80 | private void writeOutputStream(String jsonString) throws IOException { 81 | DataOutputStream wr = new DataOutputStream(con.getOutputStream()); 82 | wr.write(jsonString.getBytes(StandardCharsets.UTF_8)); 83 | wr.flush(); 84 | wr.close(); 85 | } 86 | 87 | @Override 88 | public void flatMap(UserConsumed value, Collector, String, Integer>> out) throws Exception { 89 | buildConnection(); 90 | int userId = BuildFeature.getUserId(value.userId); 91 | List items = BuildFeature.convertItems(value.items); 92 | long timestamp = value.windowEnd; 93 | String time = format(timestamp); 94 | 95 | if (items.size() == this.histNum) { 96 | String jsonString; 97 | if (withState) { 98 | List stateEmbeds = BuildFeature.getEmbedding(userId, items); 99 | jsonString = convertEmbedding(stateEmbeds, userId, numRec); 100 | } else { 101 | jsonString = convertSeq(items, userId, numRec); 102 | } 103 | 104 | writeOutputStream(jsonString); 105 | int responseCode = con.getResponseCode(); 106 | // System.out.println("Posted parameters : " + jsonString); 107 | System.out.println("Response Code : " + responseCode); 108 | 109 | List recommend = new ArrayList<>(); 110 | String printOut; 111 | if (responseCode == 200) { 112 | BufferedReader br = new BufferedReader(new InputStreamReader(con.getInputStream())); 113 | String inputLine; 114 | StringBuilder message = new StringBuilder(); 115 | while ((inputLine = br.readLine()) != null) { 116 | message.append(inputLine); 117 | } 118 | br.close(); 119 | con.disconnect(); 120 | 121 | JSONArray res = convertJSON(message.toString()); 122 | for (int i = 0; i < res.length(); i++) { 123 | recommend.add(res.getInt(i)); 124 | } 125 | 126 | int hotRecommend = numRec - recommend.size(); 127 | if (hotRecommend > 0) { 128 | for (int i = 1; i < hotRecommend + 1; i++) { 129 | String item = jedis.get(String.valueOf(i)); 130 | if (null != item) { 131 | recommend.add(Integer.parseInt(item)); 132 | } 133 | } 134 | } 135 | 
printOut = String.format("user: %d, recommend(%d) + hot(%d): %s, time: %s", 136 | userId, recommend.size(), hotRecommend, recommend, time); 137 | } else { 138 | for (int i = 1; i < numRec + 1; i++) { 139 | String item = jedis.get(String.valueOf(i)); 140 | if (null != item) { 141 | recommend.add(Integer.parseInt(item)); 142 | } 143 | } 144 | printOut = String.format("user: %d, bad request, recommend hot(%d): %s, " + 145 | "time: %s", userId, recommend.size(), recommend, time); 146 | } 147 | 148 | System.out.println(printOut); 149 | System.out.println(StringUtils.repeat("=", 60)); 150 | 151 | int lastReward = updateLastReward(userId, items, recommend); 152 | out.collect(Tuple4.of(userId, recommend, time, lastReward)); 153 | } 154 | } 155 | 156 | private int updateLastReward(int userId, List items, List recommend) throws Exception { 157 | // for (Map.Entry> entry: lastRecState.entries()) { 158 | // System.out.println(entry.getKey() + " - " + entry.getValue()); 159 | // } 160 | // System.out.println("user: " + lastRecState.contains(userId)); 161 | 162 | int lastReward; 163 | List lastRecommend = lastRecState.get(userId); 164 | if (lastRecommend != null) { 165 | lastReward = 0; 166 | for (int rec : lastRecommend) { 167 | for (int click : items) { 168 | if (rec == click) { 169 | lastReward++; 170 | } 171 | } 172 | } 173 | } else { 174 | lastReward = -1; 175 | } 176 | 177 | lastRecState.put(userId, recommend); 178 | return lastReward; 179 | } 180 | } 181 | 182 | -------------------------------------------------------------------------------- /FlinkRL/src/main/java/com/mass/sink/MongodbRecommendSink.java: -------------------------------------------------------------------------------- 1 | package com.mass.sink; 2 | 3 | import com.mongodb.client.MongoClient; 4 | import com.mongodb.client.MongoClients; 5 | import com.mongodb.client.MongoCollection; 6 | import com.mongodb.client.MongoDatabase; 7 | import org.apache.flink.api.java.tuple.Tuple4; 8 | import org.apache.flink.configuration.Configuration; 9 | import org.apache.flink.streaming.api.functions.sink.RichSinkFunction; 10 | import org.bson.Document; 11 | 12 | import java.util.List; 13 | 14 | public class MongodbRecommendSink extends RichSinkFunction, String, Integer>> { 15 | static MongoClient mongoClient; 16 | static MongoDatabase database; 17 | static MongoCollection recCollection; 18 | 19 | @Override 20 | public void open(Configuration parameters) { 21 | mongoClient = MongoClients.create("mongodb://localhost:27017"); 22 | database = mongoClient.getDatabase("flink-rl"); 23 | recCollection = database.getCollection("recommendResults"); 24 | } 25 | 26 | @Override 27 | public void close() { 28 | if (mongoClient != null) { 29 | mongoClient.close(); 30 | } 31 | } 32 | 33 | @Override 34 | public void invoke(Tuple4, String, Integer> value, Context context) { 35 | Document recDoc = new Document() 36 | .append("user", value.f0) 37 | .append("itemRec", value.f1) 38 | .append("time", value.f2) 39 | .append("lastReward", value.f3); 40 | recCollection.insertOne(recDoc); 41 | } 42 | } 43 | 44 | 45 | -------------------------------------------------------------------------------- /FlinkRL/src/main/java/com/mass/sink/MongodbRecordSink.java: -------------------------------------------------------------------------------- 1 | package com.mass.sink; 2 | 3 | import com.mass.entity.RecordEntity; 4 | import com.mongodb.client.*; 5 | import com.mongodb.client.model.Updates; 6 | import org.apache.flink.configuration.Configuration; 7 | import 
org.apache.flink.streaming.api.functions.sink.RichSinkFunction; 8 | import org.bson.Document; 9 | 10 | import java.util.ArrayList; 11 | 12 | import static com.mass.util.FormatTimestamp.format; 13 | import static com.mongodb.client.model.Filters.eq; 14 | 15 | public class MongodbRecordSink extends RichSinkFunction { 16 | private static MongoClient mongoClient; 17 | private static MongoDatabase database; 18 | private static MongoCollection recordCollection; 19 | private static MongoCollection consumedCollection; 20 | 21 | @Override 22 | public void open(Configuration parameters) { 23 | mongoClient = MongoClients.create("mongodb://localhost:27017"); 24 | database = mongoClient.getDatabase("flink-rl"); 25 | recordCollection = database.getCollection("records"); 26 | consumedCollection = database.getCollection("userConsumed"); 27 | } 28 | 29 | @Override 30 | public void close() { 31 | if (mongoClient != null) { 32 | mongoClient.close(); 33 | } 34 | } 35 | 36 | @Override 37 | public void invoke(RecordEntity value, Context context) { 38 | Document recordDoc = new Document() 39 | .append("user", value.getUserId()) 40 | .append("item", value.getItemId()) 41 | .append("time", format(value.getTime())); 42 | recordCollection.insertOne(recordDoc); 43 | 44 | int userId = value.getUserId(); 45 | int itemId = value.getItemId(); 46 | long timestamp = value.getTime(); 47 | String time = format(timestamp); 48 | FindIterable fe = consumedCollection.find(eq("user", userId)); 49 | if (null == fe.first()) { 50 | Document consumedDoc = new Document(); 51 | consumedDoc.append("user", userId); 52 | consumedDoc.append("consumed", new ArrayList()); 53 | consumedDoc.append("time", new ArrayList()); 54 | consumedCollection.insertOne(consumedDoc); 55 | } else { 56 | consumedCollection.updateOne(eq("user", userId), Updates.addToSet("consumed", itemId)); 57 | consumedCollection.updateOne(eq("user", userId), Updates.addToSet("time", time)); 58 | } 59 | } 60 | } 61 | 62 | -------------------------------------------------------------------------------- /FlinkRL/src/main/java/com/mass/sink/TopNRedisSink.java: -------------------------------------------------------------------------------- 1 | package com.mass.sink; 2 | 3 | import org.apache.flink.api.java.tuple.Tuple2; 4 | import org.apache.flink.streaming.connectors.redis.common.mapper.RedisCommand; 5 | import org.apache.flink.streaming.connectors.redis.common.mapper.RedisCommandDescription; 6 | import org.apache.flink.streaming.connectors.redis.common.mapper.RedisMapper; 7 | 8 | public class TopNRedisSink implements RedisMapper> { 9 | 10 | @Override 11 | public RedisCommandDescription getCommandDescription() { 12 | return new RedisCommandDescription(RedisCommand.SET); 13 | } 14 | 15 | @Override 16 | public String getKeyFromData(Tuple2 data) { 17 | return data.f1; // rank 18 | } 19 | 20 | @Override 21 | public String getValueFromData(Tuple2 data) { 22 | return data.f0; // item 23 | } 24 | } 25 | 26 | 27 | -------------------------------------------------------------------------------- /FlinkRL/src/main/java/com/mass/source/CustomFileSource.java: -------------------------------------------------------------------------------- 1 | package com.mass.source; 2 | 3 | import com.mass.util.Property; 4 | import org.apache.flink.configuration.Configuration; 5 | import org.apache.flink.streaming.api.functions.source.RichSourceFunction; 6 | import com.mass.entity.RecordEntity; 7 | 8 | import java.io.BufferedReader; 9 | import java.io.FileReader; 10 | import java.io.IOException; 11 | 12 | 
public class CustomFileSource extends RichSourceFunction { 13 | 14 | private static BufferedReader br; 15 | private Boolean header; 16 | private String filePath; 17 | 18 | public CustomFileSource(String filePath, Boolean header) { 19 | this.filePath = filePath; 20 | this.header = header; 21 | } 22 | 23 | @Override 24 | public void open(Configuration parameters) throws IOException { 25 | // String dataPath = Thread.currentThread().getContextClassLoader().getResource(filePath).getFile(); 26 | String dataPath = Property.getStrValue("data.path"); 27 | br = new BufferedReader(new FileReader(dataPath)); 28 | } 29 | 30 | @Override 31 | public void close() throws IOException { 32 | br.close(); 33 | } 34 | 35 | @Override 36 | public void run(SourceContext ctx) throws IOException, InterruptedException { 37 | String temp; 38 | int i = 0; 39 | while ((temp = br.readLine()) != null) { 40 | i++; 41 | if (header && i == 1) continue; // skip header line 42 | String[] line = temp.split(","); 43 | int userId = Integer.valueOf(line[0]); 44 | int itemId = Integer.valueOf(line[1]); 45 | long time = Long.valueOf(line[3]); 46 | ctx.collect(new RecordEntity(userId, itemId, time)); 47 | Thread.sleep(5L); 48 | } 49 | } 50 | 51 | @Override 52 | public void cancel() { } 53 | } 54 | 55 | 56 | -------------------------------------------------------------------------------- /FlinkRL/src/main/java/com/mass/source/KafkaSource.java: -------------------------------------------------------------------------------- 1 | package com.mass.source; 2 | 3 | import org.apache.kafka.clients.producer.KafkaProducer; 4 | import org.apache.kafka.clients.producer.ProducerRecord; 5 | import org.apache.kafka.clients.producer.RecordMetadata; 6 | import org.apache.kafka.clients.producer.Callback; 7 | 8 | import java.io.*; 9 | import java.util.Properties; 10 | 11 | public class KafkaSource { 12 | public static void sendData(Properties kafkaProps, String dataPath, Boolean header) 13 | throws IOException, InterruptedException { 14 | KafkaProducer producer = new KafkaProducer<>(kafkaProps); 15 | BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(dataPath))); 16 | String temp; 17 | int i = 0; 18 | while ((temp = br.readLine()) != null) { 19 | i++; 20 | if (header && i == 1) continue; // skip header line 21 | String[] splitted = temp.split(","); 22 | String[] record = {splitted[0], splitted[1], splitted[3]}; 23 | String csvString = String.join(",", record); 24 | producer.send(new ProducerRecord<>("flink-rl", "news", csvString), new Callback() { 25 | @Override 26 | public void onCompletion(RecordMetadata metadata, Exception e) { 27 | if (e != null) { 28 | e.printStackTrace(); 29 | } 30 | } 31 | }); 32 | Thread.sleep(10L); 33 | } 34 | 35 | br.close(); 36 | producer.close(); 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /FlinkRL/src/main/java/com/mass/task/FileToKafka.java: -------------------------------------------------------------------------------- 1 | package com.mass.task; 2 | 3 | import com.mass.source.KafkaSource; 4 | import com.mass.util.Property; 5 | 6 | import java.io.IOException; 7 | import java.util.Properties; 8 | 9 | public class FileToKafka { 10 | public static void main(String[] args) throws IOException, InterruptedException { 11 | Properties kafkaProps = new Properties(); 12 | kafkaProps.setProperty("bootstrap.servers", "localhost:9092"); 13 | kafkaProps.setProperty("ack", "1"); 14 | kafkaProps.setProperty("batch.size", "16384"); 15 | 
kafkaProps.setProperty("linger.ms", "1"); 16 | kafkaProps.setProperty("buffer.memory", "33554432"); 17 | kafkaProps.setProperty("key.serializer", "org.apache.kafka.common.serialization.StringSerializer"); 18 | kafkaProps.setProperty("value.serializer", "org.apache.kafka.common.serialization.StringSerializer"); 19 | 20 | // ClassLoader classloader = Thread.currentThread().getContextClassLoader(); 21 | // InputStream is = classloader.getResourceAsStream("/news_data.csv"); 22 | // String dataPath = FileToKafka.class.getResource("/tianchi.csv").getFile(); 23 | String dataPath = Property.getStrValue("data.path"); 24 | KafkaSource.sendData(kafkaProps, dataPath, false); 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /FlinkRL/src/main/java/com/mass/task/IntervalPopularItems.java: -------------------------------------------------------------------------------- 1 | package com.mass.task; 2 | 3 | import com.mass.agg.CountAgg; 4 | import com.mass.entity.RecordEntity; 5 | import com.mass.map.RecordMap; 6 | import com.mass.process.TopNPopularItems; 7 | import com.mass.sink.TopNRedisSink; 8 | import com.mass.window.CountProcessWindowFunction; 9 | import org.apache.flink.api.common.functions.FlatMapFunction; 10 | import org.apache.flink.api.common.serialization.SimpleStringSchema; 11 | import org.apache.flink.api.java.tuple.Tuple2; 12 | import org.apache.flink.streaming.api.TimeCharacteristic; 13 | import org.apache.flink.streaming.api.datastream.DataStream; 14 | import org.apache.flink.streaming.api.datastream.DataStreamSource; 15 | import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; 16 | import org.apache.flink.streaming.api.functions.timestamps.AscendingTimestampExtractor; 17 | import org.apache.flink.streaming.api.windowing.time.Time; 18 | import org.apache.flink.streaming.connectors.kafka.FlinkKafkaConsumer; 19 | import org.apache.flink.streaming.connectors.redis.RedisSink; 20 | import org.apache.flink.streaming.connectors.redis.common.config.FlinkJedisPoolConfig; 21 | import org.apache.flink.util.Collector; 22 | import org.apache.kafka.common.serialization.StringDeserializer; 23 | 24 | import java.text.SimpleDateFormat; 25 | import java.util.Date; 26 | import java.util.List; 27 | import java.util.Properties; 28 | 29 | public class IntervalPopularItems { 30 | private static final int topSize = 17; 31 | private static SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); 32 | 33 | public static void main(String[] args) throws Exception { 34 | StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); 35 | env.setParallelism(1); 36 | env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime); 37 | 38 | FlinkJedisPoolConfig redisConf = new FlinkJedisPoolConfig 39 | .Builder() 40 | .setHost("localhost") 41 | .setPort(6379) 42 | .build(); 43 | Properties prop = new Properties(); 44 | prop.setProperty("bootstrap.servers", "localhost:9092"); 45 | prop.setProperty("group.id", "rec"); 46 | prop.setProperty("key.deserializer", StringDeserializer.class.getName()); 47 | prop.setProperty("value.deserializer", 48 | Class.forName("org.apache.kafka.common.serialization.StringDeserializer").getName()); 49 | DataStreamSource sourceStream = env 50 | .addSource(new FlinkKafkaConsumer<>("flink-rl", new SimpleStringSchema(), prop)); 51 | 52 | DataStream> topItem = sourceStream 53 | .map(new RecordMap()) 54 | .assignTimestampsAndWatermarks(new AscendingTimestampExtractor() { 55 | @Override 56 | public 
long extractAscendingTimestamp(RecordEntity element) { 57 | return element.getTime(); 58 | } 59 | }) 60 | // .keyBy("itemId") // if using POJOs, will convert to tuple 61 | // .keyBy(i -> i.getItemId()) 62 | .keyBy(RecordEntity::getItemId) 63 | .timeWindow(Time.minutes(60), Time.seconds(80)) 64 | .aggregate(new CountAgg(), new CountProcessWindowFunction()) 65 | .keyBy("windowEnd") 66 | .process(new TopNPopularItems(topSize)) 67 | .flatMap(new FlatMapFunction, Tuple2>() { 68 | @Override 69 | public void flatMap(List value, Collector> out) { 70 | showTime(value.get(0)); 71 | for (int rank = 1; rank < value.size(); ++rank) { 72 | String item = value.get(rank); 73 | System.out.println(String.format("item: %s, rank: %d", item, rank)); 74 | out.collect(Tuple2.of(item, String.valueOf(rank))); 75 | } 76 | } 77 | }); 78 | 79 | topItem.addSink(new RedisSink>(redisConf, new TopNRedisSink())); 80 | env.execute("IntervalPopularItems"); 81 | } 82 | 83 | private static void showTime(String timestamp) { 84 | long windowTime = Long.valueOf(timestamp); 85 | Date date = new Date(windowTime); 86 | System.out.println("============== Top N Items during " + sdf.format(date) + " =============="); 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /FlinkRL/src/main/java/com/mass/task/ReadFromFile.java: -------------------------------------------------------------------------------- 1 | package com.mass.task; 2 | 3 | import org.apache.flink.api.common.functions.FlatMapFunction; 4 | import org.apache.flink.api.java.tuple.Tuple2; 5 | import org.apache.flink.streaming.api.datastream.DataStream; 6 | import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; 7 | import org.apache.flink.util.Collector; 8 | 9 | import java.util.Arrays; 10 | 11 | public class ReadFromFile { 12 | public static void main(String[] args) throws Exception { 13 | StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); 14 | 15 | // DataStream stream = env.addSource(new CustomFileSource()); 16 | // String dataPath = ReadFromFile.class.getResource("/tianchi.csv").getFile(); 17 | String dataPath = Thread.currentThread().getContextClassLoader().getResource("tianchi.csv").getFile(); 18 | DataStream stream = env.readTextFile(dataPath); 19 | DataStream> result = stream.flatMap( 20 | new FlatMapFunction>() { 21 | @Override 22 | public void flatMap(String value, Collector> out) { 23 | String[] splitted = value.split(","); 24 | boolean first = false; 25 | try { 26 | Integer.parseInt(splitted[0]); 27 | } catch (NumberFormatException e) { 28 | System.out.println("skip header line..."); 29 | first = true; 30 | } 31 | if (!first) { 32 | String ratings = String.join(",", Arrays.copyOfRange(splitted, 4, 5)); 33 | out.collect(new Tuple2<>(Integer.valueOf(splitted[0]), ratings)); 34 | } 35 | } 36 | }).keyBy(0); 37 | 38 | result.print(); 39 | env.execute(); 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /FlinkRL/src/main/java/com/mass/task/RecordToMongoDB.java: -------------------------------------------------------------------------------- 1 | package com.mass.task; 2 | 3 | import com.mass.entity.RecordEntity; 4 | import com.mass.sink.MongodbRecordSink; 5 | import com.mass.source.CustomFileSource; 6 | import com.mass.util.Property; 7 | import com.mass.util.RecordToEntity; 8 | import org.apache.flink.api.common.functions.RichMapFunction; 9 | import org.apache.flink.api.common.serialization.SimpleStringSchema; 10 | 
import org.apache.flink.configuration.Configuration; 11 | import org.apache.flink.streaming.api.TimeCharacteristic; 12 | import org.apache.flink.streaming.api.datastream.DataStream; 13 | import org.apache.flink.streaming.api.datastream.DataStreamSource; 14 | import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; 15 | import org.apache.flink.streaming.connectors.kafka.FlinkKafkaConsumer; 16 | import org.apache.kafka.common.serialization.StringDeserializer; 17 | import org.json.JSONObject; 18 | import org.json.JSONTokener; 19 | 20 | import java.io.FileInputStream; 21 | import java.io.FileNotFoundException; 22 | import java.io.InputStream; 23 | import java.util.Properties; 24 | 25 | public class RecordToMongoDB { 26 | public static void main(String[] args) throws Exception { 27 | StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); 28 | env.setParallelism(1); 29 | env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime); 30 | 31 | Properties prop = new Properties(); 32 | prop.setProperty("bootstrap.servers", "localhost:9092"); 33 | prop.setProperty("group.id", "rec"); 34 | prop.setProperty("key.deserializer", StringDeserializer.class.getName()); 35 | prop.setProperty("value.deserializer", 36 | Class.forName("org.apache.kafka.common.serialization.StringDeserializer").getName()); 37 | DataStreamSource sourceStream = env.addSource( 38 | new FlinkKafkaConsumer<>("flink-rl", new SimpleStringSchema(), prop)); 39 | 40 | DataStream recordStream = sourceStream.map(new RichMapFunction() { 41 | private InputStream userStream; 42 | private InputStream itemStream; 43 | private JSONObject userJSON; 44 | private JSONObject itemJSON; 45 | private int numUsers; 46 | private int numItems; 47 | 48 | @Override 49 | public void open(Configuration parameters) throws FileNotFoundException { 50 | userStream = new FileInputStream(Property.getStrValue("user_map")); 51 | itemStream = new FileInputStream(Property.getStrValue("item_map")); 52 | try { 53 | userJSON = new JSONObject(new JSONTokener(userStream)); 54 | itemJSON = new JSONObject(new JSONTokener(itemStream)); 55 | } catch (NullPointerException e) { 56 | e.printStackTrace(); 57 | } 58 | numUsers = userJSON.length(); 59 | numItems = itemJSON.length(); 60 | } 61 | 62 | @Override 63 | public RecordEntity map(String value) { 64 | return RecordToEntity.getRecord(value, userJSON, itemJSON, numUsers, numItems); 65 | } 66 | 67 | @Override 68 | public void close() throws Exception { 69 | userStream.close(); 70 | itemStream.close(); 71 | } 72 | } 73 | ); 74 | 75 | recordStream.addSink(new MongodbRecordSink()); 76 | env.execute("RecordToMongoDB"); 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /FlinkRL/src/main/java/com/mass/task/RecordToMysql.java: -------------------------------------------------------------------------------- 1 | package com.mass.task; 2 | 3 | import com.mass.entity.RecordEntity; 4 | import com.mass.util.RecordToEntity; 5 | import com.mass.map.RecordMap; 6 | import org.apache.flink.api.common.functions.RichMapFunction; 7 | import org.apache.flink.api.common.serialization.SimpleStringSchema; 8 | import org.apache.flink.streaming.api.datastream.DataStream; 9 | import org.apache.flink.streaming.api.datastream.DataStreamSource; 10 | import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; 11 | import org.apache.flink.streaming.connectors.kafka.FlinkKafkaConsumer; 12 | import 
org.apache.kafka.common.serialization.StringDeserializer; 13 | import org.apache.flink.configuration.Configuration; 14 | import org.apache.flink.streaming.api.functions.sink.RichSinkFunction; 15 | import org.apache.flink.streaming.api.functions.sink.SinkFunction; 16 | import org.json.JSONObject; 17 | import org.json.JSONTokener; 18 | 19 | import java.io.InputStream; 20 | import java.text.DateFormat; 21 | import java.text.SimpleDateFormat; 22 | import java.util.Properties; 23 | import java.sql.Connection; 24 | import java.sql.DriverManager; 25 | import java.sql.PreparedStatement; 26 | 27 | 28 | public class RecordToMysql { 29 | public static void main(String[] args) throws Exception { 30 | StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); 31 | env.setParallelism(1); 32 | 33 | Properties prop = new Properties(); 34 | prop.setProperty("bootstrap.servers", "localhost:9092"); 35 | prop.setProperty("group.id", "rec"); 36 | prop.setProperty("key.deserializer", StringDeserializer.class.getName()); 37 | prop.setProperty("value.deserializer", 38 | Class.forName("org.apache.kafka.common.serialization.StringDeserializer").getName()); 39 | DataStreamSource stream = env.addSource( 40 | new FlinkKafkaConsumer<>("flink-rl", new SimpleStringSchema(), prop)); 41 | // stream.print(); 42 | 43 | DataStream stream2 = stream.map(new RecordMap()); 44 | stream2.addSink(new CustomMySqlSink()); 45 | 46 | env.execute("RecordToMySql"); 47 | } 48 | 49 | private static class CustomMySqlSink extends RichSinkFunction { 50 | Connection conn; 51 | PreparedStatement stmt; 52 | 53 | @Override 54 | public void open(Configuration parameters) throws Exception { 55 | conn = DriverManager.getConnection("jdbc:mysql://localhost:3306/flink", "root", "123456"); 56 | stmt = conn.prepareStatement("INSERT INTO record (user, item, time) VALUES (?, ?, ?)"); 57 | } 58 | 59 | @Override 60 | public void close() throws Exception { 61 | stmt.close(); 62 | conn.close(); 63 | } 64 | 65 | @Override 66 | public void invoke(RecordEntity value, Context context) throws Exception { 67 | stmt.setInt(1, value.getUserId()); 68 | stmt.setInt(2, value.getItemId()); 69 | DateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); 70 | stmt.setString(3, sdf.format(value.getTime())); 71 | stmt.executeUpdate(); 72 | } 73 | } 74 | } 75 | 76 | 77 | -------------------------------------------------------------------------------- /FlinkRL/src/main/java/com/mass/task/SeqRecommend.java: -------------------------------------------------------------------------------- 1 | package com.mass.task; 2 | 3 | import com.mass.entity.RecordEntity; 4 | import com.mass.recommend.fastapi.FastapiRecommender; 5 | // import com.mass.recommend.onnx.ONNXRecommender; 6 | import com.mass.sink.MongodbRecommendSink; 7 | import com.mass.source.CustomFileSource; 8 | import com.mass.window.ItemCollectWindowFunction; 9 | import org.apache.flink.streaming.api.TimeCharacteristic; 10 | import org.apache.flink.streaming.api.datastream.DataStream; 11 | import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; 12 | import org.apache.flink.streaming.api.functions.timestamps.BoundedOutOfOrdernessTimestampExtractor; 13 | import org.apache.flink.streaming.api.windowing.time.Time; 14 | 15 | public class SeqRecommend { 16 | public static void main(String[] args) throws Exception { 17 | StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); 18 | env.setParallelism(1); 19 | 
env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime); 20 | 21 | DataStream stream = env.addSource(new CustomFileSource("tianchi.csv", false)); 22 | stream.assignTimestampsAndWatermarks( 23 | new BoundedOutOfOrdernessTimestampExtractor(Time.seconds(10)) { 24 | @Override 25 | public long extractTimestamp(RecordEntity element) { 26 | return element.getTime(); 27 | } 28 | }) 29 | .keyBy("userId") // return tuple 30 | // .keyBy(RecordEntity -> RecordEntity.userId) // return Integer 31 | .timeWindow(Time.seconds(60), Time.seconds(20)) 32 | .process(new ItemCollectWindowFunction(10)) 33 | .keyBy("windowEnd") 34 | .flatMap(new FastapiRecommender(8, 10, "ddpg", false)) 35 | .addSink(new MongodbRecommendSink()); 36 | 37 | env.execute("ItemSeqRecommend"); 38 | } 39 | } 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /FlinkRL/src/main/java/com/mass/util/FormatTimestamp.java: -------------------------------------------------------------------------------- 1 | package com.mass.util; 2 | 3 | import java.text.SimpleDateFormat; 4 | import java.util.Date; 5 | 6 | public class FormatTimestamp { 7 | public static String format(long timestamp) { 8 | Date date = new Date(timestamp); 9 | SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); 10 | return sdf.format(date); 11 | } 12 | } 13 | 14 | -------------------------------------------------------------------------------- /FlinkRL/src/main/java/com/mass/util/Property.java: -------------------------------------------------------------------------------- 1 | package com.mass.util; 2 | 3 | import java.io.IOException; 4 | import java.io.InputStream; 5 | import java.io.InputStreamReader; 6 | import java.nio.charset.StandardCharsets; 7 | import java.util.Properties; 8 | 9 | public class Property { 10 | private static Properties contextProperties; 11 | private final static String CONFIG_NAME = "config.properties"; 12 | 13 | static { 14 | InputStream in = Thread.currentThread().getContextClassLoader().getResourceAsStream(CONFIG_NAME); 15 | contextProperties = new Properties(); 16 | try { 17 | InputStreamReader inputStreamReader = new InputStreamReader(in, StandardCharsets.UTF_8); 18 | contextProperties.load(inputStreamReader); 19 | } catch (IOException e) { 20 | System.err.println("===== Failed to load config file. ====="); 21 | e.printStackTrace(); 22 | } 23 | } 24 | 25 | public static String getStrValue(String key) { 26 | return contextProperties.getProperty(key); 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /FlinkRL/src/main/java/com/mass/util/RecordToEntity.java: -------------------------------------------------------------------------------- 1 | package com.mass.util; 2 | 3 | import com.mass.entity.RecordEntity; 4 | import org.json.JSONObject; 5 | 6 | public class RecordToEntity { 7 | public static RecordEntity getRecord(String s, JSONObject userJSON, JSONObject itemJSON, int numUsers, int numItems) { 8 | String[] values = s.split(","); 9 | RecordEntity record = new RecordEntity(); 10 | record.setUserId(userJSON.has(values[0]) ? userJSON.getInt(values[0]) : numUsers); 11 | record.setItemId(itemJSON.has(values[1]) ? 
itemJSON.getInt(values[1]) : numItems); 12 | record.setTime(Long.parseLong(values[2])); 13 | return record; 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /FlinkRL/src/main/java/com/mass/util/TypeConvert.java: -------------------------------------------------------------------------------- 1 | package com.mass.util; 2 | 3 | import org.apache.commons.lang3.StringUtils; 4 | import org.json.JSONArray; 5 | import org.json.JSONObject; 6 | 7 | import java.util.Arrays; 8 | import java.util.List; 9 | 10 | public class TypeConvert { 11 | 12 | public static String convertString(List items) { 13 | StringBuilder sb = new StringBuilder("{\"columns\": [\"x\"], \"data\": [[["); 14 | for (Integer item : items) { 15 | sb.append(String.format("%d,", item)); 16 | } 17 | sb.deleteCharAt(sb.length() - 1); 18 | sb.append("]]]}"); 19 | return sb.toString(); 20 | } 21 | 22 | public static JSONArray convertJSON(String message) { 23 | JSONArray resArray = new JSONArray(message); 24 | return resArray.getJSONArray(0); 25 | } 26 | 27 | public static String convertEmbedding(List embeds, int user, int numRec) { 28 | // String repeat = StringUtils.repeat("\"x\",", embeds.size()); 29 | // String columns = "{\"columns\": [" + repeat.substring(0, repeat.length() - 1) + "], "; 30 | String state = "{\"user\": " + user + ", \"n_rec\": " + numRec + ", "; 31 | StringBuilder sb = new StringBuilder("\"embedding\": [["); 32 | for (float num : embeds) { 33 | sb.append(String.format("%f,", num)); 34 | } 35 | sb.deleteCharAt(sb.length() - 1); 36 | sb.append("]]}"); 37 | return state + sb.toString(); 38 | } 39 | 40 | public static String convertSeq(List seq, int user, int numRec) { 41 | String state = "{\"user\": [" + user + "], \"n_rec\": " + numRec + ", "; 42 | StringBuilder sb = new StringBuilder("\"item\": [["); 43 | for (int item : seq) { 44 | sb.append(String.format("%d,", item)); 45 | } 46 | sb.deleteCharAt(sb.length() - 1); 47 | sb.append("]]}"); 48 | return state + sb.toString(); 49 | } 50 | 51 | public static void main(String[] args) { 52 | Float[] aa = {1.2f, 3.5f, 5.666f}; 53 | List bb = Arrays.asList(aa); 54 | System.out.println(convertEmbedding(bb, 1, 10)); 55 | Integer[] cc = {1, 2, 3}; 56 | List dd = Arrays.asList(cc); 57 | System.out.println(convertSeq(dd, 1, 10)); 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /FlinkRL/src/main/java/com/mass/window/CountProcessWindowFunction.java: -------------------------------------------------------------------------------- 1 | package com.mass.window; 2 | 3 | import com.mass.entity.TopItemEntity; 4 | import org.apache.flink.streaming.api.functions.windowing.ProcessWindowFunction; 5 | import org.apache.flink.streaming.api.windowing.windows.TimeWindow; 6 | import org.apache.flink.util.Collector; 7 | 8 | public class CountProcessWindowFunction extends ProcessWindowFunction { 9 | 10 | @Override 11 | public void process(Integer itemId, Context ctx, Iterable aggregateResults, Collector out) { 12 | long windowEnd = ctx.window().getEnd(); 13 | long count = aggregateResults.iterator().next(); 14 | out.collect(TopItemEntity.of(itemId, windowEnd, count)); 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /FlinkRL/src/main/java/com/mass/window/ItemCollectWindowFunction.java: -------------------------------------------------------------------------------- 1 | package com.mass.window; 2 | 3 | import com.mass.entity.RecordEntity; 4 | import 
com.mass.entity.UserConsumed; 5 | import org.apache.flink.api.common.state.ValueState; 6 | import org.apache.flink.api.common.state.ValueStateDescriptor; 7 | import org.apache.flink.api.common.typeinfo.TypeHint; 8 | import org.apache.flink.api.common.typeinfo.TypeInformation; 9 | import org.apache.flink.api.java.tuple.Tuple; 10 | import org.apache.flink.configuration.Configuration; 11 | import org.apache.flink.streaming.api.functions.windowing.ProcessWindowFunction; 12 | import org.apache.flink.streaming.api.windowing.windows.TimeWindow; 13 | import org.apache.flink.util.Collector; 14 | 15 | import java.io.IOException; 16 | import java.util.LinkedList; 17 | 18 | public class ItemCollectWindowFunction extends ProcessWindowFunction { 19 | 20 | private final int histNum; 21 | private ValueState> itemState; 22 | 23 | public ItemCollectWindowFunction(int histNum) { 24 | this.histNum = histNum; 25 | } 26 | 27 | @Override 28 | public void open(Configuration parameters) throws Exception { 29 | super.open(parameters); 30 | ValueStateDescriptor> itemStateDesc = new ValueStateDescriptor<>( 31 | "itemState", TypeInformation.of(new TypeHint>() {}), new LinkedList<>()); 32 | itemState = getRuntimeContext().getState(itemStateDesc); 33 | } 34 | 35 | @Override 36 | public void process(Tuple key, Context ctx, Iterable elements, Collector out) 37 | throws IOException { 38 | int userId = key.getField(0); 39 | long windowEnd = ctx.window().getEnd(); 40 | 41 | LinkedList histItems = itemState.value(); 42 | // if (histItems == null) { 43 | // histItems = new LinkedList<>(); 44 | // } 45 | elements.forEach(e -> histItems.add(e.getItemId())); 46 | while (!histItems.isEmpty() && histItems.size() > histNum) { 47 | histItems.poll(); 48 | } 49 | itemState.update(histItems); 50 | out.collect(UserConsumed.of(userId, histItems, windowEnd)); 51 | } 52 | 53 | @Override 54 | public void close() { 55 | itemState.clear(); 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /FlinkRL/src/main/java/com/mass/window/ItemCollectWindowFunctionRedis.java: -------------------------------------------------------------------------------- 1 | package com.mass.window; 2 | 3 | import com.mass.entity.RecordEntity; 4 | import com.mass.entity.UserConsumed; 5 | import org.apache.flink.api.java.tuple.Tuple; 6 | import org.apache.flink.configuration.Configuration; 7 | import org.apache.flink.streaming.api.functions.windowing.ProcessWindowFunction; 8 | import org.apache.flink.streaming.api.windowing.windows.TimeWindow; 9 | import org.apache.flink.util.Collector; 10 | import redis.clients.jedis.Jedis; 11 | 12 | import java.util.ArrayList; 13 | import java.util.List; 14 | 15 | public class ItemCollectWindowFunctionRedis extends ProcessWindowFunction { 16 | 17 | private final int histNum; 18 | private static Jedis jedis; 19 | 20 | public ItemCollectWindowFunctionRedis(int histNum) { 21 | this.histNum = histNum; 22 | } 23 | 24 | @Override 25 | public void open(Configuration parameters) throws Exception { 26 | super.open(parameters); 27 | jedis = new Jedis("localhost", 6379); 28 | } 29 | 30 | @Override 31 | public void process(Tuple key, Context ctx, Iterable elements, Collector out) { 32 | int userId = key.getField(0); 33 | long windowEnd = ctx.window().getEnd(); 34 | 35 | String keyList = "hist_" + userId; 36 | elements.forEach(e -> jedis.rpush(keyList , String.valueOf(e.getItemId()))); 37 | while (jedis.llen(keyList) > histNum) { 38 | jedis.lpop(keyList); 39 | } 40 | 41 | List histItemsJedis = 
jedis.lrange(keyList, 0, -1); 42 | List histItems = new ArrayList<>(); 43 | for (String s : histItemsJedis) { 44 | histItems.add(Integer.parseInt(s)); 45 | } 46 | out.collect(UserConsumed.of(userId, histItems, windowEnd)); 47 | } 48 | 49 | @Override 50 | public void close() { 51 | jedis.close(); 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /FlinkRL/src/main/resources/config.properties: -------------------------------------------------------------------------------- 1 | data.path=/home/massquantity/Workspace/flink-reinforcement-learning/FlinkRL/src/main/resources/tianchi.csv 2 | user_map=/home/massquantity/Workspace/flink-reinforcement-learning/FlinkRL/src/main/resources/features/user_map.json 3 | item_map=/home/massquantity/Workspace/flink-reinforcement-learning/FlinkRL/src/main/resources/features/item_map.json -------------------------------------------------------------------------------- /FlinkRL/src/main/resources/db/create_table.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE IF NOT EXISTS record 2 | ( 3 | user int NOT NULL, 4 | item int NOT NULL, 5 | time TIMESTAMP NOT NULL, 6 | PRIMARY KEY (user, item) 7 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8; 8 | 9 | 10 | CREATE TABLE IF NOT EXISTS top 11 | ( 12 | item int NOT NULL, 13 | rank VARCHAR(5) NOT NULL, 14 | PRIMARY KEY (item, rank) 15 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8; 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /FlinkRL/src/main/resources/log4j.properties: -------------------------------------------------------------------------------- 1 | log4j.rootLogger=WARN, console 2 | log4j.appender.console=org.apache.log4j.ConsoleAppender 3 | log4j.appender.console.target=System.out 4 | log4j.appender.console.layout=org.apache.log4j.PatternLayout 5 | log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{2}: %m%n 6 | 7 | 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Flink-Reinforcement-Learning 2 | 3 | ### `English`   [`简体中文`](https://github.com/massquantity/Flink-Reinforcement-Learning/blob/master/README_zh.md)   [`blog post`](https://www.cnblogs.com/massquantity/p/13842139.html) 4 | 5 |
6 | 7 | FlinkRL is a realtime recommender system based upon Flink and reinforcement learning. Specifically, Flink is known for its high-performance stateful stream processing, which lets the system respond to user requests quickly and accurately. Reinforcement learning, in turn, is good at planning for long-term reward and adjusting recommendation results quickly according to realtime user feedback. The combination of the two enables the system to capture the dynamics of user interests and provide insightful recommendations. 8 | 9 | FlinkRL is mainly used for online inference, while the offline model training part is implemented in another repo, i.e. [DBRL](https://github.com/massquantity/DBRL). The full system architecture is as follows: 10 | 11 | ![](https://s1.ax1x.com/2020/10/19/0x5t4P.png) 12 | 13 | 14 | 15 | ## Main Workflow 16 | 17 | To simulate an online environment, a dataset is used as a producer to send data to Kafka, and Flink consumes the data from Kafka. Afterwards, Flink executes three tasks: 18 | 19 | + Save the user behavior log to MongoDB and MySQL; the log can be used for offline model training. 20 | + Compute realtime top-N popular items and save them to Redis. 21 | + Collect each user's recently consumed items and post them to a web service created with [FastAPI](https://github.com/tiangolo/fastapi) to get recommendation results, then save the recommendations to MongoDB. 22 | 23 | Thus an online web server can directly fetch recommendation results and popular items from the databases and send them to the client. Why use FastAPI to build a separate web service? Because the model is trained with PyTorch, and there is seemingly no unified way to deploy a PyTorch model, FastAPI is used to load the trained model, do some processing and produce the final recommendations. 24 | 25 | 26 | 27 | ## Data 28 | 29 | The dataset comes from a competition held by Tianchi, a Chinese competition platform. Please refer to the original website for the [full description](https://tianchi.aliyun.com/competition/entrance/231721/information?lang=en-us). Note that only the round 2 data is used here. 30 | 31 | You can also download the data from [Google Drive](https://drive.google.com/file/d/1erBjYEOa7IuOIGpI8pGPn1WNBAC4Rv0-/view?usp=sharing). 32 | 33 | 34 | 35 | ## Usage 36 | 37 | Python dependencies: python>=3.6, numpy, pandas, torch>=1.3, tqdm, FastAPI. 38 | 39 | You'll need to clone both `FlinkRL` and `DBRL`: 40 | 41 | ```shell 42 | $ git clone https://github.com/massquantity/Flink-Reinforcement-Learning.git 43 | $ git clone https://github.com/massquantity/DBRL.git 44 | ``` 45 | 46 | We'll first use the `DBRL` repo for offline training. After downloading the data, unzip it and put the files into the `DBRL/dbrl/resources` folder. The original dataset consists of three tables: `user.csv`, `item.csv`, `user_behavior.csv`. We first need to filter out users with too few interactions and merge all features together, which is accomplished by `run_prepare_data.py`. Then we pretrain embeddings for every user and item by running `run_pretrain_embeddings.py`: 47 | 48 | ```shell 49 | $ cd DBRL/dbrl 50 | $ python run_prepare_data.py 51 | $ python run_pretrain_embeddings.py --lr 0.001 --n_epochs 4 52 | ``` 53 | 54 | You can tune the `lr` and `n_epochs` hyper-parameters to get a better evaluation loss.
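
Before training the model, it can be worth sanity-checking the pretrained embedding files produced by this step (a minimal sketch, not part of the original scripts; the file names are the ones listed in the next section, and the 32-dimensional size is what the serving scripts under `python_api` assume — adjust the paths to wherever the files were actually written):

```python
import numpy as np

# The serving scripts in python_api use hist_num = 10 and 32-dimensional
# embeddings (action_dim / embed_size = 32), so a quick shape check can catch
# a mismatch before the RL training step.
user_embeds = np.load("DBRL/resources/tianchi_user_embeddings.npy")
item_embeds = np.load("DBRL/resources/tianchi_item_embeddings.npy")
print(user_embeds.shape, item_embeds.shape)  # e.g. (n_users, 32), (n_items, 32)
assert user_embeds.shape[1] == item_embeds.shape[1] == 32
```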
Then we begin to train the model. Currently there are three algorithms in `DBRL`, so we can choose one of them: 55 | 56 | ```shell 57 | $ python run_reinforce.py --n_epochs 5 --lr 1e-5 58 | $ python run_ddpg.py --n_epochs 5 --lr 1e-5 59 | $ python run_bcq.py --n_epochs 5 --lr 1e-5 60 | ``` 61 | 62 | At this point, the `DBRL/resources` folder should contain at least 6 files: 63 | 64 | + `model_xxx.pt`, the trained PyTorch model. 65 | + `tianchi.csv`, the transformed dataset. 66 | + `tianchi_user_embeddings.npy`, the pretrained user embeddings in numpy `npy` format. 67 | + `tianchi_item_embeddings.npy`, the pretrained item embeddings in numpy `npy` format. 68 | + `user_map.json`, a JSON file that maps original user ids to the ids used in the model. 69 | + `item_map.json`, a JSON file that maps original item ids to the ids used in the model. 70 | 71 | 72 | 73 | After the offline training, we turn to `FlinkRL`. First put three files, `model_xxx.pt`, `tianchi_user_embeddings.npy` and `tianchi_item_embeddings.npy`, into the `Flink-Reinforcement-Learning/python_api` folder. Make sure you have already installed [FastAPI](https://github.com/tiangolo/fastapi), then start the service: 74 | 75 | ```shell 76 | $ gunicorn reinforce:app -w 4 -k uvicorn.workers.UvicornWorker # if the model is reinforce 77 | $ gunicorn ddpg:app -w 4 -k uvicorn.workers.UvicornWorker # if the model is ddpg 78 | $ gunicorn bcq:app -w 4 -k uvicorn.workers.UvicornWorker # if the model is bcq 79 | ``` 80 | 81 | 82 | 83 | The other three files, `tianchi.csv`, `user_map.json` and `item_map.json`, are used in Flink. In principle they can be put anywhere, as long as you specify their absolute paths in the `Flink-Reinforcement-Learning/FlinkRL/src/main/resources/config.properties` file. 84 | 85 | For a quick run, you can directly import `FlinkRL` into an IDE, e.g. `IntelliJ IDEA`. To run on a cluster, use Maven to package it into a jar file: 86 | 87 | ```shell 88 | $ cd FlinkRL 89 | $ mvn clean package 90 | ``` 91 | 92 | Put the generated `FlinkRL-1.0-SNAPSHOT-jar-with-dependencies.jar` into the Flink installation directory.
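
Before launching the Flink jobs, you can check that the FastAPI service started above responds correctly. The request below simply mirrors the example given in the comments at the bottom of the `python_api` scripts; the `user`, `item` and `n_rec` fields correspond to the `Seq` model defined there, and the ids are placeholders (replace `reinforce` with `ddpg` or `bcq` depending on which service was started):

```shell
$ curl -X POST "http://127.0.0.1:8000/reinforce" -H "accept: application/json" -d '{"user": [1], "item": [[1,2,3,4,5,6,7,8,9,10]], "n_rec": 8}'
```

The response is a nested list containing `n_rec` recommended item ids for each user in the request.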
Now, assuming `Kafka`, `MongoDB` and `Redis` have all been started, we can start the Flink cluster and run the tasks: 93 | 94 | ```shell 95 | $ kafka-topics.sh --create --zookeeper localhost:2181 --replication-factor 1 --partitions 1 --topic flink-rl # create a topic called flink-rl in Kafka 96 | $ use flink-rl # create a database called flink-rl in MongoDB 97 | ``` 98 | 99 | ```shell 100 | $ # first cd into the Flink installation folder 101 | $ ./bin/start-cluster.sh # start the cluster 102 | $ ./bin/flink run --class com.mass.task.FileToKafka FlinkRL-1.0-SNAPSHOT-jar-with-dependencies.jar # import data into Kafka 103 | $ ./bin/flink run --class com.mass.task.RecordToMongoDB FlinkRL-1.0-SNAPSHOT-jar-with-dependencies.jar # save log to MongoDB 104 | $ ./bin/flink run --class com.mass.task.IntervalPopularItems FlinkRL-1.0-SNAPSHOT-jar-with-dependencies.jar # compute realtime popular items 105 | $ ./bin/flink run --class com.mass.task.SeqRecommend FlinkRL-1.0-SNAPSHOT-jar-with-dependencies.jar # recommend items using the reinforcement learning model 106 | $ ./bin/stop-cluster.sh # stop the cluster 107 | ``` 108 | 109 | Open [http://localhost:8081](http://localhost:8081/) to use the Flink WebUI: 110 | 111 | ![](https://s1.ax1x.com/2020/10/19/0zCM2F.png) 112 | 113 | 114 | 115 | FastAPI also comes with an interactive WebUI; you can access it at http://127.0.0.1:8000/docs : 116 | 117 | ![](https://s1.ax1x.com/2020/10/19/0x58HA.jpg) 118 | 119 | 120 | 121 | 122 | 123 | 124 | -------------------------------------------------------------------------------- /README_zh.md: -------------------------------------------------------------------------------- 1 | # Flink-Reinforcement-Learning 2 | 3 | ### [`English`](https://github.com/massquantity/Flink-Reinforcement-Learning)   `简体中文`   [`博客文章`](https://www.cnblogs.com/massquantity/p/13842139.html) 4 | 5 |
6 | 7 | FlinkRL 是一个基于 Flink 和强化学习的推荐系统。具体来说,Flink 因其高性能有状态流处理而闻名,这能使系统快速而准确地响应用户请求。而强化学习则擅长规划长期收益,以及根据用户实时反馈快速调整推荐结果。二者的结合使得推荐系统能够捕获用户兴趣的动态变化规律并提供富有洞察力的推荐结果。 8 | 9 | FlinkRL 主要用于在线推理,而离线训练在另一个仓库 [DBRL](https://github.com/massquantity/DBRL) 中实现,下图是整体系统架构: 10 | 11 | ![](https://s1.ax1x.com/2020/10/19/0x5Qje.png) 12 | 13 | 14 | 15 | ## 主要流程 16 | 17 | 为模拟在线环境,由一个数据集作为生产者发送数据到 Kafka,Flink 则从 Kafka 消费数据。之后,Flink 可以执行三种任务: 18 | 19 | + 保存用户行为日志到 MongoDB 和 MySQL,日志可用于离线模型训练。 20 | + 计算实时 top-N 热门物品,保存到 Redis 。 21 | + 收集用户近期消费过的物品,发送数据到由 [FastAPI](https://github.com/tiangolo/fastapi) 创建的网络服务中,得到推荐结果并将其存到 MongoDB 。 22 | 23 | 之后在线服务器可以直接从数据库中调用推荐结果和热门物品发送到客户端。这里使用 FastAPI 建立另一个服务是因为线下训练使用的是 PyTorch,对于 PyTorch 模型现在貌似没有一个统一的部署方案,所以 FastAPI 用于载入模型,预处理和产生最终推荐结果。 24 | 25 | 26 | 27 | ## 数据 28 | 29 | 数据来源于天池的一个比赛,详情可参阅[官方网站](https://tianchi.aliyun.com/competition/entrance/231721/information?lang=zh-cn) ,注意这里只是用了第二轮的数据。也可以从 [Google Drive](https://drive.google.com/file/d/1erBjYEOa7IuOIGpI8pGPn1WNBAC4Rv0-/view?usp=sharing) 下载。 30 | 31 | 32 | 33 | ## 使用步骤 34 | 35 | Python 依赖库:python>=3.6, numpy, pandas, torch>=1.3, tqdm, FastAPI 36 | 37 | 首先 clone 两个仓库 38 | 39 | ```shell 40 | $ git clone https://github.com/massquantity/Flink-Reinforcement-Learning.git 41 | $ git clone https://github.com/massquantity/DBRL.git 42 | ``` 43 | 44 | 首先使用 `DBRL` 作离线训练。下载完数据后,解压并放到 `DBRL/dbrl/resources` 文件夹中。原始数据有三张表:`user.csv`, `item.csv`, `user_behavior.csv` 。首先用脚本 `run_prepare_data.py` 过滤掉一些行为太少的用户并将所有特征合并到一张表。接着用 `run_pretrain_embeddings.py` 为每个用户和物品预训练 embedding: 45 | 46 | ```shell 47 | $ cd DBRL/dbrl 48 | $ python run_prepare_data.py 49 | $ python run_pretrain_embeddings.py --lr 0.001 --n_epochs 4 50 | ``` 51 | 52 | 可以调整一些参数如 `lr` 和 `n_epochs` 来获得更好的评估效果。接下来开始训练模型,现在在 `DBRL` 中有三种模型,任选一种即可: 53 | 54 | ```shell 55 | $ python run_reinforce.py --n_epochs 5 --lr 1e-5 56 | $ python run_ddpg.py --n_epochs 5 --lr 1e-5 57 | $ python run_bcq.py --n_epochs 5 --lr 1e-5 58 | ``` 59 | 60 | 到这个阶段,`DBRL/resources` 中应该至少有 6 个文件: 61 | 62 | + `model_xxx.pt`, 训练好的 PyTorch 模型。 63 | + `tianchi.csv`, 转换过的数据集。 64 | + `tianchi_user_embeddings.npy`, `npy` 格式的 user 预训练 embedding。 65 | + `tianchi_item_embeddings.npy`, `npy` 格式的 item 预训练 embedding。 66 | + `user_map.json`, 将原始用户 id 映射到模型中 id 的 json 文件。 67 | + `item_map.json`, 将原始物品 id 映射到模型中 id 的 json 文件。 68 | 69 | 70 | 71 | 离线训练完成后,接下来使用 `FlinkRL` 作在线推理。首先将三个文件 `model_xxx.pt`, `tianchi_user_embeddings.npy`, `tianchi_item_embeddings.npy` 放入 `Flink-Reinforcement-Learning/python_api` 文件夹。确保已经安装了[FastAPI](https://github.com/tiangolo/fastapi), 就可以启动服务了: 72 | 73 | ```shell 74 | $ gunicorn reinforce:app -w 4 -k uvicorn.workers.UvicornWorker # if the model is reinforce 75 | $ gunicorn ddpg:app -w 4 -k uvicorn.workers.UvicornWorker # if the model is ddpg 76 | $ gunicorn bcq:app -w 4 -k uvicorn.workers.UvicornWorker # if the model is bcq 77 | ``` 78 | 79 | 80 | 81 | 另外三个文件:`tianchi.csv`, `user_map.json`, `item_map.json` 在 Flink 中使用,原则上可以放在任何地方,只要在 `Flink-Reinforcement-Learning/FlinkRL/src/main/resources/config.properties` 中指定绝对路径。 82 | 83 | 如果想要快速开始,可以直接将 `FlinkRL` 导入到 `IntelliJ IDEA` 这样的 IDE 中。而要在集群上运行,则使用 Maven 达成 jar 包: 84 | 85 | ```shell 86 | $ cd FlinkRL 87 | $ mvn clean package 88 | ``` 89 | 90 | 将生成的 `FlinkRL-1.0-SNAPSHOT-jar-with-dependencies.jar` 放到 Flink 安装目录。现在假设 `Kafka`, `MongoDB` and `Redis` 都已然启动,然后就可以启动 Flink 集群并执行任务: 91 | 92 | ```shell 93 | $ kafka-topics.sh --create --zookeeper localhost:2181 --replication-factor 1 --partitions 1 --topic flink-rl # cereate a topic called flink-rl in Kafka 94 | $ use 
flink-rl # cereate a database called flink-rl in MongoDB 95 | ``` 96 | 97 | ```shell 98 | $ # first cd into the Flink installation folder 99 | $ ./bin/start-cluster.sh # start the cluster 100 | $ ./bin/flink run --class com.mass.task.FileToKafka FlinkRL-1.0-SNAPSHOT-jar-with-dependencies.jar # import data into Kafka 101 | $ ./bin/flink run --class com.mass.task.RecordToMongoDB FlinkRL-1.0-SNAPSHOT-jar-with-dependencies.jar # save log to MongoDB 102 | $ ./bin/flink run --class com.mass.task.IntervalPopularItems FlinkRL-1.0-SNAPSHOT-jar-with-dependencies.jar # compute realtime popular items 103 | $ ./bin/flink run --class com.mass.task.SeqRecommend FlinkRL-1.0-SNAPSHOT-jar-with-dependencies.jar # recommend items using reinforcement learning model 104 | $ ./bin/stop-cluster.sh # stop the cluster 105 | ``` 106 | 107 | 打开 [http://localhost:8081](http://localhost:8081/) 可使用 Flink WebUI : 108 | 109 | ![](https://s1.ax1x.com/2020/10/19/0zCM2F.png) 110 | 111 | 112 | 113 | FastAPI 中也提供了交互式的 WebUI,可访问 http://127.0.0.1:8000/docs : 114 | 115 | ![](https://s1.ax1x.com/2020/10/19/0x58HA.jpg) 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | -------------------------------------------------------------------------------- /pic/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/massquantity/Flink-Reinforcement-Learning/4f1b252a6884fa35b05122417d174901508fb788/pic/5.png -------------------------------------------------------------------------------- /pic/6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/massquantity/Flink-Reinforcement-Learning/4f1b252a6884fa35b05122417d174901508fb788/pic/6.jpg -------------------------------------------------------------------------------- /pic/readme1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/massquantity/Flink-Reinforcement-Learning/4f1b252a6884fa35b05122417d174901508fb788/pic/readme1.png -------------------------------------------------------------------------------- /pic/readme4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/massquantity/Flink-Reinforcement-Learning/4f1b252a6884fa35b05122417d174901508fb788/pic/readme4.png -------------------------------------------------------------------------------- /python_api/bcq.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from typing import Optional, List 4 | from fastapi import FastAPI, HTTPException, Body 5 | from pydantic import BaseModel, Field 6 | from utils import load_model_bcq 7 | 8 | app = FastAPI() 9 | hist_num = 10 10 | action_dim = 32 11 | input_dim = action_dim * (hist_num + 1) 12 | hidden_size = 64 13 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 14 | model_path = "model_bcq.pt" 15 | user_embeddings_path = "tianchi_user_embeddings.npy" 16 | item_embeddings_path = "tianchi_item_embeddings.npy" 17 | 18 | generator, perturbator = load_model_bcq( 19 | model_path, user_embeddings_path, item_embeddings_path, 20 | input_dim, action_dim, hidden_size, device=device 21 | ) 22 | 23 | 24 | class Seq(BaseModel): 25 | user: List[int] 26 | item: List[List[int]] = Field(..., example=[[1,2,3,4,5,6,7,8,9,10]]) 27 | n_rec: int 28 | 29 | 30 | class State(BaseModel): 31 | user: List[int] 32 | embedding: List[List[float]] 33 | 
n_rec: int 34 | 35 | 36 | @app.post("/{algo}") 37 | async def recommend(algo: str, seq: Seq) -> list: 38 | if algo == "bcq": 39 | with torch.no_grad(): 40 | data = { 41 | "user": torch.as_tensor(seq.user), 42 | "item": torch.as_tensor(seq.item) 43 | } 44 | state = generator.get_state(data) 45 | gen_actions = generator.decode(state) 46 | action = perturbator(state, gen_actions) 47 | scores = torch.matmul(action, generator.item_embeds.weight.T) 48 | _, res = torch.topk(scores, seq.n_rec, dim=1, sorted=False) 49 | # return f"Recommend {seq.n_rec} items for user {seq.user}: {res}" 50 | return res.tolist() 51 | else: 52 | raise HTTPException(status_code=404, detail="wrong algorithm.") 53 | 54 | 55 | @app.post("/{algo}/state") 56 | async def recommend_with_state(algo: str, state: State): 57 | if algo == "bcq": 58 | with torch.no_grad(): 59 | data = torch.as_tensor(state.embedding) 60 | gen_actions = generator.decode(data) 61 | action = perturbator(data, gen_actions) 62 | scores = torch.matmul(action, generator.item_embeds.weight.T) 63 | _, res = torch.topk(scores, state.n_rec, dim=1, sorted=False) 64 | # return f"Recommend {state.n_rec} items for user {state.user}: {res}" 65 | return res.tolist() 66 | else: 67 | raise HTTPException(status_code=404, detail="wrong algorithm.") 68 | 69 | 70 | # gunicorn bcq:app -w 4 -k uvicorn.workers.UvicornWorker 71 | # curl -X POST "http://127.0.0.1:8000/bcq" -H "accept: application/json" -d '{"user": [1], "item": [[1,2,3,4,5,6,7,8,9,10]], "n_rec": 8}' 72 | -------------------------------------------------------------------------------- /python_api/ddpg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from typing import Optional, List 4 | from fastapi import FastAPI, HTTPException, Body 5 | from pydantic import BaseModel, Field 6 | from utils import load_model_ddpg 7 | 8 | app = FastAPI() 9 | hist_num = 10 10 | action_dim = 32 11 | input_dim = action_dim * (hist_num + 1) 12 | hidden_size = 64 13 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 14 | model_path = "model_ddpg.pt" 15 | user_embeddings_path = "tianchi_user_embeddings.npy" 16 | item_embeddings_path = "tianchi_item_embeddings.npy" 17 | 18 | model = load_model_ddpg( 19 | model_path, user_embeddings_path, item_embeddings_path, 20 | input_dim, action_dim, hidden_size, device=device 21 | ) 22 | 23 | 24 | class Seq(BaseModel): 25 | user: List[int] 26 | item: List[List[int]] = Field(..., example=[[1,2,3,4,5,6,7,8,9,10]]) 27 | n_rec: int 28 | 29 | 30 | class State(BaseModel): 31 | user: List[int] 32 | embedding: List[List[float]] 33 | n_rec: int 34 | 35 | 36 | @app.post("/{algo}") 37 | async def recommend(algo: str, seq: Seq) -> list: 38 | if algo == "ddpg": 39 | with torch.no_grad(): 40 | data = { 41 | "user": torch.as_tensor(seq.user), 42 | "item": torch.as_tensor(seq.item) 43 | } 44 | _, action = model(data) 45 | scores = torch.matmul(action, model.item_embeds.weight.T) 46 | _, res = torch.topk(scores, seq.n_rec, dim=1, sorted=False) 47 | # return f"Recommend {seq.n_rec} items for user {seq.user}: {res}" 48 | return res.tolist() 49 | else: 50 | raise HTTPException(status_code=404, detail="wrong algorithm.") 51 | 52 | 53 | @app.post("/{algo}/state") 54 | async def recommend_with_state(algo: str, state: State): 55 | if algo == "ddpg": 56 | with torch.no_grad(): 57 | data = torch.as_tensor(state.embedding) 58 | action = model.get_action(data) 59 | scores = torch.matmul(action, model.item_embeds.weight.T) 60
| _, res = torch.topk(scores, state.n_rec, dim=1, sorted=False) 61 | # return f"Recommend {state.n_rec} items for user {state.user}: {res}" 62 | return res.tolist() 63 | else: 64 | raise HTTPException(status_code=404, detail="wrong algorithm.") 65 | 66 | 67 | # gunicorn ddpg:app -w 4 -k uvicorn.workers.UvicornWorker 68 | # curl -X POST "http://127.0.0.1:8000/ddpg" -H "accept: application/json" -d '{"user": [1], "item": [[1,2,3,4,5,6,7,8,9,10]], "n_rec": 8}' 69 | -------------------------------------------------------------------------------- /python_api/net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def multihead_attention( 7 | attention, 8 | user_embeds, 9 | item_embeds, 10 | items, 11 | pad_val, 12 | att_mode="normal" 13 | ): 14 | mask = (items == pad_val) 15 | mask[torch.where((mask == torch.tensor(1)).all(1))] = False 16 | item_repr = item_embeds.transpose(1, 0) 17 | if att_mode == "normal": 18 | user_repr = user_embeds.unsqueeze(0) 19 | output, weight = attention(user_repr, item_repr, item_repr, 20 | key_padding_mask=mask) 21 | output = output.squeeze() 22 | elif att_mode == "self": 23 | output, weight = attention(item_repr, item_repr, item_repr, 24 | key_padding_mask=mask) 25 | output = output.transpose(1, 0).mean(dim=1) # max[0], sum 26 | else: 27 | raise ValueError("attention must be 'normal' or 'self'") 28 | return output 29 | 30 | 31 | class Actor(nn.Module): 32 | def __init__( 33 | self, 34 | input_dim, 35 | action_dim, 36 | hidden_size, 37 | user_embeds, 38 | item_embeds, 39 | attention=None, 40 | pad_val=0, 41 | head=1, 42 | device=torch.device("cpu") 43 | ): 44 | super(Actor, self).__init__() 45 | self.user_embeds = nn.Embedding.from_pretrained( 46 | torch.as_tensor(user_embeds), freeze=True, 47 | ) 48 | self.item_embeds = nn.Embedding.from_pretrained( 49 | torch.as_tensor(item_embeds), freeze=True, 50 | ) 51 | if attention is not None: 52 | self.attention = nn.MultiheadAttention( 53 | item_embeds.shape[1], head 54 | ).to(device) 55 | self.att_mode = attention 56 | else: 57 | self.attention = None 58 | 59 | # self.dropout = nn.Dropout(0.5) 60 | self.fc1 = nn.Linear(input_dim, hidden_size) 61 | self.fc2 = nn.Linear(hidden_size, hidden_size) 62 | self.fc3 = nn.Linear(hidden_size, action_dim) 63 | self.pad_val = pad_val 64 | self.device = device 65 | 66 | def get_state(self, data, next_state=False): 67 | user_repr = self.user_embeds(data["user"]) 68 | item_col = "next_item" if next_state else "item" 69 | items = data[item_col] 70 | 71 | if not self.attention: 72 | batch_size = data[item_col].size(0) 73 | item_repr = self.item_embeds(data[item_col]).view(batch_size, -1) 74 | 75 | # true_num = torch.sum(items != self.pad_val, dim=1) + 1e-5 76 | # item_repr = self.item_embeds(items).sum(dim=1) 77 | # compute mean embeddings 78 | # item_repr = torch.div(item_repr, true_num.float().view(-1, 1)) 79 | else: 80 | item_repr = self.item_embeds(items) 81 | item_repr = multihead_attention(self.attention, user_repr, 82 | item_repr, items, self.pad_val, 83 | self.att_mode) 84 | 85 | state = torch.cat([user_repr, item_repr], dim=1) 86 | return state 87 | 88 | def get_action(self, state, tanh=False): 89 | action = F.relu(self.fc1(state)) 90 | action = F.relu(self.fc2(action)) 91 | action = self.fc3(action) 92 | if tanh: 93 | action = F.tanh(action) 94 | return action 95 | 96 | def forward(self, data, tanh=False): 97 | state = self.get_state(data) 98 | action = 
self.get_action(state, tanh) 99 | return state, action 100 | 101 | 102 | class Critic(nn.Module): 103 | def __init__( 104 | self, 105 | input_dim, 106 | action_dim, 107 | hidden_size 108 | ): 109 | super(Critic, self).__init__() 110 | self.fc1 = nn.Linear(input_dim + action_dim, hidden_size) 111 | self.fc2 = nn.Linear(hidden_size, hidden_size) 112 | self.fc3 = nn.Linear(hidden_size, 1) 113 | 114 | def forward(self, state, action): 115 | out = torch.cat([state, action], dim=1) 116 | out = F.relu(self.fc1(out)) 117 | out = F.relu(self.fc2(out)) 118 | out = self.fc3(out) 119 | return out.squeeze() 120 | 121 | 122 | class PolicyPi(Actor): 123 | def __init__( 124 | self, 125 | input_dim, 126 | action_dim, 127 | hidden_size, 128 | user_embeds, 129 | item_embeds, 130 | attention=None, 131 | pad_val=0, 132 | head=1, 133 | device=torch.device("cpu") 134 | ): 135 | super(PolicyPi, self).__init__( 136 | input_dim, 137 | action_dim, 138 | hidden_size, 139 | user_embeds, 140 | item_embeds, 141 | attention, 142 | pad_val, 143 | head, 144 | device 145 | ) 146 | 147 | # self.dropout = nn.Dropout(0.5) 148 | self.pfc1 = nn.Linear(input_dim, hidden_size) 149 | self.pfc2 = nn.Linear(hidden_size, hidden_size) 150 | self.softmax_fc = nn.Linear(hidden_size, action_dim) 151 | 152 | def get_action(self, state, tanh=False): 153 | action = F.relu(self.pfc1(state)) 154 | action = F.relu(self.pfc2(action)) 155 | if tanh: 156 | action = F.tanh(action) 157 | return action 158 | 159 | def forward(self, data, tanh=False): 160 | state = self.get_state(data) 161 | action = self.get_action(state, tanh) 162 | return state, action 163 | 164 | def get_log_probs(self, data): 165 | state, action = self.forward(data) 166 | logits = self.softmax_fc(action) 167 | log_prob = F.log_softmax(logits, dim=1) 168 | return state, log_prob, action 169 | 170 | def get_beta_state(self, data): 171 | user_repr = self.user_embeds(data["beta_user"]) 172 | items = data["beta_item"] 173 | if not self.attention: 174 | # true_num = torch.sum(items != self.pad_val, dim=1) + 1e-5 175 | # item_repr = self.item_embeds(items).sum(dim=1) 176 | # item_repr = torch.div(item_repr, true_num.float().view(-1, 1)) 177 | batch_size = data["item"].size(0) 178 | item_repr = self.item_embeds(data["item"]).view(batch_size, -1) 179 | else: 180 | item_repr = self.item_embeds(items) 181 | item_repr = multihead_attention(self.attention, user_repr, 182 | item_repr, items, self.pad_val, 183 | self.att_mode) 184 | 185 | state = torch.cat([user_repr, item_repr], dim=1) 186 | return state 187 | 188 | 189 | class Beta(nn.Module): 190 | def __init__( 191 | self, 192 | input_dim, 193 | action_dim, 194 | hidden_size 195 | ): 196 | super(Beta, self).__init__() 197 | self.bfc = nn.Linear(input_dim, hidden_size) 198 | self.softmax_fc = nn.Linear(hidden_size, action_dim) 199 | 200 | def forward(self, state): 201 | out = F.relu(self.bfc(state)) 202 | return out 203 | 204 | def get_log_probs(self, state): 205 | action = self.forward(state) 206 | logits = self.softmax_fc(action) 207 | log_prob = F.log_softmax(logits, dim=1) 208 | return log_prob, logits 209 | 210 | 211 | class VAE(Actor): 212 | def __init__( 213 | self, 214 | input_dim, 215 | action_dim, 216 | latent_dim, 217 | hidden_size, 218 | user_embeds, 219 | item_embeds, 220 | attention=None, 221 | pad_val=0, 222 | head=1, 223 | device=torch.device("cpu") 224 | ): 225 | super(VAE, self).__init__( 226 | input_dim, 227 | action_dim, 228 | hidden_size, 229 | user_embeds, 230 | item_embeds, 231 | attention, 232 | pad_val, 233 | head, 
234 | device 235 | ) 236 | 237 | self.encoder1 = nn.Linear(input_dim + action_dim, hidden_size) 238 | self.encoder2 = nn.Linear(hidden_size, hidden_size) 239 | 240 | self.mean = nn.Linear(hidden_size, latent_dim) 241 | self.log_std = nn.Linear(hidden_size, latent_dim) 242 | 243 | self.decoder1 = nn.Linear(input_dim + latent_dim, hidden_size) 244 | self.decoder2 = nn.Linear(hidden_size, hidden_size) 245 | self.decoder3 = nn.Linear(hidden_size, action_dim) 246 | self.latent_dim = latent_dim 247 | self.device = device 248 | 249 | def forward(self, data, action): 250 | state = self.get_state(data) 251 | z = F.relu(self.encoder1(torch.cat([state, action], dim=1))) 252 | z = F.relu(self.encoder2(z)) 253 | 254 | mean = self.mean(z) 255 | log_std = self.log_std(z).clamp(-4, 15) 256 | std = log_std.exp() 257 | z = mean + std * torch.randn_like(std) 258 | 259 | u = self.decode(state, z) 260 | return state, u, mean, std 261 | 262 | def decode(self, state, z=None): 263 | if z is None: 264 | z = torch.randn(state.size(0), self.latent_dim) 265 | z = z.clamp(-0.5, 0.5).to(self.device) 266 | action = F.relu(self.decoder1(torch.cat([state, z], dim=1))) 267 | action = F.relu(self.decoder2(action)) 268 | action = self.decoder3(action) 269 | return action 270 | 271 | 272 | class Perturbator(nn.Module): 273 | def __init__( 274 | self, 275 | input_dim, 276 | action_dim, 277 | hidden_size, 278 | phi=0.05, 279 | action_range=None 280 | ): 281 | super(Perturbator, self).__init__() 282 | self.fc1 = nn.Linear(input_dim + action_dim, hidden_size) 283 | self.fc2 = nn.Linear(hidden_size, hidden_size) 284 | self.fc3 = nn.Linear(hidden_size, action_dim) 285 | self.phi = phi 286 | self.action_range = action_range 287 | 288 | def forward(self, state, action): 289 | a = F.relu(self.fc1(torch.cat([state, action], dim=1))) 290 | a = F.relu(self.fc2(a)) 291 | a = self.fc3(a) 292 | if self.phi is not None: 293 | a = self.phi * torch.tanh(a) 294 | 295 | act = action + a 296 | if self.action_range is not None: 297 | act = act.clamp(self.action_range[0], self.action_range[1]) 298 | return act 299 | 300 | -------------------------------------------------------------------------------- /python_api/reinforce.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from typing import Optional, List 4 | from fastapi import FastAPI, HTTPException, Body 5 | from pydantic import BaseModel, Field 6 | from utils import load_model_reinforce 7 | 8 | app = FastAPI() 9 | hist_num = 10 10 | embed_size = 32 11 | hidden_size = 64 12 | input_dim = embed_size * (hist_num + 1) 13 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 14 | model_path = "model_reinforce.pt" 15 | user_embeddings_path = "tianchi_user_embeddings.npy" 16 | item_embeddings_path = "tianchi_item_embeddings.npy" 17 | with open(item_embeddings_path, "rb") as f: 18 | action_dim = len(np.load(f)) 19 | 20 | model = load_model_reinforce( 21 | model_path, user_embeddings_path, item_embeddings_path, 22 | input_dim, action_dim, hidden_size, device=device 23 | ) 24 | 25 | 26 | class Seq(BaseModel): 27 | user: List[int] 28 | item: List[List[int]] = Field(..., example=[[1,2,3,4,5,6,7,8,9,10]]) 29 | n_rec: int 30 | 31 | 32 | class State(BaseModel): 33 | user: List[int] 34 | embedding: List[List[float]] 35 | n_rec: int 36 | 37 | 38 | @app.post("/{algo}") 39 | async def recommend(algo: str, seq: Seq) -> list: 40 | if algo == "reinforce": 41 | with torch.no_grad(): 42 | data = { 43 | "user": 
torch.as_tensor(seq.user), 44 | "item": torch.as_tensor(seq.item) 45 | } 46 | _, action_probs, _ = model.get_log_probs(data) 47 | _, res = torch.topk(action_probs, seq.n_rec, dim=1, sorted=False) 48 | # return f"Recommend {seq.n_rec} items for user {seq.user}: {res}" 49 | return res.tolist() 50 | else: 51 | raise HTTPException(status_code=404, detail="wrong algorithm.") 52 | 53 | 54 | @app.post("/{algo}/state") 55 | async def recommend_with_state(algo: str, state: State): 56 | if algo == "reinforce": 57 | with torch.no_grad(): 58 | data = torch.as_tensor(state.embedding) 59 | action = model.get_action(data) 60 | action_probs = model.softmax_fc(action) 61 | _, res = torch.topk(action_probs, state.n_rec, dim=1, sorted=False) 62 | # return f"Recommend {state.n_rec} items for user {state.user}: {res}" 63 | return res.tolist() 64 | else: 65 | raise HTTPException(status_code=404, detail="wrong algorithm.") 66 | 67 | 68 | # gunicorn reinforce:app -w 4 -k uvicorn.workers.UvicornWorker 69 | # curl -X POST "http://127.0.0.1:8000/reinforce" -H "accept: application/json" -d '{"user": [1], "item": [[1,2,3,4,5,6,7,8,9,10]], "n_rec": 8}' 70 | -------------------------------------------------------------------------------- /python_api/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from net import Actor, PolicyPi, VAE, Perturbator 4 | 5 | 6 | def load_model_reinforce( 7 | model_path, 8 | user_embeddings_path, 9 | item_embeddings_path, 10 | input_dim, 11 | action_dim, 12 | hidden_size, 13 | device 14 | ): 15 | with open(user_embeddings_path, "rb") as f: 16 | user_embeddings = np.load(f) 17 | with open(item_embeddings_path, "rb") as f: 18 | item_embeddings = np.load(f) 19 | model = PolicyPi( 20 | input_dim, action_dim, hidden_size, user_embeddings, item_embeddings 21 | ) 22 | model.load_state_dict(torch.load(model_path, map_location=device)) 23 | model.eval() 24 | return model 25 | 26 | 27 | def load_model_ddpg( 28 | model_path, 29 | user_embeddings_path, 30 | item_embeddings_path, 31 | input_dim, 32 | action_dim, 33 | hidden_size, 34 | device 35 | ): 36 | with open(user_embeddings_path, "rb") as f: 37 | user_embeddings = np.load(f) 38 | with open(item_embeddings_path, "rb") as f: 39 | item_embeddings = np.load(f) 40 | model = Actor( 41 | input_dim, action_dim, hidden_size, user_embeddings, item_embeddings 42 | ) 43 | model.load_state_dict(torch.load(model_path, map_location=device)) 44 | model.eval() 45 | return model 46 | 47 | 48 | def load_model_bcq( 49 | model_path, 50 | user_embeddings_path, 51 | item_embeddings_path, 52 | input_dim, 53 | action_dim, 54 | hidden_size, 55 | device 56 | ): 57 | with open(user_embeddings_path, "rb") as f: 58 | user_embeddings = np.load(f) 59 | with open(item_embeddings_path, "rb") as f: 60 | item_embeddings = np.load(f) 61 | generator = VAE( 62 | input_dim, action_dim, action_dim * 2, hidden_size, 63 | user_embeddings, item_embeddings 64 | ) 65 | perturbator = Perturbator(input_dim, action_dim, hidden_size) 66 | checkpoint = torch.load(model_path, map_location=device) 67 | generator.load_state_dict(checkpoint["generator"]) 68 | perturbator.load_state_dict(checkpoint["perturbator"]) 69 | generator.eval() 70 | perturbator.eval() 71 | return generator, perturbator 72 | --------------------------------------------------------------------------------
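
Note: besides the `/{algo}` endpoint exercised by the curl examples above, `bcq.py`, `ddpg.py` and `reinforce.py` also define a `/{algo}/state` endpoint that takes a precomputed state embedding instead of raw item ids. Below is a minimal sketch of calling the ddpg variant with the Python standard library (an illustration, not part of the repo: the 352-dimensional state follows from `input_dim = action_dim * (hist_num + 1) = 32 * 11` in `ddpg.py`, the random vector is only a placeholder, and the host/port assume the gunicorn command shown in the READMEs):

```python
import json
import random
import urllib.request

# One state vector per user; its length must equal input_dim = 32 * (10 + 1) = 352,
# matching the first linear layer of the Actor network that ddpg.py loads.
payload = {
    "user": [1],
    "embedding": [[random.random() for _ in range(352)]],  # placeholder state
    "n_rec": 8,
}
req = urllib.request.Request(
    "http://127.0.0.1:8000/ddpg/state",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read()))  # nested list with 8 recommended item ids per user
```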