├── .gitattributes ├── data ├── .DS_Store ├── yelp │ ├── .DS_Store │ └── .ipynb_checkpoints │ │ └── preprocessing-checkpoint.ipynb ├── ml-1m │ └── .DS_Store └── VideoGame │ └── .DS_Store ├── README.md ├── test.py ├── data_preprocessor.py ├── JCA.py └── utility.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-detectable=false 2 | -------------------------------------------------------------------------------- /data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zziwei/Joint-Collaborative-Autoencoder/HEAD/data/.DS_Store -------------------------------------------------------------------------------- /data/yelp/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zziwei/Joint-Collaborative-Autoencoder/HEAD/data/yelp/.DS_Store -------------------------------------------------------------------------------- /data/ml-1m/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zziwei/Joint-Collaborative-Autoencoder/HEAD/data/ml-1m/.DS_Store -------------------------------------------------------------------------------- /data/VideoGame/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zziwei/Joint-Collaborative-Autoencoder/HEAD/data/VideoGame/.DS_Store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Joint-Collaborative-Autoencoder 2 | The implementation of our paper: 3 | 4 | Ziwei Zhu, Jianling Wang and James Caverlee. Improving Top-K Recommendation via Joint Collaborative Autoencoders. 
In Proceedings of WWW'19, San Francisco, May 13-17, 2019 5 | 6 | The implementation is based on Tensorflow. 7 | 8 | Author: Ziwei Zhu (zhuziwei@tamu.edu) 9 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Ziwei Zhu 3 | Computer Science and Engineering Department, Texas A&M University 4 | zhuziwei@tamu.edu 5 | """ 6 | from data_preprocessor import * 7 | import tensorflow as tf 8 | import time 9 | import argparse 10 | import os 11 | from JCA import JCA 12 | 13 | if __name__ == '__main__': 14 | neg_sample_rate = 1 15 | 16 | date = time.strftime('%y-%m-%d', time.localtime()) 17 | current_time = time.strftime('%H:%M:%S', time.localtime()) 18 | data_name = 'ml-1m' 19 | base = 'u' 20 | 21 | parser = argparse.ArgumentParser(description='JCA') 22 | 23 | parser.add_argument('--train_epoch', type=int, default=200) 24 | parser.add_argument('--batch_size', type=int, default=1500) 25 | parser.add_argument('--display_step', type=int, default=1) 26 | parser.add_argument('--lr', type=float, default=0.003) 27 | parser.add_argument('--lambda_value', type=float, default=0.001) 28 | parser.add_argument('--margin', type=float, default=0.15) 29 | parser.add_argument('--optimizer_method', choices=['Adam', 'Adadelta', 'Adagrad', 'RMSProp', 'GradientDescent', 30 | 'Momentum'], default='Adam') 31 | parser.add_argument('--g_act', choices=['Sigmoid', 'Relu', 'Elu', 'Tanh', "Identity"], default='Sigmoid') 32 | parser.add_argument('--f_act', choices=['Sigmoid', 'Relu', 'Elu', 'Tanh', "Identity"], default='Sigmoid') 33 | parser.add_argument('--U_hidden_neuron', type=int, default=160) 34 | parser.add_argument('--I_hidden_neuron', type=int, default=160) 35 | parser.add_argument('--base', type=str, default=base) 36 | parser.add_argument('--neg_sample_rate', type=int, default=neg_sample_rate) 37 | args = parser.parse_args() 38 | 39 | sess = tf.Session() 40 | 41 | 
train_R, test_R = ml1m.test()  # load the dataset declared by data_name ('ml-1m') above; was yelp.test(), which mislabeled results 42 | metric_path = './metric_results_test/' + date + '/' 43 | if not os.path.exists(metric_path): 44 | os.makedirs(metric_path) 45 | metric_path = metric_path + '/' + str(parser.description) + "_" + str(current_time) 46 | jca = JCA(sess, args, train_R, test_R, metric_path, date, data_name) 47 | jca.run(train_R, test_R) 48 | -------------------------------------------------------------------------------- /data_preprocessor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Ziwei Zhu 3 | Computer Science and Engineering Department, Texas A&M University 4 | zhuziwei@tamu.edu 5 | """ 6 | import numpy as np 7 | import pandas as pd 8 | 9 | 10 | class ml1m: 11 | def __init__(self): 12 | return 13 | 14 | @staticmethod 15 | def train(n): 16 | train_df = pd.read_csv('./data/ml-1m/train_%d.csv' % n) 17 | vali_df = pd.read_csv('./data/ml-1m/vali_%d.csv' % n) 18 | num_users = np.max(train_df['userId']) 19 | num_items = np.max(train_df['movieId']) 20 | 21 | train_R = np.zeros((num_users, num_items)) # training rating matrix 22 | vali_R = np.zeros((num_users, num_items)) # validation rating matrix 23 | 24 | train_mat = train_df.values 25 | for i in range(len(train_df)): 26 | user_idx = int(train_mat[i, 0]) - 1 27 | item_idx = int(train_mat[i, 1]) - 1 28 | train_R[user_idx, item_idx] = 1 29 | 30 | vali_mat = vali_df.values 31 | for i in range(len(vali_df)): 32 | user_idx = int(vali_mat[i, 0]) - 1 33 | item_idx = int(vali_mat[i, 1]) - 1 34 | vali_R[user_idx, item_idx] = 1 35 | return train_R, vali_R 36 | 37 | @staticmethod 38 | def test(): 39 | test_df = pd.read_csv('./data/ml-1m/test.csv') 40 | num_users = np.max(test_df['userId']) 41 | num_items = np.max(test_df['movieId']) 42 | 43 | test_R = np.zeros((num_users, num_items)) # testing rating matrix 44 | 45 | test_mat = test_df.values 46 | for i in range(len(test_df)): 47 | user_idx = int(test_mat[i, 0]) - 1 48 | item_idx = int(test_mat[i, 1]) - 1 49 |
test_R[user_idx, item_idx] = 1 50 | 51 | train_df = pd.read_csv('./data/ml-1m/train.csv') 52 | num_users = np.max(train_df['userId']) 53 | num_items = np.max(train_df['movieId']) 54 | 55 | train_R = np.zeros((num_users, num_items)) # testing rating matrix 56 | 57 | train_mat = train_df.values 58 | for i in range(len(train_df)): 59 | user_idx = int(train_mat[i, 0]) - 1 60 | item_idx = int(train_mat[i, 1]) - 1 61 | train_R[user_idx, item_idx] = 1 62 | train_R[user_idx, item_idx] = 1 63 | 64 | return train_R, test_R 65 | 66 | 67 | 68 | class yelp: 69 | def __init__(self): 70 | return 71 | 72 | @staticmethod 73 | def train(n): 74 | train_df = pd.read_csv('./data/yelp/train_%d.csv' % n) 75 | vali_df = pd.read_csv('./data/yelp/vali_%d.csv' % n) 76 | num_users = np.max(train_df['userId']) 77 | num_items = np.max(train_df['itemId']) 78 | 79 | train_R = np.zeros((num_users, num_items)) # training rating matrix 80 | vali_R = np.zeros((num_users, num_items)) # validation rating matrix 81 | 82 | train_mat = train_df.values 83 | for i in range(len(train_df)): 84 | user_idx = int(train_mat[i, 0]) - 1 85 | item_idx = int(train_mat[i, 1]) - 1 86 | train_R[user_idx, item_idx] = 1 87 | 88 | vali_mat = vali_df.values 89 | for i in range(len(vali_df)): 90 | user_idx = int(vali_mat[i, 0]) - 1 91 | item_idx = int(vali_mat[i, 1]) - 1 92 | vali_R[user_idx, item_idx] = 1 93 | return train_R, vali_R 94 | 95 | @staticmethod 96 | def test(): 97 | test_df = pd.read_csv('./data/yelp/test.csv') 98 | num_users = np.max(test_df['userId']) 99 | num_items = np.max(test_df['itemId']) 100 | 101 | test_R = np.zeros((num_users, num_items)) # testing rating matrix 102 | 103 | test_mat = test_df.values 104 | for i in range(len(test_df)): 105 | user_idx = int(test_mat[i, 0]) - 1 106 | item_idx = int(test_mat[i, 1]) - 1 107 | test_R[user_idx, item_idx] = 1 108 | 109 | train_df = pd.read_csv('./data/yelp/train.csv') 110 | num_users = np.max(train_df['userId']) 111 | num_items = np.max(train_df['itemId']) 112 
| 113 | train_R = np.zeros((num_users, num_items)) # testing rating matrix 114 | 115 | train_mat = train_df.values 116 | for i in range(len(train_df)): 117 | user_idx = int(train_mat[i, 0]) - 1 118 | item_idx = int(train_mat[i, 1]) - 1 119 | train_R[user_idx, item_idx] = 1 120 | train_R[user_idx, item_idx] = 1 121 | 122 | return train_R, test_R 123 | 124 | 125 | class VideoGame: 126 | def __init__(self): 127 | return 128 | 129 | @staticmethod 130 | def train(n): 131 | train_df = pd.read_csv('./data/VideoGame/train_%d.csv' % n) 132 | vali_df = pd.read_csv('./data/VideoGame/vali_%d.csv' % n) 133 | num_users = np.max(train_df['userId']) 134 | num_items = np.max(train_df['itemId']) 135 | 136 | train_R = np.zeros((num_users, num_items)) # training rating matrix 137 | vali_R = np.zeros((num_users, num_items)) # validation rating matrix 138 | 139 | train_mat = train_df.values 140 | for i in range(len(train_df)): 141 | user_idx = int(train_mat[i, 0]) - 1 142 | item_idx = int(train_mat[i, 1]) - 1 143 | train_R[user_idx, item_idx] = 1 144 | 145 | vali_mat = vali_df.values 146 | for i in range(len(vali_df)): 147 | user_idx = int(vali_mat[i, 0]) - 1 148 | item_idx = int(vali_mat[i, 1]) - 1 149 | vali_R[user_idx, item_idx] = 1 150 | return train_R, vali_R 151 | 152 | @staticmethod 153 | def test(): 154 | test_df = pd.read_csv('./data/VideoGame/test.csv') 155 | num_users = np.max(test_df['userId']) 156 | num_items = np.max(test_df['itemId']) 157 | 158 | test_R = np.zeros((int(num_users), int(num_items))) # testing rating matrix 159 | 160 | test_mat = test_df.values 161 | for i in range(len(test_df)): 162 | user_idx = int(test_mat[i, 0]) - 1 163 | item_idx = int(test_mat[i, 1]) - 1 164 | test_R[user_idx, item_idx] = 1 165 | 166 | train_df = pd.read_csv('./data/VideoGame/train.csv') 167 | num_users = np.max(train_df['userId']) 168 | num_items = np.max(train_df['itemId']) 169 | 170 | train_R = np.zeros((int(num_users), int(num_items))) # testing rating matrix 171 | 172 | train_mat 
= train_df.values 173 | for i in range(len(train_df)): 174 | user_idx = int(train_mat[i, 0]) - 1 175 | item_idx = int(train_mat[i, 1]) - 1 176 | train_R[user_idx, item_idx] = 1 177 | train_R[user_idx, item_idx] = 1 178 | 179 | return train_R, test_R 180 | 181 | -------------------------------------------------------------------------------- /JCA.py: -------------------------------------------------------------------------------- 1 | """ 2 | Ziwei Zhu 3 | Computer Science and Engineering Department, Texas A&M University 4 | zhuziwei@tamu.edu 5 | """ 6 | import tensorflow as tf 7 | import time 8 | import numpy as np 9 | import os 10 | import matplotlib 11 | import copy 12 | import utility 13 | 14 | 15 | class JCA: 16 | 17 | def __init__(self, sess, args, train_R, vali_R, metric_path, date, data_name, 18 | result_path=None): 19 | 20 | if args.f_act == "Sigmoid": 21 | f_act = tf.nn.sigmoid 22 | elif args.f_act == "Relu": 23 | f_act = tf.nn.relu 24 | elif args.f_act == "Tanh": 25 | f_act = tf.nn.tanh 26 | elif args.f_act == "Identity": 27 | f_act = tf.identity 28 | elif args.f_act == "Elu": 29 | f_act = tf.nn.elu 30 | else: 31 | raise NotImplementedError("ERROR") 32 | 33 | if args.g_act == "Sigmoid": 34 | g_act = tf.nn.sigmoid 35 | elif args.g_act == "Relu": 36 | g_act = tf.nn.relu 37 | elif args.g_act == "Tanh": 38 | g_act = tf.nn.tanh 39 | elif args.g_act == "Identity": 40 | g_act = tf.identity 41 | elif args.g_act == "Elu": 42 | g_act = tf.nn.elu 43 | else: 44 | raise NotImplementedError("ERROR") 45 | 46 | self.sess = sess 47 | self.args = args 48 | 49 | self.base = args.base 50 | 51 | self.num_rows = train_R.shape[0] 52 | self.num_cols = train_R.shape[1] 53 | self.U_hidden_neuron = args.U_hidden_neuron 54 | self.I_hidden_neuron = args.I_hidden_neuron 55 | 56 | self.train_R = train_R 57 | self.vali_R = vali_R 58 | self.num_test_ratings = np.sum(vali_R) 59 | 60 | self.train_epoch = args.train_epoch 61 | self.batch_size = args.batch_size 62 | self.num_batch_U = 
int(self.num_rows / float(self.batch_size)) + 1 63 | self.num_batch_I = int(self.num_cols / float(self.batch_size)) + 1 64 | 65 | self.lr = args.lr # learning rate 66 | self.optimizer_method = args.optimizer_method 67 | self.display_step = args.display_step 68 | self.margin = args.margin 69 | 70 | self.f_act = f_act # the activation function for the output layer 71 | self.g_act = g_act # the activation function for the hidden layer 72 | 73 | self.global_step = tf.Variable(0, trainable=False) 74 | 75 | self.lambda_value = args.lambda_value # regularization term trade-off 76 | 77 | self.result_path = result_path 78 | self.metric_path = metric_path 79 | self.date = date # today's date 80 | self.data_name = data_name 81 | 82 | self.neg_sample_rate = args.neg_sample_rate 83 | self.U_OH_mat = np.eye(self.num_rows, dtype=float) 84 | self.I_OH_mat = np.eye(self.num_cols, dtype=float) 85 | 86 | print('**********JCA**********') 87 | print(self.args) 88 | self.prepare_model() 89 | 90 | def run(self, train_R, vali_R): 91 | self.train_R = train_R 92 | self.vali_R = vali_R 93 | init = tf.global_variables_initializer() 94 | self.sess.run(init) 95 | for epoch_itr in xrange(self.train_epoch): 96 | self.train_model(epoch_itr) 97 | if epoch_itr % 1 == 0: 98 | self.test_model(epoch_itr) 99 | return self.make_records() 100 | 101 | def prepare_model(self): 102 | 103 | # input rating vector 104 | self.input_R_U = tf.placeholder(dtype=tf.float32, shape=[None, self.num_cols], name="input_R_U") 105 | self.input_R_I = tf.placeholder(dtype=tf.float32, shape=[self.num_rows, None], name="input_R_I") 106 | self.input_OH_I = tf.placeholder(dtype=tf.float32, shape=[None, self.num_cols], name="input_OH_I") 107 | self.input_P_cor = tf.placeholder(dtype=tf.int32, shape=[None, 2], name="input_P_cor") 108 | self.input_N_cor = tf.placeholder(dtype=tf.int32, shape=[None, 2], name="input_N_cor") 109 | 110 | # input indicator vector indicator 111 | self.row_idx = tf.placeholder(dtype=tf.int32, shape=[None, 
1], name="row_idx") 112 | self.col_idx = tf.placeholder(dtype=tf.int32, shape=[None, 1], name="col_idx") 113 | 114 | # user component 115 | # first layer weights 116 | UV = tf.get_variable(name="UV", initializer=tf.truncated_normal(shape=[self.num_cols, self.U_hidden_neuron], 117 | mean=0, stddev=0.03), dtype=tf.float32) 118 | # second layer weights 119 | UW = tf.get_variable(name="UW", initializer=tf.truncated_normal(shape=[self.U_hidden_neuron, self.num_cols], 120 | mean=0, stddev=0.03), dtype=tf.float32) 121 | # first layer bias 122 | Ub1 = tf.get_variable(name="Ub1", initializer=tf.truncated_normal(shape=[1, self.U_hidden_neuron], 123 | mean=0, stddev=0.03), dtype=tf.float32) 124 | # second layer bias 125 | Ub2 = tf.get_variable(name="Ub2", initializer=tf.truncated_normal(shape=[1, self.num_cols], 126 | mean=0, stddev=0.03), dtype=tf.float32) 127 | 128 | # item component 129 | # first layer weights 130 | IV = tf.get_variable(name="IV", initializer=tf.truncated_normal(shape=[self.num_rows, self.I_hidden_neuron], 131 | mean=0, stddev=0.03), dtype=tf.float32) 132 | # second layer weights 133 | IW = tf.get_variable(name="IW", initializer=tf.truncated_normal(shape=[self.I_hidden_neuron, self.num_rows], 134 | mean=0, stddev=0.03), dtype=tf.float32) 135 | # first layer bias 136 | Ib1 = tf.get_variable(name="Ib1", initializer=tf.truncated_normal(shape=[1, self.I_hidden_neuron], 137 | mean=0, stddev=0.03), dtype=tf.float32) 138 | # second layer bias 139 | Ib2 = tf.get_variable(name="Ib2", initializer=tf.truncated_normal(shape=[1, self.num_rows], 140 | mean=0, stddev=0.03), dtype=tf.float32) 141 | 142 | 143 | I_factor_vector = tf.get_variable(name="I_factor_vector", initializer=tf.random_uniform(shape=[1, self.num_cols]), 144 | dtype=tf.float32) 145 | 146 | # user component 147 | U_pre_Encoder = tf.matmul(self.input_R_U, UV) + Ub1 # input to the hidden layer 148 | self.U_Encoder = self.g_act(U_pre_Encoder) # output of the hidden layer 149 | U_pre_Decoder = 
tf.matmul(self.U_Encoder, UW) + Ub2 # input to the output layer 150 | self.U_Decoder = self.f_act(U_pre_Decoder) # output of the output layer 151 | 152 | # item component 153 | I_pre_mul = tf.transpose(tf.matmul(I_factor_vector, tf.transpose(self.input_OH_I))) 154 | I_pre_Encoder = tf.matmul(tf.transpose(self.input_R_I), IV) + Ib1 # input to the hidden layer 155 | self.I_Encoder = self.g_act(I_pre_Encoder * I_pre_mul) # output of the hidden layer 156 | I_pre_Decoder = tf.matmul(self.I_Encoder, IW) + Ib2 # input to the output layer 157 | self.I_Decoder = self.f_act(I_pre_Decoder) # output of the output layer 158 | 159 | # final output 160 | self.Decoder = ((tf.transpose(tf.gather_nd(tf.transpose(self.U_Decoder), self.col_idx))) 161 | + tf.gather_nd(tf.transpose(self.I_Decoder), self.row_idx)) / 2.0 162 | 163 | pos_data = tf.gather_nd(self.Decoder, self.input_P_cor) 164 | neg_data = tf.gather_nd(self.Decoder, self.input_N_cor) 165 | 166 | pre_cost1 = tf.maximum(neg_data - pos_data + self.margin, 167 | tf.zeros(tf.shape(neg_data)[0])) 168 | cost1 = tf.reduce_sum(pre_cost1) # prediction squared error 169 | pre_cost2 = tf.square(self.l2_norm(UW)) + tf.square(self.l2_norm(UV)) \ 170 | + tf.square(self.l2_norm(IW)) + tf.square(self.l2_norm(IV))\ 171 | + tf.square(self.l2_norm(Ib1)) + tf.square(self.l2_norm(Ib2))\ 172 | + tf.square(self.l2_norm(Ub1)) + tf.square(self.l2_norm(Ub2)) 173 | cost2 = self.lambda_value * 0.5 * pre_cost2 # regularization term 174 | 175 | self.cost = cost1 + cost2 # the loss function 176 | 177 | if self.optimizer_method == "Adam": 178 | optimizer = tf.train.AdamOptimizer(self.lr) 179 | elif self.optimizer_method == "Adadelta": 180 | optimizer = tf.train.AdadeltaOptimizer(self.lr) 181 | elif self.optimizer_method == "Adagrad": 182 | optimizer = tf.train.AdagradOptimizer(self.lr)  # bug fix: was AdadeltaOptimizer, so --optimizer_method Adagrad silently ran Adadelta 183 | elif self.optimizer_method == "RMSProp": 184 | optimizer = tf.train.RMSPropOptimizer(self.lr) 185 | elif self.optimizer_method == "GradientDescent": 186 | optimizer =
tf.train.GradientDescentOptimizer(self.lr) 187 | elif self.optimizer_method == "Momentum": 188 | optimizer = tf.train.MomentumOptimizer(self.lr, 0.9) 189 | else: 190 | raise ValueError("Optimizer Key ERROR") 191 | 192 | gvs = optimizer.compute_gradients(self.cost) 193 | self.optimizer = optimizer.apply_gradients(gvs, global_step=self.global_step) 194 | 195 | def train_model(self, itr): 196 | start_time = time.time() 197 | random_row_idx = np.random.permutation(self.num_rows) # randomly permute the rows 198 | random_col_idx = np.random.permutation(self.num_cols) # randomly permute the cols 199 | batch_cost = 0 200 | ts = 0 201 | for i in xrange(self.num_batch_U): # iterate each batch 202 | if i == self.num_batch_U - 1: 203 | row_idx = random_row_idx[i * self.batch_size:] 204 | else: 205 | row_idx = random_row_idx[(i * self.batch_size):((i + 1) * self.batch_size)] 206 | for j in xrange(self.num_batch_I): 207 | # get the indices of the current batch 208 | if j == self.num_batch_I - 1: 209 | col_idx = random_col_idx[j * self.batch_size:] 210 | else: 211 | col_idx = random_col_idx[(j * self.batch_size):((j + 1) * self.batch_size)] 212 | ts1 = time.time() 213 | p_input, n_input = utility.pairwise_neg_sampling(self.train_R, row_idx, col_idx, self.neg_sample_rate) 214 | ts2 = time.time() 215 | ts += (ts2 - ts1) 216 | input_tmp = self.train_R[row_idx, :] 217 | input_tmp = input_tmp[:, col_idx] 218 | 219 | input_R_U = self.train_R[row_idx, :] 220 | input_R_I = self.train_R[:, col_idx] 221 | _, cost = self.sess.run( # do the optimization by the minibatch 222 | [self.optimizer, self.cost], 223 | feed_dict={ 224 | self.input_R_U: input_R_U, 225 | self.input_R_I: input_R_I, 226 | self.input_OH_I: self.I_OH_mat[col_idx, :], 227 | self.input_P_cor: p_input, 228 | self.input_N_cor: n_input, 229 | self.row_idx: np.reshape(row_idx, (len(row_idx), 1)), 230 | self.col_idx: np.reshape(col_idx, (len(col_idx), 1))}) 231 | batch_cost = batch_cost + cost 232 | 233 | if itr % 
self.display_step == 0: 234 | print ("Training //", "Epoch %d //" % itr, " Total cost = {:.2f}".format(batch_cost), 235 | "Elapsed time : %d sec //" % (time.time() - start_time), "Sampling time: %d s //" %(ts)) 236 | 237 | def test_model(self, itr): # calculate the cost and rmse of testing set in each epoch 238 | start_time = time.time() 239 | _, Decoder = self.sess.run([self.cost, self.Decoder], 240 | feed_dict={ 241 | self.input_R_U: self.train_R, 242 | self.input_R_I: self.train_R, 243 | self.input_OH_I: self.I_OH_mat, 244 | self.input_P_cor: [[0, 0]], 245 | self.input_N_cor: [[0, 0]], 246 | self.row_idx: np.reshape(xrange(self.num_rows), (self.num_rows, 1)), 247 | self.col_idx: np.reshape(xrange(self.num_cols), (self.num_cols, 1))}) 248 | if itr % self.display_step == 0: 249 | 250 | pre_numerator = np.multiply((Decoder - self.vali_R), self.vali_R) 251 | numerator = np.sum(np.square(pre_numerator)) 252 | denominator = self.num_test_ratings 253 | RMSE = np.sqrt(numerator / float(denominator)) 254 | 255 | if itr % 1 == 0: 256 | if self.base == 'i': 257 | [precision, recall, f_score, NDCG] = utility.test_model_all(Decoder.T, self.vali_R.T, 258 | self.train_R.T) 259 | else: 260 | [precision, recall, f_score, NDCG] = utility.test_model_all(Decoder, self.vali_R, self.train_R) 261 | 262 | print ( 263 | "Testing //", "Epoch %d //" % itr, " Total cost = {:.2f}".format(numerator), 264 | " RMSE = {:.5f}".format(RMSE), 265 | "Elapsed time : %d sec" % (time.time() - start_time)) 266 | print "=" * 100 267 | 268 | def make_records(self): # record all the results' details into files 269 | _, Decoder = self.sess.run([self.cost, self.Decoder], 270 | feed_dict={ 271 | self.input_R_U: self.train_R, 272 | self.input_R_I: self.train_R, 273 | self.input_OH_I: self.I_OH_mat, 274 | self.input_P_cor: [[0, 0]], 275 | self.input_N_cor: [[0, 0]], 276 | self.row_idx: np.reshape(xrange(self.num_rows), (self.num_rows, 1)), 277 | self.col_idx: np.reshape(xrange(self.num_cols), (self.num_cols, 
1))}) 278 | if self.base == 'i': 279 | [precision, recall, f_score, NDCG] = utility.test_model_all(Decoder.T, self.vali_R.T, self.train_R.T) 280 | else: 281 | [precision, recall, f_score, NDCG] = utility.test_model_all(Decoder, self.vali_R, self.train_R) 282 | 283 | utility.metric_record(precision, recall, f_score, NDCG, self.args, self.metric_path) 284 | 285 | utility.test_model_factor(Decoder, self.vali_R, self.train_R) 286 | 287 | return precision, recall, f_score, NDCG 288 | 289 | @staticmethod 290 | def l2_norm(tensor): 291 | return tf.sqrt(tf.reduce_sum(tf.square(tensor))) 292 | -------------------------------------------------------------------------------- /utility.py: -------------------------------------------------------------------------------- 1 | """ 2 | Ziwei Zhu 3 | Computer Science and Engineering Department, Texas A&M University 4 | zhuziwei@tamu.edu 5 | """ 6 | from __future__ import division 7 | 8 | from math import log 9 | import numpy as np 10 | import pandas as pd 11 | import copy 12 | from operator import itemgetter 13 | import time 14 | 15 | 16 | # calculate NDCG@k 17 | def NDCG_at_k(predicted_list, ground_truth, k): 18 | dcg_value = [(v / log(i + 1 + 1, 2)) for i, v in enumerate(predicted_list[:k])] 19 | dcg = np.sum(dcg_value) 20 | if len(ground_truth) < k: 21 | ground_truth += [0 for i in range(k - len(ground_truth))] 22 | idcg_value = [(v / log(i + 1 + 1, 2)) for i, v in enumerate(ground_truth[:k])] 23 | idcg = np.sum(idcg_value) 24 | return dcg / idcg 25 | 26 | 27 | # calculate precision@k, recall@k, NDCG@k, where k = 1,5,10,15 28 | def user_precision_recall_ndcg(new_user_prediction, test): 29 | dcg_list = [] 30 | 31 | # compute the number of true positive items at top k 32 | count_1, count_5, count_10, count_15 = 0, 0, 0, 0 33 | for i in xrange(15): 34 | if i == 0 and new_user_prediction[i][0] in test: 35 | count_1 = 1.0 36 | if i < 5 and new_user_prediction[i][0] in test: 37 | count_5 += 1.0 38 | if i < 10 and 
new_user_prediction[i][0] in test: 39 | count_10 += 1.0 40 | if new_user_prediction[i][0] in test: 41 | count_15 += 1.0 42 | dcg_list.append(1) 43 | else: 44 | dcg_list.append(0) 45 | 46 | # calculate NDCG@k 47 | idcg_list = [1 for i in range(len(test))] 48 | ndcg_tmp_1 = NDCG_at_k(dcg_list, idcg_list, 1) 49 | ndcg_tmp_5 = NDCG_at_k(dcg_list, idcg_list, 5) 50 | ndcg_tmp_10 = NDCG_at_k(dcg_list, idcg_list, 10) 51 | ndcg_tmp_15 = NDCG_at_k(dcg_list, idcg_list, 15) 52 | 53 | # precision@k 54 | precision_1 = count_1 55 | precision_5 = count_5 / 5.0 56 | precision_10 = count_10 / 10.0 57 | precision_15 = count_15 / 15.0 58 | 59 | l = len(test) 60 | if l == 0: 61 | l = 1 62 | # recall@k 63 | recall_1 = count_1 / l 64 | recall_5 = count_5 / l 65 | recall_10 = count_10 / l 66 | recall_15 = count_15 / l 67 | 68 | # return precision, recall, ndcg_tmp 69 | return np.array([precision_1, precision_5, precision_10, precision_15]),\ 70 | np.array([recall_1, recall_5, recall_10, recall_15]),\ 71 | np.array([ndcg_tmp_1, ndcg_tmp_5, ndcg_tmp_10, ndcg_tmp_15]) 72 | 73 | 74 | # calculate the metrics of the result 75 | def test_model_all(prediction, test_mask, train_mask): 76 | precision_1, precision_5, precision_10, precision_15 = 0.0000, 0.0000, 0.0000, 0.0000 77 | recall_1, recall_5, recall_10, recall_15 = 0.0000, 0.0000, 0.0000, 0.0000 78 | ndcg_1, ndcg_5, ndcg_10, ndcg_15 = 0.0000, 0.0000, 0.0000, 0.0000 79 | precision = np.array([precision_1, precision_5, precision_10, precision_15]) 80 | recall = np.array([recall_1, recall_5, recall_10, recall_15]) 81 | ndcg = np.array([ndcg_1, ndcg_5, ndcg_10, ndcg_15]) 82 | 83 | prediction = prediction + train_mask * -100000.0 84 | 85 | user_num = prediction.shape[0] 86 | for u in range(user_num): # iterate each user 87 | u_test = test_mask[u, :] 88 | u_test = np.where(u_test == 1)[0] # the indices of the true positive items in the test set 89 | u_pred = prediction[u, :] 90 | 91 | top15_item_idx_no_train = np.argpartition(u_pred, -15)[-15:] 92 
| top15 = (np.array([top15_item_idx_no_train, u_pred[top15_item_idx_no_train]])).T 93 | top15 = sorted(top15, key=itemgetter(1), reverse=True) 94 | 95 | # calculate the metrics 96 | if not len(u_test) == 0: 97 | precision_u, recall_u, ndcg_u = user_precision_recall_ndcg(top15, u_test) 98 | precision += precision_u 99 | recall += recall_u 100 | ndcg += ndcg_u 101 | else: 102 | user_num -= 1 103 | 104 | # compute the average over all users 105 | precision /= user_num 106 | recall /= user_num 107 | ndcg /= user_num 108 | print 'precision_1\t[%.7f],\t||\t precision_5\t[%.7f],\t||\t precision_10\t[%.7f],\t||\t precision_15\t[%.7f]' \ 109 | % (precision[0], 110 | precision[1], 111 | precision[2], 112 | precision[3]) 113 | print 'recall_1 \t[%.7f],\t||\t recall_5 \t[%.7f],\t||\t recall_10 \t[%.7f],\t||\t recall_15 \t[%.7f]' \ 114 | % (recall[0], recall[1], 115 | recall[2], recall[3]) 116 | f_measure_1 = 2 * (precision[0] * recall[0]) / (precision[0] + recall[0]) if not precision[0] + recall[0] == 0 else 0 117 | f_measure_5 = 2 * (precision[1] * recall[1]) / (precision[1] + recall[1]) if not precision[1] + recall[1] == 0 else 0 118 | f_measure_10 = 2 * (precision[2] * recall[2]) / (precision[2] + recall[2]) if not precision[2] + recall[2] == 0 else 0 119 | f_measure_15 = 2 * (precision[3] * recall[3]) / (precision[3] + recall[3]) if not precision[3] + recall[3] == 0 else 0 120 | print 'f_measure_1\t[%.7f],\t||\t f_measure_5\t[%.7f],\t||\t f_measure_10\t[%.7f],\t||\t f_measure_15\t[%.7f]' \ 121 | % (f_measure_1, 122 | f_measure_5, 123 | f_measure_10, 124 | f_measure_15) 125 | f_score = [f_measure_1, f_measure_5, f_measure_10, f_measure_15] 126 | print 'ndcg_1 \t[%.7f],\t||\t ndcg_5 \t[%.7f],\t||\t ndcg_10 \t[%.7f],\t||\t ndcg_15 \t[%.7f]' \ 127 | % (ndcg[0], 128 | ndcg[1], 129 | ndcg[2], 130 | ndcg[3]) 131 | return precision, recall, f_score, ndcg 132 | 133 | 134 | def metric_record(precision, recall, f_score, NDCG, args, metric_path): # record all the results' details into 
files 135 | path = metric_path + '.txt' 136 | 137 | with open(path, 'w') as f: 138 | f.write(str(args) + '\n') 139 | f.write('precision:' + str(precision) + '\n') 140 | f.write('recall:' + str(recall) + '\n') 141 | f.write('f score:' + str(f_score) + '\n') 142 | f.write('NDCG:' + str(NDCG) + '\n') 143 | f.write('\n') 144 | f.close() 145 | 146 | 147 | def get_train_instances(train_R, neg_sample_rate): 148 | """ 149 | generate the training dataset for NCF models in each iteration 150 | :param train_R: binary (0/1) user-item rating matrix 151 | :param neg_sample_rate: ratio of negative samples per observed positive (see neg_sampling) 152 | :return: (user_input, item_input, labels) lists covering the sampled mask entries 153 | """ 154 | # randomly sample negative samples 155 | mask = neg_sampling(train_R, range(train_R.shape[0]), neg_sample_rate) 156 | 157 | user_input, item_input, labels = [], [], [] 158 | idx = np.array(np.where(mask == 1)) 159 | for i in range(idx.shape[1]): 160 | # positive instance 161 | u_i = idx[0, i] 162 | i_i = idx[1, i] 163 | user_input.append(u_i) 164 | item_input.append(i_i) 165 | labels.append(train_R[u_i, i_i]) 166 | return user_input, item_input, labels 167 | 168 | 169 | def neg_sampling(train_R, idx, neg_sample_rate): 170 | """ 171 | randomly sample negative (unobserved) entries into the training mask 172 | :param train_R: binary (0/1) user-item rating matrix 173 | :param idx: iterable of row (user) indices to sample negatives for 174 | :param neg_sample_rate: negatives drawn per row = int(#observed positives * rate), at least 1 175 | :return: copy of train_R with the sampled negative positions also set to 1 176 | """ 177 | num_cols = train_R.shape[1] 178 | num_rows = train_R.shape[0] 179 | # randomly sample negative samples 180 | mask = copy.copy(train_R) 181 | if neg_sample_rate == 0: 182 | return mask 183 | for b_idx in idx: 184 | mask_list = mask[b_idx, :] 185 | unobsv_list = np.where(mask_list == 0) 186 | unobsv_list = unobsv_list[0] # unobserved indices 187 | obsv_num = num_cols - len(unobsv_list) 188 | neg_num = int(obsv_num * neg_sample_rate) 189 | if neg_num > len(unobsv_list): # if the observed positive ratings are more than the half 190 | neg_num = len(unobsv_list) 191 | if neg_num == 0: 192 | neg_num = 1 193 | neg_samp_list = np.random.choice(unobsv_list, size=neg_num, replace=False) 194 | mask_list[neg_samp_list] = 1 195 | mask[b_idx, :] = mask_list 196 | return
mask 197 | 198 | 199 | def pairwise_neg_sampling(train_R, r_idx, c_idx, neg_sample_rate): 200 | R = train_R[r_idx, :] 201 | R = R[:, c_idx] 202 | p_input, n_input = [], [] 203 | obsv_list = np.where(R == 1) 204 | 205 | unobsv_mat = [] 206 | for r in range(R.shape[0]): 207 | unobsv_list = np.where(R[r, :] == 0) 208 | unobsv_list = unobsv_list[0] 209 | unobsv_mat.append(unobsv_list) 210 | 211 | for i in range(len(obsv_list[1])): 212 | # positive instance 213 | u = obsv_list[0][i] 214 | # negative instances 215 | unobsv_list = unobsv_mat[u] 216 | neg_samp_list = np.random.choice(unobsv_list, size=neg_sample_rate, replace=False) 217 | for ns in neg_samp_list: 218 | p_input.append([u, obsv_list[1][i]]) 219 | n_input.append([u, ns]) 220 | # print('dataset size = ' + str(len(p_input))) 221 | return np.array(p_input), np.array(n_input) 222 | 223 | 224 | # calculate the metrics of the result 225 | def test_model_batch(prediction, test_mask, train_mask): 226 | precision_1, precision_5, precision_10, precision_15 = 0.0000, 0.0000, 0.0000, 0.0000 227 | recall_1, recall_5, recall_10, recall_15 = 0.0000, 0.0000, 0.0000, 0.0000 228 | ndcg_1, ndcg_5, ndcg_10, ndcg_15 = 0.0000, 0.0000, 0.0000, 0.0000 229 | precision = np.array([precision_1, precision_5, precision_10, precision_15]) 230 | recall = np.array([recall_1, recall_5, recall_10, recall_15]) 231 | ndcg = np.array([ndcg_1, ndcg_5, ndcg_10, ndcg_15]) 232 | 233 | prediction = prediction + train_mask * -100000.0 234 | 235 | user_num = prediction.shape[0] 236 | for u in range(user_num): # iterate each user 237 | u_test = test_mask[u, :] 238 | u_test = np.where(u_test == 1)[0] # the indices of the true positive items in the test set 239 | u_pred = prediction[u, :] 240 | 241 | top15_item_idx_no_train = np.argpartition(u_pred, -15)[-15:] 242 | top15 = (np.array([top15_item_idx_no_train, u_pred[top15_item_idx_no_train]])).T 243 | top15 = sorted(top15, key=itemgetter(1), reverse=True) 244 | 245 | # calculate the metrics 246 | if not 
len(u_test) == 0: 247 | precision_u, recall_u, ndcg_u = user_precision_recall_ndcg(top15, u_test) 248 | precision += precision_u 249 | recall += recall_u 250 | ndcg += ndcg_u 251 | else: 252 | user_num -= 1 253 | 254 | return precision, recall, ndcg 255 | 256 | 257 | # calculate the metrics of the result 258 | def test_model_cold_start(prediction, test_mask, train_mask): 259 | precision_1, precision_5, precision_10, precision_15 = 0.0000, 0.0000, 0.0000, 0.0000 260 | recall_1, recall_5, recall_10, recall_15 = 0.0000, 0.0000, 0.0000, 0.0000 261 | ndcg_1, ndcg_5, ndcg_10, ndcg_15 = 0.0000, 0.0000, 0.0000, 0.0000 262 | precision = np.array([precision_1, precision_5, precision_10, precision_15]) 263 | recall = np.array([recall_1, recall_5, recall_10, recall_15]) 264 | ndcg = np.array([ndcg_1, ndcg_5, ndcg_10, ndcg_15]) 265 | 266 | prediction = prediction + train_mask * -100000.0 267 | 268 | user_num = prediction.shape[0] 269 | n = 0 270 | for u in range(user_num): # iterate each user 271 | u_test = test_mask[u, :] 272 | u_test = np.where(u_test == 1)[0] # the indices of the true positive items in the test set 273 | if len(u_test) > 10: 274 | continue 275 | u_pred = prediction[u, :] 276 | 277 | top15_item_idx_no_train = np.argpartition(u_pred, -15)[-15:] 278 | top15 = (np.array([top15_item_idx_no_train, u_pred[top15_item_idx_no_train]])).T 279 | top15 = sorted(top15, key=itemgetter(1), reverse=True) 280 | 281 | # calculate the metrics 282 | if not len(u_test) == 0: 283 | precision_u, recall_u, ndcg_u = user_precision_recall_ndcg(top15, u_test) 284 | precision += precision_u 285 | recall += recall_u 286 | ndcg += ndcg_u 287 | n += 1 288 | 289 | # compute the average over all users 290 | precision /= n 291 | recall /= n 292 | ndcg /= n 293 | print 'precision_1\t[%.7f],\t||\t precision_5\t[%.7f],\t||\t precision_10\t[%.7f],\t||\t precision_15\t[%.7f]' \ 294 | % (precision[0], 295 | precision[1], 296 | precision[2], 297 | precision[3]) 298 | print 'recall_1 \t[%.7f],\t||\t 
recall_5 \t[%.7f],\t||\t recall_10 \t[%.7f],\t||\t recall_15 \t[%.7f]' \ 299 | % (recall[0], recall[1], 300 | recall[2], recall[3]) 301 | f_measure_1 = 2 * (precision[0] * recall[0]) / (precision[0] + recall[0]) if not precision[0] + recall[0] == 0 else 0 302 | f_measure_5 = 2 * (precision[1] * recall[1]) / (precision[1] + recall[1]) if not precision[1] + recall[1] == 0 else 0 303 | f_measure_10 = 2 * (precision[2] * recall[2]) / (precision[2] + recall[2]) if not precision[2] + recall[2] == 0 else 0 304 | f_measure_15 = 2 * (precision[3] * recall[3]) / (precision[3] + recall[3]) if not precision[3] + recall[3] == 0 else 0 305 | print 'f_measure_1\t[%.7f],\t||\t f_measure_5\t[%.7f],\t||\t f_measure_10\t[%.7f],\t||\t f_measure_15\t[%.7f]' \ 306 | % (f_measure_1, 307 | f_measure_5, 308 | f_measure_10, 309 | f_measure_15) 310 | f_score = [f_measure_1, f_measure_5, f_measure_10, f_measure_15] 311 | print 'ndcg_1 \t[%.7f],\t||\t ndcg_5 \t[%.7f],\t||\t ndcg_10 \t[%.7f],\t||\t ndcg_15 \t[%.7f]' \ 312 | % (ndcg[0], 313 | ndcg[1], 314 | ndcg[2], 315 | ndcg[3]) 316 | return precision, recall, f_score, ndcg 317 | 318 | 319 | def test_model_factor(prediction, test_mask, train_mask): 320 | item_list = np.zeros(train_mask.shape[1]) 321 | item_list_rank = np.zeros(train_mask.shape[1]) 322 | 323 | prediction = prediction + train_mask * -100000.0 324 | 325 | user_num = prediction.shape[0] 326 | for u in range(user_num): # iterate each user 327 | u_test = test_mask[u, :] 328 | u_test = np.where(u_test == 1)[0] # the indices of the true positive items in the test set 329 | len_u_test = len(u_test) 330 | u_pred = prediction[u, :] 331 | 332 | top10_item_idx_no_train = np.argpartition(u_pred, -10)[-10:] 333 | item_list[top10_item_idx_no_train] += 1 334 | for i in range(len(top10_item_idx_no_train)): 335 | item_list_rank[top10_item_idx_no_train[i]] += (10 - i) 336 | 337 | item_count = np.sum(train_mask, axis=0) 338 | df = pd.DataFrame({'item_pred_freq': item_list, 'item_count': 
item_count}) 339 | df.to_csv('data/no-factor' + time.strftime('%y-%m-%d-%H-%M-%S', time.localtime()) + '.csv') 340 | df = pd.DataFrame({'item_pred_rank': item_list_rank, 'item_count': item_count}) 341 | df.to_csv('data/rank-no-factor' + time.strftime('%y-%m-%d-%H-%M-%S', time.localtime()) + '.csv') 342 | -------------------------------------------------------------------------------- /data/yelp/.ipynb_checkpoints/preprocessing-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 2, 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "data": { 19 | "text/html": [ 20 | "
\n", 21 | "\n", 34 | "\n", 35 | " \n", 36 | " \n", 37 | " \n", 38 | " \n", 39 | " \n", 40 | " \n", 41 | " \n", 42 | " \n", 43 | " \n", 44 | " \n", 45 | " \n", 46 | " \n", 47 | " \n", 48 | " \n", 49 | " \n", 50 | " \n", 51 | " \n", 52 | " \n", 53 | " \n", 54 | " \n", 55 | " \n", 56 | " \n", 57 | " \n", 58 | " \n", 59 | " \n", 60 | " \n", 61 | " \n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | "
userIditemIdratingtimestamp
0004.01329148800
1004.01337011200
2014.01335888000
3024.01379260800
4032.01367856000
\n", 82 | "
" 83 | ], 84 | "text/plain": [ 85 | " userId itemId rating timestamp\n", 86 | "0 0 0 4.0 1329148800\n", 87 | "1 0 0 4.0 1337011200\n", 88 | "2 0 1 4.0 1335888000\n", 89 | "3 0 2 4.0 1379260800\n", 90 | "4 0 3 2.0 1367856000" 91 | ] 92 | }, 93 | "execution_count": 2, 94 | "metadata": {}, 95 | "output_type": "execute_result" 96 | } 97 | ], 98 | "source": [ 99 | "data_df = pd.read_csv('./yelp.csv', sep=',')\n", 100 | "data_df.head()" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 3, 106 | "metadata": {}, 107 | "outputs": [ 108 | { 109 | "data": { 110 | "text/html": [ 111 | "
\n", 112 | "\n", 125 | "\n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | "
userIditemIdrating
0004.0
1004.0
2014.0
3024.0
4032.0
\n", 167 | "
" 168 | ], 169 | "text/plain": [ 170 | " userId itemId rating\n", 171 | "0 0 0 4.0\n", 172 | "1 0 0 4.0\n", 173 | "2 0 1 4.0\n", 174 | "3 0 2 4.0\n", 175 | "4 0 3 2.0" 176 | ] 177 | }, 178 | "execution_count": 3, 179 | "metadata": {}, 180 | "output_type": "execute_result" 181 | } 182 | ], 183 | "source": [ 184 | "data_df.drop('timestamp', axis=1, inplace=True)\n", 185 | "data_df.head()" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": 46, 191 | "metadata": {}, 192 | "outputs": [ 193 | { 194 | "data": { 195 | "text/plain": [ 196 | "731671" 197 | ] 198 | }, 199 | "execution_count": 46, 200 | "metadata": {}, 201 | "output_type": "execute_result" 202 | } 203 | ], 204 | "source": [ 205 | "len(data_df)" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": null, 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "# data_df['mean'] = data_df.groupby('userId')['rating'].transform('mean')" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 4, 220 | "metadata": { 221 | "scrolled": true 222 | }, 223 | "outputs": [ 224 | { 225 | "name": "stdout", 226 | "output_type": "stream", 227 | "text": [ 228 | "0\n", 229 | "100000\n", 230 | "200000\n", 231 | "300000\n", 232 | "400000\n", 233 | "500000\n", 234 | "600000\n", 235 | "700000\n" 236 | ] 237 | } 238 | ], 239 | "source": [ 240 | "for i in range(len(data_df)):\n", 241 | " if data_df.at[i, 'rating'] > 3:\n", 242 | " data_df.at[i, 'rating'] = 1\n", 243 | " else:\n", 244 | " data_df.at[i, 'rating'] = 0\n", 245 | " if i % 100000 == 0:\n", 246 | " print i" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 7, 252 | "metadata": { 253 | "scrolled": true 254 | }, 255 | "outputs": [], 256 | "source": [ 257 | "data_df['userId'] = data_df['userId'] + 1\n", 258 | "data_df['itemId'] = data_df['itemId'] + 1" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 47, 264 | "metadata": {}, 265 | "outputs": 
[], 266 | "source": [ 267 | "df = data_df.copy()" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": 48, 273 | "metadata": { 274 | "scrolled": true 275 | }, 276 | "outputs": [ 277 | { 278 | "data": { 279 | "text/html": [ 280 | "
\n", 281 | "\n", 294 | "\n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " 
\n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " \n", 549 | " \n", 550 | " \n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | " \n", 555 | " \n", 556 | " \n", 557 | " \n", 558 | " \n", 559 | " \n", 560 | " \n", 561 | " \n", 562 | " \n", 563 | " \n", 564 | " \n", 565 | " \n", 566 | " \n", 567 | " \n", 568 | " \n", 569 | " \n", 570 | " \n", 571 | " \n", 572 | " \n", 573 | " \n", 574 | " \n", 575 | " \n", 576 | " \n", 577 | " \n", 578 | " \n", 579 | " \n", 580 | " \n", 581 | " \n", 582 | " \n", 583 | " \n", 584 | " \n", 585 | " \n", 586 | " \n", 587 | " \n", 588 | " \n", 589 | " \n", 590 | " \n", 591 | " \n", 592 | " \n", 593 | " \n", 594 | " \n", 595 | " \n", 596 | " \n", 597 | " \n", 598 | " \n", 599 | " \n", 600 | 
" \n", 601 | " \n", 602 | " \n", 603 | " \n", 604 | " \n", 605 | " \n", 606 | " \n", 607 | " \n", 608 | " \n", 609 | " \n", 610 | " \n", 611 | " \n", 612 | " \n", 613 | " \n", 614 | " \n", 615 | " \n", 616 | " \n", 617 | " \n", 618 | " \n", 619 | " \n", 620 | " \n", 621 | " \n", 622 | " \n", 623 | " \n", 624 | " \n", 625 | " \n", 626 | " \n", 627 | " \n", 628 | " \n", 629 | " \n", 630 | " \n", 631 | " \n", 632 | " \n", 633 | " \n", 634 | " \n", 635 | " \n", 636 | " \n", 637 | " \n", 638 | " \n", 639 | " \n", 640 | " \n", 641 | " \n", 642 | " \n", 643 | " \n", 644 | " \n", 645 | " \n", 646 | " \n", 647 | " \n", 648 | " \n", 649 | " \n", 650 | " \n", 651 | " \n", 652 | " \n", 653 | " \n", 654 | " \n", 655 | " \n", 656 | " \n", 657 | " \n", 658 | " \n", 659 | " \n", 660 | " \n", 661 | " \n", 662 | " \n", 663 | " \n", 664 | " \n", 665 | " \n", 666 | " \n", 667 | " \n", 668 | " \n", 669 | " \n", 670 | " \n", 671 | "
userIditemIdrating
0111.0
1111.0
2121.0
3131.0
7171.0
8181.0
101101.0
121121.0
151151.0
161161.0
171171.0
211211.0
221221.0
231231.0
241241.0
271271.0
291291.0
301301.0
311311.0
321321.0
331331.0
341341.0
351351.0
361361.0
391391.0
401401.0
411411.0
421421.0
431421.0
451441.0
............
7316412567591091.0
7316422567589171.0
73164325675186291.0
7316442567525781.0
73164525675225571.0
73164625675202921.0
73164725675225551.0
73164825675241681.0
73164925675242551.0
73165025675246931.0
7316512567688501.0
7316522567619461.0
73165325676208371.0
73165425676208371.0
7316552567618401.0
7316562567660331.0
7316572567630171.0
73165825676200161.0
7316592567630421.0
73166025676192561.0
7316612567754491.0
73166225677148391.0
73166325677112331.0
731664256771661.0
7316652567785231.0
73166625677220871.0
7316672567756151.0
73166825677182231.0
7316692567754531.0
7316702567715891.0
\n", 672 | "

486499 rows × 3 columns

\n", 673 | "
" 674 | ], 675 | "text/plain": [ 676 | " userId itemId rating\n", 677 | "0 1 1 1.0\n", 678 | "1 1 1 1.0\n", 679 | "2 1 2 1.0\n", 680 | "3 1 3 1.0\n", 681 | "7 1 7 1.0\n", 682 | "8 1 8 1.0\n", 683 | "10 1 10 1.0\n", 684 | "12 1 12 1.0\n", 685 | "15 1 15 1.0\n", 686 | "16 1 16 1.0\n", 687 | "17 1 17 1.0\n", 688 | "21 1 21 1.0\n", 689 | "22 1 22 1.0\n", 690 | "23 1 23 1.0\n", 691 | "24 1 24 1.0\n", 692 | "27 1 27 1.0\n", 693 | "29 1 29 1.0\n", 694 | "30 1 30 1.0\n", 695 | "31 1 31 1.0\n", 696 | "32 1 32 1.0\n", 697 | "33 1 33 1.0\n", 698 | "34 1 34 1.0\n", 699 | "35 1 35 1.0\n", 700 | "36 1 36 1.0\n", 701 | "39 1 39 1.0\n", 702 | "40 1 40 1.0\n", 703 | "41 1 41 1.0\n", 704 | "42 1 42 1.0\n", 705 | "43 1 42 1.0\n", 706 | "45 1 44 1.0\n", 707 | "... ... ... ...\n", 708 | "731641 25675 9109 1.0\n", 709 | "731642 25675 8917 1.0\n", 710 | "731643 25675 18629 1.0\n", 711 | "731644 25675 2578 1.0\n", 712 | "731645 25675 22557 1.0\n", 713 | "731646 25675 20292 1.0\n", 714 | "731647 25675 22555 1.0\n", 715 | "731648 25675 24168 1.0\n", 716 | "731649 25675 24255 1.0\n", 717 | "731650 25675 24693 1.0\n", 718 | "731651 25676 8850 1.0\n", 719 | "731652 25676 1946 1.0\n", 720 | "731653 25676 20837 1.0\n", 721 | "731654 25676 20837 1.0\n", 722 | "731655 25676 1840 1.0\n", 723 | "731656 25676 6033 1.0\n", 724 | "731657 25676 3017 1.0\n", 725 | "731658 25676 20016 1.0\n", 726 | "731659 25676 3042 1.0\n", 727 | "731660 25676 19256 1.0\n", 728 | "731661 25677 5449 1.0\n", 729 | "731662 25677 14839 1.0\n", 730 | "731663 25677 11233 1.0\n", 731 | "731664 25677 166 1.0\n", 732 | "731665 25677 8523 1.0\n", 733 | "731666 25677 22087 1.0\n", 734 | "731667 25677 5615 1.0\n", 735 | "731668 25677 18223 1.0\n", 736 | "731669 25677 5453 1.0\n", 737 | "731670 25677 1589 1.0\n", 738 | "\n", 739 | "[486499 rows x 3 columns]" 740 | ] 741 | }, 742 | "execution_count": 48, 743 | "metadata": {}, 744 | "output_type": "execute_result" 745 | } 746 | ], 747 | "source": [ 748 | "df.drop(df.index[df['rating'] 
== 0], axis=0, inplace=True)\n", 749 | "df" 750 | ] 751 | }, 752 | { 753 | "cell_type": "code", 754 | "execution_count": 49, 755 | "metadata": {}, 756 | "outputs": [], 757 | "source": [ 758 | "# df.drop('mean', axis=1, inplace=True)\n", 759 | "df.drop('rating', axis=1, inplace=True)" 760 | ] 761 | }, 762 | { 763 | "cell_type": "code", 764 | "execution_count": 50, 765 | "metadata": { 766 | "scrolled": true 767 | }, 768 | "outputs": [ 769 | { 770 | "data": { 771 | "text/plain": [ 772 | "1115 819\n", 773 | "1252 809\n", 774 | "1864 613\n", 775 | "42 600\n", 776 | "151 570\n", 777 | "3309 553\n", 778 | "4149 550\n", 779 | "2253 531\n", 780 | "1901 524\n", 781 | "2954 517\n", 782 | "12 501\n", 783 | "106 497\n", 784 | "767 489\n", 785 | "4500 468\n", 786 | "9 468\n", 787 | "1447 460\n", 788 | "2324 458\n", 789 | "3923 448\n", 790 | "2111 438\n", 791 | "2730 431\n", 792 | "3822 413\n", 793 | "103 405\n", 794 | "1854 398\n", 795 | "3992 394\n", 796 | "1970 394\n", 797 | "621 393\n", 798 | "615 392\n", 799 | "510 388\n", 800 | "512 376\n", 801 | "829 374\n", 802 | " ... 
\n", 803 | "13373 1\n", 804 | "22115 1\n", 805 | "24162 1\n", 806 | "24602 1\n", 807 | "1129 1\n", 808 | "25290 1\n", 809 | "7109 1\n", 810 | "23501 1\n", 811 | "24094 1\n", 812 | "13504 1\n", 813 | "19663 1\n", 814 | "19660 1\n", 815 | "3130 1\n", 816 | "16238 1\n", 817 | "22150 1\n", 818 | "19811 1\n", 819 | "12861 1\n", 820 | "21710 1\n", 821 | "9664 1\n", 822 | "19555 1\n", 823 | "11119 1\n", 824 | "4968 1\n", 825 | "21199 1\n", 826 | "14274 1\n", 827 | "24527 1\n", 828 | "13072 1\n", 829 | "18975 1\n", 830 | "16926 1\n", 831 | "21088 1\n", 832 | "15213 1\n", 833 | "Name: itemId, Length: 24930, dtype: int64" 834 | ] 835 | }, 836 | "execution_count": 50, 837 | "metadata": {}, 838 | "output_type": "execute_result" 839 | } 840 | ], 841 | "source": [ 842 | "df['user_freq'] = df.groupby('userId')['userId'].transform('count')\n", 843 | "df['item_freq'] = df.groupby('itemId')['itemId'].transform('count')\n", 844 | "df.drop(df.index[df['user_freq'] < 10], inplace=True)\n", 845 | "df['item_freq'] = df.groupby('itemId')['itemId'].transform('count')\n", 846 | "df['itemId'].value_counts()" 847 | ] 848 | }, 849 | { 850 | "cell_type": "code", 851 | "execution_count": 51, 852 | "metadata": { 853 | "scrolled": true 854 | }, 855 | "outputs": [ 856 | { 857 | "data": { 858 | "text/plain": [ 859 | "3730 596\n", 860 | "1420 580\n", 861 | "13384 569\n", 862 | "13063 556\n", 863 | "12942 501\n", 864 | "1447 490\n", 865 | "3782 466\n", 866 | "1805 445\n", 867 | "13232 445\n", 868 | "3555 445\n", 869 | "4478 413\n", 870 | "2236 406\n", 871 | "3609 394\n", 872 | "2176 387\n", 873 | "363 366\n", 874 | "4870 337\n", 875 | "5793 325\n", 876 | "14002 322\n", 877 | "16099 305\n", 878 | "4041 298\n", 879 | "14302 293\n", 880 | "3324 285\n", 881 | "4246 285\n", 882 | "4468 282\n", 883 | "3862 277\n", 884 | "1944 275\n", 885 | "9384 274\n", 886 | "2975 274\n", 887 | "22 271\n", 888 | "13880 265\n", 889 | " ... 
\n", 890 | "2116 3\n", 891 | "18799 3\n", 892 | "25145 3\n", 893 | "18962 3\n", 894 | "2704 3\n", 895 | "6877 3\n", 896 | "7762 3\n", 897 | "2529 3\n", 898 | "19199 3\n", 899 | "25618 3\n", 900 | "18829 3\n", 901 | "25333 3\n", 902 | "13455 3\n", 903 | "25051 3\n", 904 | "10556 3\n", 905 | "11482 2\n", 906 | "24397 2\n", 907 | "11881 2\n", 908 | "25219 2\n", 909 | "12712 2\n", 910 | "25274 2\n", 911 | "10963 2\n", 912 | "1879 2\n", 913 | "17221 2\n", 914 | "8924 2\n", 915 | "25488 2\n", 916 | "10506 2\n", 917 | "19117 1\n", 918 | "25564 1\n", 919 | "1934 1\n", 920 | "Name: userId, Length: 16066, dtype: int64" 921 | ] 922 | }, 923 | "execution_count": 51, 924 | "metadata": {}, 925 | "output_type": "execute_result" 926 | } 927 | ], 928 | "source": [ 929 | "df.drop(df.index[df['item_freq'] < 10], inplace=True)\n", 930 | "df['user_freq'] = df.groupby('userId')['userId'].transform('count')\n", 931 | "df['userId'].value_counts()" 932 | ] 933 | }, 934 | { 935 | "cell_type": "code", 936 | "execution_count": 52, 937 | "metadata": { 938 | "scrolled": true 939 | }, 940 | "outputs": [ 941 | { 942 | "data": { 943 | "text/plain": [ 944 | "1252 774\n", 945 | "1115 761\n", 946 | "1864 583\n", 947 | "42 579\n", 948 | "151 548\n", 949 | "3309 527\n", 950 | "4149 525\n", 951 | "2253 509\n", 952 | "2954 497\n", 953 | "1901 492\n", 954 | "106 477\n", 955 | "12 475\n", 956 | "767 472\n", 957 | "9 448\n", 958 | "2324 438\n", 959 | "4500 435\n", 960 | "1447 435\n", 961 | "3923 431\n", 962 | "2111 424\n", 963 | "2730 404\n", 964 | "3822 402\n", 965 | "1854 388\n", 966 | "3992 386\n", 967 | "103 385\n", 968 | "615 378\n", 969 | "1970 377\n", 970 | "621 371\n", 971 | "510 366\n", 972 | "512 359\n", 973 | "829 355\n", 974 | " ... 
\n", 975 | "13381 6\n", 976 | "7042 6\n", 977 | "13678 6\n", 978 | "16639 6\n", 979 | "13223 6\n", 980 | "19481 6\n", 981 | "12670 6\n", 982 | "13092 6\n", 983 | "2550 6\n", 984 | "1671 6\n", 985 | "14907 6\n", 986 | "22576 6\n", 987 | "6843 6\n", 988 | "5984 6\n", 989 | "20163 6\n", 990 | "2359 6\n", 991 | "15841 6\n", 992 | "11745 6\n", 993 | "13587 5\n", 994 | "1637 5\n", 995 | "19117 5\n", 996 | "13374 5\n", 997 | "24956 5\n", 998 | "3065 5\n", 999 | "6304 5\n", 1000 | "13349 5\n", 1001 | "4306 5\n", 1002 | "24966 4\n", 1003 | "20899 4\n", 1004 | "13388 1\n", 1005 | "Name: itemId, Length: 10112, dtype: int64" 1006 | ] 1007 | }, 1008 | "execution_count": 52, 1009 | "metadata": {}, 1010 | "output_type": "execute_result" 1011 | } 1012 | ], 1013 | "source": [ 1014 | "df.drop(df.index[df['user_freq'] < 10], inplace=True)\n", 1015 | "df['user_freq'] = df.groupby('userId')['userId'].transform('count')\n", 1016 | "df['item_freq'] = df.groupby('itemId')['itemId'].transform('count')\n", 1017 | "df['itemId'].value_counts()" 1018 | ] 1019 | }, 1020 | { 1021 | "cell_type": "code", 1022 | "execution_count": 53, 1023 | "metadata": { 1024 | "scrolled": true 1025 | }, 1026 | "outputs": [ 1027 | { 1028 | "data": { 1029 | "text/plain": [ 1030 | "3730 584\n", 1031 | "1420 573\n", 1032 | "13384 559\n", 1033 | "13063 548\n", 1034 | "12942 499\n", 1035 | "1447 484\n", 1036 | "3782 450\n", 1037 | "1805 442\n", 1038 | "13232 442\n", 1039 | "3555 439\n", 1040 | "4478 407\n", 1041 | "2236 402\n", 1042 | "3609 385\n", 1043 | "2176 384\n", 1044 | "363 365\n", 1045 | "4870 333\n", 1046 | "5793 325\n", 1047 | "14002 318\n", 1048 | "16099 304\n", 1049 | "4041 291\n", 1050 | "14302 290\n", 1051 | "4246 284\n", 1052 | "3324 281\n", 1053 | "4468 276\n", 1054 | "1944 273\n", 1055 | "3862 271\n", 1056 | "2975 271\n", 1057 | "9384 270\n", 1058 | "13880 264\n", 1059 | "22 262\n", 1060 | " ... 
\n", 1061 | "467 8\n", 1062 | "17192 8\n", 1063 | "8484 8\n", 1064 | "9089 8\n", 1065 | "1106 8\n", 1066 | "8190 8\n", 1067 | "2997 8\n", 1068 | "3053 8\n", 1069 | "25089 8\n", 1070 | "3069 7\n", 1071 | "3035 7\n", 1072 | "3205 7\n", 1073 | "25341 7\n", 1074 | "25223 7\n", 1075 | "3141 7\n", 1076 | "25113 7\n", 1077 | "25175 7\n", 1078 | "25311 7\n", 1079 | "3203 7\n", 1080 | "2639 7\n", 1081 | "3097 7\n", 1082 | "3025 7\n", 1083 | "3138 7\n", 1084 | "2037 7\n", 1085 | "24994 6\n", 1086 | "25391 6\n", 1087 | "1131 6\n", 1088 | "3132 6\n", 1089 | "3146 4\n", 1090 | "25119 4\n", 1091 | "Name: userId, Length: 13099, dtype: int64" 1092 | ] 1093 | }, 1094 | "execution_count": 53, 1095 | "metadata": {}, 1096 | "output_type": "execute_result" 1097 | } 1098 | ], 1099 | "source": [ 1100 | "df.drop(df.index[df['item_freq'] < 10], inplace=True)\n", 1101 | "df['user_freq'] = df.groupby('userId')['userId'].transform('count')\n", 1102 | "df['userId'].value_counts()" 1103 | ] 1104 | }, 1105 | { 1106 | "cell_type": "code", 1107 | "execution_count": 64, 1108 | "metadata": { 1109 | "scrolled": true 1110 | }, 1111 | "outputs": [ 1112 | { 1113 | "data": { 1114 | "text/plain": [ 1115 | "1252 769\n", 1116 | "1115 756\n", 1117 | "1864 581\n", 1118 | "42 578\n", 1119 | "151 548\n", 1120 | "3309 525\n", 1121 | "4149 523\n", 1122 | "2253 507\n", 1123 | "2954 492\n", 1124 | "1901 489\n", 1125 | "106 472\n", 1126 | "767 471\n", 1127 | "12 470\n", 1128 | "9 445\n", 1129 | "2324 438\n", 1130 | "1447 434\n", 1131 | "4500 430\n", 1132 | "3923 428\n", 1133 | "2111 422\n", 1134 | "3822 400\n", 1135 | "2730 400\n", 1136 | "1854 388\n", 1137 | "3992 384\n", 1138 | "103 382\n", 1139 | "1970 377\n", 1140 | "615 373\n", 1141 | "621 369\n", 1142 | "510 366\n", 1143 | "512 358\n", 1144 | "2016 354\n", 1145 | " ... 
\n", 1146 | "13621 10\n", 1147 | "830 10\n", 1148 | "6679 10\n", 1149 | "9488 10\n", 1150 | "17460 10\n", 1151 | "17716 10\n", 1152 | "5620 10\n", 1153 | "11971 10\n", 1154 | "20463 10\n", 1155 | "12089 10\n", 1156 | "10296 10\n", 1157 | "788 10\n", 1158 | "15488 10\n", 1159 | "903 10\n", 1160 | "14272 10\n", 1161 | "8387 10\n", 1162 | "14208 10\n", 1163 | "20020 10\n", 1164 | "14611 10\n", 1165 | "24493 10\n", 1166 | "13536 10\n", 1167 | "7141 10\n", 1168 | "14139 10\n", 1169 | "8355 10\n", 1170 | "10979 10\n", 1171 | "6596 10\n", 1172 | "18743 10\n", 1173 | "2367 10\n", 1174 | "16950 10\n", 1175 | "2824 9\n", 1176 | "Name: itemId, Length: 9245, dtype: int64" 1177 | ] 1178 | }, 1179 | "execution_count": 64, 1180 | "metadata": {}, 1181 | "output_type": "execute_result" 1182 | } 1183 | ], 1184 | "source": [ 1185 | "df.drop(df.index[df['user_freq'] < 10], inplace=True)\n", 1186 | "df['user_freq'] = df.groupby('userId')['userId'].transform('count')\n", 1187 | "df['item_freq'] = df.groupby('itemId')['itemId'].transform('count')\n", 1188 | "df['itemId'].value_counts()" 1189 | ] 1190 | }, 1191 | { 1192 | "cell_type": "code", 1193 | "execution_count": 65, 1194 | "metadata": { 1195 | "scrolled": true 1196 | }, 1197 | "outputs": [ 1198 | { 1199 | "data": { 1200 | "text/plain": [ 1201 | "3730 584\n", 1202 | "1420 572\n", 1203 | "13384 557\n", 1204 | "13063 548\n", 1205 | "12942 497\n", 1206 | "1447 482\n", 1207 | "3782 448\n", 1208 | "13232 442\n", 1209 | "1805 442\n", 1210 | "3555 439\n", 1211 | "4478 405\n", 1212 | "2236 401\n", 1213 | "3609 385\n", 1214 | "2176 384\n", 1215 | "363 364\n", 1216 | "4870 332\n", 1217 | "5793 324\n", 1218 | "14002 318\n", 1219 | "16099 303\n", 1220 | "4041 291\n", 1221 | "14302 288\n", 1222 | "4246 284\n", 1223 | "3324 279\n", 1224 | "4468 275\n", 1225 | "1944 273\n", 1226 | "2975 271\n", 1227 | "3862 270\n", 1228 | "9384 269\n", 1229 | "13880 264\n", 1230 | "22 262\n", 1231 | " ... 
\n", 1232 | "9240 10\n", 1233 | "14768 10\n", 1234 | "12115 10\n", 1235 | "10322 10\n", 1236 | "6235 10\n", 1237 | "22611 10\n", 1238 | "24336 10\n", 1239 | "13488 10\n", 1240 | "11443 10\n", 1241 | "19694 10\n", 1242 | "1462 10\n", 1243 | "11490 10\n", 1244 | "21948 10\n", 1245 | "14363 10\n", 1246 | "19438 10\n", 1247 | "2077 10\n", 1248 | "11346 10\n", 1249 | "23997 10\n", 1250 | "19294 10\n", 1251 | "11090 10\n", 1252 | "5812 10\n", 1253 | "16991 10\n", 1254 | "14256 10\n", 1255 | "10162 10\n", 1256 | "4437 10\n", 1257 | "22460 10\n", 1258 | "183 10\n", 1259 | "24763 10\n", 1260 | "22716 10\n", 1261 | "24704 10\n", 1262 | "Name: userId, Length: 12704, dtype: int64" 1263 | ] 1264 | }, 1265 | "execution_count": 65, 1266 | "metadata": {}, 1267 | "output_type": "execute_result" 1268 | } 1269 | ], 1270 | "source": [ 1271 | "df.drop(df.index[df['item_freq'] < 10], inplace=True)\n", 1272 | "df['user_freq'] = df.groupby('userId')['userId'].transform('count')\n", 1273 | "df['userId'].value_counts()" 1274 | ] 1275 | }, 1276 | { 1277 | "cell_type": "code", 1278 | "execution_count": 66, 1279 | "metadata": { 1280 | "scrolled": true 1281 | }, 1282 | "outputs": [ 1283 | { 1284 | "data": { 1285 | "text/plain": [ 1286 | "3730 584\n", 1287 | "1420 572\n", 1288 | "13384 557\n", 1289 | "13063 548\n", 1290 | "12942 497\n", 1291 | "1447 482\n", 1292 | "3782 448\n", 1293 | "13232 442\n", 1294 | "1805 442\n", 1295 | "3555 439\n", 1296 | "4478 405\n", 1297 | "2236 401\n", 1298 | "3609 385\n", 1299 | "2176 384\n", 1300 | "363 364\n", 1301 | "4870 332\n", 1302 | "5793 324\n", 1303 | "14002 318\n", 1304 | "16099 303\n", 1305 | "4041 291\n", 1306 | "14302 288\n", 1307 | "4246 284\n", 1308 | "3324 279\n", 1309 | "4468 275\n", 1310 | "1944 273\n", 1311 | "2975 271\n", 1312 | "3862 270\n", 1313 | "9384 269\n", 1314 | "13880 264\n", 1315 | "22 262\n", 1316 | " ... 
\n", 1317 | "9240 10\n", 1318 | "14768 10\n", 1319 | "12115 10\n", 1320 | "10322 10\n", 1321 | "6235 10\n", 1322 | "22611 10\n", 1323 | "24336 10\n", 1324 | "13488 10\n", 1325 | "11443 10\n", 1326 | "19694 10\n", 1327 | "1462 10\n", 1328 | "11490 10\n", 1329 | "21948 10\n", 1330 | "14363 10\n", 1331 | "19438 10\n", 1332 | "2077 10\n", 1333 | "11346 10\n", 1334 | "23997 10\n", 1335 | "19294 10\n", 1336 | "11090 10\n", 1337 | "5812 10\n", 1338 | "16991 10\n", 1339 | "14256 10\n", 1340 | "10162 10\n", 1341 | "4437 10\n", 1342 | "22460 10\n", 1343 | "183 10\n", 1344 | "24763 10\n", 1345 | "22716 10\n", 1346 | "24704 10\n", 1347 | "Name: userId, Length: 12704, dtype: int64" 1348 | ] 1349 | }, 1350 | "execution_count": 66, 1351 | "metadata": {}, 1352 | "output_type": "execute_result" 1353 | } 1354 | ], 1355 | "source": [ 1356 | "df['userId'].value_counts()" 1357 | ] 1358 | }, 1359 | { 1360 | "cell_type": "code", 1361 | "execution_count": 67, 1362 | "metadata": { 1363 | "scrolled": true 1364 | }, 1365 | "outputs": [ 1366 | { 1367 | "data": { 1368 | "text/plain": [ 1369 | "1252 769\n", 1370 | "1115 756\n", 1371 | "1864 581\n", 1372 | "42 578\n", 1373 | "151 548\n", 1374 | "3309 525\n", 1375 | "4149 523\n", 1376 | "2253 507\n", 1377 | "2954 492\n", 1378 | "1901 489\n", 1379 | "106 472\n", 1380 | "767 471\n", 1381 | "12 470\n", 1382 | "9 445\n", 1383 | "2324 438\n", 1384 | "1447 434\n", 1385 | "4500 430\n", 1386 | "3923 428\n", 1387 | "2111 422\n", 1388 | "3822 400\n", 1389 | "2730 400\n", 1390 | "1854 388\n", 1391 | "3992 384\n", 1392 | "103 382\n", 1393 | "1970 377\n", 1394 | "615 373\n", 1395 | "621 369\n", 1396 | "510 366\n", 1397 | "512 358\n", 1398 | "2016 354\n", 1399 | " ... 
\n", 1400 | "2401 10\n", 1401 | "14181 10\n", 1402 | "23326 10\n", 1403 | "22891 10\n", 1404 | "2657 10\n", 1405 | "13777 10\n", 1406 | "25708 10\n", 1407 | "9427 10\n", 1408 | "20073 10\n", 1409 | "10084 10\n", 1410 | "16231 10\n", 1411 | "10768 10\n", 1412 | "20701 10\n", 1413 | "16745 10\n", 1414 | "10852 10\n", 1415 | "5364 10\n", 1416 | "19560 10\n", 1417 | "12049 10\n", 1418 | "20840 10\n", 1419 | "16081 10\n", 1420 | "9488 10\n", 1421 | "20585 10\n", 1422 | "5620 10\n", 1423 | "4961 10\n", 1424 | "2914 10\n", 1425 | "9319 10\n", 1426 | "3426 10\n", 1427 | "13669 10\n", 1428 | "788 10\n", 1429 | "14172 10\n", 1430 | "Name: itemId, Length: 9244, dtype: int64" 1431 | ] 1432 | }, 1433 | "execution_count": 67, 1434 | "metadata": {}, 1435 | "output_type": "execute_result" 1436 | } 1437 | ], 1438 | "source": [ 1439 | "df['itemId'].value_counts()" 1440 | ] 1441 | }, 1442 | { 1443 | "cell_type": "code", 1444 | "execution_count": 68, 1445 | "metadata": {}, 1446 | "outputs": [], 1447 | "source": [ 1448 | "df.drop('user_freq', axis=1, inplace=True)\n", 1449 | "df.drop('item_freq', axis=1, inplace=True)\n", 1450 | "df.reset_index(drop=True, inplace=True)" 1451 | ] 1452 | }, 1453 | { 1454 | "cell_type": "code", 1455 | "execution_count": 69, 1456 | "metadata": {}, 1457 | "outputs": [ 1458 | { 1459 | "name": "stdout", 1460 | "output_type": "stream", 1461 | "text": [ 1462 | "start\n", 1463 | "u20000\n", 1464 | "m10000\n", 1465 | "m20000\n" 1466 | ] 1467 | } 1468 | ], 1469 | "source": [ 1470 | "import numpy as np\n", 1471 | "user_table = np.zeros(np.max(df['userId']) + 1)\n", 1472 | "movie_table = np.zeros(np.max(df['itemId']) + 1)\n", 1473 | "user_set = set(df['userId'].tolist())\n", 1474 | "movie_set = set(df['itemId'].tolist())\n", 1475 | "print('start')\n", 1476 | "u = 1\n", 1477 | "for i in range(1, np.max(df['userId']) + 1):\n", 1478 | " if i in user_set:\n", 1479 | " user_table[i] = u\n", 1480 | " u += 1\n", 1481 | " if i % 20000 == 0:\n", 1482 | " print('u' + 
str(i))\n", 1483 | "m = 1\n", 1484 | "for i in range(1, np.max(df['itemId']) + 1):\n", 1485 | " if i in movie_set:\n", 1486 | " movie_table[i] = m\n", 1487 | " m += 1\n", 1488 | " if i % 10000 == 0:\n", 1489 | " print('m' + str(i))" 1490 | ] 1491 | }, 1492 | { 1493 | "cell_type": "code", 1494 | "execution_count": 70, 1495 | "metadata": {}, 1496 | "outputs": [ 1497 | { 1498 | "name": "stdout", 1499 | "output_type": "stream", 1500 | "text": [ 1501 | "12705\n", 1502 | "9245\n" 1503 | ] 1504 | } 1505 | ], 1506 | "source": [ 1507 | "print u\n", 1508 | "print m" 1509 | ] 1510 | }, 1511 | { 1512 | "cell_type": "code", 1513 | "execution_count": 71, 1514 | "metadata": { 1515 | "scrolled": true 1516 | }, 1517 | "outputs": [ 1518 | { 1519 | "name": "stdout", 1520 | "output_type": "stream", 1521 | "text": [ 1522 | "0\n", 1523 | "100000\n", 1524 | "200000\n", 1525 | "300000\n" 1526 | ] 1527 | }, 1528 | { 1529 | "data": { 1530 | "text/html": [ 1531 | "
\n", 1532 | "\n", 1545 | "\n", 1546 | " \n", 1547 | " \n", 1548 | " \n", 1549 | " \n", 1550 | " \n", 1551 | " \n", 1552 | " \n", 1553 | " \n", 1554 | " \n", 1555 | " \n", 1556 | " \n", 1557 | " \n", 1558 | " \n", 1559 | " \n", 1560 | " \n", 1561 | " \n", 1562 | " \n", 1563 | " \n", 1564 | " \n", 1565 | " \n", 1566 | " \n", 1567 | " \n", 1568 | " \n", 1569 | " \n", 1570 | " \n", 1571 | " \n", 1572 | " \n", 1573 | " \n", 1574 | " \n", 1575 | " \n", 1576 | " \n", 1577 | " \n", 1578 | " \n", 1579 | " \n", 1580 | " \n", 1581 | " \n", 1582 | " \n", 1583 | " \n", 1584 | " \n", 1585 | " \n", 1586 | " \n", 1587 | " \n", 1588 | " \n", 1589 | " \n", 1590 | " \n", 1591 | " \n", 1592 | " \n", 1593 | " \n", 1594 | " \n", 1595 | " \n", 1596 | " \n", 1597 | " \n", 1598 | " \n", 1599 | " \n", 1600 | " \n", 1601 | " \n", 1602 | " \n", 1603 | " \n", 1604 | " \n", 1605 | " \n", 1606 | " \n", 1607 | " \n", 1608 | " \n", 1609 | " \n", 1610 | " \n", 1611 | " \n", 1612 | " \n", 1613 | " \n", 1614 | " \n", 1615 | " \n", 1616 | " \n", 1617 | " \n", 1618 | " \n", 1619 | " \n", 1620 | " \n", 1621 | " \n", 1622 | " \n", 1623 | " \n", 1624 | " \n", 1625 | " \n", 1626 | " \n", 1627 | " \n", 1628 | " \n", 1629 | " \n", 1630 | " \n", 1631 | " \n", 1632 | " \n", 1633 | " \n", 1634 | " \n", 1635 | " \n", 1636 | " \n", 1637 | " \n", 1638 | " \n", 1639 | " \n", 1640 | " \n", 1641 | " \n", 1642 | " \n", 1643 | " \n", 1644 | " \n", 1645 | " \n", 1646 | " \n", 1647 | " \n", 1648 | " \n", 1649 | " \n", 1650 | " \n", 1651 | " \n", 1652 | " \n", 1653 | " \n", 1654 | " \n", 1655 | " \n", 1656 | " \n", 1657 | " \n", 1658 | " \n", 1659 | " \n", 1660 | " \n", 1661 | " \n", 1662 | " \n", 1663 | " \n", 1664 | " \n", 1665 | " \n", 1666 | " \n", 1667 | " \n", 1668 | " \n", 1669 | " \n", 1670 | " \n", 1671 | " \n", 1672 | " \n", 1673 | " \n", 1674 | " \n", 1675 | " \n", 1676 | " \n", 1677 | " \n", 1678 | " \n", 1679 | " \n", 1680 | " \n", 1681 | " \n", 1682 | " \n", 1683 | " \n", 1684 | " \n", 1685 | " \n", 1686 | " 
\n", 1687 | " \n", 1688 | " \n", 1689 | " \n", 1690 | " \n", 1691 | " \n", 1692 | " \n", 1693 | " \n", 1694 | " \n", 1695 | " \n", 1696 | " \n", 1697 | " \n", 1698 | " \n", 1699 | " \n", 1700 | " \n", 1701 | " \n", 1702 | " \n", 1703 | " \n", 1704 | " \n", 1705 | " \n", 1706 | " \n", 1707 | " \n", 1708 | " \n", 1709 | " \n", 1710 | " \n", 1711 | " \n", 1712 | " \n", 1713 | " \n", 1714 | " \n", 1715 | " \n", 1716 | " \n", 1717 | " \n", 1718 | " \n", 1719 | " \n", 1720 | " \n", 1721 | " \n", 1722 | " \n", 1723 | " \n", 1724 | " \n", 1725 | " \n", 1726 | " \n", 1727 | " \n", 1728 | " \n", 1729 | " \n", 1730 | " \n", 1731 | " \n", 1732 | " \n", 1733 | " \n", 1734 | " \n", 1735 | " \n", 1736 | " \n", 1737 | " \n", 1738 | " \n", 1739 | " \n", 1740 | " \n", 1741 | " \n", 1742 | " \n", 1743 | " \n", 1744 | " \n", 1745 | " \n", 1746 | " \n", 1747 | " \n", 1748 | " \n", 1749 | " \n", 1750 | " \n", 1751 | " \n", 1752 | " \n", 1753 | " \n", 1754 | " \n", 1755 | " \n", 1756 | " \n", 1757 | " \n", 1758 | " \n", 1759 | " \n", 1760 | " \n", 1761 | " \n", 1762 | " \n", 1763 | " \n", 1764 | " \n", 1765 | " \n", 1766 | " \n", 1767 | " \n", 1768 | " \n", 1769 | " \n", 1770 | " \n", 1771 | " \n", 1772 | " \n", 1773 | " \n", 1774 | " \n", 1775 | " \n", 1776 | " \n", 1777 | " \n", 1778 | " \n", 1779 | " \n", 1780 | " \n", 1781 | " \n", 1782 | " \n", 1783 | " \n", 1784 | " \n", 1785 | " \n", 1786 | " \n", 1787 | " \n", 1788 | " \n", 1789 | " \n", 1790 | " \n", 1791 | " \n", 1792 | " \n", 1793 | " \n", 1794 | " \n", 1795 | " \n", 1796 | " \n", 1797 | " \n", 1798 | " \n", 1799 | " \n", 1800 | " \n", 1801 | " \n", 1802 | " \n", 1803 | " \n", 1804 | " \n", 1805 | " \n", 1806 | " \n", 1807 | " \n", 1808 | " \n", 1809 | " \n", 1810 | " \n", 1811 | " \n", 1812 | " \n", 1813 | " \n", 1814 | " \n", 1815 | " \n", 1816 | " \n", 1817 | " \n", 1818 | " \n", 1819 | " \n", 1820 | " \n", 1821 | " \n", 1822 | " \n", 1823 | " \n", 1824 | " \n", 1825 | " \n", 1826 | " \n", 1827 | " \n", 1828 | " \n", 1829 | 
" \n", 1830 | " \n", 1831 | " \n", 1832 | " \n", 1833 | " \n", 1834 | " \n", 1835 | " \n", 1836 | " \n", 1837 | " \n", 1838 | " \n", 1839 | " \n", 1840 | " \n", 1841 | " \n", 1842 | " \n", 1843 | " \n", 1844 | " \n", 1845 | " \n", 1846 | " \n", 1847 | " \n", 1848 | " \n", 1849 | " \n", 1850 | " \n", 1851 | " \n", 1852 | " \n", 1853 | " \n", 1854 | " \n", 1855 | " \n", 1856 | " \n", 1857 | " \n", 1858 | " \n", 1859 | " \n", 1860 | "
userIditemId
011
112
214
316
418
5111
6112
7113
8116
9117
10118
11119
12121
13122
14123
15124
16127
17128
18128
19129
20130
21131
22133
23134
24135
25136
26139
27140
28142
29145
.........
318284127026209
318285127024980
318286127027463
318287127025212
318288127025212
318289127025212
318290127025212
318291127025191
318292127028499
318293127035006
318294127037926
318295127037787
318296127034000
318297127034005
318298127035008
318299127035008
318300127036521
318301127038368
318302127032852
318303127039009
318304127044829
318305127041103
318306127048649
318307127048649
318308127041026
318309127043684
318310127041709
318311127048506
318312127041727
318313127048406
\n", 1861 | "

318314 rows × 2 columns

\n", 1862 | "
" 1863 | ], 1864 | "text/plain": [ 1865 | " userId itemId\n", 1866 | "0 1 1\n", 1867 | "1 1 2\n", 1868 | "2 1 4\n", 1869 | "3 1 6\n", 1870 | "4 1 8\n", 1871 | "5 1 11\n", 1872 | "6 1 12\n", 1873 | "7 1 13\n", 1874 | "8 1 16\n", 1875 | "9 1 17\n", 1876 | "10 1 18\n", 1877 | "11 1 19\n", 1878 | "12 1 21\n", 1879 | "13 1 22\n", 1880 | "14 1 23\n", 1881 | "15 1 24\n", 1882 | "16 1 27\n", 1883 | "17 1 28\n", 1884 | "18 1 28\n", 1885 | "19 1 29\n", 1886 | "20 1 30\n", 1887 | "21 1 31\n", 1888 | "22 1 33\n", 1889 | "23 1 34\n", 1890 | "24 1 35\n", 1891 | "25 1 36\n", 1892 | "26 1 39\n", 1893 | "27 1 40\n", 1894 | "28 1 42\n", 1895 | "29 1 45\n", 1896 | "... ... ...\n", 1897 | "318284 12702 6209\n", 1898 | "318285 12702 4980\n", 1899 | "318286 12702 7463\n", 1900 | "318287 12702 5212\n", 1901 | "318288 12702 5212\n", 1902 | "318289 12702 5212\n", 1903 | "318290 12702 5212\n", 1904 | "318291 12702 5191\n", 1905 | "318292 12702 8499\n", 1906 | "318293 12703 5006\n", 1907 | "318294 12703 7926\n", 1908 | "318295 12703 7787\n", 1909 | "318296 12703 4000\n", 1910 | "318297 12703 4005\n", 1911 | "318298 12703 5008\n", 1912 | "318299 12703 5008\n", 1913 | "318300 12703 6521\n", 1914 | "318301 12703 8368\n", 1915 | "318302 12703 2852\n", 1916 | "318303 12703 9009\n", 1917 | "318304 12704 4829\n", 1918 | "318305 12704 1103\n", 1919 | "318306 12704 8649\n", 1920 | "318307 12704 8649\n", 1921 | "318308 12704 1026\n", 1922 | "318309 12704 3684\n", 1923 | "318310 12704 1709\n", 1924 | "318311 12704 8506\n", 1925 | "318312 12704 1727\n", 1926 | "318313 12704 8406\n", 1927 | "\n", 1928 | "[318314 rows x 2 columns]" 1929 | ] 1930 | }, 1931 | "execution_count": 71, 1932 | "metadata": {}, 1933 | "output_type": "execute_result" 1934 | } 1935 | ], 1936 | "source": [ 1937 | "tmp = df.values\n", 1938 | "for i in range(len(df)):\n", 1939 | " tmp[i, 0] = user_table[int(tmp[i, 0])]\n", 1940 | " tmp[i, 1] = movie_table[int(tmp[i, 1])]\n", 1941 | " if i % 100000 == 0:\n", 1942 | " print i\n", 1943 | 
"df = pd.DataFrame(tmp, columns=['userId', 'itemId'])\n", 1944 | "df" 1945 | ] 1946 | }, 1947 | { 1948 | "cell_type": "code", 1949 | "execution_count": 72, 1950 | "metadata": {}, 1951 | "outputs": [ 1952 | { 1953 | "name": "stdout", 1954 | "output_type": "stream", 1955 | "text": [ 1956 | "number of users = 12704\n", 1957 | "number of items = 9244\n", 1958 | "sparsity = 0.002710536864\n" 1959 | ] 1960 | } 1961 | ], 1962 | "source": [ 1963 | "num_user = u - 1\n", 1964 | "num_movie = m - 1\n", 1965 | "print('number of users = ' + str(num_user))\n", 1966 | "print('number of items = ' + str(num_movie))\n", 1967 | "sparsity = len(df) * 1.0 / (num_user * num_movie)\n", 1968 | "print('sparsity = ' + str(sparsity))" 1969 | ] 1970 | }, 1971 | { 1972 | "cell_type": "code", 1973 | "execution_count": 35, 1974 | "metadata": {}, 1975 | "outputs": [], 1976 | "source": [ 1977 | "df.to_csv('./data.csv', index=False)" 1978 | ] 1979 | }, 1980 | { 1981 | "cell_type": "code", 1982 | "execution_count": 73, 1983 | "metadata": {}, 1984 | "outputs": [], 1985 | "source": [ 1986 | "train_df = df.copy()\n", 1987 | "test_df = df.copy()" 1988 | ] 1989 | }, 1990 | { 1991 | "cell_type": "code", 1992 | "execution_count": 37, 1993 | "metadata": {}, 1994 | "outputs": [ 1995 | { 1996 | "name": "stdout", 1997 | "output_type": "stream", 1998 | "text": [ 1999 | "500\n", 2000 | "1000\n", 2001 | "1500\n", 2002 | "2000\n", 2003 | "2500\n", 2004 | "3000\n", 2005 | "3500\n", 2006 | "4000\n", 2007 | "4500\n", 2008 | "5000\n", 2009 | "5500\n", 2010 | "6000\n", 2011 | "6500\n", 2012 | "7000\n", 2013 | "7500\n", 2014 | "8000\n", 2015 | "8500\n", 2016 | "9000\n", 2017 | "9500\n", 2018 | "10000\n", 2019 | "10500\n", 2020 | "11000\n", 2021 | "11500\n", 2022 | "12000\n", 2023 | "12500\n", 2024 | "13000\n", 2025 | "13500\n", 2026 | "14000\n", 2027 | "14500\n", 2028 | "15000\n", 2029 | "15500\n", 2030 | "16000\n", 2031 | "16500\n", 2032 | "17000\n", 2033 | "17500\n", 2034 | "18000\n", 2035 | "18500\n", 2036 | 
"19000\n", 2037 | "19500\n", 2038 | "20000\n", 2039 | "20500\n", 2040 | "21000\n", 2041 | "21500\n", 2042 | "22000\n", 2043 | "22500\n", 2044 | "23000\n", 2045 | "23500\n", 2046 | "24000\n", 2047 | "24500\n", 2048 | "25000\n", 2049 | "25500\n", 2050 | "26000\n", 2051 | "26500\n", 2052 | "27000\n", 2053 | "27500\n", 2054 | "28000\n", 2055 | "28500\n", 2056 | "29000\n", 2057 | "29500\n", 2058 | "30000\n", 2059 | "30500\n", 2060 | "31000\n", 2061 | "31500\n", 2062 | "32000\n", 2063 | "32500\n", 2064 | "33000\n", 2065 | "33500\n", 2066 | "34000\n", 2067 | "34500\n", 2068 | "35000\n", 2069 | "35500\n", 2070 | "36000\n", 2071 | "36500\n", 2072 | "37000\n", 2073 | "37500\n", 2074 | "38000\n", 2075 | "38500\n", 2076 | "39000\n", 2077 | "39500\n", 2078 | "40000\n", 2079 | "40500\n", 2080 | "41000\n", 2081 | "41500\n", 2082 | "42000\n", 2083 | "42500\n", 2084 | "43000\n", 2085 | "43500\n", 2086 | "44000\n", 2087 | "44500\n", 2088 | "45000\n", 2089 | "45500\n", 2090 | "46000\n", 2091 | "46500\n", 2092 | "47000\n", 2093 | "47500\n", 2094 | "48000\n", 2095 | "48500\n", 2096 | "49000\n", 2097 | "49500\n", 2098 | "50000\n", 2099 | "50500\n", 2100 | "51000\n", 2101 | "51500\n", 2102 | "52000\n", 2103 | "52500\n", 2104 | "53000\n", 2105 | "53500\n", 2106 | "54000\n", 2107 | "54500\n", 2108 | "55000\n", 2109 | "55500\n", 2110 | "56000\n", 2111 | "56500\n", 2112 | "57000\n", 2113 | "57500\n", 2114 | "58000\n", 2115 | "58500\n", 2116 | "59000\n", 2117 | "59500\n", 2118 | "60000\n", 2119 | "60500\n", 2120 | "61000\n", 2121 | "61500\n", 2122 | "62000\n", 2123 | "62500\n", 2124 | "63000\n", 2125 | "63500\n", 2126 | "64000\n", 2127 | "64500\n", 2128 | "65000\n", 2129 | "65500\n", 2130 | "66000\n", 2131 | "66500\n", 2132 | "67000\n", 2133 | "67500\n", 2134 | "68000\n", 2135 | "68500\n", 2136 | "69000\n" 2137 | ] 2138 | } 2139 | ], 2140 | "source": [ 2141 | "train_ratio = 0.8\n", 2142 | "test_ratio = 1 - train_ratio\n", 2143 | "num_users = np.max(df['userId'])\n", 2144 | "num_items = 
np.max(df['itemId'])\n", 2145 | "test_idx = []\n", 2146 | "for u in range(1, num_users+1):\n", 2147 | " u_idx = train_df.index[train_df['userId'] == u]\n", 2148 | " idx_len = len(u_idx)\n", 2149 | " test_len = int(idx_len * test_ratio)\n", 2150 | " if test_len == 0:\n", 2151 | " test_len = 1\n", 2152 | " tmp = np.random.choice(u_idx, size=test_len, replace=False)\n", 2153 | " test_idx += tmp.tolist()\n", 2154 | " if u % 500 == 0:\n", 2155 | " print u" 2156 | ] 2157 | }, 2158 | { 2159 | "cell_type": "code", 2160 | "execution_count": 38, 2161 | "metadata": {}, 2162 | "outputs": [], 2163 | "source": [ 2164 | "test_set = set(test_idx)\n", 2165 | "train_set = set(range(len(df)))\n", 2166 | "train_set -= test_set\n", 2167 | "train_idx = list(train_set)\n", 2168 | "train_df.drop(test_idx, axis=0, inplace=True)\n", 2169 | "test_df.drop(train_idx, axis=0, inplace=True)" 2170 | ] 2171 | }, 2172 | { 2173 | "cell_type": "code", 2174 | "execution_count": 39, 2175 | "metadata": {}, 2176 | "outputs": [ 2177 | { 2178 | "name": "stdout", 2179 | "output_type": "stream", 2180 | "text": [ 2181 | "1134946\n", 2182 | "4746650\n" 2183 | ] 2184 | } 2185 | ], 2186 | "source": [ 2187 | "print len(test_df)\n", 2188 | "print len(train_df)" 2189 | ] 2190 | }, 2191 | { 2192 | "cell_type": "code", 2193 | "execution_count": 40, 2194 | "metadata": { 2195 | "scrolled": true 2196 | }, 2197 | "outputs": [ 2198 | { 2199 | "data": { 2200 | "text/html": [ 2201 | "
\n", 2202 | "\n", 2215 | "\n", 2216 | " \n", 2217 | " \n", 2218 | " \n", 2219 | " \n", 2220 | " \n", 2221 | " \n", 2222 | " \n", 2223 | " \n", 2224 | " \n", 2225 | " \n", 2226 | " \n", 2227 | " \n", 2228 | " \n", 2229 | " \n", 2230 | " \n", 2231 | " \n", 2232 | " \n", 2233 | " \n", 2234 | " \n", 2235 | " \n", 2236 | " \n", 2237 | " \n", 2238 | " \n", 2239 | " \n", 2240 | " \n", 2241 | " \n", 2242 | " \n", 2243 | " \n", 2244 | " \n", 2245 | " \n", 2246 | " \n", 2247 | " \n", 2248 | " \n", 2249 | " \n", 2250 | " \n", 2251 | " \n", 2252 | " \n", 2253 | " \n", 2254 | " \n", 2255 | " \n", 2256 | " \n", 2257 | " \n", 2258 | " \n", 2259 | " \n", 2260 | " \n", 2261 | " \n", 2262 | " \n", 2263 | " \n", 2264 | " \n", 2265 | " \n", 2266 | " \n", 2267 | " \n", 2268 | " \n", 2269 | " \n", 2270 | " \n", 2271 | " \n", 2272 | " \n", 2273 | " \n", 2274 | " \n", 2275 | " \n", 2276 | " \n", 2277 | " \n", 2278 | " \n", 2279 | " \n", 2280 | " \n", 2281 | " \n", 2282 | " \n", 2283 | " \n", 2284 | " \n", 2285 | " \n", 2286 | " \n", 2287 | " \n", 2288 | " \n", 2289 | " \n", 2290 | " \n", 2291 | " \n", 2292 | " \n", 2293 | " \n", 2294 | " \n", 2295 | " \n", 2296 | " \n", 2297 | " \n", 2298 | " \n", 2299 | " \n", 2300 | " \n", 2301 | " \n", 2302 | " \n", 2303 | " \n", 2304 | " \n", 2305 | " \n", 2306 | " \n", 2307 | " \n", 2308 | " \n", 2309 | " \n", 2310 | " \n", 2311 | " \n", 2312 | " \n", 2313 | " \n", 2314 | " \n", 2315 | " \n", 2316 | " \n", 2317 | " \n", 2318 | " \n", 2319 | " \n", 2320 | " \n", 2321 | " \n", 2322 | " \n", 2323 | " \n", 2324 | " \n", 2325 | " \n", 2326 | " \n", 2327 | " \n", 2328 | " \n", 2329 | " \n", 2330 | " \n", 2331 | " \n", 2332 | " \n", 2333 | " \n", 2334 | " \n", 2335 | " \n", 2336 | " \n", 2337 | " \n", 2338 | " \n", 2339 | " \n", 2340 | " \n", 2341 | " \n", 2342 | " \n", 2343 | " \n", 2344 | " \n", 2345 | " \n", 2346 | " \n", 2347 | " \n", 2348 | " \n", 2349 | " \n", 2350 | " \n", 2351 | " \n", 2352 | " \n", 2353 | " \n", 2354 | " \n", 2355 | " \n", 2356 | " 
\n", 2357 | " \n", 2358 | " \n", 2359 | " \n", 2360 | " \n", 2361 | " \n", 2362 | " \n", 2363 | " \n", 2364 | " \n", 2365 | " \n", 2366 | " \n", 2367 | " \n", 2368 | " \n", 2369 | " \n", 2370 | " \n", 2371 | " \n", 2372 | " \n", 2373 | " \n", 2374 | " \n", 2375 | " \n", 2376 | " \n", 2377 | " \n", 2378 | " \n", 2379 | " \n", 2380 | " \n", 2381 | " \n", 2382 | " \n", 2383 | " \n", 2384 | " \n", 2385 | " \n", 2386 | " \n", 2387 | " \n", 2388 | " \n", 2389 | " \n", 2390 | " \n", 2391 | " \n", 2392 | " \n", 2393 | " \n", 2394 | " \n", 2395 | " \n", 2396 | " \n", 2397 | " \n", 2398 | " \n", 2399 | " \n", 2400 | " \n", 2401 | " \n", 2402 | " \n", 2403 | " \n", 2404 | " \n", 2405 | " \n", 2406 | " \n", 2407 | " \n", 2408 | " \n", 2409 | " \n", 2410 | " \n", 2411 | " \n", 2412 | " \n", 2413 | " \n", 2414 | " \n", 2415 | " \n", 2416 | " \n", 2417 | " \n", 2418 | " \n", 2419 | " \n", 2420 | " \n", 2421 | " \n", 2422 | " \n", 2423 | " \n", 2424 | " \n", 2425 | " \n", 2426 | " \n", 2427 | " \n", 2428 | " \n", 2429 | " \n", 2430 | " \n", 2431 | " \n", 2432 | " \n", 2433 | " \n", 2434 | " \n", 2435 | " \n", 2436 | " \n", 2437 | " \n", 2438 | " \n", 2439 | " \n", 2440 | " \n", 2441 | " \n", 2442 | " \n", 2443 | " \n", 2444 | " \n", 2445 | " \n", 2446 | " \n", 2447 | " \n", 2448 | " \n", 2449 | " \n", 2450 | " \n", 2451 | " \n", 2452 | " \n", 2453 | " \n", 2454 | " \n", 2455 | " \n", 2456 | " \n", 2457 | " \n", 2458 | " \n", 2459 | " \n", 2460 | " \n", 2461 | " \n", 2462 | " \n", 2463 | " \n", 2464 | " \n", 2465 | " \n", 2466 | " \n", 2467 | " \n", 2468 | " \n", 2469 | " \n", 2470 | " \n", 2471 | " \n", 2472 | " \n", 2473 | " \n", 2474 | " \n", 2475 | " \n", 2476 | " \n", 2477 | " \n", 2478 | " \n", 2479 | " \n", 2480 | " \n", 2481 | " \n", 2482 | " \n", 2483 | " \n", 2484 | " \n", 2485 | " \n", 2486 | " \n", 2487 | " \n", 2488 | " \n", 2489 | " \n", 2490 | " \n", 2491 | " \n", 2492 | " \n", 2493 | " \n", 2494 | " \n", 2495 | " \n", 2496 | " \n", 2497 | " \n", 2498 | " \n", 2499 | 
" \n", 2500 | " \n", 2501 | " \n", 2502 | " \n", 2503 | " \n", 2504 | " \n", 2505 | " \n", 2506 | " \n", 2507 | " \n", 2508 | " \n", 2509 | " \n", 2510 | " \n", 2511 | " \n", 2512 | " \n", 2513 | " \n", 2514 | " \n", 2515 | " \n", 2516 | " \n", 2517 | " \n", 2518 | " \n", 2519 | " \n", 2520 | " \n", 2521 | " \n", 2522 | " \n", 2523 | " \n", 2524 | " \n", 2525 | " \n", 2526 | " \n", 2527 | " \n", 2528 | " \n", 2529 | " \n", 2530 | "
userIdmovieId
01119
11180
21225
31286
41310
51322
61348
71349
81355
91357
121408
141468
151508
161527
171570
181572
201577
211597
222107
232146
242208
2621083
2721173
2821179
2921203
3121461
3221491
3423152
3523403
3624199
.........
588155969412312
588156069412332
588156169412335
588156269412341
588156369412373
588156469412374
588156669412572
588156769412574
588156969412578
588157069413107
588157269413254
588157369413310
588157569413573
588157669413741
588157769413749
588157869413783
588157969413848
5881580694131020
5881581694131072
5881582694131126
5881584694131139
5881585694131143
5881587694131281
5881588694131314
5881589694131430
5881591694131657
5881592694131755
5881593694131758
5881594694131865
5881595694132116
\n", 2531 | "

4746650 rows × 2 columns

\n", 2532 | "
" 2533 | ], 2534 | "text/plain": [ 2535 | " userId movieId\n", 2536 | "0 1 119\n", 2537 | "1 1 180\n", 2538 | "2 1 225\n", 2539 | "3 1 286\n", 2540 | "4 1 310\n", 2541 | "5 1 322\n", 2542 | "6 1 348\n", 2543 | "7 1 349\n", 2544 | "8 1 355\n", 2545 | "9 1 357\n", 2546 | "12 1 408\n", 2547 | "14 1 468\n", 2548 | "15 1 508\n", 2549 | "16 1 527\n", 2550 | "17 1 570\n", 2551 | "18 1 572\n", 2552 | "20 1 577\n", 2553 | "21 1 597\n", 2554 | "22 2 107\n", 2555 | "23 2 146\n", 2556 | "24 2 208\n", 2557 | "26 2 1083\n", 2558 | "27 2 1173\n", 2559 | "28 2 1179\n", 2560 | "29 2 1203\n", 2561 | "31 2 1461\n", 2562 | "32 2 1491\n", 2563 | "34 2 3152\n", 2564 | "35 2 3403\n", 2565 | "36 2 4199\n", 2566 | "... ... ...\n", 2567 | "5881559 69412 312\n", 2568 | "5881560 69412 332\n", 2569 | "5881561 69412 335\n", 2570 | "5881562 69412 341\n", 2571 | "5881563 69412 373\n", 2572 | "5881564 69412 374\n", 2573 | "5881566 69412 572\n", 2574 | "5881567 69412 574\n", 2575 | "5881569 69412 578\n", 2576 | "5881570 69413 107\n", 2577 | "5881572 69413 254\n", 2578 | "5881573 69413 310\n", 2579 | "5881575 69413 573\n", 2580 | "5881576 69413 741\n", 2581 | "5881577 69413 749\n", 2582 | "5881578 69413 783\n", 2583 | "5881579 69413 848\n", 2584 | "5881580 69413 1020\n", 2585 | "5881581 69413 1072\n", 2586 | "5881582 69413 1126\n", 2587 | "5881584 69413 1139\n", 2588 | "5881585 69413 1143\n", 2589 | "5881587 69413 1281\n", 2590 | "5881588 69413 1314\n", 2591 | "5881589 69413 1430\n", 2592 | "5881591 69413 1657\n", 2593 | "5881592 69413 1755\n", 2594 | "5881593 69413 1758\n", 2595 | "5881594 69413 1865\n", 2596 | "5881595 69413 2116\n", 2597 | "\n", 2598 | "[4746650 rows x 2 columns]" 2599 | ] 2600 | }, 2601 | "execution_count": 40, 2602 | "metadata": {}, 2603 | "output_type": "execute_result" 2604 | } 2605 | ], 2606 | "source": [ 2607 | "train_df" 2608 | ] 2609 | }, 2610 | { 2611 | "cell_type": "code", 2612 | "execution_count": 41, 2613 | "metadata": { 2614 | "scrolled": true 2615 | }, 2616 | 
"outputs": [ 2617 | { 2618 | "data": { 2619 | "text/html": [ 2620 | "
\n", 2621 | "\n", 2634 | "\n", 2635 | " \n", 2636 | " \n", 2637 | " \n", 2638 | " \n", 2639 | " \n", 2640 | " \n", 2641 | " \n", 2642 | " \n", 2643 | " \n", 2644 | " \n", 2645 | " \n", 2646 | " \n", 2647 | " \n", 2648 | " \n", 2649 | " \n", 2650 | " \n", 2651 | " \n", 2652 | " \n", 2653 | " \n", 2654 | " \n", 2655 | " \n", 2656 | " \n", 2657 | " \n", 2658 | " \n", 2659 | " \n", 2660 | " \n", 2661 | " \n", 2662 | " \n", 2663 | " \n", 2664 | " \n", 2665 | " \n", 2666 | " \n", 2667 | " \n", 2668 | " \n", 2669 | " \n", 2670 | " \n", 2671 | " \n", 2672 | " \n", 2673 | " \n", 2674 | " \n", 2675 | " \n", 2676 | " \n", 2677 | " \n", 2678 | " \n", 2679 | " \n", 2680 | " \n", 2681 | " \n", 2682 | " \n", 2683 | " \n", 2684 | " \n", 2685 | " \n", 2686 | " \n", 2687 | " \n", 2688 | " \n", 2689 | " \n", 2690 | " \n", 2691 | " \n", 2692 | " \n", 2693 | " \n", 2694 | " \n", 2695 | " \n", 2696 | " \n", 2697 | " \n", 2698 | " \n", 2699 | " \n", 2700 | " \n", 2701 | " \n", 2702 | " \n", 2703 | " \n", 2704 | " \n", 2705 | " \n", 2706 | " \n", 2707 | " \n", 2708 | " \n", 2709 | " \n", 2710 | " \n", 2711 | " \n", 2712 | " \n", 2713 | " \n", 2714 | " \n", 2715 | " \n", 2716 | " \n", 2717 | " \n", 2718 | " \n", 2719 | " \n", 2720 | " \n", 2721 | " \n", 2722 | " \n", 2723 | " \n", 2724 | " \n", 2725 | " \n", 2726 | " \n", 2727 | " \n", 2728 | " \n", 2729 | " \n", 2730 | " \n", 2731 | " \n", 2732 | " \n", 2733 | " \n", 2734 | " \n", 2735 | " \n", 2736 | " \n", 2737 | " \n", 2738 | " \n", 2739 | " \n", 2740 | " \n", 2741 | " \n", 2742 | " \n", 2743 | " \n", 2744 | " \n", 2745 | " \n", 2746 | " \n", 2747 | " \n", 2748 | " \n", 2749 | " \n", 2750 | " \n", 2751 | " \n", 2752 | " \n", 2753 | " \n", 2754 | " \n", 2755 | " \n", 2756 | " \n", 2757 | " \n", 2758 | " \n", 2759 | " \n", 2760 | " \n", 2761 | " \n", 2762 | " \n", 2763 | " \n", 2764 | " \n", 2765 | " \n", 2766 | " \n", 2767 | " \n", 2768 | " \n", 2769 | " \n", 2770 | " \n", 2771 | " \n", 2772 | " \n", 2773 | " \n", 2774 | " \n", 2775 | " 
\n", 2776 | " \n", 2777 | " \n", 2778 | " \n", 2779 | " \n", 2780 | " \n", 2781 | " \n", 2782 | " \n", 2783 | " \n", 2784 | " \n", 2785 | " \n", 2786 | " \n", 2787 | " \n", 2788 | " \n", 2789 | " \n", 2790 | " \n", 2791 | " \n", 2792 | " \n", 2793 | " \n", 2794 | " \n", 2795 | " \n", 2796 | " \n", 2797 | " \n", 2798 | " \n", 2799 | " \n", 2800 | " \n", 2801 | " \n", 2802 | " \n", 2803 | " \n", 2804 | " \n", 2805 | " \n", 2806 | " \n", 2807 | " \n", 2808 | " \n", 2809 | " \n", 2810 | " \n", 2811 | " \n", 2812 | " \n", 2813 | " \n", 2814 | " \n", 2815 | " \n", 2816 | " \n", 2817 | " \n", 2818 | " \n", 2819 | " \n", 2820 | " \n", 2821 | " \n", 2822 | " \n", 2823 | " \n", 2824 | " \n", 2825 | " \n", 2826 | " \n", 2827 | " \n", 2828 | " \n", 2829 | " \n", 2830 | " \n", 2831 | " \n", 2832 | " \n", 2833 | " \n", 2834 | " \n", 2835 | " \n", 2836 | " \n", 2837 | " \n", 2838 | " \n", 2839 | " \n", 2840 | " \n", 2841 | " \n", 2842 | " \n", 2843 | " \n", 2844 | " \n", 2845 | " \n", 2846 | " \n", 2847 | " \n", 2848 | " \n", 2849 | " \n", 2850 | " \n", 2851 | " \n", 2852 | " \n", 2853 | " \n", 2854 | " \n", 2855 | " \n", 2856 | " \n", 2857 | " \n", 2858 | " \n", 2859 | " \n", 2860 | " \n", 2861 | " \n", 2862 | " \n", 2863 | " \n", 2864 | " \n", 2865 | " \n", 2866 | " \n", 2867 | " \n", 2868 | " \n", 2869 | " \n", 2870 | " \n", 2871 | " \n", 2872 | " \n", 2873 | " \n", 2874 | " \n", 2875 | " \n", 2876 | " \n", 2877 | " \n", 2878 | " \n", 2879 | " \n", 2880 | " \n", 2881 | " \n", 2882 | " \n", 2883 | " \n", 2884 | " \n", 2885 | " \n", 2886 | " \n", 2887 | " \n", 2888 | " \n", 2889 | " \n", 2890 | " \n", 2891 | " \n", 2892 | " \n", 2893 | " \n", 2894 | " \n", 2895 | " \n", 2896 | " \n", 2897 | " \n", 2898 | " \n", 2899 | " \n", 2900 | " \n", 2901 | " \n", 2902 | " \n", 2903 | " \n", 2904 | " \n", 2905 | " \n", 2906 | " \n", 2907 | " \n", 2908 | " \n", 2909 | " \n", 2910 | " \n", 2911 | " \n", 2912 | " \n", 2913 | " \n", 2914 | " \n", 2915 | " \n", 2916 | " \n", 2917 | " \n", 2918 | 
" \n", 2919 | " \n", 2920 | " \n", 2921 | " \n", 2922 | " \n", 2923 | " \n", 2924 | " \n", 2925 | " \n", 2926 | " \n", 2927 | " \n", 2928 | " \n", 2929 | " \n", 2930 | " \n", 2931 | " \n", 2932 | " \n", 2933 | " \n", 2934 | " \n", 2935 | " \n", 2936 | " \n", 2937 | " \n", 2938 | " \n", 2939 | " \n", 2940 | " \n", 2941 | " \n", 2942 | " \n", 2943 | " \n", 2944 | " \n", 2945 | " \n", 2946 | " \n", 2947 | " \n", 2948 | " \n", 2949 | "
userIdmovieId
101363
111370
131454
191573
252574
3021329
3321562
4527220
4627224
50334
543156
623468
673573
73447
754108
814327
884520
904529
984869
1004873
1014876
1044990
11541136
12441207
13251124
13951319
14351521
14551618
15353699
158699
.........
5881455694111044
5881456694111061
5881459694111126
5881461694111134
5881468694111157
5881471694111162
5881475694111177
5881477694111179
5881479694111197
5881480694111205
5881486694111329
5881491694111779
5881498694111895
5881500694112158
5881505694112347
5881509694112435
5881512694112791
5881524694113163
5881529694113238
5881537694121
588154769412122
588155269412229
588155769412290
588156569412565
588156869412575
588157169413191
588157469413468
5881583694131130
5881586694131246
5881590694131599
\n", 2950 | "

1134946 rows × 2 columns

\n", 2951 | "
" 2952 | ], 2953 | "text/plain": [ 2954 | " userId movieId\n", 2955 | "10 1 363\n", 2956 | "11 1 370\n", 2957 | "13 1 454\n", 2958 | "19 1 573\n", 2959 | "25 2 574\n", 2960 | "30 2 1329\n", 2961 | "33 2 1562\n", 2962 | "45 2 7220\n", 2963 | "46 2 7224\n", 2964 | "50 3 34\n", 2965 | "54 3 156\n", 2966 | "62 3 468\n", 2967 | "67 3 573\n", 2968 | "73 4 47\n", 2969 | "75 4 108\n", 2970 | "81 4 327\n", 2971 | "88 4 520\n", 2972 | "90 4 529\n", 2973 | "98 4 869\n", 2974 | "100 4 873\n", 2975 | "101 4 876\n", 2976 | "104 4 990\n", 2977 | "115 4 1136\n", 2978 | "124 4 1207\n", 2979 | "132 5 1124\n", 2980 | "139 5 1319\n", 2981 | "143 5 1521\n", 2982 | "145 5 1618\n", 2983 | "153 5 3699\n", 2984 | "158 6 99\n", 2985 | "... ... ...\n", 2986 | "5881455 69411 1044\n", 2987 | "5881456 69411 1061\n", 2988 | "5881459 69411 1126\n", 2989 | "5881461 69411 1134\n", 2990 | "5881468 69411 1157\n", 2991 | "5881471 69411 1162\n", 2992 | "5881475 69411 1177\n", 2993 | "5881477 69411 1179\n", 2994 | "5881479 69411 1197\n", 2995 | "5881480 69411 1205\n", 2996 | "5881486 69411 1329\n", 2997 | "5881491 69411 1779\n", 2998 | "5881498 69411 1895\n", 2999 | "5881500 69411 2158\n", 3000 | "5881505 69411 2347\n", 3001 | "5881509 69411 2435\n", 3002 | "5881512 69411 2791\n", 3003 | "5881524 69411 3163\n", 3004 | "5881529 69411 3238\n", 3005 | "5881537 69412 1\n", 3006 | "5881547 69412 122\n", 3007 | "5881552 69412 229\n", 3008 | "5881557 69412 290\n", 3009 | "5881565 69412 565\n", 3010 | "5881568 69412 575\n", 3011 | "5881571 69413 191\n", 3012 | "5881574 69413 468\n", 3013 | "5881583 69413 1130\n", 3014 | "5881586 69413 1246\n", 3015 | "5881590 69413 1599\n", 3016 | "\n", 3017 | "[1134946 rows x 2 columns]" 3018 | ] 3019 | }, 3020 | "execution_count": 41, 3021 | "metadata": {}, 3022 | "output_type": "execute_result" 3023 | } 3024 | ], 3025 | "source": [ 3026 | "test_df" 3027 | ] 3028 | }, 3029 | { 3030 | "cell_type": "code", 3031 | "execution_count": 45, 3032 | "metadata": {}, 3033 | "outputs": 
[], 3034 | "source": [ 3035 | "train_df.to_csv('./train.csv', index=False)\n", 3036 | "test_df.to_csv('./test.csv', index=False)" 3037 | ] 3038 | }, 3039 | { 3040 | "cell_type": "code", 3041 | "execution_count": 43, 3042 | "metadata": {}, 3043 | "outputs": [], 3044 | "source": [ 3045 | "train_df.reset_index(drop=True, inplace=True)" 3046 | ] 3047 | }, 3048 | { 3049 | "cell_type": "code", 3050 | "execution_count": 44, 3051 | "metadata": {}, 3052 | "outputs": [ 3053 | { 3054 | "name": "stdout", 3055 | "output_type": "stream", 3056 | "text": [ 3057 | "3000\n", 3058 | "6000\n", 3059 | "9000\n", 3060 | "12000\n", 3061 | "15000\n", 3062 | "18000\n", 3063 | "21000\n", 3064 | "24000\n", 3065 | "27000\n", 3066 | "30000\n", 3067 | "33000\n", 3068 | "36000\n", 3069 | "39000\n", 3070 | "42000\n", 3071 | "45000\n", 3072 | "48000\n", 3073 | "51000\n", 3074 | "54000\n", 3075 | "57000\n", 3076 | "60000\n", 3077 | "63000\n", 3078 | "66000\n", 3079 | "69000\n", 3080 | "3000\n", 3081 | "6000\n", 3082 | "9000\n", 3083 | "12000\n", 3084 | "15000\n", 3085 | "18000\n", 3086 | "21000\n", 3087 | "24000\n", 3088 | "27000\n", 3089 | "30000\n", 3090 | "33000\n", 3091 | "36000\n", 3092 | "39000\n", 3093 | "42000\n", 3094 | "45000\n", 3095 | "48000\n", 3096 | "51000\n", 3097 | "54000\n", 3098 | "57000\n", 3099 | "60000\n", 3100 | "63000\n", 3101 | "66000\n", 3102 | "69000\n", 3103 | "3000\n", 3104 | "6000\n", 3105 | "9000\n", 3106 | "12000\n", 3107 | "15000\n", 3108 | "18000\n", 3109 | "21000\n", 3110 | "24000\n", 3111 | "27000\n", 3112 | "30000\n", 3113 | "33000\n", 3114 | "36000\n", 3115 | "39000\n", 3116 | "42000\n", 3117 | "45000\n", 3118 | "48000\n", 3119 | "51000\n", 3120 | "54000\n", 3121 | "57000\n", 3122 | "60000\n", 3123 | "63000\n", 3124 | "66000\n", 3125 | "69000\n", 3126 | "3000\n", 3127 | "6000\n", 3128 | "9000\n", 3129 | "12000\n", 3130 | "15000\n", 3131 | "18000\n", 3132 | "21000\n", 3133 | "24000\n", 3134 | "27000\n", 3135 | "30000\n", 3136 | "33000\n", 3137 | "36000\n", 3138 | 
"39000\n", 3139 | "42000\n", 3140 | "45000\n", 3141 | "48000\n", 3142 | "51000\n", 3143 | "54000\n", 3144 | "57000\n", 3145 | "60000\n", 3146 | "63000\n", 3147 | "66000\n", 3148 | "69000\n", 3149 | "3000\n", 3150 | "6000\n", 3151 | "9000\n", 3152 | "12000\n", 3153 | "15000\n", 3154 | "18000\n", 3155 | "21000\n", 3156 | "24000\n", 3157 | "27000\n", 3158 | "30000\n", 3159 | "33000\n", 3160 | "36000\n", 3161 | "39000\n", 3162 | "42000\n", 3163 | "45000\n", 3164 | "48000\n", 3165 | "51000\n", 3166 | "54000\n", 3167 | "57000\n", 3168 | "60000\n", 3169 | "63000\n", 3170 | "66000\n", 3171 | "69000\n" 3172 | ] 3173 | } 3174 | ], 3175 | "source": [ 3176 | "train_ratio = 0.9\n", 3177 | "vali_ratio = 1 - train_ratio\n", 3178 | "for i in range(5):\n", 3179 | " train_tmp_df = train_df.copy()\n", 3180 | " vali_df = train_df.copy()\n", 3181 | " vali_idx = []\n", 3182 | " for u in range(1, num_users+1):\n", 3183 | " u_idx = train_tmp_df.index[train_tmp_df['userId'] == u]\n", 3184 | " idx_len = len(u_idx)\n", 3185 | " vali_len = int(idx_len * vali_ratio)\n", 3186 | " if vali_len == 0:\n", 3187 | " vali_len = 1\n", 3188 | " tmp = np.random.choice(u_idx, size=vali_len, replace=False)\n", 3189 | " vali_idx += tmp.tolist()\n", 3190 | " if u % 3000 == 0:\n", 3191 | " print u\n", 3192 | " vali_set = set(vali_idx)\n", 3193 | " train_set = set(range(len(train_tmp_df)))\n", 3194 | " train_set -= vali_set\n", 3195 | " train_tmp_idx = list(train_set)\n", 3196 | " train_tmp_df.drop(vali_idx, axis=0, inplace=True)\n", 3197 | " vali_df.drop(train_tmp_idx, axis=0, inplace=True)\n", 3198 | " train_tmp_df.to_csv('./train_'+str(i)+'.csv', index=False)\n", 3199 | " vali_df.to_csv('./vali_'+str(i)+'.csv', index=False)" 3200 | ] 3201 | } 3202 | ], 3203 | "metadata": { 3204 | "kernelspec": { 3205 | "display_name": "Python 2", 3206 | "language": "python", 3207 | "name": "python2" 3208 | }, 3209 | "language_info": { 3210 | "codemirror_mode": { 3211 | "name": "ipython", 3212 | "version": 2 3213 | }, 3214 | 
"file_extension": ".py", 3215 | "mimetype": "text/x-python", 3216 | "name": "python", 3217 | "nbconvert_exporter": "python", 3218 | "pygments_lexer": "ipython2", 3219 | "version": "2.7.12" 3220 | } 3221 | }, 3222 | "nbformat": 4, 3223 | "nbformat_minor": 2 3224 | } 3225 | --------------------------------------------------------------------------------