├── model
│   ├── __init__.py
│   ├── cf.py
│   ├── prank.py
│   └── lfm.py
├── workflow
│   ├── __init__.py
│   ├── cf_workflow.py
│   ├── lfm_workflow.py
│   └── prank_workflow.py
├── 前置需求.txt
├── 运行备注.txt
├── run_system.py
├── .gitignore
├── manage.py
├── preprocess.py
├── db.py
└── excute.py

/model/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/workflow/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/前置需求.txt:
--------------------------------------------------------------------------------
pandas
numpy
--------------------------------------------------------------------------------
/运行备注.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HyZhu39/recommendation/HEAD/运行备注.txt
--------------------------------------------------------------------------------
/run_system.py:
--------------------------------------------------------------------------------
import sys
import os

arg = sys.argv[1]  # target user id; read it first so a missing argument fails fast
os.system("python db.py")                    # export MySQL tables to data/*.dat
os.system("python excute.py preprocess")     # convert data/*.dat to data/*.csv
os.system("python excute.py excute " + arg)  # run all three recommenders for the user
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# project temp files
*.pyc
*.log
_build/
temp/
data/

# pycharm temp files
.idea/
__pycache__/

# mac temp files
.DS_Store

# wheel temp files
build/
*.egg-info/

# Vim temp files
*.swp
*.swo
--------------------------------------------------------------------------------
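A quick usage note for run_system.py above, hedged on the environment it assumes (a local MySQL instance with the `today` database that db.py queries, plus pandas and numpy per 前置需求.txt):

    python run_system.py 5

This one command exports the tables to data/*.dat, converts them to data/*.csv, and runs all three recommenders for user 5, writing the merged top-10 list to recommand_result_5.txt.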
/workflow/cf_workflow.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import time
import os
from model.cf import UserCf


def run():
    assert os.path.exists('data/ratings.csv'), \
        'data/ratings.csv does not exist, run preprocess.py first.'
    print('Start..')
    start = time.time()
    movies = UserCf().calculate()
    # for movie in movies:
    #     print(movie)
    print('Cost time: %f' % (time.time() - start))
    return movies
--------------------------------------------------------------------------------
/manage.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import sys
from preprocess import Channel
from workflow.cf_workflow import run as user_cf
from workflow.lfm_workflow import run as lfm
from workflow.prank_workflow import run as prank


def manage():
    arg = sys.argv[1]
    if arg == 'preprocess':
        Channel().process()
    elif arg == 'cf':
        user_cf()
    elif arg == 'lfm':
        lfm(int(sys.argv[2]))    # lfm_workflow.run requires a target user id
    elif arg == 'prank':
        prank(int(sys.argv[2]))  # prank_workflow.run requires a target user id
    else:
        print('Arg must be one of ["preprocess", "cf", "lfm", "prank"].')
        sys.exit()


if __name__ == '__main__':
    manage()
--------------------------------------------------------------------------------
/workflow/lfm_workflow.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import time
import os
from model.lfm import LFM, Corpus

# Change the userid argument to change the target user of the recommendation.
def run(userid):
    assert os.path.exists('data/ratings.csv'), \
        'data/ratings.csv does not exist, run preprocess.py first.'
    print('Start..')
    start = time.time()
    if not os.path.exists('data/lfm_items.dict'):
        Corpus.pre_process()
    if not os.path.exists('data/lfm.model'):
        LFM().train()
    movies = LFM().predict(user_id=userid)
    # for movie in movies:
    #     print(movie)
    print('Cost time: %f' % (time.time() - start))
    return movies
--------------------------------------------------------------------------------
/workflow/prank_workflow.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import time
import os
from model.prank import Graph, PersonalRank

# Change the userid argument to change the target user of the recommendation.
def run(userid):
    assert os.path.exists('data/ratings.csv'), \
        'data/ratings.csv does not exist, run preprocess.py first.'
    print('Start..')
    start = time.time()
    if not os.path.exists('data/prank.graph'):
        Graph.gen_graph()
    if not os.path.exists('data/prank_' + str(userid) + '.model'):
        PersonalRank().train(user_id=userid)
    movies = PersonalRank().predict(user_id=userid)
    # for movie in movies:
    #     print(movie)
    print('Cost time: %f' % (time.time() - start))
    return movies
--------------------------------------------------------------------------------
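All three workflows share a cache-and-skip pattern: artifacts under data/ are built once and reused, and a model is only retrained when its file is missing. A minimal sketch of forcing a full retrain (the user id 5 in the prank filename is just an example):

import os

# removing the cached artifacts makes the next workflow run rebuild them
for cached in ('data/lfm_items.dict', 'data/lfm.model',
               'data/prank.graph', 'data/prank_5.model'):
    if os.path.exists(cached):
        os.remove(cached)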
/preprocess.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Origin resource from MovieLens: http://grouplens.org/datasets/movielens/1m
import pandas as pd


class Channel:
    """
    Simple conversion of MovieLens-style *.dat files to *.csv.
    """

    def __init__(self):
        self.origin_path = 'data/{}'

    def process(self):
        print('Process user data...')
        self._process_user_data()
        print('Process movies data...')
        self._process_movies_data()
        print('Process rating data...')
        self._process_rating_data()
        print('End.')

    def _process_user_data(self, file='users.dat'):
        f = pd.read_table(self.origin_path.format(file), sep='::', engine='python',
                          names=['UserID', 'Gender', 'Age', 'Occupation', 'Zip-code'])
        f.to_csv(self.origin_path.format('users.csv'), index=False)

    def _process_rating_data(self, file='ratings.dat'):
        f = pd.read_table(self.origin_path.format(file), sep='::', engine='python',
                          names=['UserID', 'MovieID', 'Rating', 'Timestamp'])
        f.to_csv(self.origin_path.format('ratings.csv'), index=False)

    def _process_movies_data(self, file='movies.dat'):
        f = pd.read_table(self.origin_path.format(file), sep='::', engine='python',
                          names=['MovieID', 'Title', 'Genres'])
        f.to_csv(self.origin_path.format('movies.csv'), index=False)


if __name__ == '__main__':
    Channel().process()
--------------------------------------------------------------------------------
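For concreteness, each '::'-separated .dat record becomes one CSV row under a header; illustrative values in the MovieLens 1M format:

    ratings.dat:  1::1193::5::978300760
    ratings.csv:  UserID,MovieID,Rating,Timestamp
                  1,1193,5,978300760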
/db.py:
--------------------------------------------------------------------------------
# -*- coding: UTF-8 -*-

import pymysql
import traceback

# Open the database connection
db = pymysql.connect(host='localhost', user='root', passwd='ascent', db='today', port=3306, charset='utf8')

# Get a cursor for executing queries
cursor = db.cursor()

# SQL queries
sql = "SELECT * FROM users"
sql2 = "SELECT * FROM orders"
sql3 = "SELECT * FROM dishes"


# Output column layouts:
# ['user_id', 'gender', 'age', 'occupation', 'zip']
# ['user_id', 'movie_id', 'rating', 'timestamp']
# ['movie_id', 'title', 'genres']

# Forward-slash paths work on Windows too and match the rest of the repo
with open('data/users.dat', 'w', encoding='utf-8') as f:
    # Opening in 'w' mode means each run overwrites the previous export
    try:
        # Execute the SQL statement
        cursor.execute(sql)
        # Fetch all rows
        results = cursor.fetchall()
        for row in results:
            # Positional indices assume the `users` table layout in the `today` database
            userid = row[0]
            gender = row[5]
            age = row[6]
            occupation = row[8]
            zip_code = row[9]
            # Write and print the record
            f.write(str(userid) + "::" + str(gender) + "::" + str(age) + "::" + str(occupation) + "::" + str(zip_code) + "\n")
            print("userid=%s,gender=%s,age=%s,occupation=%s,zip=%s" %
                  (userid, gender, age, occupation, zip_code))
    except:
        print("Error: unable to fetch data")
        print(traceback.format_exc())


with open('data/ratings.dat', 'w', encoding='utf-8') as f2:
    try:
        cursor.execute(sql2)
        results = cursor.fetchall()
        for row in results:
            # Positional indices assume the `orders` table layout
            userid = row[1]
            movieid = row[5]
            rating = row[4]
            timestamp = row[3]
            f2.write(str(userid) + "::" + str(movieid) + "::" + str(rating) + "::" + str(timestamp) + "\n")
            print("userid=%s,movieid=%s,rating=%s,timestamp=%s" %
                  (userid, movieid, rating, timestamp))
    except:
        print("Error: unable to fetch data")
        print(traceback.format_exc())

with open('data/movies.dat', 'w', encoding='utf-8') as f3:
    try:
        cursor.execute(sql3)
        results = cursor.fetchall()
        for row in results:
            # Positional indices assume the `dishes` table layout
            movieid = row[0]
            title = row[2]
            genres = row[5]
            f3.write(str(movieid) + "::" + str(title) + "::" + str(genres) + "\n")
            print("movieid=%s,title=%s,genres=%s" %
                  (movieid, title, genres))
    except:
        print("Error: unable to fetch data")
        print(traceback.format_exc())

# Close the database connection
db.close()
--------------------------------------------------------------------------------
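The three export blocks in db.py above repeat one pattern: run a query, then dump selected columns as a '::'-separated .dat file. A minimal sketch of that pattern as a single reusable helper (the helper name `export` and its signature are my own; the column indices mirror the script's assumptions about the `today` schema):

def export(cursor, sql, path, columns):
    # run the query and write the chosen columns as one '::'-joined line per row
    cursor.execute(sql)
    with open(path, 'w', encoding='utf-8') as f:
        for row in cursor.fetchall():
            f.write('::'.join(str(row[i]) for i in columns) + '\n')

# e.g. export(cursor, "SELECT * FROM users", 'data/users.dat', [0, 5, 6, 8, 9])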
/excute.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import sys
from preprocess import Channel
from workflow.cf_workflow import run as user_cf
from workflow.lfm_workflow import run as lfm
from workflow.prank_workflow import run as prank


def delete_substr_method2(in_str, in_substr):
    # Remove the first occurrence of in_substr from in_str
    # (assumes in_substr is present, e.g. stripping the 'item_' prefix).
    start_loc = in_str.find(in_substr)
    len_substr = len(in_substr)
    res_str = in_str[:start_loc] + in_str[start_loc + len_substr:]
    return res_str


def manage():
    arg = sys.argv[1]
    if arg == 'preprocess':
        Channel().process()
    elif arg == 'excute':
        target = int(sys.argv[2])  # the second CLI argument is the target user id
        result1 = user_cf()
        result2 = lfm(target)
        result3 = prank(target)  # each result is a list of (id, score) pairs
        # Convert pairs to lists up front so the score-100 marks below persist
        result1 = [list(x) for x in result1]
        result2 = [list(x) for x in result2]
        result3 = [list(x) for x in result3]
        result = []
        count = 0

        # First pass: collect movies that at least two of the three models agree on,
        # marking matched entries with score 100 so they are skipped later.
        # cf/lfm ids are numeric while prank keys are 'item_<id>' strings, so
        # comparisons against the stripped prank key are done on str().
        for i in result2:
            for j in result3:
                for k in result1:
                    if str(i[0]) == delete_substr_method2(j[0], "item_") == str(k[0]):
                        result.append(i[0])
                        i[1] = 100
                        j[1] = 100
                        k[1] = 100
                    elif str(i[0]) == delete_substr_method2(j[0], "item_"):
                        result.append(i[0])
                        i[1] = 100
                        j[1] = 100
                    elif k[0] == i[0]:
                        result.append(i[0])
                        k[1] = 100
                        i[1] = 100
                    elif str(k[0]) == delete_substr_method2(j[0], "item_"):
                        result.append(k[0])
                        k[1] = 100
                        j[1] = 100

        # Second pass: append the remaining unmatched recommendations, once each.
        for i in result2:
            if i[1] != 100:
                result.append(i[0])
        for j in result3:
            if j[1] != 100:
                result.append(delete_substr_method2(j[0], "item_"))
        for k in result1:
            if k[1] != 100:
                result.append(k[0])

        # Keep only the first 10 recommendations.
        with open("recommand_result_" + sys.argv[2] + ".txt", 'w', encoding='utf-8') as f:
            for ok in result:
                if count < 10:
                    print(ok)
                    f.write(str(ok) + "\n")
                    count = count + 1
    else:
        print('Arg must be one of ["preprocess", "excute"].')
        sys.exit()


if __name__ == '__main__':
    manage()
--------------------------------------------------------------------------------
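The nested loops above implement an "agreement first" merge: movies recommended by at least two models come first, then the leftovers. A minimal set-based sketch of the same idea (the function name `merge` is my own; it assumes the result formats produced by the three models above, and ordering within each group is arbitrary here):

def merge(cf, lfm, prank):
    # cf and lfm yield (movie_id, score); prank yields ('item_<id>', score)
    ids = [{str(m) for m, _ in cf},
           {str(m) for m, _ in lfm},
           {k.replace('item_', '', 1) for k, _ in prank}]
    # movies at least two models agree on, then everything else
    agreed = (ids[0] & ids[1]) | (ids[0] & ids[2]) | (ids[1] & ids[2])
    rest = (ids[0] | ids[1] | ids[2]) - agreed
    return list(agreed) + list(rest)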
/model/cf.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import math
import pandas as pd


class UserCf:

    def __init__(self):
        self.file_path = 'data/ratings.csv'
        self._init_frame()

    def _init_frame(self):
        self.frame = pd.read_csv(self.file_path)

    @staticmethod
    def _cosine_sim(target_movies, movies):
        '''
        Simple method for calculating cosine similarity.
        e.g.: x = [1 0 1 1 0], y = [0 1 1 0 1]
        cosine = (x1*y1 + x2*y2 + ...) / [sqrt(x1^2 + x2^2 + ...) * sqrt(y1^2 + y2^2 + ...)]
        For 0/1 vectors this reduces to
        union_len(movies1, movies2) / sqrt(len(movies1) * len(movies2)).
        '''
        union_len = len(set(target_movies) & set(movies))
        if union_len == 0:
            return 0.0
        product = len(target_movies) * len(movies)
        cosine = union_len / math.sqrt(product)
        return cosine

    def _get_top_n_users(self, target_user_id, top_n):
        '''
        Calculate similarity between the target user and all other users,
        and return the top N most similar users.
        '''
        target_movies = self.frame[self.frame['UserID'] == target_user_id]['MovieID']
        other_users_id = [i for i in set(self.frame['UserID']) if i != target_user_id]
        other_movies = [self.frame[self.frame['UserID'] == i]['MovieID'] for i in other_users_id]

        sim_list = [self._cosine_sim(target_movies, movies) for movies in other_movies]
        sim_list = sorted(zip(other_users_id, sim_list), key=lambda x: x[1], reverse=True)
        return sim_list[:top_n]

    def _get_candidates_items(self, target_user_id):
        """
        Find all movies in the source data that the target user has not seen before.
        """
        target_user_movies = set(self.frame[self.frame['UserID'] == target_user_id]['MovieID'])
        other_user_movies = set(self.frame[self.frame['UserID'] != target_user_id]['MovieID'])
        # set difference: movies rated by others but not by the target user
        candidates_movies = list(other_user_movies - target_user_movies)
        return candidates_movies

    def _get_top_n_items(self, top_n_users, candidates_movies, top_n):
        """
        Calculate interest in candidate movies and return the top n movies.
        e.g. interest = sum(sim * normalized_rating)
        """
        top_n_user_data = [self.frame[self.frame['UserID'] == k] for k, _ in top_n_users]
        interest_list = []
        for movie_id in candidates_movies:
            tmp = []
            for user_data in top_n_user_data:
                if movie_id in user_data['MovieID'].values:
                    # normalize the 1-5 star rating into [0, 1]
                    tmp.append(user_data[user_data['MovieID'] == movie_id]['Rating'].values[0] / 5)
                else:
                    tmp.append(0)
            interest = sum([top_n_users[i][1] * tmp[i] for i in range(len(top_n_users))])
            interest_list.append((movie_id, interest))
        interest_list = sorted(interest_list, key=lambda x: x[1], reverse=True)
        return interest_list[:top_n]

    def calculate(self, target_user_id=1, top_n=10):
        """
        User-based CF for movie recommendation.
        """
        # most similar top n users
        top_n_users = self._get_top_n_users(target_user_id, top_n)
        # candidate movies for recommendation
        candidates_movies = self._get_candidates_items(target_user_id)
        # most interesting top n movies
        top_n_movies = self._get_top_n_items(top_n_users, candidates_movies, top_n)
        return top_n_movies
--------------------------------------------------------------------------------
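A small worked example of _cosine_sim above: with implicit 0/1 vectors the dot product is just the overlap size, and each norm is the square root of the list length. _cosine_sim is a staticmethod, so it can be called without loading the ratings file:

from model.cf import UserCf

# target user saw {1, 2, 3}, other user saw {2, 3, 4}:
# overlap = 2, so cosine = 2 / sqrt(3 * 3) = 0.6667
print(UserCf._cosine_sim([1, 2, 3], [2, 3, 4]))  # -> 0.666...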
/model/prank.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import pickle
import pandas as pd


class Graph:  # the recommendation target is user_id

    graph_path = 'data/prank.graph'

    @classmethod
    def _gen_user_graph(cls, user_id):
        print('Gen graph user: {}'.format(user_id))
        item_ids = list(set(cls.frame[cls.frame['UserID'] == user_id]['MovieID']))
        graph_dict = {'item_{}'.format(item_id): 1 for item_id in item_ids}
        return graph_dict

    @classmethod
    def _gen_item_graph(cls, item_id):
        print('Gen graph item: {}'.format(item_id))
        user_ids = list(set(cls.frame[cls.frame['MovieID'] == item_id]['UserID']))
        graph_dict = {'user_{}'.format(user_id): 1 for user_id in user_ids}
        return graph_dict

    @classmethod
    def gen_graph(cls):
        """
        Generate the graph. Each user and each movie is a node, and every
        movie rated by a user creates an edge between that user and that
        movie; the edge weight is simply 1.
        """
        file_path = 'data/ratings.csv'
        cls.frame = pd.read_csv(file_path)
        user_ids = list(set(cls.frame['UserID']))
        item_ids = list(set(cls.frame['MovieID']))
        cls.graph = {'user_{}'.format(user_id): cls._gen_user_graph(user_id) for user_id in user_ids}
        for item_id in item_ids:
            cls.graph['item_{}'.format(item_id)] = cls._gen_item_graph(item_id)
        cls.save()

    @classmethod
    def save(cls):
        f = open(cls.graph_path, 'wb')
        pickle.dump(cls.graph, f)
        f.close()

    @classmethod
    def load(cls):
        f = open(cls.graph_path, 'rb')
        graph = pickle.load(f)
        f.close()
        return graph


class PersonalRank:

    def __init__(self):
        self.graph = Graph.load()
        self.alpha = 0.6
        self.iter_count = 20
        self._init_model()

    def _init_model(self):
        """
        Initialize the prob of every node, zero by default.
        """
        self.params = {k: 0 for k in self.graph.keys()}

    def train(self, user_id):
        """
        For the target user, every round restarts the walk at that node,
        so its initial prob is 1. Each node is updated by:
            for each node i, for each node j with an edge to i:
                prob_i += alpha * prob_j / out_degree(j)
        where alpha is the probability of continuing the walk
        (and 1 - alpha the probability of restarting at the root).
        """
        self.params['user_{}'.format(user_id)] = 1
        for count in range(self.iter_count):
            print('Step {}...'.format(count))
            tmp = {k: 0 for k in self.graph.keys()}
            # edges means all edges out of node
            for node, edges in self.graph.items():
                for next_node, _ in edges.items():
                    # every edge coming into next_node updates its prob
                    tmp[next_node] += self.alpha * self.params[node] / len(edges)
            # root node receives the restart probability
            tmp['user_' + str(user_id)] += 1 - self.alpha
            self.params = tmp
        self.params = sorted(self.params.items(), key=lambda x: x[1], reverse=True)
        self.save(user_id)

    def predict(self, user_id, top_n=10):
        """
        Return the top n item nodes, excluding movies the target user
        has already rated and excluding all user nodes.
        """
        self.load(user_id)
        frame = pd.read_csv('data/ratings.csv')
        item_ids = ['item_' + str(item_id) for item_id in list(set(frame[frame['UserID'] == user_id]['MovieID']))]
        # item_ids = [str(item_id) for item_id in list(set(frame[frame['UserID'] == user_id]['MovieID']))]
        candidates = [(key, value) for key, value in self.params if key not in item_ids and 'user' not in key]
        return candidates[:top_n]

    def save(self, user_id):
        f = open('data/prank_{}.model'.format(user_id), 'wb')
        pickle.dump(self.params, f)
        f.close()

    def load(self, user_id):
        f = open('data/prank_{}.model'.format(user_id), 'rb')
        self.params = pickle.load(f)
        f.close()
--------------------------------------------------------------------------------
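A toy run of PersonalRank's update rule on a two-node graph, to show where the alpha terms go. This is a hand-rolled sketch independent of the classes above, not the repo's API:

# toy graph: user u rated item i (one edge each way), alpha as in PersonalRank
graph = {'user_u': {'item_i': 1}, 'item_i': {'user_u': 1}}
alpha, params = 0.6, {'user_u': 1.0, 'item_i': 0.0}
for _ in range(20):
    tmp = {k: 0.0 for k in graph}
    for node, edges in graph.items():
        for nxt in edges:
            tmp[nxt] += alpha * params[node] / len(edges)
    tmp['user_u'] += 1 - alpha  # restart mass goes back to the root
    params = tmp
print(params)  # converges toward {'user_u': 0.625, 'item_i': 0.375}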
/model/lfm.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import random
import pickle
import pandas as pd
import numpy as np
from math import exp


class Corpus:

    items_dict_path = 'data/lfm_items.dict'

    @classmethod
    def pre_process(cls):
        file_path = 'data/ratings.csv'
        cls.frame = pd.read_csv(file_path)
        cls.user_ids = set(cls.frame['UserID'].values)
        cls.item_ids = set(cls.frame['MovieID'].values)
        cls.items_dict = {user_id: cls._get_pos_neg_item(user_id) for user_id in list(cls.user_ids)}
        cls.save()

    @classmethod
    def _get_pos_neg_item(cls, user_id):
        """
        Define the positive and negative items for a user.
        Positive items are items the user has rated; negative items are drawn
        from items the user has never seen, and are simply down-sampled to the
        number of positives to balance the classes.
        """
        print('Process: {}'.format(user_id))
        pos_item_ids = set(cls.frame[cls.frame['UserID'] == user_id]['MovieID'])
        neg_item_ids = cls.item_ids - pos_item_ids
        # neg_item_ids = [(item_id, len(self.frame[self.frame['MovieID'] == item_id]['UserID'])) for item_id in neg_item_ids]
        # neg_item_ids = sorted(neg_item_ids, key=lambda x: x[1], reverse=True)
        neg_item_ids = list(neg_item_ids)[:len(pos_item_ids)]
        item_dict = {}
        for item in pos_item_ids: item_dict[item] = 1
        for item in neg_item_ids: item_dict[item] = 0
        return item_dict

    @classmethod
    def save(cls):
        f = open(cls.items_dict_path, 'wb')
        pickle.dump(cls.items_dict, f)
        f.close()

    @classmethod
    def load(cls):
        f = open(cls.items_dict_path, 'rb')
        items_dict = pickle.load(f)
        f.close()
        return items_dict


class LFM:

    def __init__(self):
        self.class_count = 5
        self.iter_count = 5
        self.lr = 0.02
        self.lam = 0.01
        self._init_model()

    def _init_model(self):
        """
        Get the corpus and initialize model params.
        """
        file_path = 'data/ratings.csv'
        self.frame = pd.read_csv(file_path)
        self.user_ids = set(self.frame['UserID'].values)
        self.item_ids = set(self.frame['MovieID'].values)
        self.items_dict = Corpus.load()

        array_p = np.random.randn(len(self.user_ids), self.class_count)
        array_q = np.random.randn(len(self.item_ids), self.class_count)
        self.p = pd.DataFrame(array_p, columns=range(0, self.class_count), index=list(self.user_ids))
        self.q = pd.DataFrame(array_q, columns=range(0, self.class_count), index=list(self.item_ids))

    def _predict(self, user_id, item_id):
        """
        Calculate the interest between user_id and item_id.
        p holds each user's interest in each latent class;
        q holds the probability of each item belonging to each class.
        """
        # .loc replaces the DataFrame.ix accessor removed in recent pandas
        p = np.mat(self.p.loc[user_id].values)
        q = np.mat(self.q.loc[item_id].values).T
        r = (p * q).sum()
        logit = 1.0 / (1 + exp(-r))
        return logit

    def _loss(self, user_id, item_id, y, step):
        """
        The loss is MSE-based, but note that this returns the raw error
        e = y - predict, not the squared loss itself.
        """
        e = y - self._predict(user_id, item_id)
        print('Step: {}, user_id: {}, item_id: {}, y: {}, loss: {}'.
              format(step, user_id, item_id, y, e))
        return e

    def _optimize(self, user_id, item_id, e):
        """
        Use SGD as the optimizer, with L2 regularization on p and q.
        e.g.: E = 1/2 * (y - predict)^2, predict = sigmoid(matrix_p * matrix_q)
        dE/dp = -matrix_q * (y - predict), dE/dq = -matrix_p * (y - predict)
        d(l2_square)/dp = lam * p, d(l2_square)/dq = lam * q
        delta_p = lr * (dE/dp + lam * p)
        delta_q = lr * (dE/dq + lam * q)
        """
        gradient_p = -e * self.q.loc[item_id].values
        l2_p = self.lam * self.p.loc[user_id].values
        delta_p = self.lr * (gradient_p + l2_p)

        gradient_q = -e * self.p.loc[user_id].values
        l2_q = self.lam * self.q.loc[item_id].values
        delta_q = self.lr * (gradient_q + l2_q)

        self.p.loc[user_id] -= delta_p
        self.q.loc[item_id] -= delta_q

    def train(self):
        """
        Train the model with SGD over iter_count epochs.
        """
        for step in range(0, self.iter_count):
            for user_id, item_dict in self.items_dict.items():
                item_ids = list(item_dict.keys())
                random.shuffle(item_ids)
                for item_id in item_ids:
                    e = self._loss(user_id, item_id, item_dict[item_id], step)
                    self._optimize(user_id, item_id, e)
            self.lr *= 0.9  # decay the learning rate after each epoch
        self.save()

    def predict(self, user_id, top_n=10):
        """
        Score all items the user has not seen before and return
        the top n by predicted interest.
        """
        self.load()
        user_item_ids = set(self.frame[self.frame['UserID'] == user_id]['MovieID'])
        other_item_ids = self.item_ids - user_item_ids
        interest_list = [self._predict(user_id, item_id) for item_id in other_item_ids]
        candidates = sorted(zip(list(other_item_ids), interest_list), key=lambda x: x[1], reverse=True)
        return candidates[:top_n]

    def save(self):
        """
        Save model params.
        """
        f = open('data/lfm.model', 'wb')
        pickle.dump((self.p, self.q), f)
        f.close()

    def load(self):
        """
        Load model params.
        """
        f = open('data/lfm.model', 'rb')
        self.p, self.q = pickle.load(f)
        f.close()
--------------------------------------------------------------------------------
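A worked instance of the prediction formula in _predict above: interest is the sigmoid of the dot product between a user row of p and an item row of q. The latent vectors here are made-up illustrative values (class_count = 5, as in LFM):

import numpy as np

p_u = np.array([0.5, -0.2, 0.1, 0.0, 0.3])  # one user's latent vector
q_i = np.array([0.4, 0.1, -0.3, 0.2, 0.6])  # one item's latent vector
r = float(p_u @ q_i)                        # 0.2 - 0.02 - 0.03 + 0 + 0.18 = 0.33
print(1.0 / (1 + np.exp(-r)))               # sigmoid -> ~0.582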