├── .idea
│   └── vcs.xml
├── README.md
├── cf_gbdt_lr_predict.py
├── data
│   ├── u.data
│   ├── u.info
│   └── u.item
├── data_process.py
└── gbdt_lr.py

/.idea/vcs.xml:
--------------------------------------------------------------------------------
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="VcsDirectoryMappings">
    <mapping directory="$PROJECT_DIR$" vcs="Git" />
  </component>
</project>

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# cf_gbdt_lr
A simple implementation of the recall and ranking models of a recommender system. The recall model uses collaborative filtering (ALS); the ranking model uses the GBDT + LR algorithm.

The data used is the MovieLens ml-100k dataset.
data_process.py is the data-processing script.
gbdt_lr.py is the training script for the GBDT + LR ranking model.
cf_gbdt_lr_predict.py runs the combined ALS and GBDT + LR prediction, where ALS is the recall model.

--------------------------------------------------------------------------------
/cf_gbdt_lr_predict.py:
--------------------------------------------------------------------------------
from pyspark import SparkContext, HiveContext
from pyspark.ml.recommendation import ALS
from pyspark.sql import Row

import numpy as np
import joblib  # sklearn.externals.joblib is deprecated; use the standalone joblib package
from data_process import Rating_info, transformed_feature, item_info, user_info


def predict(rating_file_path, user_file_path, item_file_path):
    # Recall stage: train ALS on the ratings and take the top-10 items per user.
    rating_df = Rating_info(sc, rating_file_path).toDF()
    als = ALS(rank=10, maxIter=6, userCol='user_id', itemCol='item_id', ratingCol='rating')
    als_model = als.fit(rating_df)
    # recommendForAllUsers returns a DataFrame of (user_id, [(item_id, score), ...]);
    # flatten it into one (user_id, item_id) row per recommended item.
    user_item = als_model.recommendForAllUsers(10).rdd \
        .flatMap(lambda row: [Row(user_id=row.user_id, item_id=rec.item_id)
                              for rec in row.recommendations]).toDF()
    item_info_df = item_info(sc, item_file_path)
    user_info_df = user_info(sc, user_file_path)
    df = user_item.join(user_info_df, on='user_id', how='left') \
                  .join(item_info_df, on='item_id', how='left')
    feature = ['age', 'gender', 'action', 'adventure', 'animation', 'childrens', 'comedy',
               'crime', 'documentary', 'drama', 'fantasy', 'film_noir', 'horror', 'musical',
               'mystery', 'romance', 'sci_fi', 'thriller', 'unknown', 'war', 'western']
    predict_data = [[float(row[f]) for f in feature] for row in df.select(feature).collect()]

    # Ranking stage: GBDT leaf indices -> one-hot features -> LR score.
    print("starting gbdt...")
    gbdt_model = joblib.load('../model/gbdt_model/gbdt.model')
    leaf = gbdt_model.apply(np.array(predict_data))[:, :, 0].astype(int)
    print("starting transform...")
    # num_leaf must match the value used when the LR model was trained
    transform_feature = transformed_feature(leaf, leaf.max() + 1)
    print("starting lr model...")
    lr_model = joblib.load("../model/lr.model")
    y_pred = lr_model.predict(transform_feature)
    print(y_pred[:10])


if __name__ == "__main__":
    sc = SparkContext('local', 'predict')
    sqlcontext = HiveContext(sc)  # registering a SQL context enables RDD.toDF()
    sc.setLogLevel("ERROR")
    rating_file_path = "E:/data/ml-100k/u.data"
    user_file_path = "E:/data/ml-100k/u.user"
    item_file_path = "E:/data/ml-100k/u.item"
    predict(rating_file_path, user_file_path, item_file_path)

--------------------------------------------------------------------------------
/data/u.info:
--------------------------------------------------------------------------------
943 users
1682 items
100000 ratings

--------------------------------------------------------------------------------
/data/u.item:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuxinping1992/cf_gbdt_lr/2345f68db309a8577ac358a8da1a5a436abe4bf7/data/u.item

--------------------------------------------------------------------------------
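The ranking stage hinges on one trick: each sample is represented by the index of the leaf it reaches in every GBDT tree, and those indices are one-hot encoded before the LR stage (the transformed_feature function in data_process.py below). A tiny standalone illustration of that encoding, using synthetic numbers not taken from the repo:

import numpy as np

# 2 samples, 3 trees; leaf[i][t] = index of the leaf that sample i reached in tree t
leaf = np.array([[0, 2, 1],
                 [1, 0, 2]])
num_leaf = 3  # leaves per tree, so each tree owns a block of 3 columns

onehot = np.zeros((len(leaf), leaf.shape[1] * num_leaf), dtype=np.int64)
for i in range(len(leaf)):
    # the column offset of tree t is t * num_leaf; mark the leaf that fired
    onehot[i, np.arange(leaf.shape[1]) * num_leaf + leaf[i]] = 1

print(onehot)
# [[1 0 0 0 0 1 0 1 0]
#  [0 1 0 1 0 0 0 0 1]]
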
/data_process.py:
--------------------------------------------------------------------------------
import random

from pyspark import SparkContext, HiveContext
from pyspark.sql import Row
import numpy as np


# load rating data: u.data is tab-separated (user id, item id, rating, timestamp)
def Rating_info(sc, file_path):
    rating = sc.textFile(file_path).map(lambda x: x.split('\t')) \
               .map(lambda x: Row(user_id=int(x[0]), item_id=int(x[1]), rating=float(x[2])))
    return rating


# bucket raw ages into four coarse groups
def split_age(age):
    if age <= 20:
        return 0
    elif age <= 45:
        return 1
    elif age <= 60:
        return 2
    else:
        return 3


def transform_gender(gender):
    return 0 if gender == 'F' else 1


# transform feature: one-hot encode GBDT leaf indices; each tree owns a block of
# num_leaf columns, and the column of the leaf the sample reached is set to 1
def transformed_feature(leaf, num_leaf):
    transform_feature_matrix = np.zeros([len(leaf), len(leaf[0]) * num_leaf], dtype=np.int64)
    for i in range(len(leaf)):
        idx = np.arange(len(leaf[0])) * num_leaf + np.array(leaf[i])
        transform_feature_matrix[i][idx] = 1
    return transform_feature_matrix


# load user info data
def user_info(sc, file_path):
    user_rdd = sc.textFile(file_path).map(lambda x: x.split('|'))
    user_info_df = user_rdd.map(lambda data: Row(user_id=int(data[0]),
                                                 age=float(split_age(int(data[1]))),
                                                 gender=transform_gender(data[2]))).toDF()
    return user_info_df


# load item info data: u.item fields 5..23 are the 19 ml-100k genre flags,
# in the order unknown, Action, Adventure, ..., War, Western
def item_info(sc, file_path):
    item_rdd = sc.textFile(file_path).map(lambda line: line.split('|'))
    item_info_df = item_rdd.map(lambda data: Row(
        item_id=int(data[0]), movie_title=data[1], release_date=data[2],
        video_release_date=data[3], imdb_url=data[4], unknown=float(data[5]),
        action=float(data[6]), adventure=float(data[7]), animation=float(data[8]),
        childrens=float(data[9]), comedy=float(data[10]), crime=float(data[11]),
        documentary=float(data[12]), drama=float(data[13]), fantasy=float(data[14]),
        film_noir=float(data[15]), horror=float(data[16]), musical=float(data[17]),
        mystery=float(data[18]), romance=float(data[19]), sci_fi=float(data[20]),
        thriller=float(data[21]), war=float(data[22]), western=float(data[23]))).toDF()
    return item_info_df


# load samples: every observed rating is a positive example; for each positive
# pair, k random unrated items are drawn as negatives
def sample(sc, rating_file_path, user_file_path, item_file_path, k):
    item_info_df = item_info(sc, item_file_path)
    user_info_df = user_info(sc, user_file_path)
    num_item = item_info_df.count()  # ml-100k item ids run from 1 to num_item
    # positive examples
    pos = sc.textFile(rating_file_path).map(lambda x: x.split('\t'))
    pos_sample = pos.map(lambda data: Row(user_id=int(data[0]), item_id=int(data[1]), label=1.0)).toDF()
    pos_user_item = [[int(data[0]), int(data[1])] for data in pos_sample.select(['user_id', 'item_id']).collect()]
    pos_user_item_dict = {}
    neg_sample = []
    print("starting...")
    for user_id, item_id in pos_user_item:
        pos_user_item_dict.setdefault(user_id, set()).add(item_id)
    # the source file is truncated in the middle of this loop; the completion
    # below follows its apparent intent and is an assumption
    for user_id, _ in pos_user_item:
        i = 0
        while i < k:
            item = random.randint(1, num_item)
            if item not in pos_user_item_dict[user_id]:
                neg_sample.append(Row(user_id=user_id, item_id=item, label=0.0))
                i += 1
    return pos_sample, sc.parallelize(neg_sample).toDF()
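
The dump cuts off before /gbdt_lr.py, which the README describes as the training script for the GBDT + LR ranking model. A minimal sketch of what such a script could look like, consistent with the interfaces and model paths used above; the hyperparameters and the synthetic training data are assumptions, not the author's code:

import joblib
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression

from data_process import transformed_feature

# stand-in training data; in the real script this would come from sample()
# in data_process.py (21 features per example, binary labels)
train_x = np.random.rand(1000, 21)
train_y = np.random.randint(0, 2, 1000)

# GBDT stage: fit the trees, then read off the leaf each sample reaches
gbdt = GradientBoostingClassifier(n_estimators=50, max_depth=3)
gbdt.fit(train_x, train_y)
leaf = gbdt.apply(train_x)[:, :, 0].astype(int)

# LR stage: logistic regression over the one-hot encoded leaf indices
lr = LogisticRegression()
lr.fit(transformed_feature(leaf, leaf.max() + 1), train_y)

# persist to the paths that cf_gbdt_lr_predict.py loads from
joblib.dump(gbdt, '../model/gbdt_model/gbdt.model')
joblib.dump(lr, '../model/lr.model')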