├── FlinkRL
│   ├── pom.xml
│   └── src
│       └── main
│           ├── java
│           │   └── com
│           │       └── mass
│           │           ├── agg
│           │           │   └── CountAgg.java
│           │           ├── entity
│           │           │   ├── RecordEntity.java
│           │           │   ├── TopItemEntity.java
│           │           │   └── UserConsumed.java
│           │           ├── feature
│           │           │   ├── BuildFeature.java
│           │           │   ├── Embedding.java
│           │           │   └── IdConverter.java
│           │           ├── map
│           │           │   └── RecordMap.java
│           │           ├── process
│           │           │   └── TopNPopularItems.java
│           │           ├── recommend
│           │           │   └── fastapi
│           │           │       └── FastapiRecommender.java
│           │           ├── sink
│           │           │   ├── MongodbRecommendSink.java
│           │           │   ├── MongodbRecordSink.java
│           │           │   └── TopNRedisSink.java
│           │           ├── source
│           │           │   ├── CustomFileSource.java
│           │           │   └── KafkaSource.java
│           │           ├── task
│           │           │   ├── FileToKafka.java
│           │           │   ├── IntervalPopularItems.java
│           │           │   ├── ReadFromFile.java
│           │           │   ├── RecordToMongoDB.java
│           │           │   ├── RecordToMysql.java
│           │           │   └── SeqRecommend.java
│           │           ├── util
│           │           │   ├── FormatTimestamp.java
│           │           │   ├── Property.java
│           │           │   ├── RecordToEntity.java
│           │           │   └── TypeConvert.java
│           │           └── window
│           │               ├── CountProcessWindowFunction.java
│           │               ├── ItemCollectWindowFunction.java
│           │               └── ItemCollectWindowFunctionRedis.java
│           └── resources
│               ├── config.properties
│               ├── db
│               │   └── create_table.sql
│               ├── features
│               │   ├── item_map.json
│               │   └── user_map.json
│               └── log4j.properties
├── README.md
├── README_zh.md
├── pic
│   ├── 5.png
│   ├── 6.jpg
│   ├── readme1.png
│   └── readme4.png
└── python_api
    ├── bcq.py
    ├── ddpg.py
    ├── net.py
    ├── reinforce.py
    └── utils.py
/FlinkRL/pom.xml:
--------------------------------------------------------------------------------
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.mass</groupId>
    <artifactId>FlinkRL</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <java.version>1.8</java.version>
        <flink.version>1.9.3</flink.version>
        <scala.binary.version>2.11</scala.binary.version>
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
    </properties>

    <dependencies>
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>4.12</version>
            <scope>test</scope>
        </dependency>

        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-log4j12</artifactId>
            <version>1.7.25</version>
        </dependency>

        <dependency>
            <groupId>mysql</groupId>
            <artifactId>mysql-connector-java</artifactId>
            <version>8.0.17</version>
        </dependency>

        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-java</artifactId>
            <version>${flink.version}</version>
        </dependency>

        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-streaming-java_2.11</artifactId>
            <version>${flink.version}</version>
        </dependency>

        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-clients_2.11</artifactId>
            <version>${flink.version}</version>
        </dependency>

        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-connector-kafka_2.11</artifactId>
            <version>${flink.version}</version>
        </dependency>

        <dependency>
            <groupId>org.mongodb</groupId>
            <artifactId>mongodb-driver-sync</artifactId>
            <version>4.0.5</version>
        </dependency>

        <dependency>
            <groupId>org.mongodb</groupId>
            <artifactId>mongodb-driver-legacy</artifactId>
            <version>4.0.5</version>
        </dependency>

        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-connector-redis_2.11</artifactId>
            <version>1.1.5</version>
        </dependency>

        <dependency>
            <groupId>org.json</groupId>
            <artifactId>json</artifactId>
            <version>20180813</version>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>3.8.1</version>
                <configuration>
                    <source>1.8</source>
                    <target>1.8</target>
                </configuration>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-assembly-plugin</artifactId>
                <version>3.3.0</version>
                <executions>
                    <execution>
                        <phase>package</phase>
                        <goals>
                            <goal>single</goal>
                        </goals>
                    </execution>
                </executions>
                <configuration>
                    <descriptorRefs>
                        <descriptorRef>jar-with-dependencies</descriptorRef>
                    </descriptorRefs>
                </configuration>
            </plugin>
        </plugins>
    </build>

    <repositories>
        <repository>
            <id>alimaven</id>
            <name>aliyun maven</name>
            <url>http://maven.aliyun.com/nexus/content/groups/public/</url>
            <releases>
                <enabled>true</enabled>
            </releases>
            <snapshots>
                <enabled>false</enabled>
            </snapshots>
        </repository>
    </repositories>
</project>
--------------------------------------------------------------------------------
/FlinkRL/src/main/java/com/mass/agg/CountAgg.java:
--------------------------------------------------------------------------------
1 | package com.mass.agg;
2 |
3 | import org.apache.flink.api.common.functions.AggregateFunction;
4 | import com.mass.entity.RecordEntity;
5 |
6 | public class CountAgg implements AggregateFunction<RecordEntity, Long, Long> {
7 |
8 | @Override
9 | public Long createAccumulator() {
10 | return 0L;
11 | }
12 |
13 | @Override
14 | public Long add(RecordEntity value, Long accumulator) {
15 | return accumulator + 1;
16 | }
17 |
18 | @Override
19 | public Long getResult(Long accumulator) {
20 | return accumulator;
21 | }
22 |
23 | @Override
24 | public Long merge(Long a, Long b) {
25 | return a + b;
26 | }
27 | }
28 |
29 |
--------------------------------------------------------------------------------
/FlinkRL/src/main/java/com/mass/entity/RecordEntity.java:
--------------------------------------------------------------------------------
1 | package com.mass.entity;
2 |
3 | public class RecordEntity {
4 | private int userId;
5 | private int itemId;
6 | private long time;
7 |
8 | public RecordEntity() { }
9 |
10 | public RecordEntity(int userId, int itemId, long time) {
11 | this.userId = userId;
12 | this.itemId = itemId;
13 | this.time = time;
14 | }
15 |
16 | public int getUserId() {
17 | return userId;
18 | }
19 |
20 | public void setUserId(int userId) {
21 | this.userId = userId;
22 | }
23 |
24 | public int getItemId() {
25 | return itemId;
26 | }
27 |
28 | public void setItemId(int itemId) {
29 | this.itemId = itemId;
30 | }
31 |
32 | public long getTime() {
33 | return time;
34 | }
35 |
36 | public void setTime(long time) {
37 | this.time = time;
38 | }
39 |
40 | public String toString() {
41 | return String.format("user: %5d, item: %5d, time: %5d", userId, itemId, time);
42 | }
43 | }
44 |
--------------------------------------------------------------------------------
/FlinkRL/src/main/java/com/mass/entity/TopItemEntity.java:
--------------------------------------------------------------------------------
1 | package com.mass.entity;
2 |
3 | public class TopItemEntity {
4 | private int itemId;
5 | private int counts;
6 | private long windowEnd;
7 |
8 | public static TopItemEntity of(Integer itemId, long end, Long count) {
9 | TopItemEntity res = new TopItemEntity();
10 | res.setCounts(count.intValue());
11 | res.setItemId(itemId);
12 | res.setWindowEnd(end);
13 | return res;
14 | }
15 |
16 | public long getWindowEnd() {
17 | return windowEnd;
18 | }
19 |
20 | public void setWindowEnd(long windowEnd) {
21 | this.windowEnd = windowEnd;
22 | }
23 |
24 |
25 | public int getItemId() {
26 | return itemId;
27 | }
28 |
29 | public void setItemId(int productId) {
30 | this.itemId = productId;
31 | }
32 |
33 | public int getCounts() {
34 | return counts;
35 | }
36 |
37 | public void setCounts(int actionTimes) {
38 | this.counts = actionTimes;
39 | }
40 |
41 | public String toString() {
42 | return String.format("TopItemEntity: itemId:%5d, count:%3d, windowEnd: %5d", itemId, counts, windowEnd);
43 | }
44 | }
45 |
46 |
--------------------------------------------------------------------------------
/FlinkRL/src/main/java/com/mass/entity/UserConsumed.java:
--------------------------------------------------------------------------------
1 | package com.mass.entity;
2 |
3 | import java.util.List;
4 |
5 | public class UserConsumed {
6 | public int userId;
7 | public List<Integer> items;
8 | public long windowEnd;
9 |
10 | public static UserConsumed of(int userId, List<Integer> items, long windowEnd) {
11 | UserConsumed result = new UserConsumed();
12 | result.userId = userId;
13 | result.items = items;
14 | result.windowEnd = windowEnd;
15 | return result;
16 | }
17 |
18 | public int getUserId() {
19 | return userId;
20 | }
21 |
22 | public List<Integer> getItems() {
23 | return items;
24 | }
25 |
26 | public long getWindowEnd() {
27 | return windowEnd;
28 | }
29 |
30 | public String toString() {
31 | return String.format("userId:%5d, itemConsumed:%6s, windowEnd:%13s",
32 | userId, items, windowEnd);
33 | }
34 | }
35 |
36 |
--------------------------------------------------------------------------------
/FlinkRL/src/main/java/com/mass/feature/BuildFeature.java:
--------------------------------------------------------------------------------
1 | package com.mass.feature;
2 |
3 | import org.json.JSONArray;
4 | import org.json.JSONObject;
5 | import org.json.JSONTokener;
6 |
7 | import java.io.IOException;
8 | import java.util.ArrayList;
9 | import java.util.List;
10 |
11 | public class BuildFeature {
12 |
13 | public static int getUserId(int userIndex) {
14 | return IdConverter.getUserId(userIndex);
15 | }
16 |
17 | public static int getItemId(int itemIndex) {
18 | return IdConverter.getItemId(itemIndex);
19 | }
20 |
21 | private static JSONArray getUserEmbedding(int userId) {
22 | // int userId = getUserId(userIndex);
23 | return Embedding.getEmbedding("user", userId);
24 | }
25 |
26 | private static JSONArray getItemEmbedding(int itemId) {
27 | // int itemId = getItemId(itemIndex);
28 | return Embedding.getEmbedding("item", itemId);
29 | }
30 |
31 | public static List<Float> getEmbedding(int user, List<Integer> items) {
32 | List<Float> features = new ArrayList<>();
33 | JSONArray userArray = getUserEmbedding(user);
34 | for (int i = 0; i < userArray.length(); i++) {
35 | features.add(userArray.getFloat(i));
36 | }
37 |
38 | for (int i : items) {
39 | JSONArray itemArray = getItemEmbedding(i);
40 | for (int j = 0; j < itemArray.length(); j++) {
41 | features.add(itemArray.getFloat(j));
42 | }
43 | }
44 | return features;
45 | }
46 |
47 | public static List<Integer> getSeq(List<Integer> items) {
48 | List<Integer> seq = new ArrayList<>();
49 | for (int i: items) {
50 | int itemId = getItemId(i);
51 | seq.add(itemId);
52 | }
53 | return seq;
54 | }
55 |
56 | public static List<Integer> convertItems(List<Integer> items) {
57 | return IdConverter.convertItems(items);
58 | }
59 |
60 | public static void close(Boolean withState) throws IOException {
61 | if (withState) {
62 | Embedding.close();
63 | }
64 | IdConverter.close();
65 | }
66 | }
67 |
--------------------------------------------------------------------------------
/FlinkRL/src/main/java/com/mass/feature/Embedding.java:
--------------------------------------------------------------------------------
1 | package com.mass.feature;
2 |
3 | import org.json.JSONArray;
4 | import org.json.JSONObject;
5 | import org.json.JSONTokener;
6 |
7 | import java.io.IOException;
8 | import java.io.InputStream;
9 | import java.io.Serializable;
10 |
11 | public class Embedding {
12 | private static InputStream embedStream;
13 | private static JSONObject embedJSON;
14 |
15 | static {
16 | embedStream = Thread.currentThread().getContextClassLoader().getResourceAsStream("features/embeddings.json");
17 | try {
18 | embedJSON = new JSONObject(new JSONTokener(embedStream));
19 | } catch (NullPointerException e) {
20 | e.printStackTrace();
21 | }
22 | }
23 |
24 | public static JSONArray getEmbedding(String feat, int index) {
25 | JSONArray embeds = embedJSON.getJSONArray(feat);
26 | return embeds.getJSONArray(index);
27 | }
28 |
29 | public static void close() throws IOException {
30 | embedStream.close();
31 | }
32 | }
33 |
--------------------------------------------------------------------------------
/FlinkRL/src/main/java/com/mass/feature/IdConverter.java:
--------------------------------------------------------------------------------
1 | package com.mass.feature;
2 |
3 | import com.mass.util.Property;
4 | import com.typesafe.config.ConfigException;
5 | import org.json.JSONObject;
6 | import org.json.JSONTokener;
7 |
8 | import java.io.*;
9 | import java.util.ArrayList;
10 | import java.util.List;
11 |
12 | public class IdConverter {
13 | private static InputStream userStream;
14 | private static InputStream itemStream;
15 | private static JSONObject userJSON;
16 | private static JSONObject itemJSON;
17 | private static int numUsers;
18 | private static int numItems;
19 |
20 | static {
21 | try {
22 | userStream = new FileInputStream(Property.getStrValue("user_map"));
23 | itemStream = new FileInputStream(Property.getStrValue("item_map"));
24 | userJSON = new JSONObject(new JSONTokener(userStream));
25 | itemJSON = new JSONObject(new JSONTokener(itemStream));
26 | } catch (FileNotFoundException e) {
27 | e.printStackTrace();
28 | }
29 |
30 | numUsers = userJSON.length();
31 | numItems = itemJSON.length();
32 | }
33 |
34 | public static int getUserId(int userIndex) {
35 | String user = String.valueOf(userIndex);
36 | return userJSON.has(user) ? userJSON.getInt(user) : numUsers;
37 | }
38 |
39 | public static int getItemId(int itemIndex) {
40 | String item = String.valueOf(itemIndex);
41 | return itemJSON.has(item) ? itemJSON.getInt(item) : numItems;
42 | }
43 |
44 | public static List<Integer> convertItems(List<Integer> items) {
45 | List<Integer> converted = new ArrayList<>();
46 | for (int i : items) {
47 | converted.add(getItemId(i));
48 | }
49 | return converted;
50 | }
51 |
52 | public static void close() throws IOException {
53 | userStream.close();
54 | itemStream.close();
55 | }
56 | }
57 |
--------------------------------------------------------------------------------
/FlinkRL/src/main/java/com/mass/map/RecordMap.java:
--------------------------------------------------------------------------------
1 | package com.mass.map;
2 |
3 | import com.mass.entity.RecordEntity;
4 | import com.mass.util.Property;
5 | import com.mass.util.RecordToEntity;
6 | import org.apache.flink.api.common.functions.RichMapFunction;
7 | import org.apache.flink.configuration.Configuration;
8 | import org.json.JSONObject;
9 | import org.json.JSONTokener;
10 |
11 | import java.io.FileInputStream;
12 | import java.io.FileNotFoundException;
13 | import java.io.InputStream;
14 |
15 | public class RecordMap extends RichMapFunction<String, RecordEntity> {
16 | private static InputStream userStream;
17 | private static InputStream itemStream;
18 | private static JSONObject userJSON;
19 | private static JSONObject itemJSON;
20 | private static int numUsers;
21 | private static int numItems;
22 |
23 | @Override
24 | public void open(Configuration parameters) {
25 | try {
26 | userStream = new FileInputStream(Property.getStrValue("user_map"));
27 | itemStream = new FileInputStream(Property.getStrValue("item_map"));
28 | userJSON = new JSONObject(new JSONTokener(userStream));
29 | itemJSON = new JSONObject(new JSONTokener(itemStream));
30 | } catch (FileNotFoundException e) {
31 | e.printStackTrace();
32 | }
33 |
34 | numUsers = userJSON.length();
35 | numItems = itemJSON.length();
36 | }
37 |
38 | @Override
39 | public RecordEntity map(String value) {
40 | RecordEntity record = RecordToEntity.getRecord(value, userJSON, itemJSON, numUsers, numItems);
41 | return record;
42 | }
43 |
44 | @Override
45 | public void close() throws Exception {
46 | userStream.close();
47 | itemStream.close();
48 | }
49 | }
50 |
51 |
--------------------------------------------------------------------------------
/FlinkRL/src/main/java/com/mass/process/TopNPopularItems.java:
--------------------------------------------------------------------------------
1 | package com.mass.process;
2 |
3 | import com.mass.entity.TopItemEntity;
4 | import org.apache.flink.api.common.state.ListStateDescriptor;
5 | import org.apache.flink.api.java.tuple.Tuple;
6 | import org.apache.flink.configuration.Configuration;
7 | import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
8 | import org.apache.flink.util.Collector;
9 | import org.apache.flink.api.common.state.ListState;
10 |
11 | import java.util.ArrayList;
12 | import java.util.List;
13 | import java.util.Comparator;
14 |
15 | public class TopNPopularItems extends KeyedProcessFunction<Tuple, TopItemEntity, List<String>> {
16 |
17 | private ListState<TopItemEntity> itemState;
18 | private final int topSize;
19 |
20 | public TopNPopularItems(int topSize) {
21 | this.topSize = topSize;
22 | }
23 |
24 | @Override
25 | public void open(Configuration parameters) throws Exception {
26 | super.open(parameters);
27 | ListStateDescriptor<TopItemEntity> itemStateDesc = new ListStateDescriptor<>("itemState", TopItemEntity.class);
28 | itemState = getRuntimeContext().getListState(itemStateDesc);
29 | }
30 |
31 | @Override
32 | public void processElement(TopItemEntity value, Context ctx, Collector<List<String>> out) throws Exception {
33 | itemState.add(value);
34 | ctx.timerService().registerEventTimeTimer(value.getWindowEnd() + 1);
35 | }
36 |
37 | @Override
38 | public void onTimer(long timestamp, OnTimerContext ctx, Collector<List<String>> out) throws Exception {
39 | List<TopItemEntity> allItems = new ArrayList<>();
40 | for (TopItemEntity item : itemState.get()) {
41 | allItems.add(item);
42 | }
43 | itemState.clear();
44 |
45 | allItems.sort(new Comparator<TopItemEntity>() {
46 | @Override
47 | public int compare(TopItemEntity o1, TopItemEntity o2) {
48 | return (int) (o2.getCounts() - o1.getCounts());
49 | }
50 | });
51 |
52 | List<String> ret = new ArrayList<>();
53 | long lastTime = ctx.getCurrentKey().getField(0);
54 | ret.add(String.valueOf(lastTime)); // add timestamp at first
55 | allItems.forEach(i -> ret.add(String.valueOf(i.getItemId())));
56 | out.collect(ret);
57 | }
58 | }
59 |
60 |
61 |
--------------------------------------------------------------------------------
/FlinkRL/src/main/java/com/mass/recommend/fastapi/FastapiRecommender.java:
--------------------------------------------------------------------------------
1 | package com.mass.recommend.fastapi;
2 |
3 | import com.mass.entity.UserConsumed;
4 | import com.mass.feature.BuildFeature;
5 | import com.mass.feature.IdConverter;
6 | import org.apache.commons.lang3.StringUtils;
7 | import org.apache.flink.api.common.functions.RichFlatMapFunction;
8 | import org.apache.flink.api.common.state.MapState;
9 | import org.apache.flink.api.common.state.MapStateDescriptor;
10 | import org.apache.flink.api.common.typeinfo.TypeHint;
11 | import org.apache.flink.api.common.typeinfo.TypeInformation;
12 | import org.apache.flink.api.java.tuple.Tuple4;
13 | import org.apache.flink.configuration.Configuration;
14 | import org.apache.flink.util.Collector;
15 | import org.json.JSONArray;
16 | import org.json.JSONObject;
17 | import redis.clients.jedis.Jedis;
18 |
19 | import java.io.BufferedReader;
20 | import java.io.DataOutputStream;
21 | import java.io.IOException;
22 | import java.io.InputStreamReader;
23 | import java.net.HttpURLConnection;
24 | import java.net.URL;
25 | import java.nio.charset.StandardCharsets;
26 | import java.util.ArrayList;
27 | import java.util.List;
28 |
29 | import static com.mass.util.FormatTimestamp.format;
30 | import static com.mass.util.TypeConvert.convertEmbedding;
31 | import static com.mass.util.TypeConvert.convertJSON;
32 | import static com.mass.util.TypeConvert.convertSeq;
33 |
34 | public class FastapiRecommender extends RichFlatMapFunction<UserConsumed, Tuple4<Integer, List<Integer>, String, Integer>> {
35 | private final int histNum;
36 | private final int numRec;
37 | private final String algo;
38 | private final Boolean withState;
39 | private static Jedis jedis;
40 | private static HttpURLConnection con;
41 | private static MapState<Integer, List<Integer>> lastRecState;
42 |
43 | public FastapiRecommender(int numRec, int histNum, String algo, Boolean withState) {
44 | this.histNum = histNum;
45 | this.numRec = numRec;
46 | this.algo = algo;
47 | this.withState = withState;
48 | }
49 |
50 | @Override
51 | public void open(Configuration parameters) {
52 | jedis = new Jedis("localhost", 6379);
53 | MapStateDescriptor<Integer, List<Integer>> lastRecStateDesc = new MapStateDescriptor<>("lastRecState",
54 | TypeInformation.of(new TypeHint<Integer>() {}), TypeInformation.of(new TypeHint<List<Integer>>() {}));
55 | lastRecState = getRuntimeContext().getMapState(lastRecStateDesc);
56 | }
57 |
58 | @Override
59 | public void close() throws IOException {
60 | jedis.close();
61 | lastRecState.clear();
62 | BuildFeature.close(withState);
63 | }
64 |
65 | private void buildConnection() throws IOException {
66 | String path = String.format("http://127.0.0.1:8000/%s", algo);
67 | if (withState) {
68 | path += "/state";
69 | }
70 | URL obj = new URL(path);
71 | con = (HttpURLConnection) obj.openConnection();
72 | con.setRequestMethod("POST");
73 | con.setRequestProperty("Content-Type", "application/json");
74 | // con.setRequestProperty("Connection", "Keep-Alive");
75 | con.setConnectTimeout(5000);
76 | con.setReadTimeout(5000);
77 | con.setDoOutput(true);
78 | }
79 |
80 | private void writeOutputStream(String jsonString) throws IOException {
81 | DataOutputStream wr = new DataOutputStream(con.getOutputStream());
82 | wr.write(jsonString.getBytes(StandardCharsets.UTF_8));
83 | wr.flush();
84 | wr.close();
85 | }
86 |
87 | @Override
88 | public void flatMap(UserConsumed value, Collector<Tuple4<Integer, List<Integer>, String, Integer>> out) throws Exception {
89 | buildConnection();
90 | int userId = BuildFeature.getUserId(value.userId);
91 | List<Integer> items = BuildFeature.convertItems(value.items);
92 | long timestamp = value.windowEnd;
93 | String time = format(timestamp);
94 |
95 | if (items.size() == this.histNum) {
96 | String jsonString;
97 | if (withState) {
98 | List<Float> stateEmbeds = BuildFeature.getEmbedding(userId, items);
99 | jsonString = convertEmbedding(stateEmbeds, userId, numRec);
100 | } else {
101 | jsonString = convertSeq(items, userId, numRec);
102 | }
103 |
104 | writeOutputStream(jsonString);
105 | int responseCode = con.getResponseCode();
106 | // System.out.println("Posted parameters : " + jsonString);
107 | System.out.println("Response Code : " + responseCode);
108 |
109 | List<Integer> recommend = new ArrayList<>();
110 | String printOut;
111 | if (responseCode == 200) {
112 | BufferedReader br = new BufferedReader(new InputStreamReader(con.getInputStream()));
113 | String inputLine;
114 | StringBuilder message = new StringBuilder();
115 | while ((inputLine = br.readLine()) != null) {
116 | message.append(inputLine);
117 | }
118 | br.close();
119 | con.disconnect();
120 |
121 | JSONArray res = convertJSON(message.toString());
122 | for (int i = 0; i < res.length(); i++) {
123 | recommend.add(res.getInt(i));
124 | }
125 |
126 | int hotRecommend = numRec - recommend.size();
127 | if (hotRecommend > 0) {
128 | for (int i = 1; i < hotRecommend + 1; i++) {
129 | String item = jedis.get(String.valueOf(i));
130 | if (null != item) {
131 | recommend.add(Integer.parseInt(item));
132 | }
133 | }
134 | }
135 | printOut = String.format("user: %d, recommend(%d) + hot(%d): %s, time: %s",
136 | userId, recommend.size(), hotRecommend, recommend, time);
137 | } else {
138 | for (int i = 1; i < numRec + 1; i++) {
139 | String item = jedis.get(String.valueOf(i));
140 | if (null != item) {
141 | recommend.add(Integer.parseInt(item));
142 | }
143 | }
144 | printOut = String.format("user: %d, bad request, recommend hot(%d): %s, " +
145 | "time: %s", userId, recommend.size(), recommend, time);
146 | }
147 |
148 | System.out.println(printOut);
149 | System.out.println(StringUtils.repeat("=", 60));
150 |
151 | int lastReward = updateLastReward(userId, items, recommend);
152 | out.collect(Tuple4.of(userId, recommend, time, lastReward));
153 | }
154 | }
155 |
156 | private int updateLastReward(int userId, List<Integer> items, List<Integer> recommend) throws Exception {
157 | // for (Map.Entry<Integer, List<Integer>> entry: lastRecState.entries()) {
158 | // System.out.println(entry.getKey() + " - " + entry.getValue());
159 | // }
160 | // System.out.println("user: " + lastRecState.contains(userId));
161 |
162 | int lastReward;
163 | List<Integer> lastRecommend = lastRecState.get(userId);
164 | if (lastRecommend != null) {
165 | lastReward = 0;
166 | for (int rec : lastRecommend) {
167 | for (int click : items) {
168 | if (rec == click) {
169 | lastReward++;
170 | }
171 | }
172 | }
173 | } else {
174 | lastReward = -1;
175 | }
176 |
177 | lastRecState.put(userId, recommend);
178 | return lastReward;
179 | }
180 | }
181 |
182 |
--------------------------------------------------------------------------------
/FlinkRL/src/main/java/com/mass/sink/MongodbRecommendSink.java:
--------------------------------------------------------------------------------
1 | package com.mass.sink;
2 |
3 | import com.mongodb.client.MongoClient;
4 | import com.mongodb.client.MongoClients;
5 | import com.mongodb.client.MongoCollection;
6 | import com.mongodb.client.MongoDatabase;
7 | import org.apache.flink.api.java.tuple.Tuple4;
8 | import org.apache.flink.configuration.Configuration;
9 | import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
10 | import org.bson.Document;
11 |
12 | import java.util.List;
13 |
14 | public class MongodbRecommendSink extends RichSinkFunction<Tuple4<Integer, List<Integer>, String, Integer>> {
15 | static MongoClient mongoClient;
16 | static MongoDatabase database;
17 | static MongoCollection<Document> recCollection;
18 |
19 | @Override
20 | public void open(Configuration parameters) {
21 | mongoClient = MongoClients.create("mongodb://localhost:27017");
22 | database = mongoClient.getDatabase("flink-rl");
23 | recCollection = database.getCollection("recommendResults");
24 | }
25 |
26 | @Override
27 | public void close() {
28 | if (mongoClient != null) {
29 | mongoClient.close();
30 | }
31 | }
32 |
33 | @Override
34 | public void invoke(Tuple4<Integer, List<Integer>, String, Integer> value, Context context) {
35 | Document recDoc = new Document()
36 | .append("user", value.f0)
37 | .append("itemRec", value.f1)
38 | .append("time", value.f2)
39 | .append("lastReward", value.f3);
40 | recCollection.insertOne(recDoc);
41 | }
42 | }
43 |
44 |
45 |
--------------------------------------------------------------------------------
/FlinkRL/src/main/java/com/mass/sink/MongodbRecordSink.java:
--------------------------------------------------------------------------------
1 | package com.mass.sink;
2 |
3 | import com.mass.entity.RecordEntity;
4 | import com.mongodb.client.*;
5 | import com.mongodb.client.model.Updates;
6 | import org.apache.flink.configuration.Configuration;
7 | import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
8 | import org.bson.Document;
9 |
10 | import java.util.ArrayList;
11 |
12 | import static com.mass.util.FormatTimestamp.format;
13 | import static com.mongodb.client.model.Filters.eq;
14 |
15 | public class MongodbRecordSink extends RichSinkFunction {
16 | private static MongoClient mongoClient;
17 | private static MongoDatabase database;
18 | private static MongoCollection<Document> recordCollection;
19 | private static MongoCollection<Document> consumedCollection;
20 |
21 | @Override
22 | public void open(Configuration parameters) {
23 | mongoClient = MongoClients.create("mongodb://localhost:27017");
24 | database = mongoClient.getDatabase("flink-rl");
25 | recordCollection = database.getCollection("records");
26 | consumedCollection = database.getCollection("userConsumed");
27 | }
28 |
29 | @Override
30 | public void close() {
31 | if (mongoClient != null) {
32 | mongoClient.close();
33 | }
34 | }
35 |
36 | @Override
37 | public void invoke(RecordEntity value, Context context) {
38 | Document recordDoc = new Document()
39 | .append("user", value.getUserId())
40 | .append("item", value.getItemId())
41 | .append("time", format(value.getTime()));
42 | recordCollection.insertOne(recordDoc);
43 |
44 | int userId = value.getUserId();
45 | int itemId = value.getItemId();
46 | long timestamp = value.getTime();
47 | String time = format(timestamp);
48 | FindIterable<Document> fe = consumedCollection.find(eq("user", userId));
49 | if (null == fe.first()) {
50 | Document consumedDoc = new Document();
51 | consumedDoc.append("user", userId);
52 | consumedDoc.append("consumed", new ArrayList());
53 | consumedDoc.append("time", new ArrayList());
54 | consumedCollection.insertOne(consumedDoc);
55 | } else {
56 | consumedCollection.updateOne(eq("user", userId), Updates.addToSet("consumed", itemId));
57 | consumedCollection.updateOne(eq("user", userId), Updates.addToSet("time", time));
58 | }
59 | }
60 | }
61 |
62 |
--------------------------------------------------------------------------------
/FlinkRL/src/main/java/com/mass/sink/TopNRedisSink.java:
--------------------------------------------------------------------------------
1 | package com.mass.sink;
2 |
3 | import org.apache.flink.api.java.tuple.Tuple2;
4 | import org.apache.flink.streaming.connectors.redis.common.mapper.RedisCommand;
5 | import org.apache.flink.streaming.connectors.redis.common.mapper.RedisCommandDescription;
6 | import org.apache.flink.streaming.connectors.redis.common.mapper.RedisMapper;
7 |
8 | public class TopNRedisSink implements RedisMapper<Tuple2<String, String>> {
9 |
10 | @Override
11 | public RedisCommandDescription getCommandDescription() {
12 | return new RedisCommandDescription(RedisCommand.SET);
13 | }
14 |
15 | @Override
16 | public String getKeyFromData(Tuple2<String, String> data) {
17 | return data.f1; // rank
18 | }
19 |
20 | @Override
21 | public String getValueFromData(Tuple2<String, String> data) {
22 | return data.f0; // item
23 | }
24 | }
25 |
26 |
27 |
--------------------------------------------------------------------------------
/FlinkRL/src/main/java/com/mass/source/CustomFileSource.java:
--------------------------------------------------------------------------------
1 | package com.mass.source;
2 |
3 | import com.mass.util.Property;
4 | import org.apache.flink.configuration.Configuration;
5 | import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
6 | import com.mass.entity.RecordEntity;
7 |
8 | import java.io.BufferedReader;
9 | import java.io.FileReader;
10 | import java.io.IOException;
11 |
12 | public class CustomFileSource extends RichSourceFunction<RecordEntity> {
13 |
14 | private static BufferedReader br;
15 | private Boolean header;
16 | private String filePath;
17 |
18 | public CustomFileSource(String filePath, Boolean header) {
19 | this.filePath = filePath;
20 | this.header = header;
21 | }
22 |
23 | @Override
24 | public void open(Configuration parameters) throws IOException {
25 | // String dataPath = Thread.currentThread().getContextClassLoader().getResource(filePath).getFile();
26 | String dataPath = Property.getStrValue("data.path");
27 | br = new BufferedReader(new FileReader(dataPath));
28 | }
29 |
30 | @Override
31 | public void close() throws IOException {
32 | br.close();
33 | }
34 |
35 | @Override
36 | public void run(SourceContext<RecordEntity> ctx) throws IOException, InterruptedException {
37 | String temp;
38 | int i = 0;
39 | while ((temp = br.readLine()) != null) {
40 | i++;
41 | if (header && i == 1) continue; // skip header line
42 | String[] line = temp.split(",");
43 | int userId = Integer.valueOf(line[0]);
44 | int itemId = Integer.valueOf(line[1]);
45 | long time = Long.valueOf(line[3]);
46 | ctx.collect(new RecordEntity(userId, itemId, time));
47 | Thread.sleep(5L);
48 | }
49 | }
50 |
51 | @Override
52 | public void cancel() { }
53 | }
54 |
55 |
56 |
--------------------------------------------------------------------------------
/FlinkRL/src/main/java/com/mass/source/KafkaSource.java:
--------------------------------------------------------------------------------
1 | package com.mass.source;
2 |
3 | import org.apache.kafka.clients.producer.KafkaProducer;
4 | import org.apache.kafka.clients.producer.ProducerRecord;
5 | import org.apache.kafka.clients.producer.RecordMetadata;
6 | import org.apache.kafka.clients.producer.Callback;
7 |
8 | import java.io.*;
9 | import java.util.Properties;
10 |
11 | public class KafkaSource {
12 | public static void sendData(Properties kafkaProps, String dataPath, Boolean header)
13 | throws IOException, InterruptedException {
14 | KafkaProducer<String, String> producer = new KafkaProducer<>(kafkaProps);
15 | BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(dataPath)));
16 | String temp;
17 | int i = 0;
18 | while ((temp = br.readLine()) != null) {
19 | i++;
20 | if (header && i == 1) continue; // skip header line
21 | String[] splitted = temp.split(",");
22 | String[] record = {splitted[0], splitted[1], splitted[3]};
23 | String csvString = String.join(",", record);
24 | producer.send(new ProducerRecord<>("flink-rl", "news", csvString), new Callback() {
25 | @Override
26 | public void onCompletion(RecordMetadata metadata, Exception e) {
27 | if (e != null) {
28 | e.printStackTrace();
29 | }
30 | }
31 | });
32 | Thread.sleep(10L);
33 | }
34 |
35 | br.close();
36 | producer.close();
37 | }
38 | }
39 |
--------------------------------------------------------------------------------
/FlinkRL/src/main/java/com/mass/task/FileToKafka.java:
--------------------------------------------------------------------------------
1 | package com.mass.task;
2 |
3 | import com.mass.source.KafkaSource;
4 | import com.mass.util.Property;
5 |
6 | import java.io.IOException;
7 | import java.util.Properties;
8 |
9 | public class FileToKafka {
10 | public static void main(String[] args) throws IOException, InterruptedException {
11 | Properties kafkaProps = new Properties();
12 | kafkaProps.setProperty("bootstrap.servers", "localhost:9092");
13 | kafkaProps.setProperty("ack", "1");
14 | kafkaProps.setProperty("batch.size", "16384");
15 | kafkaProps.setProperty("linger.ms", "1");
16 | kafkaProps.setProperty("buffer.memory", "33554432");
17 | kafkaProps.setProperty("key.serializer", "org.apache.kafka.common.serialization.StringSerializer");
18 | kafkaProps.setProperty("value.serializer", "org.apache.kafka.common.serialization.StringSerializer");
19 |
20 | // ClassLoader classloader = Thread.currentThread().getContextClassLoader();
21 | // InputStream is = classloader.getResourceAsStream("/news_data.csv");
22 | // String dataPath = FileToKafka.class.getResource("/tianchi.csv").getFile();
23 | String dataPath = Property.getStrValue("data.path");
24 | KafkaSource.sendData(kafkaProps, dataPath, false);
25 | }
26 | }
27 |
--------------------------------------------------------------------------------
/FlinkRL/src/main/java/com/mass/task/IntervalPopularItems.java:
--------------------------------------------------------------------------------
1 | package com.mass.task;
2 |
3 | import com.mass.agg.CountAgg;
4 | import com.mass.entity.RecordEntity;
5 | import com.mass.map.RecordMap;
6 | import com.mass.process.TopNPopularItems;
7 | import com.mass.sink.TopNRedisSink;
8 | import com.mass.window.CountProcessWindowFunction;
9 | import org.apache.flink.api.common.functions.FlatMapFunction;
10 | import org.apache.flink.api.common.serialization.SimpleStringSchema;
11 | import org.apache.flink.api.java.tuple.Tuple2;
12 | import org.apache.flink.streaming.api.TimeCharacteristic;
13 | import org.apache.flink.streaming.api.datastream.DataStream;
14 | import org.apache.flink.streaming.api.datastream.DataStreamSource;
15 | import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
16 | import org.apache.flink.streaming.api.functions.timestamps.AscendingTimestampExtractor;
17 | import org.apache.flink.streaming.api.windowing.time.Time;
18 | import org.apache.flink.streaming.connectors.kafka.FlinkKafkaConsumer;
19 | import org.apache.flink.streaming.connectors.redis.RedisSink;
20 | import org.apache.flink.streaming.connectors.redis.common.config.FlinkJedisPoolConfig;
21 | import org.apache.flink.util.Collector;
22 | import org.apache.kafka.common.serialization.StringDeserializer;
23 |
24 | import java.text.SimpleDateFormat;
25 | import java.util.Date;
26 | import java.util.List;
27 | import java.util.Properties;
28 |
29 | public class IntervalPopularItems {
30 | private static final int topSize = 17;
31 | private static SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
32 |
33 | public static void main(String[] args) throws Exception {
34 | StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
35 | env.setParallelism(1);
36 | env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime);
37 |
38 | FlinkJedisPoolConfig redisConf = new FlinkJedisPoolConfig
39 | .Builder()
40 | .setHost("localhost")
41 | .setPort(6379)
42 | .build();
43 | Properties prop = new Properties();
44 | prop.setProperty("bootstrap.servers", "localhost:9092");
45 | prop.setProperty("group.id", "rec");
46 | prop.setProperty("key.deserializer", StringDeserializer.class.getName());
47 | prop.setProperty("value.deserializer",
48 | Class.forName("org.apache.kafka.common.serialization.StringDeserializer").getName());
49 | DataStreamSource<String> sourceStream = env
50 | .addSource(new FlinkKafkaConsumer<>("flink-rl", new SimpleStringSchema(), prop));
51 |
52 | DataStream<Tuple2<String, String>> topItem = sourceStream
53 | .map(new RecordMap())
54 | .assignTimestampsAndWatermarks(new AscendingTimestampExtractor<RecordEntity>() {
55 | @Override
56 | public long extractAscendingTimestamp(RecordEntity element) {
57 | return element.getTime();
58 | }
59 | })
60 | // .keyBy("itemId") // if using POJOs, will convert to tuple
61 | // .keyBy(i -> i.getItemId())
62 | .keyBy(RecordEntity::getItemId)
63 | .timeWindow(Time.minutes(60), Time.seconds(80))
64 | .aggregate(new CountAgg(), new CountProcessWindowFunction())
65 | .keyBy("windowEnd")
66 | .process(new TopNPopularItems(topSize))
67 | .flatMap(new FlatMapFunction<List<String>, Tuple2<String, String>>() {
68 | @Override
69 | public void flatMap(List<String> value, Collector<Tuple2<String, String>> out) {
70 | showTime(value.get(0));
71 | for (int rank = 1; rank < value.size(); ++rank) {
72 | String item = value.get(rank);
73 | System.out.println(String.format("item: %s, rank: %d", item, rank));
74 | out.collect(Tuple2.of(item, String.valueOf(rank)));
75 | }
76 | }
77 | });
78 |
79 | topItem.addSink(new RedisSink<Tuple2<String, String>>(redisConf, new TopNRedisSink()));
80 | env.execute("IntervalPopularItems");
81 | }
82 |
83 | private static void showTime(String timestamp) {
84 | long windowTime = Long.valueOf(timestamp);
85 | Date date = new Date(windowTime);
86 | System.out.println("============== Top N Items during " + sdf.format(date) + " ==============");
87 | }
88 | }
89 |
--------------------------------------------------------------------------------
/FlinkRL/src/main/java/com/mass/task/ReadFromFile.java:
--------------------------------------------------------------------------------
1 | package com.mass.task;
2 |
3 | import org.apache.flink.api.common.functions.FlatMapFunction;
4 | import org.apache.flink.api.java.tuple.Tuple2;
5 | import org.apache.flink.streaming.api.datastream.DataStream;
6 | import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
7 | import org.apache.flink.util.Collector;
8 |
9 | import java.util.Arrays;
10 |
11 | public class ReadFromFile {
12 | public static void main(String[] args) throws Exception {
13 | StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
14 |
15 | // DataStream stream = env.addSource(new CustomFileSource());
16 | // String dataPath = ReadFromFile.class.getResource("/tianchi.csv").getFile();
17 | String dataPath = Thread.currentThread().getContextClassLoader().getResource("tianchi.csv").getFile();
18 | DataStream<String> stream = env.readTextFile(dataPath);
19 | DataStream<Tuple2<Integer, String>> result = stream.flatMap(
20 | new FlatMapFunction<String, Tuple2<Integer, String>>() {
21 | @Override
22 | public void flatMap(String value, Collector<Tuple2<Integer, String>> out) {
23 | String[] splitted = value.split(",");
24 | boolean first = false;
25 | try {
26 | Integer.parseInt(splitted[0]);
27 | } catch (NumberFormatException e) {
28 | System.out.println("skip header line...");
29 | first = true;
30 | }
31 | if (!first) {
32 | String ratings = String.join(",", Arrays.copyOfRange(splitted, 4, 5));
33 | out.collect(new Tuple2<>(Integer.valueOf(splitted[0]), ratings));
34 | }
35 | }
36 | }).keyBy(0);
37 |
38 | result.print();
39 | env.execute();
40 | }
41 | }
42 |
--------------------------------------------------------------------------------
/FlinkRL/src/main/java/com/mass/task/RecordToMongoDB.java:
--------------------------------------------------------------------------------
1 | package com.mass.task;
2 |
3 | import com.mass.entity.RecordEntity;
4 | import com.mass.sink.MongodbRecordSink;
5 | import com.mass.source.CustomFileSource;
6 | import com.mass.util.Property;
7 | import com.mass.util.RecordToEntity;
8 | import org.apache.flink.api.common.functions.RichMapFunction;
9 | import org.apache.flink.api.common.serialization.SimpleStringSchema;
10 | import org.apache.flink.configuration.Configuration;
11 | import org.apache.flink.streaming.api.TimeCharacteristic;
12 | import org.apache.flink.streaming.api.datastream.DataStream;
13 | import org.apache.flink.streaming.api.datastream.DataStreamSource;
14 | import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
15 | import org.apache.flink.streaming.connectors.kafka.FlinkKafkaConsumer;
16 | import org.apache.kafka.common.serialization.StringDeserializer;
17 | import org.json.JSONObject;
18 | import org.json.JSONTokener;
19 |
20 | import java.io.FileInputStream;
21 | import java.io.FileNotFoundException;
22 | import java.io.InputStream;
23 | import java.util.Properties;
24 |
25 | public class RecordToMongoDB {
26 | public static void main(String[] args) throws Exception {
27 | StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
28 | env.setParallelism(1);
29 | env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime);
30 |
31 | Properties prop = new Properties();
32 | prop.setProperty("bootstrap.servers", "localhost:9092");
33 | prop.setProperty("group.id", "rec");
34 | prop.setProperty("key.deserializer", StringDeserializer.class.getName());
35 | prop.setProperty("value.deserializer",
36 | Class.forName("org.apache.kafka.common.serialization.StringDeserializer").getName());
37 | DataStreamSource<String> sourceStream = env.addSource(
38 | new FlinkKafkaConsumer<>("flink-rl", new SimpleStringSchema(), prop));
39 |
40 | DataStream<RecordEntity> recordStream = sourceStream.map(new RichMapFunction<String, RecordEntity>() {
41 | private InputStream userStream;
42 | private InputStream itemStream;
43 | private JSONObject userJSON;
44 | private JSONObject itemJSON;
45 | private int numUsers;
46 | private int numItems;
47 |
48 | @Override
49 | public void open(Configuration parameters) throws FileNotFoundException {
50 | userStream = new FileInputStream(Property.getStrValue("user_map"));
51 | itemStream = new FileInputStream(Property.getStrValue("item_map"));
52 | try {
53 | userJSON = new JSONObject(new JSONTokener(userStream));
54 | itemJSON = new JSONObject(new JSONTokener(itemStream));
55 | } catch (NullPointerException e) {
56 | e.printStackTrace();
57 | }
58 | numUsers = userJSON.length();
59 | numItems = itemJSON.length();
60 | }
61 |
62 | @Override
63 | public RecordEntity map(String value) {
64 | return RecordToEntity.getRecord(value, userJSON, itemJSON, numUsers, numItems);
65 | }
66 |
67 | @Override
68 | public void close() throws Exception {
69 | userStream.close();
70 | itemStream.close();
71 | }
72 | }
73 | );
74 |
75 | recordStream.addSink(new MongodbRecordSink());
76 | env.execute("RecordToMongoDB");
77 | }
78 | }
79 |
--------------------------------------------------------------------------------
/FlinkRL/src/main/java/com/mass/task/RecordToMysql.java:
--------------------------------------------------------------------------------
1 | package com.mass.task;
2 |
3 | import com.mass.entity.RecordEntity;
4 | import com.mass.util.RecordToEntity;
5 | import com.mass.map.RecordMap;
6 | import org.apache.flink.api.common.functions.RichMapFunction;
7 | import org.apache.flink.api.common.serialization.SimpleStringSchema;
8 | import org.apache.flink.streaming.api.datastream.DataStream;
9 | import org.apache.flink.streaming.api.datastream.DataStreamSource;
10 | import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
11 | import org.apache.flink.streaming.connectors.kafka.FlinkKafkaConsumer;
12 | import org.apache.kafka.common.serialization.StringDeserializer;
13 | import org.apache.flink.configuration.Configuration;
14 | import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
15 | import org.apache.flink.streaming.api.functions.sink.SinkFunction;
16 | import org.json.JSONObject;
17 | import org.json.JSONTokener;
18 |
19 | import java.io.InputStream;
20 | import java.text.DateFormat;
21 | import java.text.SimpleDateFormat;
22 | import java.util.Properties;
23 | import java.sql.Connection;
24 | import java.sql.DriverManager;
25 | import java.sql.PreparedStatement;
26 |
27 |
28 | public class RecordToMysql {
29 | public static void main(String[] args) throws Exception {
30 | StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
31 | env.setParallelism(1);
32 |
33 | Properties prop = new Properties();
34 | prop.setProperty("bootstrap.servers", "localhost:9092");
35 | prop.setProperty("group.id", "rec");
36 | prop.setProperty("key.deserializer", StringDeserializer.class.getName());
37 | prop.setProperty("value.deserializer",
38 | Class.forName("org.apache.kafka.common.serialization.StringDeserializer").getName());
39 | DataStreamSource<String> stream = env.addSource(
40 | new FlinkKafkaConsumer<>("flink-rl", new SimpleStringSchema(), prop));
41 | // stream.print();
42 |
43 | DataStream<RecordEntity> stream2 = stream.map(new RecordMap());
44 | stream2.addSink(new CustomMySqlSink());
45 |
46 | env.execute("RecordToMySql");
47 | }
48 |
49 | private static class CustomMySqlSink extends RichSinkFunction<RecordEntity> {
50 | Connection conn;
51 | PreparedStatement stmt;
52 |
53 | @Override
54 | public void open(Configuration parameters) throws Exception {
55 | conn = DriverManager.getConnection("jdbc:mysql://localhost:3306/flink", "root", "123456");
56 | stmt = conn.prepareStatement("INSERT INTO record (user, item, time) VALUES (?, ?, ?)");
57 | }
58 |
59 | @Override
60 | public void close() throws Exception {
61 | stmt.close();
62 | conn.close();
63 | }
64 |
65 | @Override
66 | public void invoke(RecordEntity value, Context context) throws Exception {
67 | stmt.setInt(1, value.getUserId());
68 | stmt.setInt(2, value.getItemId());
69 | DateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
70 | stmt.setString(3, sdf.format(value.getTime()));
71 | stmt.executeUpdate();
72 | }
73 | }
74 | }
75 |
76 |
77 |
--------------------------------------------------------------------------------
/FlinkRL/src/main/java/com/mass/task/SeqRecommend.java:
--------------------------------------------------------------------------------
1 | package com.mass.task;
2 |
3 | import com.mass.entity.RecordEntity;
4 | import com.mass.recommend.fastapi.FastapiRecommender;
5 | // import com.mass.recommend.onnx.ONNXRecommender;
6 | import com.mass.sink.MongodbRecommendSink;
7 | import com.mass.source.CustomFileSource;
8 | import com.mass.window.ItemCollectWindowFunction;
9 | import org.apache.flink.streaming.api.TimeCharacteristic;
10 | import org.apache.flink.streaming.api.datastream.DataStream;
11 | import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
12 | import org.apache.flink.streaming.api.functions.timestamps.BoundedOutOfOrdernessTimestampExtractor;
13 | import org.apache.flink.streaming.api.windowing.time.Time;
14 |
15 | public class SeqRecommend {
16 | public static void main(String[] args) throws Exception {
17 | StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
18 | env.setParallelism(1);
19 | env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime);
20 |
21 | DataStream<RecordEntity> stream = env.addSource(new CustomFileSource("tianchi.csv", false));
22 | stream.assignTimestampsAndWatermarks(
23 | new BoundedOutOfOrdernessTimestampExtractor<RecordEntity>(Time.seconds(10)) {
24 | @Override
25 | public long extractTimestamp(RecordEntity element) {
26 | return element.getTime();
27 | }
28 | })
29 | .keyBy("userId") // return tuple
30 | // .keyBy(RecordEntity -> RecordEntity.userId) // return Integer
31 | .timeWindow(Time.seconds(60), Time.seconds(20))
32 | .process(new ItemCollectWindowFunction(10))
33 | .keyBy("windowEnd")
34 | .flatMap(new FastapiRecommender(8, 10, "ddpg", false))
35 | .addSink(new MongodbRecommendSink());
36 |
37 | env.execute("ItemSeqRecommend");
38 | }
39 | }
40 |
41 |
42 |
43 |
--------------------------------------------------------------------------------
/FlinkRL/src/main/java/com/mass/util/FormatTimestamp.java:
--------------------------------------------------------------------------------
1 | package com.mass.util;
2 |
3 | import java.text.SimpleDateFormat;
4 | import java.util.Date;
5 |
6 | public class FormatTimestamp {
7 | public static String format(long timestamp) {
8 | Date date = new Date(timestamp);
9 | SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
10 | return sdf.format(date);
11 | }
12 | }
13 |
14 |
--------------------------------------------------------------------------------
/FlinkRL/src/main/java/com/mass/util/Property.java:
--------------------------------------------------------------------------------
1 | package com.mass.util;
2 |
3 | import java.io.IOException;
4 | import java.io.InputStream;
5 | import java.io.InputStreamReader;
6 | import java.nio.charset.StandardCharsets;
7 | import java.util.Properties;
8 |
9 | public class Property {
10 | private static Properties contextProperties;
11 | private final static String CONFIG_NAME = "config.properties";
12 |
13 | static {
14 | InputStream in = Thread.currentThread().getContextClassLoader().getResourceAsStream(CONFIG_NAME);
15 | contextProperties = new Properties();
16 | try {
17 | InputStreamReader inputStreamReader = new InputStreamReader(in, StandardCharsets.UTF_8);
18 | contextProperties.load(inputStreamReader);
19 | } catch (IOException e) {
20 | System.err.println("===== Failed to load config file. =====");
21 | e.printStackTrace();
22 | }
23 | }
24 |
25 | public static String getStrValue(String key) {
26 | return contextProperties.getProperty(key);
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/FlinkRL/src/main/java/com/mass/util/RecordToEntity.java:
--------------------------------------------------------------------------------
1 | package com.mass.util;
2 |
3 | import com.mass.entity.RecordEntity;
4 | import org.json.JSONObject;
5 |
6 | public class RecordToEntity {
7 | public static RecordEntity getRecord(String s, JSONObject userJSON, JSONObject itemJSON, int numUsers, int numItems) {
8 | String[] values = s.split(",");
9 | RecordEntity record = new RecordEntity();
10 | record.setUserId(userJSON.has(values[0]) ? userJSON.getInt(values[0]) : numUsers);
11 | record.setItemId(itemJSON.has(values[1]) ? itemJSON.getInt(values[1]) : numItems);
12 | record.setTime(Long.parseLong(values[2]));
13 | return record;
14 | }
15 | }
16 |
--------------------------------------------------------------------------------
/FlinkRL/src/main/java/com/mass/util/TypeConvert.java:
--------------------------------------------------------------------------------
1 | package com.mass.util;
2 |
3 | import org.apache.commons.lang3.StringUtils;
4 | import org.json.JSONArray;
5 | import org.json.JSONObject;
6 |
7 | import java.util.Arrays;
8 | import java.util.List;
9 |
10 | public class TypeConvert {
11 |
12 | public static String convertString(List<Integer> items) {
13 | StringBuilder sb = new StringBuilder("{\"columns\": [\"x\"], \"data\": [[[");
14 | for (Integer item : items) {
15 | sb.append(String.format("%d,", item));
16 | }
17 | sb.deleteCharAt(sb.length() - 1);
18 | sb.append("]]]}");
19 | return sb.toString();
20 | }
21 |
22 | public static JSONArray convertJSON(String message) {
23 | JSONArray resArray = new JSONArray(message);
24 | return resArray.getJSONArray(0);
25 | }
26 |
27 | public static String convertEmbedding(List<Float> embeds, int user, int numRec) {
28 | // String repeat = StringUtils.repeat("\"x\",", embeds.size());
29 | // String columns = "{\"columns\": [" + repeat.substring(0, repeat.length() - 1) + "], ";
30 | String state = "{\"user\": " + user + ", \"n_rec\": " + numRec + ", ";
31 | StringBuilder sb = new StringBuilder("\"embedding\": [[");
32 | for (float num : embeds) {
33 | sb.append(String.format("%f,", num));
34 | }
35 | sb.deleteCharAt(sb.length() - 1);
36 | sb.append("]]}");
37 | return state + sb.toString();
38 | }
39 |
40 | public static String convertSeq(List<Integer> seq, int user, int numRec) {
41 | String state = "{\"user\": [" + user + "], \"n_rec\": " + numRec + ", ";
42 | StringBuilder sb = new StringBuilder("\"item\": [[");
43 | for (int item : seq) {
44 | sb.append(String.format("%d,", item));
45 | }
46 | sb.deleteCharAt(sb.length() - 1);
47 | sb.append("]]}");
48 | return state + sb.toString();
49 | }
50 |
51 | public static void main(String[] args) {
52 | Float[] aa = {1.2f, 3.5f, 5.666f};
53 | List<Float> bb = Arrays.asList(aa);
54 | System.out.println(convertEmbedding(bb, 1, 10));
55 | Integer[] cc = {1, 2, 3};
56 | List<Integer> dd = Arrays.asList(cc);
57 | System.out.println(convertSeq(dd, 1, 10));
58 | }
59 | }
60 |
--------------------------------------------------------------------------------
/FlinkRL/src/main/java/com/mass/window/CountProcessWindowFunction.java:
--------------------------------------------------------------------------------
1 | package com.mass.window;
2 |
3 | import com.mass.entity.TopItemEntity;
4 | import org.apache.flink.streaming.api.functions.windowing.ProcessWindowFunction;
5 | import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
6 | import org.apache.flink.util.Collector;
7 |
8 | public class CountProcessWindowFunction extends ProcessWindowFunction<Long, TopItemEntity, Integer, TimeWindow> {
9 |
10 | @Override
11 | public void process(Integer itemId, Context ctx, Iterable<Long> aggregateResults, Collector<TopItemEntity> out) {
12 | long windowEnd = ctx.window().getEnd();
13 | long count = aggregateResults.iterator().next();
14 | out.collect(TopItemEntity.of(itemId, windowEnd, count));
15 | }
16 | }
17 |
--------------------------------------------------------------------------------
/FlinkRL/src/main/java/com/mass/window/ItemCollectWindowFunction.java:
--------------------------------------------------------------------------------
1 | package com.mass.window;
2 |
3 | import com.mass.entity.RecordEntity;
4 | import com.mass.entity.UserConsumed;
5 | import org.apache.flink.api.common.state.ValueState;
6 | import org.apache.flink.api.common.state.ValueStateDescriptor;
7 | import org.apache.flink.api.common.typeinfo.TypeHint;
8 | import org.apache.flink.api.common.typeinfo.TypeInformation;
9 | import org.apache.flink.api.java.tuple.Tuple;
10 | import org.apache.flink.configuration.Configuration;
11 | import org.apache.flink.streaming.api.functions.windowing.ProcessWindowFunction;
12 | import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
13 | import org.apache.flink.util.Collector;
14 |
15 | import java.io.IOException;
16 | import java.util.LinkedList;
17 |
18 | public class ItemCollectWindowFunction extends ProcessWindowFunction<RecordEntity, UserConsumed, Tuple, TimeWindow> {
19 |
20 | private final int histNum;
21 | private ValueState<LinkedList<Integer>> itemState;
22 |
23 | public ItemCollectWindowFunction(int histNum) {
24 | this.histNum = histNum;
25 | }
26 |
27 | @Override
28 | public void open(Configuration parameters) throws Exception {
29 | super.open(parameters);
30 | ValueStateDescriptor<LinkedList<Integer>> itemStateDesc = new ValueStateDescriptor<>(
31 | "itemState", TypeInformation.of(new TypeHint<LinkedList<Integer>>() {}), new LinkedList<>());
32 | itemState = getRuntimeContext().getState(itemStateDesc);
33 | }
34 |
35 | @Override
36 | public void process(Tuple key, Context ctx, Iterable<RecordEntity> elements, Collector<UserConsumed> out)
37 | throws IOException {
38 | int userId = key.getField(0);
39 | long windowEnd = ctx.window().getEnd();
40 |
41 | LinkedList<Integer> histItems = itemState.value();
42 | // if (histItems == null) {
43 | // histItems = new LinkedList<>();
44 | // }
45 | elements.forEach(e -> histItems.add(e.getItemId()));
46 | while (!histItems.isEmpty() && histItems.size() > histNum) {
47 | histItems.poll();
48 | }
49 | itemState.update(histItems);
50 | out.collect(UserConsumed.of(userId, histItems, windowEnd));
51 | }
52 |
53 | @Override
54 | public void close() {
55 | itemState.clear();
56 | }
57 | }
58 |
--------------------------------------------------------------------------------
/FlinkRL/src/main/java/com/mass/window/ItemCollectWindowFunctionRedis.java:
--------------------------------------------------------------------------------
1 | package com.mass.window;
2 |
3 | import com.mass.entity.RecordEntity;
4 | import com.mass.entity.UserConsumed;
5 | import org.apache.flink.api.java.tuple.Tuple;
6 | import org.apache.flink.configuration.Configuration;
7 | import org.apache.flink.streaming.api.functions.windowing.ProcessWindowFunction;
8 | import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
9 | import org.apache.flink.util.Collector;
10 | import redis.clients.jedis.Jedis;
11 |
12 | import java.util.ArrayList;
13 | import java.util.List;
14 |
15 | public class ItemCollectWindowFunctionRedis extends ProcessWindowFunction<RecordEntity, UserConsumed, Tuple, TimeWindow> {
16 |
17 | private final int histNum;
18 | private static Jedis jedis;
19 |
20 | public ItemCollectWindowFunctionRedis(int histNum) {
21 | this.histNum = histNum;
22 | }
23 |
24 | @Override
25 | public void open(Configuration parameters) throws Exception {
26 | super.open(parameters);
27 | jedis = new Jedis("localhost", 6379);
28 | }
29 |
30 | @Override
31 | public void process(Tuple key, Context ctx, Iterable<RecordEntity> elements, Collector<UserConsumed> out) {
32 | int userId = key.getField(0);
33 | long windowEnd = ctx.window().getEnd();
34 |
35 | String keyList = "hist_" + userId;
36 | elements.forEach(e -> jedis.rpush(keyList , String.valueOf(e.getItemId())));
37 | while (jedis.llen(keyList) > histNum) {
38 | jedis.lpop(keyList);
39 | }
40 |
41 | List<String> histItemsJedis = jedis.lrange(keyList, 0, -1);
42 | List<Integer> histItems = new ArrayList<>();
43 | for (String s : histItemsJedis) {
44 | histItems.add(Integer.parseInt(s));
45 | }
46 | out.collect(UserConsumed.of(userId, histItems, windowEnd));
47 | }
48 |
49 | @Override
50 | public void close() {
51 | jedis.close();
52 | }
53 | }
54 |
--------------------------------------------------------------------------------
/FlinkRL/src/main/resources/config.properties:
--------------------------------------------------------------------------------
1 | data.path=/home/massquantity/Workspace/flink-reinforcement-learning/FlinkRL/src/main/resources/tianchi.csv
2 | user_map=/home/massquantity/Workspace/flink-reinforcement-learning/FlinkRL/src/main/resources/features/user_map.json
3 | item_map=/home/massquantity/Workspace/flink-reinforcement-learning/FlinkRL/src/main/resources/features/item_map.json
--------------------------------------------------------------------------------
/FlinkRL/src/main/resources/db/create_table.sql:
--------------------------------------------------------------------------------
1 | CREATE TABLE IF NOT EXISTS record
2 | (
3 | user int NOT NULL,
4 | item int NOT NULL,
5 | time TIMESTAMP NOT NULL,
6 | PRIMARY KEY (user, item)
7 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8;
8 |
9 |
10 | CREATE TABLE IF NOT EXISTS top
11 | (
12 | item int NOT NULL,
13 | `rank` VARCHAR(5) NOT NULL,
14 | PRIMARY KEY (item, `rank`)
15 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8;
16 |
17 |
18 |
19 |
--------------------------------------------------------------------------------
/FlinkRL/src/main/resources/log4j.properties:
--------------------------------------------------------------------------------
1 | log4j.rootLogger=WARN, console
2 | log4j.appender.console=org.apache.log4j.ConsoleAppender
3 | log4j.appender.console.target=System.out
4 | log4j.appender.console.layout=org.apache.log4j.PatternLayout
5 | log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{2}: %m%n
6 |
7 |
8 |
9 |
10 |
11 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Flink-Reinforcement-Learning
2 |
3 | ### `English` [`简体中文`](https://github.com/massquantity/Flink-Reinforcement-Learning/blob/master/README_zh.md) [`blog post`](https://www.cnblogs.com/massquantity/p/13842139.html)
4 |
5 |
6 |
7 | FlinkRL is a realtime recommender system built upon Flink and reinforcement learning. Flink is known for its high-performance stateful stream processing, which lets the system respond to user requests quickly and accurately, while reinforcement learning is good at planning for long-term reward and adjusting recommendations quickly based on realtime user feedback. Combining the two allows the system to capture the dynamic patterns of user interests and provide insightful recommendations.
8 |
9 | FlinkRL is mainly used for online inference, and the offline model training part is implemented in another repo, i.e. [DBRL](https://github.com/massquantity/DBRL). The full system architecture is as follows:
10 |
11 | 
12 |
13 |
14 |
15 | ## Main Workflow
16 |
17 | To simulate an online environment, a dataset acts as a producer that sends data to Kafka, and Flink consumes the data from Kafka. Flink then executes three tasks:
18 |
19 | + Save user behavior log to MongoDB and MySQL, and the log can be used for offline model training.
20 | + Compute realtime top-N popular items and save them to Redis.
21 | + Collect each user's recently consumed items and post them to a web service created through [FastAPI](https://github.com/tiangolo/fastapi) to get recommendation results, then save the recommendations to MongoDB.
22 |
23 | An online web server can then fetch the recommendation results and popular items directly from the databases and send them to the client. Why use FastAPI to build a separate web service? Because the model is trained with PyTorch, and there is seemingly no unified way to deploy a PyTorch model, FastAPI is used to load the trained model, do some processing and produce the final recommendations.
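
For illustration, the request/response format of that service looks roughly like this (a sketch based on the example payload in `python_api/reinforce.py`; the host, port and concrete ids are just examples):

```python
import requests

# Post one user's 10 most recent item ids to the reinforce endpoint.
# The service replies with the ids of the recommended items.
payload = {"user": [1], "item": [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], "n_rec": 8}
resp = requests.post("http://127.0.0.1:8000/reinforce", json=payload)
print(resp.json())  # one list of recommended item ids per user, e.g. [[23, 5, 87, ...]]
```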
24 |
25 |
26 |
27 | ## Data
28 |
29 | The dataset comes from a competition held by Tianchi, a Chinese competition platform. Please refer to the original website for [full description](https://tianchi.aliyun.com/competition/entrance/231721/information?lang=en-us). Note that here we only use the round2 data.
30 |
31 | You can also download the data from [Google Drive](https://drive.google.com/file/d/1erBjYEOa7IuOIGpI8pGPn1WNBAC4Rv0-/view?usp=sharing).
32 |
33 |
34 |
35 | ## Usage
36 |
37 | Python dependencies: python>=3.6, numpy, pandas, torch>=1.3, tqdm, FastAPI.
38 |
39 | You'll need to clone both `FlinkRL` and `DBRL`:
40 |
41 | ```shell
42 | $ git clone https://github.com/massquantity/Flink-Reinforcement-Learning.git
43 | $ git clone https://github.com/massquantity/DBRL.git
44 | ```
45 |
46 | We'll first use the `DBRL` repo for offline training. After downloading the data, unzip it and put the files into the `DBRL/dbrl/resources` folder. The original dataset consists of three tables: `user.csv`, `item.csv` and `user_behavior.csv`. We first need to filter out users with too few interactions and merge all features together, which is accomplished by `run_prepare_data.py` (sketched below). Then we pretrain embeddings for every user and item by running `run_pretrain_embeddings.py`:
47 |
48 | ```shell
49 | $ cd DBRL/dbrl
50 | $ python run_prepare_data.py
51 | $ python run_pretrain_embeddings.py --lr 0.001 --n_epochs 4
52 | ```
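
Conceptually, the data preparation step can be pictured with a short pandas sketch (for illustration only; the real logic lives in `run_prepare_data.py`, and the column names and the interaction threshold here are assumptions):

```python
import pandas as pd

# hypothetical column names; the actual ones are defined in run_prepare_data.py
behavior = pd.read_csv("user_behavior.csv")
users = pd.read_csv("user.csv")
items = pd.read_csv("item.csv")

# keep only users with a minimum number of interactions (threshold is made up)
counts = behavior.groupby("user").size()
active_users = counts[counts >= 10].index
behavior = behavior[behavior["user"].isin(active_users)]

# merge user and item features into a single table used by the later steps
merged = behavior.merge(users, on="user").merge(items, on="item")
merged.to_csv("tianchi.csv", index=False)
```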
53 |
54 | You can tune the `lr` and `n_epochs` hyper-parameters to get a better evaluation loss. Then we train the model. `DBRL` currently provides three algorithms, so we can choose one of them:
55 |
56 | ```shell
57 | $ python run_reinforce.py --n_epochs 5 --lr 1e-5
58 | $ python run_ddpg.py --n_epochs 5 --lr 1e-5
59 | $ python run_bcq.py --n_epochs 5 --lr 1e-5
60 | ```
61 |
62 | At this point, the `DBRL/resources` folder should contain at least 6 files:
63 |
64 | + `model_xxx.pt`, the trained pytorch model.
65 | + `tianchi.csv`, the transformed dataset.
66 | + `tianchi_user_embeddings.npy`, the pretrained user embeddings in numpy `npy` format.
67 | + `tianchi_item_embeddings.npy`, the pretrained item embeddings in numpy `npy` format.
68 | + `user_map.json`, a json file that maps original user ids to ids used in the model.
69 | + `item_map.json`, a json file that maps original item ids to the ids used in the model (see the sketch after this list).
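
As a rough illustration of how these artifacts fit together (not actual `FlinkRL` or `python_api` code; it assumes the json keys are stored as strings):

```python
import json
import numpy as np

# pretrained embeddings saved in numpy .npy format
user_embeds = np.load("tianchi_user_embeddings.npy")   # shape: (n_users, embed_size)
item_embeds = np.load("tianchi_item_embeddings.npy")   # shape: (n_items, embed_size)

# id mappings from the original dataset ids to the ids used by the model
with open("user_map.json") as f:
    user_map = json.load(f)
with open("item_map.json") as f:
    item_map = json.load(f)

# look up the model id and the embedding of an original user id
inner_id = user_map[str(100)]
print(user_embeds[inner_id])
```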
70 |
71 |
72 |
73 | After the offline training, we turn to `FlinkRL`. First put three files, `model_xxx.pt`, `tianchi_user_embeddings.npy` and `tianchi_item_embeddings.npy`, into the `Flink-Reinforcement-Learning/python_api` folder. Make sure you have already installed [FastAPI](https://github.com/tiangolo/fastapi), then start the service:
74 |
75 | ```shell
76 | $ gunicorn reinforce:app -w 4 -k uvicorn.workers.UvicornWorker # if the model is reinforce
77 | $ gunicorn ddpg:app -w 4 -k uvicorn.workers.UvicornWorker # if the model is ddpg
78 | $ gunicorn bcq:app -w 4 -k uvicorn.workers.UvicornWorker # if the model is bcq
79 | ```
80 |
81 |
82 |
83 | The other three files, `tianchi.csv`, `user_map.json` and `item_map.json`, are used by Flink. In principle they can be put anywhere, as long as you specify their absolute paths in the `Flink-Reinforcement-Learning/FlinkRL/src/main/resources/config.properties` file.
84 |
85 | For a quick run, you can directly import `FlinkRL` into an IDE, e.g. `IntelliJ IDEA`. To run on a cluster, use Maven to package it into a jar file:
86 |
87 | ```shell
88 | $ cd FlinkRL
89 | $ mvn clean package
90 | ```
91 |
92 | Put the generated `FlinkRL-1.0-SNAPSHOT-jar-with-dependencies.jar` into the Flink installation directory. Assuming `Kafka`, `MongoDB` and `Redis` have all been started, we can then start the Flink cluster and run the tasks:
93 |
94 | ```shell
95 | $ kafka-topics.sh --create --zookeeper localhost:2181 --replication-factor 1 --partitions 1 --topic flink-rl # create a topic called flink-rl in Kafka
96 | $ use flink-rl # create a database called flink-rl in MongoDB (run inside the mongo shell)
97 | ```
98 |
99 | ```shell
100 | $ # first cd into the Flink installation folder
101 | $ ./bin/start-cluster.sh # start the cluster
102 | $ ./bin/flink run --class com.mass.task.FileToKafka FlinkRL-1.0-SNAPSHOT-jar-with-dependencies.jar # import data into Kafka
103 | $ ./bin/flink run --class com.mass.task.RecordToMongoDB FlinkRL-1.0-SNAPSHOT-jar-with-dependencies.jar # save log to MongoDB
104 | $ ./bin/flink run --class com.mass.task.IntervalPopularItems FlinkRL-1.0-SNAPSHOT-jar-with-dependencies.jar # compute realtime popular items
105 | $ ./bin/flink run --class com.mass.task.SeqRecommend FlinkRL-1.0-SNAPSHOT-jar-with-dependencies.jar # recommend items using reinforcement learning model
106 | $ ./bin/stop-cluster.sh # stop the cluster
107 | ```
108 |
109 | Open [http://localhost:8081](http://localhost:8081/) to use the Flink WebUI:
110 |
111 | 
112 |
113 |
114 |
115 | FastAPI also comes with an interactive WebUI, which you can access at http://127.0.0.1:8000/docs :
116 |
117 | 
118 |
119 |
120 |
121 |
122 |
123 |
124 |
--------------------------------------------------------------------------------
/README_zh.md:
--------------------------------------------------------------------------------
1 | # Flink-Reinforcement-Learning
2 |
3 | ### [`English`](https://github.com/massquantity/Flink-Reinforcement-Learning) `简体中文` [`博客文章`](https://www.cnblogs.com/massquantity/p/13842139.html)
4 |
5 |
6 |
7 | FlinkRL 是一个基于 Flink 和强化学习的推荐系统。具体来说,Flink 因其高性能有状态流处理而闻名,这能使系统快速而准确地响应用户请求。而强化学习则擅长规划长期收益,以及根据用户实时反馈快速调整推荐结果。二者的结合使得推荐系统能够捕获用户兴趣的动态变化规律并提供富有洞察力的推荐结果。
8 |
9 | FlinkRL 主要用于在线推理,而离线训练在另一个仓库 [DBRL](https://github.com/massquantity/DBRL) 中实现,下图是整体系统架构:
10 |
11 | 
12 |
13 |
14 |
15 | ## 主要流程
16 |
17 | 为模拟在线环境,由一个数据集作为生产者发送数据到 Kafka,Flink 则从 Kafka 消费数据。之后,Flink 可以执行三种任务:
18 |
19 | + 保存用户行为日志到 MongoDB 和 MySQL,日志可用于离线模型训练。
20 | + 计算实时 top-N 热门物品,保存到 Redis 。
21 | + 收集用户近期消费过的物品,发送数据到由 [FastAPI](https://github.com/tiangolo/fastapi) 创建的网络服务中,得到推荐结果并将其存到 MongoDB 。
22 |
23 | 之后在线服务器可以直接从数据库中调用推荐结果和热门物品发送到客户端。这里使用 FastAPI 建立另一个服务是因为线下训练使用的是 PyTorch,对于 PyTorch 模型现在貌似没有一个统一的部署方案,所以 FastAPI 用于载入模型,预处理和产生最终推荐结果。
24 |
25 |
26 |
27 | ## 数据
28 |
29 | 数据来源于天池的一个比赛,详情可参阅[官方网站](https://tianchi.aliyun.com/competition/entrance/231721/information?lang=zh-cn) ,注意这里只是用了第二轮的数据。也可以从 [Google Drive](https://drive.google.com/file/d/1erBjYEOa7IuOIGpI8pGPn1WNBAC4Rv0-/view?usp=sharing) 下载。
30 |
31 |
32 |
33 | ## 使用步骤
34 |
35 | Python 依赖库:python>=3.6, numpy, pandas, torch>=1.3, tqdm, FastAPI
36 |
37 | 首先 clone 两个仓库
38 |
39 | ```shell
40 | $ git clone https://github.com/massquantity/Flink-Reinforcement-Learning.git
41 | $ git clone https://github.com/massquantity/DBRL.git
42 | ```
43 |
44 | 首先使用 `DBRL` 作离线训练。下载完数据后,解压并放到 `DBRL/dbrl/resources` 文件夹中。原始数据有三张表:`user.csv`, `item.csv`, `user_behavior.csv` 。首先用脚本 `run_prepare_data.py` 过滤掉一些行为太少的用户并将所有特征合并到一张表。接着用 `run_pretrain_embeddings.py` 为每个用户和物品预训练 embedding:
45 |
46 | ```shell
47 | $ cd DBRL/dbrl
48 | $ python run_prepare_data.py
49 | $ python run_pretrain_embeddings.py --lr 0.001 --n_epochs 4
50 | ```
51 |
52 | 可以调整一些参数如 `lr` 和 `n_epochs` 来获得更好的评估效果。接下来开始训练模型,现在在 `DBRL` 中有三种模型,任选一种即可:
53 |
54 | ```shell
55 | $ python run_reinforce.py --n_epochs 5 --lr 1e-5
56 | $ python run_ddpg.py --n_epochs 5 --lr 1e-5
57 | $ python run_bcq.py --n_epochs 5 --lr 1e-5
58 | ```
59 |
60 | 到这个阶段,`DBRL/resources` 中应该至少有 6 个文件:
61 |
62 | + `model_xxx.pt`, 训练好的 PyTorch 模型。
63 | + `tianchi.csv`, 转换过的数据集。
64 | + `tianchi_user_embeddings.npy`, `npy` 格式的 user 预训练 embedding。
65 | + `tianchi_item_embeddings.npy`, `npy` 格式的 item 预训练 embedding。
66 | + `user_map.json`, 将原始用户 id 映射到模型中 id 的 json 文件。
67 | + `item_map.json`, 将原始物品 id 映射到模型中 id 的 json 文件。
68 |
69 |
70 |
71 | 离线训练完成后,接下来使用 `FlinkRL` 作在线推理。首先将三个文件 `model_xxx.pt`, `tianchi_user_embeddings.npy`, `tianchi_item_embeddings.npy` 放入 `Flink-Reinforcement-Learning/python_api` 文件夹。确保已经安装了[FastAPI](https://github.com/tiangolo/fastapi), 就可以启动服务了:
72 |
73 | ```shell
74 | $ gunicorn reinforce:app -w 4 -k uvicorn.workers.UvicornWorker # if the model is reinforce
75 | $ gunicorn ddpg:app -w 4 -k uvicorn.workers.UvicornWorker # if the model is ddpg
76 | $ gunicorn bcq:app -w 4 -k uvicorn.workers.UvicornWorker # if the model is bcq
77 | ```
78 |
79 |
80 |
81 | 另外三个文件:`tianchi.csv`, `user_map.json`, `item_map.json` 在 Flink 中使用,原则上可以放在任何地方,只要在 `Flink-Reinforcement-Learning/FlinkRL/src/main/resources/config.properties` 中指定绝对路径。
82 |
83 | 如果想要快速开始,可以直接将 `FlinkRL` 导入到 `IntelliJ IDEA` 这样的 IDE 中。而要在集群上运行,则使用 Maven 打成 jar 包:
84 |
85 | ```shell
86 | $ cd FlinkRL
87 | $ mvn clean package
88 | ```
89 |
90 | 将生成的 `FlinkRL-1.0-SNAPSHOT-jar-with-dependencies.jar` 放到 Flink 安装目录。现在假设 `Kafka`、`MongoDB` 和 `Redis` 都已启动,然后就可以启动 Flink 集群并执行任务:
91 |
92 | ```shell
93 | $ kafka-topics.sh --create --zookeeper localhost:2181 --replication-factor 1 --partitions 1 --topic flink-rl # create a topic called flink-rl in Kafka
94 | $ use flink-rl # create a database called flink-rl in MongoDB (run inside the mongo shell)
95 | ```
96 |
97 | ```shell
98 | $ # first cd into the Flink installation folder
99 | $ ./bin/start-cluster.sh # start the cluster
100 | $ ./bin/flink run --class com.mass.task.FileToKafka FlinkRL-1.0-SNAPSHOT-jar-with-dependencies.jar # import data into Kafka
101 | $ ./bin/flink run --class com.mass.task.RecordToMongoDB FlinkRL-1.0-SNAPSHOT-jar-with-dependencies.jar # save log to MongoDB
102 | $ ./bin/flink run --class com.mass.task.IntervalPopularItems FlinkRL-1.0-SNAPSHOT-jar-with-dependencies.jar # compute realtime popular items
103 | $ ./bin/flink run --class com.mass.task.SeqRecommend FlinkRL-1.0-SNAPSHOT-jar-with-dependencies.jar # recommend items using reinforcement learning model
104 | $ ./bin/stop-cluster.sh # stop the cluster
105 | ```
106 |
107 | 打开 [http://localhost:8081](http://localhost:8081/) 可使用 Flink WebUI :
108 |
109 | 
110 |
111 |
112 |
113 | FastAPI 中也提供了交互式的 WebUI,可访问 http://127.0.0.1:8000/docs :
114 |
115 | 
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
--------------------------------------------------------------------------------
/pic/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/massquantity/Flink-Reinforcement-Learning/4f1b252a6884fa35b05122417d174901508fb788/pic/5.png
--------------------------------------------------------------------------------
/pic/6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/massquantity/Flink-Reinforcement-Learning/4f1b252a6884fa35b05122417d174901508fb788/pic/6.jpg
--------------------------------------------------------------------------------
/pic/readme1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/massquantity/Flink-Reinforcement-Learning/4f1b252a6884fa35b05122417d174901508fb788/pic/readme1.png
--------------------------------------------------------------------------------
/pic/readme4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/massquantity/Flink-Reinforcement-Learning/4f1b252a6884fa35b05122417d174901508fb788/pic/readme4.png
--------------------------------------------------------------------------------
/python_api/bcq.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from typing import Optional, List
4 | from fastapi import FastAPI, HTTPException, Body
5 | from pydantic import BaseModel, Field
6 | from utils import load_model_bcq
7 |
8 | app = FastAPI()
9 | hist_num = 10
10 | action_dim = 32
11 | input_dim = action_dim * (hist_num + 1)
12 | hidden_size = 64
13 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14 | model_path = "model_bcq.pt"
15 | user_embeddings_path = "tianchi_user_embeddings.npy"
16 | item_embeddings_path = "tianchi_item_embeddings.npy"
17 |
18 | generator, perturbator = load_model_bcq(
19 | model_path, user_embeddings_path, item_embeddings_path,
20 | input_dim, action_dim, hidden_size, device=device
21 | )
22 |
23 |
24 | class Seq(BaseModel):
25 | user: List[int]
26 | item: List[List[int]] = Field(..., example=[[1,2,3,4,5,6,7,8,9,10]])
27 | n_rec: int
28 |
29 |
30 | class State(BaseModel):
31 | user: List[int]
32 | embedding: List[List[float]]
33 | n_rec: int
34 |
35 |
36 | @app.post("/{algo}")
37 | async def recommend(algo: str, seq: Seq) -> list:
38 | if algo == "bcq":
39 | with torch.no_grad():
40 | data = {
41 | "user": torch.as_tensor(seq.user),
42 | "item": torch.as_tensor(seq.item)
43 | }
44 | state = generator.get_state(data)
45 | gen_actions = generator.decode(state)
46 | action = perturbator(state, gen_actions)
47 | scores = torch.matmul(action, generator.item_embeds.weight.T)
48 | _, res = torch.topk(scores, seq.n_rec, dim=1, sorted=False)
49 | # return f"Recommend {seq.n_rec} items for user {seq.user}: {res}"
50 | return res.tolist()
51 | else:
52 | raise HTTPException(status_code=404, detail="wrong algorithm.")
53 |
54 |
55 | @app.post("/{algo}/state")
56 | async def recommend_with_state(algo: str, state: State):
57 | if algo == "bcq":
58 | with torch.no_grad():
59 | data = torch.as_tensor(state.embedding)
60 | gen_actions = generator.decode(data)
61 | action = perturbator(data, gen_actions)  # use the state tensor built above, not the request object
62 | scores = torch.matmul(action, generator.item_embeds.weight.T)
63 | _, res = torch.topk(scores, state.n_rec, dim=1, sorted=False)
64 | # return f"Recommend {state.n_rec} items for user {state.user}: {res}"
65 | return res.tolist()
66 | else:
67 | raise HTTPException(status_code=404, detail="wrong algorithm.")
68 |
69 |
70 | # gunicorn bcq:app -w 4 -k uvicorn.workers.UvicornWorker
71 | # curl -X POST "http://127.0.0.1:8000/bcq" -H "accept: application/json" -d '{"user": [1], "item": [[1,2,3,4,5,6,7,8,9,10]], "n_rec": 8}'
72 |
--------------------------------------------------------------------------------
/python_api/ddpg.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from typing import Optional, List
4 | from fastapi import FastAPI, HTTPException, Body
5 | from pydantic import BaseModel, Field
6 | from utils import load_model_ddpg
7 |
8 | app = FastAPI()
9 | hist_num = 10
10 | action_dim = 32
11 | input_dim = action_dim * (hist_num + 1)
12 | hidden_size = 64
13 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14 | model_path = "model_ddpg.pt"
15 | user_embeddings_path = "tianchi_user_embeddings.npy"
16 | item_embeddings_path = "tianchi_item_embeddings.npy"
17 |
18 | model = load_model_ddpg(
19 | model_path, user_embeddings_path, item_embeddings_path,
20 | input_dim, action_dim, hidden_size, device=device
21 | )
22 |
23 |
24 | class Seq(BaseModel):
25 | user: List[int]
26 | item: List[List[int]] = Field(..., example=[[1,2,3,4,5,6,7,8,9,10]])
27 | n_rec: int
28 |
29 |
30 | class State(BaseModel):
31 | user: List[int]
32 | embedding: List[List[float]]
33 | n_rec: int
34 |
35 |
36 | @app.post("/{algo}")
37 | async def recommend(algo: str, seq: Seq) -> list:
38 | if algo == "ddpg":
39 | with torch.no_grad():
40 | data = {
41 | "user": torch.as_tensor(seq.user),
42 | "item": torch.as_tensor(seq.item)
43 | }
44 | _, action = model(data)
45 | scores = torch.matmul(action, model.item_embeds.weight.T)
46 | _, res = torch.topk(scores, seq.n_rec, dim=1, sorted=False)
47 | # return f"Recommend {seq.n_rec} items for user {seq.user}: {res}"
48 | return res.tolist()
49 | else:
50 | raise HTTPException(status_code=404, detail="wrong algorithm.")
51 |
52 |
53 | @app.post("/{algo}/state")
54 | async def recommend_with_state(algo: str, state: State):
55 | if algo == "ddpg":
56 | with torch.no_grad():
57 | data = torch.as_tensor(state.embedding)
58 | action = model.get_action(data)
59 | scores = torch.matmul(action, model.item_embeds.weight.T)
60 | _, res = torch.topk(scores, state.n_rec, dim=1, sorted=False)
61 | # return f"Recommend {state.n_rec} items for user {state.user}: {res}"
62 | return res.tolist()
63 | else:
64 | raise HTTPException(status_code=404, detail="wrong algorithm.")
65 |
66 |
67 | # gunicorn ddpg:app -w 4 -k uvicorn.workers.UvicornWorker
68 | # curl -X POST "http://127.0.0.1:8000/ddpg" -H "accept: application/json" -d '{"user": [1], "item": [[1,2,3,4,5,6,7,8,9,10]], "n_rec": 8}'
69 |
--------------------------------------------------------------------------------
/python_api/net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def multihead_attention(
7 | attention,
8 | user_embeds,
9 | item_embeds,
10 | items,
11 | pad_val,
12 | att_mode="normal"
13 | ):
14 | mask = (items == pad_val)
15 | mask[torch.where((mask == torch.tensor(1)).all(1))] = False
16 | item_repr = item_embeds.transpose(1, 0)
17 | if att_mode == "normal":
18 | user_repr = user_embeds.unsqueeze(0)
19 | output, weight = attention(user_repr, item_repr, item_repr,
20 | key_padding_mask=mask)
21 | output = output.squeeze()
22 | elif att_mode == "self":
23 | output, weight = attention(item_repr, item_repr, item_repr,
24 | key_padding_mask=mask)
25 | output = output.transpose(1, 0).mean(dim=1) # max[0], sum
26 | else:
27 | raise ValueError("attention must be 'normal' or 'self'")
28 | return output
29 |
30 |
31 | class Actor(nn.Module):
32 | def __init__(
33 | self,
34 | input_dim,
35 | action_dim,
36 | hidden_size,
37 | user_embeds,
38 | item_embeds,
39 | attention=None,
40 | pad_val=0,
41 | head=1,
42 | device=torch.device("cpu")
43 | ):
44 | super(Actor, self).__init__()
45 | self.user_embeds = nn.Embedding.from_pretrained(
46 | torch.as_tensor(user_embeds), freeze=True,
47 | )
48 | self.item_embeds = nn.Embedding.from_pretrained(
49 | torch.as_tensor(item_embeds), freeze=True,
50 | )
51 | if attention is not None:
52 | self.attention = nn.MultiheadAttention(
53 | item_embeds.shape[1], head
54 | ).to(device)
55 | self.att_mode = attention
56 | else:
57 | self.attention = None
58 |
59 | # self.dropout = nn.Dropout(0.5)
60 | self.fc1 = nn.Linear(input_dim, hidden_size)
61 | self.fc2 = nn.Linear(hidden_size, hidden_size)
62 | self.fc3 = nn.Linear(hidden_size, action_dim)
63 | self.pad_val = pad_val
64 | self.device = device
65 |
66 | def get_state(self, data, next_state=False):
67 | user_repr = self.user_embeds(data["user"])
68 | item_col = "next_item" if next_state else "item"
69 | items = data[item_col]
70 |
71 | if not self.attention:
72 | batch_size = data[item_col].size(0)
73 | item_repr = self.item_embeds(data[item_col]).view(batch_size, -1)
74 |
75 | # true_num = torch.sum(items != self.pad_val, dim=1) + 1e-5
76 | # item_repr = self.item_embeds(items).sum(dim=1)
77 | # compute mean embeddings
78 | # item_repr = torch.div(item_repr, true_num.float().view(-1, 1))
79 | else:
80 | item_repr = self.item_embeds(items)
81 | item_repr = multihead_attention(self.attention, user_repr,
82 | item_repr, items, self.pad_val,
83 | self.att_mode)
84 |
85 | state = torch.cat([user_repr, item_repr], dim=1)
86 | return state
87 |
88 | def get_action(self, state, tanh=False):
89 | action = F.relu(self.fc1(state))
90 | action = F.relu(self.fc2(action))
91 | action = self.fc3(action)
92 | if tanh:
93 | action = F.tanh(action)
94 | return action
95 |
96 | def forward(self, data, tanh=False):
97 | state = self.get_state(data)
98 | action = self.get_action(state, tanh)
99 | return state, action
100 |
101 |
102 | class Critic(nn.Module):
103 | def __init__(
104 | self,
105 | input_dim,
106 | action_dim,
107 | hidden_size
108 | ):
109 | super(Critic, self).__init__()
110 | self.fc1 = nn.Linear(input_dim + action_dim, hidden_size)
111 | self.fc2 = nn.Linear(hidden_size, hidden_size)
112 | self.fc3 = nn.Linear(hidden_size, 1)
113 |
114 | def forward(self, state, action):
115 | out = torch.cat([state, action], dim=1)
116 | out = F.relu(self.fc1(out))
117 | out = F.relu(self.fc2(out))
118 | out = self.fc3(out)
119 | return out.squeeze()
120 |
121 |
122 | class PolicyPi(Actor):
123 | def __init__(
124 | self,
125 | input_dim,
126 | action_dim,
127 | hidden_size,
128 | user_embeds,
129 | item_embeds,
130 | attention=None,
131 | pad_val=0,
132 | head=1,
133 | device=torch.device("cpu")
134 | ):
135 | super(PolicyPi, self).__init__(
136 | input_dim,
137 | action_dim,
138 | hidden_size,
139 | user_embeds,
140 | item_embeds,
141 | attention,
142 | pad_val,
143 | head,
144 | device
145 | )
146 |
147 | # self.dropout = nn.Dropout(0.5)
148 | self.pfc1 = nn.Linear(input_dim, hidden_size)
149 | self.pfc2 = nn.Linear(hidden_size, hidden_size)
150 | self.softmax_fc = nn.Linear(hidden_size, action_dim)
151 |
152 | def get_action(self, state, tanh=False):
153 | action = F.relu(self.pfc1(state))
154 | action = F.relu(self.pfc2(action))
155 | if tanh:
156 | action = F.tanh(action)
157 | return action
158 |
159 | def forward(self, data, tanh=False):
160 | state = self.get_state(data)
161 | action = self.get_action(state, tanh)
162 | return state, action
163 |
164 | def get_log_probs(self, data):
165 | state, action = self.forward(data)
166 | logits = self.softmax_fc(action)
167 | log_prob = F.log_softmax(logits, dim=1)
168 | return state, log_prob, action
169 |
170 | def get_beta_state(self, data):
171 | user_repr = self.user_embeds(data["beta_user"])
172 | items = data["beta_item"]
173 | if not self.attention:
174 | # true_num = torch.sum(items != self.pad_val, dim=1) + 1e-5
175 | # item_repr = self.item_embeds(items).sum(dim=1)
176 | # item_repr = torch.div(item_repr, true_num.float().view(-1, 1))
177 | batch_size = data["item"].size(0)
178 | item_repr = self.item_embeds(data["item"]).view(batch_size, -1)
179 | else:
180 | item_repr = self.item_embeds(items)
181 | item_repr = multihead_attention(self.attention, user_repr,
182 | item_repr, items, self.pad_val,
183 | self.att_mode)
184 |
185 | state = torch.cat([user_repr, item_repr], dim=1)
186 | return state
187 |
188 |
189 | class Beta(nn.Module):
190 | def __init__(
191 | self,
192 | input_dim,
193 | action_dim,
194 | hidden_size
195 | ):
196 | super(Beta, self).__init__()
197 | self.bfc = nn.Linear(input_dim, hidden_size)
198 | self.softmax_fc = nn.Linear(hidden_size, action_dim)
199 |
200 | def forward(self, state):
201 | out = F.relu(self.bfc(state))
202 | return out
203 |
204 | def get_log_probs(self, state):
205 | action = self.forward(state)
206 | logits = self.softmax_fc(action)
207 | log_prob = F.log_softmax(logits, dim=1)
208 | return log_prob, logits
209 |
210 |
211 | class VAE(Actor):
212 | def __init__(
213 | self,
214 | input_dim,
215 | action_dim,
216 | latent_dim,
217 | hidden_size,
218 | user_embeds,
219 | item_embeds,
220 | attention=None,
221 | pad_val=0,
222 | head=1,
223 | device=torch.device("cpu")
224 | ):
225 | super(VAE, self).__init__(
226 | input_dim,
227 | action_dim,
228 | hidden_size,
229 | user_embeds,
230 | item_embeds,
231 | attention,
232 | pad_val,
233 | head,
234 | device
235 | )
236 |
237 | self.encoder1 = nn.Linear(input_dim + action_dim, hidden_size)
238 | self.encoder2 = nn.Linear(hidden_size, hidden_size)
239 |
240 | self.mean = nn.Linear(hidden_size, latent_dim)
241 | self.log_std = nn.Linear(hidden_size, latent_dim)
242 |
243 | self.decoder1 = nn.Linear(input_dim + latent_dim, hidden_size)
244 | self.decoder2 = nn.Linear(hidden_size, hidden_size)
245 | self.decoder3 = nn.Linear(hidden_size, action_dim)
246 | self.latent_dim = latent_dim
247 | self.device = device
248 |
249 | def forward(self, data, action):
250 | state = self.get_state(data)
251 | z = F.relu(self.encoder1(torch.cat([state, action], dim=1)))
252 | z = F.relu(self.encoder2(z))
253 |
254 | mean = self.mean(z)
255 | log_std = self.log_std(z).clamp(-4, 15)
256 | std = log_std.exp()
257 | z = mean + std * torch.randn_like(std)
258 |
259 | u = self.decode(state, z)
260 | return state, u, mean, std
261 |
262 | def decode(self, state, z=None):
263 | if z is None:
264 | z = torch.randn(state.size(0), self.latent_dim)
265 | z = z.clamp(-0.5, 0.5).to(self.device)
266 | action = F.relu(self.decoder1(torch.cat([state, z], dim=1)))
267 | action = F.relu(self.decoder2(action))
268 | action = self.decoder3(action)
269 | return action
270 |
271 |
272 | class Perturbator(nn.Module):
273 | def __init__(
274 | self,
275 | input_dim,
276 | action_dim,
277 | hidden_size,
278 | phi=0.05,
279 | action_range=None
280 | ):
281 | super(Perturbator, self).__init__()
282 | self.fc1 = nn.Linear(input_dim + action_dim, hidden_size)
283 | self.fc2 = nn.Linear(hidden_size, hidden_size)
284 | self.fc3 = nn.Linear(hidden_size, action_dim)
285 | self.phi = phi
286 | self.action_range = action_range
287 |
288 | def forward(self, state, action):
289 | a = F.relu(self.fc1(torch.cat([state, action], dim=1)))
290 | a = F.relu(self.fc2(a))
291 | a = self.fc3(a)
292 | if self.phi is not None:
293 | a = self.phi * torch.tanh(a)
294 |
295 | act = action + a
296 | if self.action_range is not None:
297 | act = act.clamp(self.action_range[0], self.action_range[1])
298 | return act
299 |
300 |
--------------------------------------------------------------------------------
/python_api/reinforce.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from typing import Optional, List
4 | from fastapi import FastAPI, HTTPException, Body
5 | from pydantic import BaseModel, Field
6 | from utils import load_model_reinforce
7 |
8 | app = FastAPI()
9 | hist_num = 10
10 | embed_size = 32
11 | hidden_size = 64
12 | input_dim = embed_size * (hist_num + 1)
13 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14 | model_path = "model_reinforce.pt"
15 | user_embeddings_path = "tianchi_user_embeddings.npy"
16 | item_embeddings_path = "tianchi_item_embeddings.npy"
17 | with open(item_embeddings_path, "rb") as f:
18 | action_dim = len(np.load(f))
19 |
20 | model = load_model_reinforce(
21 | model_path, user_embeddings_path, item_embeddings_path,
22 | input_dim, action_dim, hidden_size, device=device
23 | )
24 |
25 |
26 | class Seq(BaseModel):
27 | user: List[int]
28 | item: List[List[int]] = Field(..., example=[[1,2,3,4,5,6,7,8,9,10]])
29 | n_rec: int
30 |
31 |
32 | class State(BaseModel):
33 | user: List[int]
34 | embedding: List[List[float]]
35 | n_rec: int
36 |
37 |
38 | @app.post("/{algo}")
39 | async def recommend(algo: str, seq: Seq) -> list:
40 | if algo == "reinforce":
41 | with torch.no_grad():
42 | data = {
43 | "user": torch.as_tensor(seq.user),
44 | "item": torch.as_tensor(seq.item)
45 | }
46 | _, action_probs, _ = model.get_log_probs(data)
47 | _, res = torch.topk(action_probs, seq.n_rec, dim=1, sorted=False)
48 | # return f"Recommend {seq.n_rec} items for user {seq.user}: {res}"
49 | return res.tolist()
50 | else:
51 | raise HTTPException(status_code=404, detail="wrong algorithm.")
52 |
53 |
54 | @app.post("/{algo}/state")
55 | async def recommend_with_state(algo: str, state: State):
56 | if algo == "reinforce":
57 | with torch.no_grad():
58 | data = torch.as_tensor(state.embedding)
59 | action = model.get_action(data)
60 | action_probs = model.softmax_fc(action)
61 | _, res = torch.topk(action_probs, state.n_rec, dim=1, sorted=False)
62 | # return f"Recommend {state.n_rec} items for user {state.user}: {res}"
63 | return res.tolist()
64 | else:
65 | raise HTTPException(status_code=404, detail="wrong algorithm.")
66 |
67 |
68 | # gunicorn reinforce:app -w 4 -k uvicorn.workers.UvicornWorker
69 | # curl -X POST "http://127.0.0.1:8000/reinforce" -H "accept: application/json" -d '{"user": [1], "item": [[1,2,3,4,5,6,7,8,9,10]], "n_rec": 8}'
70 |
--------------------------------------------------------------------------------
/python_api/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from net import Actor, PolicyPi, VAE, Perturbator
4 |
5 |
6 | def load_model_reinforce(
7 | model_path,
8 | user_embeddings_path,
9 | item_embeddings_path,
10 | input_dim,
11 | action_dim,
12 | hidden_size,
13 | device
14 | ):
15 | with open(user_embeddings_path, "rb") as f:
16 | user_embeddings = np.load(f)
17 | with open(item_embeddings_path, "rb") as f:
18 | item_embeddings = np.load(f)
19 | model = PolicyPi(
20 | input_dim, action_dim, hidden_size, user_embeddings, item_embeddings
21 | )
22 | model.load_state_dict(torch.load(model_path, map_location=device))
23 | model.eval()
24 | return model
25 |
26 |
27 | def load_model_ddpg(
28 | model_path,
29 | user_embeddings_path,
30 | item_embeddings_path,
31 | input_dim,
32 | action_dim,
33 | hidden_size,
34 | device
35 | ):
36 | with open(user_embeddings_path, "rb") as f:
37 | user_embeddings = np.load(f)
38 | with open(item_embeddings_path, "rb") as f:
39 | item_embeddings = np.load(f)
40 | model = Actor(
41 | input_dim, action_dim, hidden_size, user_embeddings, item_embeddings
42 | )
43 | model.load_state_dict(torch.load(model_path, map_location=device))
44 | model.eval()
45 | return model
46 |
47 |
48 | def load_model_bcq(
49 | model_path,
50 | user_embeddings_path,
51 | item_embeddings_path,
52 | input_dim,
53 | action_dim,
54 | hidden_size,
55 | device
56 | ):
57 | with open(user_embeddings_path, "rb") as f:
58 | user_embeddings = np.load(f)
59 | with open(item_embeddings_path, "rb") as f:
60 | item_embeddings = np.load(f)
61 | generator = VAE(
62 | input_dim, action_dim, action_dim * 2, hidden_size,
63 | user_embeddings, item_embeddings
64 | )
65 | perturbator = Perturbator(input_dim, action_dim, hidden_size)
66 | checkpoint = torch.load(model_path, map_location=device)
67 | generator.load_state_dict(checkpoint["generator"])
68 | perturbator.load_state_dict(checkpoint["perturbator"])
69 | generator.eval()
70 | perturbator.eval()
71 | return generator, perturbator
72 |
--------------------------------------------------------------------------------