├── .idea
│   └── vcs.xml
├── README.md
├── cf_gbdt_lr_predict.py
├── data
│   ├── u.data
│   ├── u.info
│   └── u.item
├── data_process.py
└── gbdt_lr.py
/README.md:
--------------------------------------------------------------------------------
# cf_gbdt_lr
A simple implementation of the recall and ranking models of a recommender system: the recall model uses collaborative filtering (ALS), and the ranking model uses GBDT+LR.

The data used is the MovieLens ml-100k dataset.

- data_process.py is the data-processing script
- gbdt_lr.py is the training script for the GBDT+LR ranking model
- cf_gbdt_lr_predict.py runs the combined ALS and GBDT+LR prediction end to end, with ALS as the recall model
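
A minimal sketch of the GBDT+LR idea (illustrative only, with synthetic stand-in data; the project's actual training code lives in gbdt_lr.py, and the leaf one-hot encoding is done by hand in data_process.transfromed_feature rather than with sklearn's OneHotEncoder):

```python
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import OneHotEncoder

# Synthetic stand-in for the processed ml-100k features and labels.
rng = np.random.RandomState(0)
X = rng.rand(1000, 21)
y = rng.randint(0, 2, 1000)

gbdt = GradientBoostingClassifier(n_estimators=50)
gbdt.fit(X, y)
leaves = gbdt.apply(X)[:, :, 0]           # (n_samples, n_trees) leaf indices
encoder = OneHotEncoder(handle_unknown='ignore')
lr = LogisticRegression(max_iter=1000)
lr.fit(encoder.fit_transform(leaves), y)  # LR over the one-hot encoded leaf indices
```

At prediction time the same pipeline is applied to the recalled (user, item) pairs: apply the GBDT, encode the leaves, then score with the LR, which is what cf_gbdt_lr_predict.py does.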
--------------------------------------------------------------------------------
/cf_gbdt_lr_predict.py:
--------------------------------------------------------------------------------
from pyspark import SparkContext
from pyspark.sql import HiveContext, Row
from pyspark.ml.recommendation import ALS

import joblib
from data_process import Rating_info, transfromed_feature, item_info, user_info


def predict(rating_file_path, user_file_path, item_file_path):
    # Recall stage: train ALS on the ratings and take the top 10 items per user.
    rating_df = Rating_info(sc, rating_file_path).toDF()  # columns: user, product, rating
    als = ALS(rank=10, maxIter=6, userCol='user', itemCol='product', ratingCol='rating')
    alsmodel = als.fit(rating_df)
    user_item = alsmodel.recommendForAllUsers(10).rdd \
        .flatMap(lambda row: [Row(user_id=str(row['user']), item_id=str(rec['product']))
                              for rec in row['recommendations']]) \
        .toDF()
    # Attach user and item side features to every recalled (user, item) pair.
    item_info_df = item_info(sc, item_file_path)
    user_info_df = user_info(sc, user_file_path)
    df = user_item.join(user_info_df, on='user_id', how='left') \
        .join(item_info_df, on='item_id', how='left')
    feature = ['age', 'gender', 'action', 'adventure', 'animation', 'childrens', 'comedy',
               'crime', 'documentary', 'drama', 'fantasy', 'film_noir', 'horror', 'musical',
               'mystery', 'romance', 'sci_fi', 'thriller', 'unknow', 'war', 'western']
    predict_data = [[float(row[f]) for f in feature] for row in df.select(feature).collect()]

    # Ranking stage: GBDT leaf indices -> one-hot encoding -> LR score.
    print("starting gbdt...")
    gbdt_model = joblib.load('../model/gbdt_model/gbdt.model')
    leaf = gbdt_model.apply(predict_data)[:, :, 0].astype(int)
    print("starting transform")
    # Note: num_leaf must match the value used when the LR model was trained,
    # otherwise the one-hot feature width will not line up.
    transform_feature = transfromed_feature(leaf, leaf.max())
    print("starting lr model...")
    lr_model = joblib.load("../model/lr.model")
    y_pred = lr_model.predict(transform_feature)
    print(y_pred[:10])


if __name__ == "__main__":
    sc = SparkContext('local', 'predict')
    sqlcontext = HiveContext(sc)  # created so that RDD.toDF() works
    sc.setLogLevel("ERROR")
    rating_file_path = "E:/data/ml-100k/u.data"
    user_file_path = "E:/data/ml-100k/u.user"
    item_file_path = "E:/data/ml-100k/u.item"
    predict(rating_file_path, user_file_path, item_file_path)
--------------------------------------------------------------------------------
/data/u.info:
--------------------------------------------------------------------------------
943 users
1682 items
100000 ratings
--------------------------------------------------------------------------------
/data/u.item:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuxinping1992/cf_gbdt_lr/2345f68db309a8577ac358a8da1a5a436abe4bf7/data/u.item
--------------------------------------------------------------------------------
/data_process.py:
--------------------------------------------------------------------------------
from pyspark import SparkContext
from pyspark.sql import HiveContext, Row
from pyspark.mllib.recommendation import Rating
import numpy as np


# load rating data

def Rating_info(sc, file_path):
    # u.data is tab-separated: user id, item id, rating, timestamp.
    rating = sc.textFile(file_path).map(lambda x: x.split('\t')) \
        .map(lambda x: Rating(int(x[0]), int(x[1]), float(x[2])))
    return rating

def split_age(age):
    if age <= 20:
        return 0
    elif age <= 45:
        return 1
    elif age <= 60:
        return 2
    else:
        return 3


def transform_gender(gender):
    if gender == 'F':
        return 0
    else:
        return 1

# transform feature: one-hot encode GBDT leaf indices.
# leaf is an (n_samples, n_trees) array of leaf ids in [1, num_leaf]; each sample
# becomes a row of length n_trees * num_leaf with a 1 at column
# tree_index * num_leaf + (leaf_id - 1) for every tree.
def transfromed_feature(leaf, num_leaf):
    transfrom_feature_matrix = np.zeros([len(leaf), len(leaf[0]) * num_leaf], dtype=np.int64)
    for i in range(len(leaf)):
        temp = np.arange(len(leaf[0])) * num_leaf - 1 + np.array(leaf[i])
        transfrom_feature_matrix[i][temp] += 1
    return transfrom_feature_matrix

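# Worked example (illustrative values): with leaf = [[3, 1]] and num_leaf = 3,
# sample 0 fell into leaf 3 of tree 0 and leaf 1 of tree 1, so its ones land at
# columns 0*3 + (3-1) = 2 and 1*3 + (1-1) = 3, giving the row [0, 0, 1, 1, 0, 0].
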
# load user info data

def user_info(sc, file_path):
    # u.user is pipe-separated: user id | age | gender | occupation | zip code
    lines = sc.textFile(file_path).map(lambda x: x.split('|'))
    user_info_df = lines.map(lambda data: Row(user_id=data[0],
                                              age=float(split_age(int(data[1]))),
                                              gender=transform_gender(data[2]))).toDF()
    return user_info_df

# load item info data

def item_info(sc, file_path):
    # u.item is pipe-separated: movie id | title | release date | video release date |
    # IMDb URL | 19 genre flags (unknown, Action, Adventure, Animation, Children's,
    # Comedy, Crime, Documentary, Drama, Fantasy, Film-Noir, Horror, Musical, Mystery,
    # Romance, Sci-Fi, Thriller, War, Western) at indices 5-23.
    lines = sc.textFile(file_path).map(lambda line: line.split("|"))
    item_info_df = lines.map(lambda data: Row(
        item_id=data[0], movie_title=data[1], release_date=data[2],
        video_release_data=data[3], imdb_url=data[4], unknow=float(data[5]),
        action=float(data[6]), adventure=float(data[7]), animation=float(data[8]),
        childrens=float(data[9]), comedy=float(data[10]), crime=float(data[11]),
        documentary=float(data[12]), drama=float(data[13]), fantasy=float(data[14]),
        film_noir=float(data[15]), horror=float(data[16]), musical=float(data[17]),
        mystery=float(data[18]), romance=float(data[19]), sci_fi=float(data[20]),
        thriller=float(data[21]), war=float(data[22]), western=float(data[23]))).toDF()
    return item_info_df

# load sample

def sample(sc, rating_file_path, user_file_path, item_file_path, k):
    item_info_df = item_info(sc, item_file_path)
    user_info_df = user_info(sc, user_file_path)
    num_item = range(item_info_df.count())
    # positive examples: every (user, item) pair present in the rating file
    pos = sc.textFile(rating_file_path).map(lambda x: x.split('\t'))
    pos_sample = pos.map(lambda data: Row(user_id=data[0], item_id=data[1], label=1.0)).toDF()
    pos_user_item = [[int(data[0]), int(data[1])] for data in pos_sample.select(['user_id', 'item_id']).collect()]
    # index each user's positive items so negative sampling can avoid them
    pos_user_item_dict = {}
    neg_sample = []
    print("starting...")
    for data in pos_user_item:
        if data[0] not in pos_user_item_dict:
            pos_user_item_dict[data[0]] = [data[1]]
        else:
            pos_user_item_dict[data[0]].append(data[1])
    for data in pos_user_item:
        i = 0
        while i