├── .gitattributes
├── data
├── .DS_Store
├── yelp
│ ├── .DS_Store
│ └── .ipynb_checkpoints
│ │ └── preprocessing-checkpoint.ipynb
├── ml-1m
│ └── .DS_Store
└── VideoGame
│ └── .DS_Store
├── README.md
├── test.py
├── data_preprocessor.py
├── JCA.py
└── utility.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.ipynb linguist-detectable=false
2 |
--------------------------------------------------------------------------------
/data/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zziwei/Joint-Collaborative-Autoencoder/HEAD/data/.DS_Store
--------------------------------------------------------------------------------
/data/yelp/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zziwei/Joint-Collaborative-Autoencoder/HEAD/data/yelp/.DS_Store
--------------------------------------------------------------------------------
/data/ml-1m/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zziwei/Joint-Collaborative-Autoencoder/HEAD/data/ml-1m/.DS_Store
--------------------------------------------------------------------------------
/data/VideoGame/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zziwei/Joint-Collaborative-Autoencoder/HEAD/data/VideoGame/.DS_Store
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Joint-Collaborative-Autoencoder
2 | The implementation of our paper:
3 |
4 | Ziwei Zhu, Jianling Wang and James Caverlee. Improving Top-K Recommendation via Joint Collaborative Autoencoders. In Proceedings of WWW'19, San Francisco, May 13-17, 2019
5 |
6 | The implementation is based on TensorFlow 1.x (it uses `tf.Session` and `tf.placeholder`).
7 |
8 | Author: Ziwei Zhu (zhuziwei@tamu.edu)
9 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | """
2 | Ziwei Zhu
3 | Computer Science and Engineering Department, Texas A&M University
4 | zhuziwei@tamu.edu
5 | """
6 | from data_preprocessor import *
7 | import tensorflow as tf
8 | import time
9 | import argparse
10 | import os
11 | from JCA import JCA
12 |
if __name__ == '__main__':
    neg_sample_rate = 1

    date = time.strftime('%y-%m-%d', time.localtime())
    current_time = time.strftime('%H:%M:%S', time.localtime())
    data_name = 'ml-1m'
    base = 'u'  # 'u' = user-based evaluation, 'i' = item-based

    parser = argparse.ArgumentParser(description='JCA')

    parser.add_argument('--train_epoch', type=int, default=200)
    parser.add_argument('--batch_size', type=int, default=1500)
    parser.add_argument('--display_step', type=int, default=1)
    parser.add_argument('--lr', type=float, default=0.003)
    parser.add_argument('--lambda_value', type=float, default=0.001)
    parser.add_argument('--margin', type=float, default=0.15)
    parser.add_argument('--optimizer_method', choices=['Adam', 'Adadelta', 'Adagrad', 'RMSProp', 'GradientDescent',
                                                       'Momentum'], default='Adam')
    parser.add_argument('--g_act', choices=['Sigmoid', 'Relu', 'Elu', 'Tanh', "Identity"], default='Sigmoid')
    parser.add_argument('--f_act', choices=['Sigmoid', 'Relu', 'Elu', 'Tanh', "Identity"], default='Sigmoid')
    parser.add_argument('--U_hidden_neuron', type=int, default=160)
    parser.add_argument('--I_hidden_neuron', type=int, default=160)
    parser.add_argument('--base', type=str, default=base)
    parser.add_argument('--neg_sample_rate', type=int, default=neg_sample_rate)
    args = parser.parse_args()

    sess = tf.Session()

    # Load the split that matches data_name.  The original unconditionally
    # called yelp.test() even though data_name was 'ml-1m', so the run was
    # trained on yelp but recorded/labelled as ml-1m.
    loaders = {'ml-1m': ml1m, 'yelp': yelp, 'VideoGame': VideoGame}
    train_R, test_R = loaders[data_name].test()

    metric_path = './metric_results_test/' + date + '/'
    if not os.path.exists(metric_path):
        os.makedirs(metric_path)
    # (dropped the stray extra '/'; the resulting path is unchanged once
    # normalized by the OS)
    metric_path = metric_path + str(parser.description) + "_" + str(current_time)
    jca = JCA(sess, args, train_R, test_R, metric_path, date, data_name)
    jca.run(train_R, test_R)
48 |
--------------------------------------------------------------------------------
/data_preprocessor.py:
--------------------------------------------------------------------------------
1 | """
2 | Ziwei Zhu
3 | Computer Science and Engineering Department, Texas A&M University
4 | zhuziwei@tamu.edu
5 | """
6 | import numpy as np
7 | import pandas as pd
8 |
9 |
class ml1m:
    """Loaders for the MovieLens-1M splits stored as CSVs under ./data/ml-1m/.

    Every loader returns dense 0/1 numpy matrices (users x items) in which a
    1 marks an observed interaction; userId/movieId in the CSVs are 1-based.
    """
    def __init__(self):
        return

    @staticmethod
    def _to_matrix(df, num_users, num_items):
        """Scatter the (userId, movieId) rows of *df* into a binary matrix."""
        R = np.zeros((int(num_users), int(num_items)))
        mat = df.values
        for i in range(len(df)):
            user_idx = int(mat[i, 0]) - 1  # ids are 1-based in the files
            item_idx = int(mat[i, 1]) - 1
            R[user_idx, item_idx] = 1
        return R

    @staticmethod
    def train(n):
        """Load cross-validation fold *n*.

        :return: (train_R, vali_R) — both sized by the maxima of the
                 TRAINING split of this fold
        """
        train_df = pd.read_csv('./data/ml-1m/train_%d.csv' % n)
        vali_df = pd.read_csv('./data/ml-1m/vali_%d.csv' % n)
        num_users = np.max(train_df['userId'])
        num_items = np.max(train_df['movieId'])

        train_R = ml1m._to_matrix(train_df, num_users, num_items)  # training rating matrix
        vali_R = ml1m._to_matrix(vali_df, num_users, num_items)  # validation rating matrix
        return train_R, vali_R

    @staticmethod
    def test():
        """Load the final train/test split.

        :return: (train_R, test_R) — each matrix is sized by the maxima of
                 its own file, so the two shapes may differ.
        """
        test_df = pd.read_csv('./data/ml-1m/test.csv')
        test_R = ml1m._to_matrix(test_df, np.max(test_df['userId']), np.max(test_df['movieId']))

        # (the original wrote each training cell twice; the duplicated
        # assignment was redundant and has been removed)
        train_df = pd.read_csv('./data/ml-1m/train.csv')
        train_R = ml1m._to_matrix(train_df, np.max(train_df['userId']), np.max(train_df['movieId']))

        return train_R, test_R
65 |
66 |
67 |
class yelp:
    """Loaders for the Yelp splits stored as CSVs under ./data/yelp/.

    Every loader returns dense 0/1 numpy matrices (users x items) in which a
    1 marks an observed interaction; userId/itemId in the CSVs are 1-based.
    """
    def __init__(self):
        return

    @staticmethod
    def _to_matrix(df, num_users, num_items):
        """Scatter the (userId, itemId) rows of *df* into a binary matrix."""
        R = np.zeros((int(num_users), int(num_items)))
        mat = df.values
        for i in range(len(df)):
            user_idx = int(mat[i, 0]) - 1  # ids are 1-based in the files
            item_idx = int(mat[i, 1]) - 1
            R[user_idx, item_idx] = 1
        return R

    @staticmethod
    def train(n):
        """Load cross-validation fold *n*.

        :return: (train_R, vali_R) — both sized by the maxima of the
                 TRAINING split of this fold
        """
        train_df = pd.read_csv('./data/yelp/train_%d.csv' % n)
        vali_df = pd.read_csv('./data/yelp/vali_%d.csv' % n)
        num_users = np.max(train_df['userId'])
        num_items = np.max(train_df['itemId'])

        train_R = yelp._to_matrix(train_df, num_users, num_items)  # training rating matrix
        vali_R = yelp._to_matrix(vali_df, num_users, num_items)  # validation rating matrix
        return train_R, vali_R

    @staticmethod
    def test():
        """Load the final train/test split.

        :return: (train_R, test_R) — each matrix is sized by the maxima of
                 its own file, so the two shapes may differ.
        """
        test_df = pd.read_csv('./data/yelp/test.csv')
        test_R = yelp._to_matrix(test_df, np.max(test_df['userId']), np.max(test_df['itemId']))

        # (the original wrote each training cell twice; the duplicated
        # assignment was redundant and has been removed)
        train_df = pd.read_csv('./data/yelp/train.csv')
        train_R = yelp._to_matrix(train_df, np.max(train_df['userId']), np.max(train_df['itemId']))

        return train_R, test_R
123 |
124 |
class VideoGame:
    """Loaders for the Amazon VideoGame splits under ./data/VideoGame/.

    Every loader returns dense 0/1 numpy matrices (users x items) in which a
    1 marks an observed interaction; userId/itemId in the CSVs are 1-based.
    """
    def __init__(self):
        return

    @staticmethod
    def _to_matrix(df, num_users, num_items):
        """Scatter the (userId, itemId) rows of *df* into a binary matrix."""
        R = np.zeros((int(num_users), int(num_items)))
        mat = df.values
        for i in range(len(df)):
            user_idx = int(mat[i, 0]) - 1  # ids are 1-based in the files
            item_idx = int(mat[i, 1]) - 1
            R[user_idx, item_idx] = 1
        return R

    @staticmethod
    def train(n):
        """Load cross-validation fold *n*.

        :return: (train_R, vali_R) — both sized by the maxima of the
                 TRAINING split of this fold
        """
        train_df = pd.read_csv('./data/VideoGame/train_%d.csv' % n)
        vali_df = pd.read_csv('./data/VideoGame/vali_%d.csv' % n)
        num_users = np.max(train_df['userId'])
        num_items = np.max(train_df['itemId'])

        train_R = VideoGame._to_matrix(train_df, num_users, num_items)  # training rating matrix
        vali_R = VideoGame._to_matrix(vali_df, num_users, num_items)  # validation rating matrix
        return train_R, vali_R

    @staticmethod
    def test():
        """Load the final train/test split.

        :return: (train_R, test_R) — each matrix is sized by the maxima of
                 its own file, so the two shapes may differ.
        """
        test_df = pd.read_csv('./data/VideoGame/test.csv')
        test_R = VideoGame._to_matrix(test_df, np.max(test_df['userId']), np.max(test_df['itemId']))

        # (the original wrote each training cell twice; the duplicated
        # assignment was redundant and has been removed)
        train_df = pd.read_csv('./data/VideoGame/train.csv')
        train_R = VideoGame._to_matrix(train_df, np.max(train_df['userId']), np.max(train_df['itemId']))

        return train_R, test_R
180 |
181 |
--------------------------------------------------------------------------------
/JCA.py:
--------------------------------------------------------------------------------
1 | """
2 | Ziwei Zhu
3 | Computer Science and Engineering Department, Texas A&M University
4 | zhuziwei@tamu.edu
5 | """
6 | import tensorflow as tf
7 | import time
8 | import numpy as np
9 | import os
10 | import matplotlib
11 | import copy
12 | import utility
13 |
14 |
15 | class JCA:
16 |
17 | def __init__(self, sess, args, train_R, vali_R, metric_path, date, data_name,
18 | result_path=None):
19 |
20 | if args.f_act == "Sigmoid":
21 | f_act = tf.nn.sigmoid
22 | elif args.f_act == "Relu":
23 | f_act = tf.nn.relu
24 | elif args.f_act == "Tanh":
25 | f_act = tf.nn.tanh
26 | elif args.f_act == "Identity":
27 | f_act = tf.identity
28 | elif args.f_act == "Elu":
29 | f_act = tf.nn.elu
30 | else:
31 | raise NotImplementedError("ERROR")
32 |
33 | if args.g_act == "Sigmoid":
34 | g_act = tf.nn.sigmoid
35 | elif args.g_act == "Relu":
36 | g_act = tf.nn.relu
37 | elif args.g_act == "Tanh":
38 | g_act = tf.nn.tanh
39 | elif args.g_act == "Identity":
40 | g_act = tf.identity
41 | elif args.g_act == "Elu":
42 | g_act = tf.nn.elu
43 | else:
44 | raise NotImplementedError("ERROR")
45 |
46 | self.sess = sess
47 | self.args = args
48 |
49 | self.base = args.base
50 |
51 | self.num_rows = train_R.shape[0]
52 | self.num_cols = train_R.shape[1]
53 | self.U_hidden_neuron = args.U_hidden_neuron
54 | self.I_hidden_neuron = args.I_hidden_neuron
55 |
56 | self.train_R = train_R
57 | self.vali_R = vali_R
58 | self.num_test_ratings = np.sum(vali_R)
59 |
60 | self.train_epoch = args.train_epoch
61 | self.batch_size = args.batch_size
62 | self.num_batch_U = int(self.num_rows / float(self.batch_size)) + 1
63 | self.num_batch_I = int(self.num_cols / float(self.batch_size)) + 1
64 |
65 | self.lr = args.lr # learning rate
66 | self.optimizer_method = args.optimizer_method
67 | self.display_step = args.display_step
68 | self.margin = args.margin
69 |
70 | self.f_act = f_act # the activation function for the output layer
71 | self.g_act = g_act # the activation function for the hidden layer
72 |
73 | self.global_step = tf.Variable(0, trainable=False)
74 |
75 | self.lambda_value = args.lambda_value # regularization term trade-off
76 |
77 | self.result_path = result_path
78 | self.metric_path = metric_path
79 | self.date = date # today's date
80 | self.data_name = data_name
81 |
82 | self.neg_sample_rate = args.neg_sample_rate
83 | self.U_OH_mat = np.eye(self.num_rows, dtype=float)
84 | self.I_OH_mat = np.eye(self.num_cols, dtype=float)
85 |
86 | print('**********JCA**********')
87 | print(self.args)
88 | self.prepare_model()
89 |
90 | def run(self, train_R, vali_R):
91 | self.train_R = train_R
92 | self.vali_R = vali_R
93 | init = tf.global_variables_initializer()
94 | self.sess.run(init)
95 | for epoch_itr in xrange(self.train_epoch):
96 | self.train_model(epoch_itr)
97 | if epoch_itr % 1 == 0:
98 | self.test_model(epoch_itr)
99 | return self.make_records()
100 |
101 | def prepare_model(self):
102 |
103 | # input rating vector
104 | self.input_R_U = tf.placeholder(dtype=tf.float32, shape=[None, self.num_cols], name="input_R_U")
105 | self.input_R_I = tf.placeholder(dtype=tf.float32, shape=[self.num_rows, None], name="input_R_I")
106 | self.input_OH_I = tf.placeholder(dtype=tf.float32, shape=[None, self.num_cols], name="input_OH_I")
107 | self.input_P_cor = tf.placeholder(dtype=tf.int32, shape=[None, 2], name="input_P_cor")
108 | self.input_N_cor = tf.placeholder(dtype=tf.int32, shape=[None, 2], name="input_N_cor")
109 |
110 | # input indicator vector indicator
111 | self.row_idx = tf.placeholder(dtype=tf.int32, shape=[None, 1], name="row_idx")
112 | self.col_idx = tf.placeholder(dtype=tf.int32, shape=[None, 1], name="col_idx")
113 |
114 | # user component
115 | # first layer weights
116 | UV = tf.get_variable(name="UV", initializer=tf.truncated_normal(shape=[self.num_cols, self.U_hidden_neuron],
117 | mean=0, stddev=0.03), dtype=tf.float32)
118 | # second layer weights
119 | UW = tf.get_variable(name="UW", initializer=tf.truncated_normal(shape=[self.U_hidden_neuron, self.num_cols],
120 | mean=0, stddev=0.03), dtype=tf.float32)
121 | # first layer bias
122 | Ub1 = tf.get_variable(name="Ub1", initializer=tf.truncated_normal(shape=[1, self.U_hidden_neuron],
123 | mean=0, stddev=0.03), dtype=tf.float32)
124 | # second layer bias
125 | Ub2 = tf.get_variable(name="Ub2", initializer=tf.truncated_normal(shape=[1, self.num_cols],
126 | mean=0, stddev=0.03), dtype=tf.float32)
127 |
128 | # item component
129 | # first layer weights
130 | IV = tf.get_variable(name="IV", initializer=tf.truncated_normal(shape=[self.num_rows, self.I_hidden_neuron],
131 | mean=0, stddev=0.03), dtype=tf.float32)
132 | # second layer weights
133 | IW = tf.get_variable(name="IW", initializer=tf.truncated_normal(shape=[self.I_hidden_neuron, self.num_rows],
134 | mean=0, stddev=0.03), dtype=tf.float32)
135 | # first layer bias
136 | Ib1 = tf.get_variable(name="Ib1", initializer=tf.truncated_normal(shape=[1, self.I_hidden_neuron],
137 | mean=0, stddev=0.03), dtype=tf.float32)
138 | # second layer bias
139 | Ib2 = tf.get_variable(name="Ib2", initializer=tf.truncated_normal(shape=[1, self.num_rows],
140 | mean=0, stddev=0.03), dtype=tf.float32)
141 |
142 |
143 | I_factor_vector = tf.get_variable(name="I_factor_vector", initializer=tf.random_uniform(shape=[1, self.num_cols]),
144 | dtype=tf.float32)
145 |
146 | # user component
147 | U_pre_Encoder = tf.matmul(self.input_R_U, UV) + Ub1 # input to the hidden layer
148 | self.U_Encoder = self.g_act(U_pre_Encoder) # output of the hidden layer
149 | U_pre_Decoder = tf.matmul(self.U_Encoder, UW) + Ub2 # input to the output layer
150 | self.U_Decoder = self.f_act(U_pre_Decoder) # output of the output layer
151 |
152 | # item component
153 | I_pre_mul = tf.transpose(tf.matmul(I_factor_vector, tf.transpose(self.input_OH_I)))
154 | I_pre_Encoder = tf.matmul(tf.transpose(self.input_R_I), IV) + Ib1 # input to the hidden layer
155 | self.I_Encoder = self.g_act(I_pre_Encoder * I_pre_mul) # output of the hidden layer
156 | I_pre_Decoder = tf.matmul(self.I_Encoder, IW) + Ib2 # input to the output layer
157 | self.I_Decoder = self.f_act(I_pre_Decoder) # output of the output layer
158 |
159 | # final output
160 | self.Decoder = ((tf.transpose(tf.gather_nd(tf.transpose(self.U_Decoder), self.col_idx)))
161 | + tf.gather_nd(tf.transpose(self.I_Decoder), self.row_idx)) / 2.0
162 |
163 | pos_data = tf.gather_nd(self.Decoder, self.input_P_cor)
164 | neg_data = tf.gather_nd(self.Decoder, self.input_N_cor)
165 |
166 | pre_cost1 = tf.maximum(neg_data - pos_data + self.margin,
167 | tf.zeros(tf.shape(neg_data)[0]))
168 | cost1 = tf.reduce_sum(pre_cost1) # prediction squared error
169 | pre_cost2 = tf.square(self.l2_norm(UW)) + tf.square(self.l2_norm(UV)) \
170 | + tf.square(self.l2_norm(IW)) + tf.square(self.l2_norm(IV))\
171 | + tf.square(self.l2_norm(Ib1)) + tf.square(self.l2_norm(Ib2))\
172 | + tf.square(self.l2_norm(Ub1)) + tf.square(self.l2_norm(Ub2))
173 | cost2 = self.lambda_value * 0.5 * pre_cost2 # regularization term
174 |
175 | self.cost = cost1 + cost2 # the loss function
176 |
177 | if self.optimizer_method == "Adam":
178 | optimizer = tf.train.AdamOptimizer(self.lr)
179 | elif self.optimizer_method == "Adadelta":
180 | optimizer = tf.train.AdadeltaOptimizer(self.lr)
181 | elif self.optimizer_method == "Adagrad":
182 | optimizer = tf.train.AdadeltaOptimizer(self.lr)
183 | elif self.optimizer_method == "RMSProp":
184 | optimizer = tf.train.RMSPropOptimizer(self.lr)
185 | elif self.optimizer_method == "GradientDescent":
186 | optimizer = tf.train.GradientDescentOptimizer(self.lr)
187 | elif self.optimizer_method == "Momentum":
188 | optimizer = tf.train.MomentumOptimizer(self.lr, 0.9)
189 | else:
190 | raise ValueError("Optimizer Key ERROR")
191 |
192 | gvs = optimizer.compute_gradients(self.cost)
193 | self.optimizer = optimizer.apply_gradients(gvs, global_step=self.global_step)
194 |
195 | def train_model(self, itr):
196 | start_time = time.time()
197 | random_row_idx = np.random.permutation(self.num_rows) # randomly permute the rows
198 | random_col_idx = np.random.permutation(self.num_cols) # randomly permute the cols
199 | batch_cost = 0
200 | ts = 0
201 | for i in xrange(self.num_batch_U): # iterate each batch
202 | if i == self.num_batch_U - 1:
203 | row_idx = random_row_idx[i * self.batch_size:]
204 | else:
205 | row_idx = random_row_idx[(i * self.batch_size):((i + 1) * self.batch_size)]
206 | for j in xrange(self.num_batch_I):
207 | # get the indices of the current batch
208 | if j == self.num_batch_I - 1:
209 | col_idx = random_col_idx[j * self.batch_size:]
210 | else:
211 | col_idx = random_col_idx[(j * self.batch_size):((j + 1) * self.batch_size)]
212 | ts1 = time.time()
213 | p_input, n_input = utility.pairwise_neg_sampling(self.train_R, row_idx, col_idx, self.neg_sample_rate)
214 | ts2 = time.time()
215 | ts += (ts2 - ts1)
216 | input_tmp = self.train_R[row_idx, :]
217 | input_tmp = input_tmp[:, col_idx]
218 |
219 | input_R_U = self.train_R[row_idx, :]
220 | input_R_I = self.train_R[:, col_idx]
221 | _, cost = self.sess.run( # do the optimization by the minibatch
222 | [self.optimizer, self.cost],
223 | feed_dict={
224 | self.input_R_U: input_R_U,
225 | self.input_R_I: input_R_I,
226 | self.input_OH_I: self.I_OH_mat[col_idx, :],
227 | self.input_P_cor: p_input,
228 | self.input_N_cor: n_input,
229 | self.row_idx: np.reshape(row_idx, (len(row_idx), 1)),
230 | self.col_idx: np.reshape(col_idx, (len(col_idx), 1))})
231 | batch_cost = batch_cost + cost
232 |
233 | if itr % self.display_step == 0:
234 | print ("Training //", "Epoch %d //" % itr, " Total cost = {:.2f}".format(batch_cost),
235 | "Elapsed time : %d sec //" % (time.time() - start_time), "Sampling time: %d s //" %(ts))
236 |
237 | def test_model(self, itr): # calculate the cost and rmse of testing set in each epoch
238 | start_time = time.time()
239 | _, Decoder = self.sess.run([self.cost, self.Decoder],
240 | feed_dict={
241 | self.input_R_U: self.train_R,
242 | self.input_R_I: self.train_R,
243 | self.input_OH_I: self.I_OH_mat,
244 | self.input_P_cor: [[0, 0]],
245 | self.input_N_cor: [[0, 0]],
246 | self.row_idx: np.reshape(xrange(self.num_rows), (self.num_rows, 1)),
247 | self.col_idx: np.reshape(xrange(self.num_cols), (self.num_cols, 1))})
248 | if itr % self.display_step == 0:
249 |
250 | pre_numerator = np.multiply((Decoder - self.vali_R), self.vali_R)
251 | numerator = np.sum(np.square(pre_numerator))
252 | denominator = self.num_test_ratings
253 | RMSE = np.sqrt(numerator / float(denominator))
254 |
255 | if itr % 1 == 0:
256 | if self.base == 'i':
257 | [precision, recall, f_score, NDCG] = utility.test_model_all(Decoder.T, self.vali_R.T,
258 | self.train_R.T)
259 | else:
260 | [precision, recall, f_score, NDCG] = utility.test_model_all(Decoder, self.vali_R, self.train_R)
261 |
262 | print (
263 | "Testing //", "Epoch %d //" % itr, " Total cost = {:.2f}".format(numerator),
264 | " RMSE = {:.5f}".format(RMSE),
265 | "Elapsed time : %d sec" % (time.time() - start_time))
266 | print "=" * 100
267 |
268 | def make_records(self): # record all the results' details into files
269 | _, Decoder = self.sess.run([self.cost, self.Decoder],
270 | feed_dict={
271 | self.input_R_U: self.train_R,
272 | self.input_R_I: self.train_R,
273 | self.input_OH_I: self.I_OH_mat,
274 | self.input_P_cor: [[0, 0]],
275 | self.input_N_cor: [[0, 0]],
276 | self.row_idx: np.reshape(xrange(self.num_rows), (self.num_rows, 1)),
277 | self.col_idx: np.reshape(xrange(self.num_cols), (self.num_cols, 1))})
278 | if self.base == 'i':
279 | [precision, recall, f_score, NDCG] = utility.test_model_all(Decoder.T, self.vali_R.T, self.train_R.T)
280 | else:
281 | [precision, recall, f_score, NDCG] = utility.test_model_all(Decoder, self.vali_R, self.train_R)
282 |
283 | utility.metric_record(precision, recall, f_score, NDCG, self.args, self.metric_path)
284 |
285 | utility.test_model_factor(Decoder, self.vali_R, self.train_R)
286 |
287 | return precision, recall, f_score, NDCG
288 |
289 | @staticmethod
290 | def l2_norm(tensor):
291 | return tf.sqrt(tf.reduce_sum(tf.square(tensor)))
292 |
--------------------------------------------------------------------------------
/utility.py:
--------------------------------------------------------------------------------
1 | """
2 | Ziwei Zhu
3 | Computer Science and Engineering Department, Texas A&M University
4 | zhuziwei@tamu.edu
5 | """
6 | from __future__ import division
7 |
8 | from math import log
9 | import numpy as np
10 | import pandas as pd
11 | import copy
12 | from operator import itemgetter
13 | import time
14 |
15 |
16 | # calculate NDCG@k
def NDCG_at_k(predicted_list, ground_truth, k):
    """Return NDCG@k for a binary relevance list.

    :param predicted_list: 0/1 relevance of the ranked predictions, best first
    :param ground_truth: ideal relevance list (typically all 1s)
    :param k: rank cutoff
    :return: DCG@k / IDCG@k (0.0 when the ideal DCG is zero)
    """
    dcg_value = [(v / log(i + 1 + 1, 2)) for i, v in enumerate(predicted_list[:k])]
    dcg = np.sum(dcg_value)
    # pad a COPY of the ideal list up to length k; the original extended the
    # caller's list in place via `+=`, mutating the argument
    ideal = list(ground_truth)
    if len(ideal) < k:
        ideal += [0 for i in range(k - len(ideal))]
    idcg_value = [(v / log(i + 1 + 1, 2)) for i, v in enumerate(ideal[:k])]
    idcg = np.sum(idcg_value)
    if idcg == 0:
        return 0.0  # no relevant items at all -> define NDCG as 0 (was nan/inf)
    return dcg / idcg
25 |
26 |
27 | # calculate precision@k, recall@k, NDCG@k, where k = 1,5,10,15
def user_precision_recall_ndcg(new_user_prediction, test):
    """Compute precision/recall/NDCG at cutoffs 1, 5, 10 and 15 for one user.

    :param new_user_prediction: ranked [item_id, score] pairs, best first
        (callers pass the top-15)
    :param test: indices of this user's held-out positive items
    :return: three np.arrays — precision@{1,5,10,15}, recall@..., ndcg@...
    """
    dcg_list = []

    # count true positives within the top 1/5/10/15 and build the 0/1
    # relevance list fed to NDCG_at_k
    count_1, count_5, count_10, count_15 = 0, 0, 0, 0
    # `range` instead of py2-only `xrange` (identical under py2); bounded by
    # the prediction length so fewer than 15 items no longer raises IndexError
    for i in range(min(15, len(new_user_prediction))):
        hit = new_user_prediction[i][0] in test
        if i == 0 and hit:
            count_1 = 1.0
        if i < 5 and hit:
            count_5 += 1.0
        if i < 10 and hit:
            count_10 += 1.0
        if hit:
            count_15 += 1.0
            dcg_list.append(1)
        else:
            dcg_list.append(0)

    # calculate NDCG@k against the ideal ranking (all positives on top)
    idcg_list = [1 for i in range(len(test))]
    ndcg_tmp_1 = NDCG_at_k(dcg_list, idcg_list, 1)
    ndcg_tmp_5 = NDCG_at_k(dcg_list, idcg_list, 5)
    ndcg_tmp_10 = NDCG_at_k(dcg_list, idcg_list, 10)
    ndcg_tmp_15 = NDCG_at_k(dcg_list, idcg_list, 15)

    # precision@k
    precision_1 = count_1
    precision_5 = count_5 / 5.0
    precision_10 = count_10 / 10.0
    precision_15 = count_15 / 15.0

    # recall@k (guard against an empty test set to avoid division by zero)
    l = len(test)
    if l == 0:
        l = 1
    recall_1 = count_1 / l
    recall_5 = count_5 / l
    recall_10 = count_10 / l
    recall_15 = count_15 / l

    return np.array([precision_1, precision_5, precision_10, precision_15]),\
           np.array([recall_1, recall_5, recall_10, recall_15]),\
           np.array([ndcg_tmp_1, ndcg_tmp_5, ndcg_tmp_10, ndcg_tmp_15])
72 |
73 |
74 | # calculate the metrics of the result
def test_model_all(prediction, test_mask, train_mask):
    """Evaluate ranking quality over all users and print the results.

    :param prediction: (num_users x num_items) score matrix
    :param test_mask: 0/1 matrix of held-out (test) positives
    :param train_mask: 0/1 matrix of training positives, which are excluded
        from the rankings
    :return: precision, recall, ndcg (np arrays @1/5/10/15) and f_score
        (list), each averaged over users that have at least one test item
    """
    precision_1, precision_5, precision_10, precision_15 = 0.0000, 0.0000, 0.0000, 0.0000
    recall_1, recall_5, recall_10, recall_15 = 0.0000, 0.0000, 0.0000, 0.0000
    ndcg_1, ndcg_5, ndcg_10, ndcg_15 = 0.0000, 0.0000, 0.0000, 0.0000
    precision = np.array([precision_1, precision_5, precision_10, precision_15])
    recall = np.array([recall_1, recall_5, recall_10, recall_15])
    ndcg = np.array([ndcg_1, ndcg_5, ndcg_10, ndcg_15])

    # push already-seen training items to a hugely negative score so they
    # can never enter a user's top-15
    prediction = prediction + train_mask * -100000.0

    user_num = prediction.shape[0]
    for u in range(user_num):  # iterate each user
        u_test = test_mask[u, :]
        u_test = np.where(u_test == 1)[0]  # the indices of the true positive items in the test set
        u_pred = prediction[u, :]

        # unordered top-15 via argpartition, then sort just those 15 by score
        top15_item_idx_no_train = np.argpartition(u_pred, -15)[-15:]
        top15 = (np.array([top15_item_idx_no_train, u_pred[top15_item_idx_no_train]])).T
        top15 = sorted(top15, key=itemgetter(1), reverse=True)

        # calculate the metrics
        if not len(u_test) == 0:
            precision_u, recall_u, ndcg_u = user_precision_recall_ndcg(top15, u_test)
            precision += precision_u
            recall += recall_u
            ndcg += ndcg_u
        else:
            # users with no test items are excluded from the final average
            user_num -= 1

    # compute the average over all users
    precision /= user_num
    recall /= user_num
    ndcg /= user_num
    print 'precision_1\t[%.7f],\t||\t precision_5\t[%.7f],\t||\t precision_10\t[%.7f],\t||\t precision_15\t[%.7f]' \
          % (precision[0],
             precision[1],
             precision[2],
             precision[3])
    print 'recall_1 \t[%.7f],\t||\t recall_5 \t[%.7f],\t||\t recall_10 \t[%.7f],\t||\t recall_15 \t[%.7f]' \
          % (recall[0], recall[1],
             recall[2], recall[3])
    # F1 at each cutoff; guard against 0/0 when both precision and recall are 0
    f_measure_1 = 2 * (precision[0] * recall[0]) / (precision[0] + recall[0]) if not precision[0] + recall[0] == 0 else 0
    f_measure_5 = 2 * (precision[1] * recall[1]) / (precision[1] + recall[1]) if not precision[1] + recall[1] == 0 else 0
    f_measure_10 = 2 * (precision[2] * recall[2]) / (precision[2] + recall[2]) if not precision[2] + recall[2] == 0 else 0
    f_measure_15 = 2 * (precision[3] * recall[3]) / (precision[3] + recall[3]) if not precision[3] + recall[3] == 0 else 0
    print 'f_measure_1\t[%.7f],\t||\t f_measure_5\t[%.7f],\t||\t f_measure_10\t[%.7f],\t||\t f_measure_15\t[%.7f]' \
          % (f_measure_1,
             f_measure_5,
             f_measure_10,
             f_measure_15)
    f_score = [f_measure_1, f_measure_5, f_measure_10, f_measure_15]
    print 'ndcg_1 \t[%.7f],\t||\t ndcg_5 \t[%.7f],\t||\t ndcg_10 \t[%.7f],\t||\t ndcg_15 \t[%.7f]' \
          % (ndcg[0],
             ndcg[1],
             ndcg[2],
             ndcg[3])
    return precision, recall, f_score, ndcg
132 |
133 |
def metric_record(precision, recall, f_score, NDCG, args, metric_path):
    """Write the final metrics (and the hyper-parameters that produced them)
    to `<metric_path>.txt`, overwriting any previous file.

    :param precision/recall/f_score/NDCG: metric containers (str()-able)
    :param args: argparse namespace of the run's hyper-parameters
    :param metric_path: output path prefix ('.txt' is appended)
    """
    path = metric_path + '.txt'

    with open(path, 'w') as f:
        f.write(str(args) + '\n')
        f.write('precision:' + str(precision) + '\n')
        f.write('recall:' + str(recall) + '\n')
        f.write('f score:' + str(f_score) + '\n')
        f.write('NDCG:' + str(NDCG) + '\n')
        f.write('\n')
        # (removed the redundant f.close(); `with` already closes the file)
145 |
146 |
def get_train_instances(train_R, neg_sample_rate):
    """
    Build (user, item, label) training triples for NCF-style models.

    Negative cells are first injected into the rating matrix through
    neg_sampling; every cell of the resulting mask equal to 1 then yields
    one triple whose label is the original rating (1 = observed positive,
    0 = sampled negative).
    :param train_R: binary rating matrix
    :param neg_sample_rate: negatives sampled per positive
    :return: (user_input, item_input, labels) parallel lists
    """
    # randomly sample negative cells on top of the observed ones
    mask = neg_sampling(train_R, range(train_R.shape[0]), neg_sample_rate)

    user_input, item_input, labels = [], [], []
    rows, cols = np.where(mask == 1)
    for u_i, i_i in zip(rows, cols):
        user_input.append(u_i)
        item_input.append(i_i)
        labels.append(train_R[u_i, i_i])
    return user_input, item_input, labels
167 |
168 |
def neg_sampling(train_R, idx, neg_sample_rate):
    """
    Randomly add negative samples to the observed entries of selected rows.

    For each row in *idx*, about neg_sample_rate times as many unobserved
    cells as there are observed ones are flipped to 1 in the returned mask —
    never more than the number of unobserved cells, and at least one.
    :param train_R: binary rating matrix (left unmodified)
    :param idx: iterable of row indices to process
    :param neg_sample_rate: ratio of negatives to positives per row
    :return: mask = copy of train_R with the sampled negatives set to 1
    """
    total_items = train_R.shape[1]
    mask = copy.copy(train_R)
    if neg_sample_rate == 0:
        return mask
    for row in idx:
        unobserved = np.where(mask[row, :] == 0)[0]  # unobserved column indices
        observed_cnt = total_items - len(unobserved)
        take = int(observed_cnt * neg_sample_rate)
        if take > len(unobserved):  # more positives than free cells: cap it
            take = len(unobserved)
        if take == 0:
            take = 1
        chosen = np.random.choice(unobserved, size=take, replace=False)
        mask[row, chosen] = 1
    return mask
197 |
198 |
def pairwise_neg_sampling(train_R, r_idx, c_idx, neg_sample_rate):
    """
    Sample (positive, negative) coordinate pairs inside a sub-matrix.

    For every observed cell of train_R[r_idx][:, c_idx], draw
    `neg_sample_rate` unobserved columns of the same row (without
    replacement) and emit one (positive, negative) pair per draw.
    :param train_R: binary rating matrix
    :param r_idx: row indices of the current batch
    :param c_idx: column indices of the current batch
    :param neg_sample_rate: negatives drawn per positive
    :return: (p_input, n_input) — parallel np arrays of [row, col] pairs
             expressed in batch-local coordinates
    """
    sub = train_R[r_idx, :]
    sub = sub[:, c_idx]
    pos_rows, pos_cols = np.where(sub == 1)

    # cache, per row, the unobserved column indices to draw negatives from
    zero_cols = [np.where(sub[r, :] == 0)[0] for r in range(sub.shape[0])]

    p_input, n_input = [], []
    for r, c in zip(pos_rows, pos_cols):
        for neg in np.random.choice(zero_cols[r], size=neg_sample_rate, replace=False):
            p_input.append([r, c])
            n_input.append([r, neg])
    # print('dataset size = ' + str(len(p_input)))
    return np.array(p_input), np.array(n_input)
222 |
223 |
224 | # calculate the metrics of the result
def test_model_batch(prediction, test_mask, train_mask):
    """
    Accumulate precision/recall/NDCG @ {1,5,10,15} over one batch of users.

    Unlike test_model_all this prints nothing and returns the UN-averaged
    metric sums; normalization over users is left to the caller.
    :param prediction: (users x items) score matrix for the batch
    :param test_mask: 0/1 matrix of held-out positives
    :param train_mask: 0/1 matrix of training positives (excluded from ranking)
    :return: (precision, recall, ndcg) summed np arrays
    """
    precision = np.zeros(4)
    recall = np.zeros(4)
    ndcg = np.zeros(4)

    # push training items far down so they never make the top-15
    prediction = prediction + train_mask * -100000.0

    for u in range(prediction.shape[0]):  # iterate each user
        positives = np.where(test_mask[u, :] == 1)[0]  # held-out item indices
        if len(positives) == 0:
            continue  # users without test items contribute nothing
        scores = prediction[u, :]

        # unordered top-15 via argpartition, then sort those 15 by score
        cand = np.argpartition(scores, -15)[-15:]
        ranked = sorted(np.array([cand, scores[cand]]).T, key=itemgetter(1), reverse=True)

        p_u, r_u, n_u = user_precision_recall_ndcg(ranked, positives)
        precision += p_u
        recall += r_u
        ndcg += n_u

    return precision, recall, ndcg
255 |
256 |
257 | # calculate the metrics of the result
# calculate the metrics of the result
def test_model_cold_start(prediction, test_mask, train_mask):
    """Evaluate precision/recall/F1/NDCG at {1, 5, 10, 15} on cold-start users.

    Only users with at most 10 positive test items are evaluated (users with
    more are skipped via ``continue``), which is this file's proxy for the
    "cold start" population.  Training items are masked out of ``prediction``
    before ranking.  Averages over the evaluated users are printed (Python 2
    print statements) and returned.

    NOTE(review): if no user qualifies (n == 0) the ``/= n`` lines raise
    ZeroDivisionError — presumably never happens on the intended datasets.
    """
    precision_1, precision_5, precision_10, precision_15 = 0.0000, 0.0000, 0.0000, 0.0000
    recall_1, recall_5, recall_10, recall_15 = 0.0000, 0.0000, 0.0000, 0.0000
    ndcg_1, ndcg_5, ndcg_10, ndcg_15 = 0.0000, 0.0000, 0.0000, 0.0000
    # metric accumulators, one slot per cutoff (1, 5, 10, 15)
    precision = np.array([precision_1, precision_5, precision_10, precision_15])
    recall = np.array([recall_1, recall_5, recall_10, recall_15])
    ndcg = np.array([ndcg_1, ndcg_5, ndcg_10, ndcg_15])

    # large negative offset pushes training items out of the top of the ranking
    prediction = prediction + train_mask * -100000.0

    user_num = prediction.shape[0]
    n = 0  # number of users actually evaluated (cold-start users with >=1 test item)
    for u in range(user_num):  # iterate each user
        u_test = test_mask[u, :]
        u_test = np.where(u_test == 1)[0]  # the indices of the true positive items in the test set
        if len(u_test) > 10:
            # not a cold-start user: too many test interactions, skip
            continue
        u_pred = prediction[u, :]

        # top-15 candidate items (argpartition is unordered), then sort by score
        top15_item_idx_no_train = np.argpartition(u_pred, -15)[-15:]
        top15 = (np.array([top15_item_idx_no_train, u_pred[top15_item_idx_no_train]])).T
        top15 = sorted(top15, key=itemgetter(1), reverse=True)

        # calculate the metrics
        if not len(u_test) == 0:
            precision_u, recall_u, ndcg_u = user_precision_recall_ndcg(top15, u_test)
            precision += precision_u
            recall += recall_u
            ndcg += ndcg_u
            n += 1

    # compute the average over all users
    precision /= n
    recall /= n
    ndcg /= n
    print 'precision_1\t[%.7f],\t||\t precision_5\t[%.7f],\t||\t precision_10\t[%.7f],\t||\t precision_15\t[%.7f]' \
          % (precision[0],
             precision[1],
             precision[2],
             precision[3])
    print 'recall_1   \t[%.7f],\t||\t recall_5   \t[%.7f],\t||\t recall_10   \t[%.7f],\t||\t recall_15   \t[%.7f]' \
          % (recall[0], recall[1],
             recall[2], recall[3])
    # F1 at each cutoff; guarded against 0/0 when both precision and recall are zero
    f_measure_1 = 2 * (precision[0] * recall[0]) / (precision[0] + recall[0]) if not precision[0] + recall[0] == 0 else 0
    f_measure_5 = 2 * (precision[1] * recall[1]) / (precision[1] + recall[1]) if not precision[1] + recall[1] == 0 else 0
    f_measure_10 = 2 * (precision[2] * recall[2]) / (precision[2] + recall[2]) if not precision[2] + recall[2] == 0 else 0
    f_measure_15 = 2 * (precision[3] * recall[3]) / (precision[3] + recall[3]) if not precision[3] + recall[3] == 0 else 0
    print 'f_measure_1\t[%.7f],\t||\t f_measure_5\t[%.7f],\t||\t f_measure_10\t[%.7f],\t||\t f_measure_15\t[%.7f]' \
          % (f_measure_1,
             f_measure_5,
             f_measure_10,
             f_measure_15)
    f_score = [f_measure_1, f_measure_5, f_measure_10, f_measure_15]
    print 'ndcg_1     \t[%.7f],\t||\t ndcg_5     \t[%.7f],\t||\t ndcg_10     \t[%.7f],\t||\t ndcg_15     \t[%.7f]' \
          % (ndcg[0],
             ndcg[1],
             ndcg[2],
             ndcg[3])
    return precision, recall, f_score, ndcg
317 |
318 |
def test_model_factor(prediction, test_mask, train_mask):
    """Dump per-item recommendation-frequency statistics to CSV.

    For every user, takes the top-10 predicted items (training items are
    masked out first) and accumulates (a) how many times each item gets
    recommended and (b) a rank-weighted score.  Both tallies are written to
    timestamped CSV files under ``data/`` alongside each item's popularity
    in the training data.

    ``test_mask`` is accepted for signature consistency with the other
    ``test_model_*`` helpers but is not used here (the dead per-user
    ``u_test``/``len_u_test`` computations of the original were removed).
    """
    item_list = np.zeros(train_mask.shape[1])       # recommendation counts per item
    item_list_rank = np.zeros(train_mask.shape[1])  # rank-weighted score per item

    # large negative offset suppresses items already observed in training
    prediction = prediction + train_mask * -100000.0

    user_num = prediction.shape[0]
    for u in range(user_num):  # iterate each user
        u_pred = prediction[u, :]

        # NOTE(review): argpartition only guarantees these are the top 10;
        # order within the slice is arbitrary, so the (10 - i) weight is an
        # approximation of true rank — behavior kept as in the original.
        top10_item_idx_no_train = np.argpartition(u_pred, -10)[-10:]
        item_list[top10_item_idx_no_train] += 1
        for i in range(len(top10_item_idx_no_train)):
            item_list_rank[top10_item_idx_no_train[i]] += (10 - i)

    # pair the tallies with each item's training popularity and persist them
    item_count = np.sum(train_mask, axis=0)
    df = pd.DataFrame({'item_pred_freq': item_list, 'item_count': item_count})
    df.to_csv('data/no-factor' + time.strftime('%y-%m-%d-%H-%M-%S', time.localtime()) + '.csv')
    df = pd.DataFrame({'item_pred_rank': item_list_rank, 'item_count': item_count})
    df.to_csv('data/rank-no-factor' + time.strftime('%y-%m-%d-%H-%M-%S', time.localtime()) + '.csv')
342 |
--------------------------------------------------------------------------------
/data/yelp/.ipynb_checkpoints/preprocessing-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import pandas as pd"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 2,
15 | "metadata": {},
16 | "outputs": [
17 | {
18 | "data": {
19 | "text/html": [
20 | "
\n",
21 | "\n",
34 | "
\n",
35 | " \n",
36 | " \n",
37 | " | \n",
38 | " userId | \n",
39 | " itemId | \n",
40 | " rating | \n",
41 | " timestamp | \n",
42 | "
\n",
43 | " \n",
44 | " \n",
45 | " \n",
46 | " | 0 | \n",
47 | " 0 | \n",
48 | " 0 | \n",
49 | " 4.0 | \n",
50 | " 1329148800 | \n",
51 | "
\n",
52 | " \n",
53 | " | 1 | \n",
54 | " 0 | \n",
55 | " 0 | \n",
56 | " 4.0 | \n",
57 | " 1337011200 | \n",
58 | "
\n",
59 | " \n",
60 | " | 2 | \n",
61 | " 0 | \n",
62 | " 1 | \n",
63 | " 4.0 | \n",
64 | " 1335888000 | \n",
65 | "
\n",
66 | " \n",
67 | " | 3 | \n",
68 | " 0 | \n",
69 | " 2 | \n",
70 | " 4.0 | \n",
71 | " 1379260800 | \n",
72 | "
\n",
73 | " \n",
74 | " | 4 | \n",
75 | " 0 | \n",
76 | " 3 | \n",
77 | " 2.0 | \n",
78 | " 1367856000 | \n",
79 | "
\n",
80 | " \n",
81 | "
\n",
82 | "
"
83 | ],
84 | "text/plain": [
85 | " userId itemId rating timestamp\n",
86 | "0 0 0 4.0 1329148800\n",
87 | "1 0 0 4.0 1337011200\n",
88 | "2 0 1 4.0 1335888000\n",
89 | "3 0 2 4.0 1379260800\n",
90 | "4 0 3 2.0 1367856000"
91 | ]
92 | },
93 | "execution_count": 2,
94 | "metadata": {},
95 | "output_type": "execute_result"
96 | }
97 | ],
98 | "source": [
99 | "data_df = pd.read_csv('./yelp.csv', sep=',')\n",
100 | "data_df.head()"
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": 3,
106 | "metadata": {},
107 | "outputs": [
108 | {
109 | "data": {
110 | "text/html": [
111 | "\n",
112 | "\n",
125 | "
\n",
126 | " \n",
127 | " \n",
128 | " | \n",
129 | " userId | \n",
130 | " itemId | \n",
131 | " rating | \n",
132 | "
\n",
133 | " \n",
134 | " \n",
135 | " \n",
136 | " | 0 | \n",
137 | " 0 | \n",
138 | " 0 | \n",
139 | " 4.0 | \n",
140 | "
\n",
141 | " \n",
142 | " | 1 | \n",
143 | " 0 | \n",
144 | " 0 | \n",
145 | " 4.0 | \n",
146 | "
\n",
147 | " \n",
148 | " | 2 | \n",
149 | " 0 | \n",
150 | " 1 | \n",
151 | " 4.0 | \n",
152 | "
\n",
153 | " \n",
154 | " | 3 | \n",
155 | " 0 | \n",
156 | " 2 | \n",
157 | " 4.0 | \n",
158 | "
\n",
159 | " \n",
160 | " | 4 | \n",
161 | " 0 | \n",
162 | " 3 | \n",
163 | " 2.0 | \n",
164 | "
\n",
165 | " \n",
166 | "
\n",
167 | "
"
168 | ],
169 | "text/plain": [
170 | " userId itemId rating\n",
171 | "0 0 0 4.0\n",
172 | "1 0 0 4.0\n",
173 | "2 0 1 4.0\n",
174 | "3 0 2 4.0\n",
175 | "4 0 3 2.0"
176 | ]
177 | },
178 | "execution_count": 3,
179 | "metadata": {},
180 | "output_type": "execute_result"
181 | }
182 | ],
183 | "source": [
184 | "data_df.drop('timestamp', axis=1, inplace=True)\n",
185 | "data_df.head()"
186 | ]
187 | },
188 | {
189 | "cell_type": "code",
190 | "execution_count": 46,
191 | "metadata": {},
192 | "outputs": [
193 | {
194 | "data": {
195 | "text/plain": [
196 | "731671"
197 | ]
198 | },
199 | "execution_count": 46,
200 | "metadata": {},
201 | "output_type": "execute_result"
202 | }
203 | ],
204 | "source": [
205 | "len(data_df)"
206 | ]
207 | },
208 | {
209 | "cell_type": "code",
210 | "execution_count": null,
211 | "metadata": {},
212 | "outputs": [],
213 | "source": [
214 | "# data_df['mean'] = data_df.groupby('userId')['rating'].transform('mean')"
215 | ]
216 | },
217 | {
218 | "cell_type": "code",
219 | "execution_count": 4,
220 | "metadata": {
221 | "scrolled": true
222 | },
223 | "outputs": [
224 | {
225 | "name": "stdout",
226 | "output_type": "stream",
227 | "text": [
228 | "0\n",
229 | "100000\n",
230 | "200000\n",
231 | "300000\n",
232 | "400000\n",
233 | "500000\n",
234 | "600000\n",
235 | "700000\n"
236 | ]
237 | }
238 | ],
239 | "source": [
240 | "for i in range(len(data_df)):\n",
241 | " if data_df.at[i, 'rating'] > 3:\n",
242 | " data_df.at[i, 'rating'] = 1\n",
243 | " else:\n",
244 | " data_df.at[i, 'rating'] = 0\n",
245 | " if i % 100000 == 0:\n",
246 | " print i"
247 | ]
248 | },
249 | {
250 | "cell_type": "code",
251 | "execution_count": 7,
252 | "metadata": {
253 | "scrolled": true
254 | },
255 | "outputs": [],
256 | "source": [
257 | "data_df['userId'] = data_df['userId'] + 1\n",
258 | "data_df['itemId'] = data_df['itemId'] + 1"
259 | ]
260 | },
261 | {
262 | "cell_type": "code",
263 | "execution_count": 47,
264 | "metadata": {},
265 | "outputs": [],
266 | "source": [
267 | "df = data_df.copy()"
268 | ]
269 | },
270 | {
271 | "cell_type": "code",
272 | "execution_count": 48,
273 | "metadata": {
274 | "scrolled": true
275 | },
276 | "outputs": [
277 | {
278 | "data": {
279 | "text/html": [
280 | "\n",
281 | "\n",
294 | "
\n",
295 | " \n",
296 | " \n",
297 | " | \n",
298 | " userId | \n",
299 | " itemId | \n",
300 | " rating | \n",
301 | "
\n",
302 | " \n",
303 | " \n",
304 | " \n",
305 | " | 0 | \n",
306 | " 1 | \n",
307 | " 1 | \n",
308 | " 1.0 | \n",
309 | "
\n",
310 | " \n",
311 | " | 1 | \n",
312 | " 1 | \n",
313 | " 1 | \n",
314 | " 1.0 | \n",
315 | "
\n",
316 | " \n",
317 | " | 2 | \n",
318 | " 1 | \n",
319 | " 2 | \n",
320 | " 1.0 | \n",
321 | "
\n",
322 | " \n",
323 | " | 3 | \n",
324 | " 1 | \n",
325 | " 3 | \n",
326 | " 1.0 | \n",
327 | "
\n",
328 | " \n",
329 | " | 7 | \n",
330 | " 1 | \n",
331 | " 7 | \n",
332 | " 1.0 | \n",
333 | "
\n",
334 | " \n",
335 | " | 8 | \n",
336 | " 1 | \n",
337 | " 8 | \n",
338 | " 1.0 | \n",
339 | "
\n",
340 | " \n",
341 | " | 10 | \n",
342 | " 1 | \n",
343 | " 10 | \n",
344 | " 1.0 | \n",
345 | "
\n",
346 | " \n",
347 | " | 12 | \n",
348 | " 1 | \n",
349 | " 12 | \n",
350 | " 1.0 | \n",
351 | "
\n",
352 | " \n",
353 | " | 15 | \n",
354 | " 1 | \n",
355 | " 15 | \n",
356 | " 1.0 | \n",
357 | "
\n",
358 | " \n",
359 | " | 16 | \n",
360 | " 1 | \n",
361 | " 16 | \n",
362 | " 1.0 | \n",
363 | "
\n",
364 | " \n",
365 | " | 17 | \n",
366 | " 1 | \n",
367 | " 17 | \n",
368 | " 1.0 | \n",
369 | "
\n",
370 | " \n",
371 | " | 21 | \n",
372 | " 1 | \n",
373 | " 21 | \n",
374 | " 1.0 | \n",
375 | "
\n",
376 | " \n",
377 | " | 22 | \n",
378 | " 1 | \n",
379 | " 22 | \n",
380 | " 1.0 | \n",
381 | "
\n",
382 | " \n",
383 | " | 23 | \n",
384 | " 1 | \n",
385 | " 23 | \n",
386 | " 1.0 | \n",
387 | "
\n",
388 | " \n",
389 | " | 24 | \n",
390 | " 1 | \n",
391 | " 24 | \n",
392 | " 1.0 | \n",
393 | "
\n",
394 | " \n",
395 | " | 27 | \n",
396 | " 1 | \n",
397 | " 27 | \n",
398 | " 1.0 | \n",
399 | "
\n",
400 | " \n",
401 | " | 29 | \n",
402 | " 1 | \n",
403 | " 29 | \n",
404 | " 1.0 | \n",
405 | "
\n",
406 | " \n",
407 | " | 30 | \n",
408 | " 1 | \n",
409 | " 30 | \n",
410 | " 1.0 | \n",
411 | "
\n",
412 | " \n",
413 | " | 31 | \n",
414 | " 1 | \n",
415 | " 31 | \n",
416 | " 1.0 | \n",
417 | "
\n",
418 | " \n",
419 | " | 32 | \n",
420 | " 1 | \n",
421 | " 32 | \n",
422 | " 1.0 | \n",
423 | "
\n",
424 | " \n",
425 | " | 33 | \n",
426 | " 1 | \n",
427 | " 33 | \n",
428 | " 1.0 | \n",
429 | "
\n",
430 | " \n",
431 | " | 34 | \n",
432 | " 1 | \n",
433 | " 34 | \n",
434 | " 1.0 | \n",
435 | "
\n",
436 | " \n",
437 | " | 35 | \n",
438 | " 1 | \n",
439 | " 35 | \n",
440 | " 1.0 | \n",
441 | "
\n",
442 | " \n",
443 | " | 36 | \n",
444 | " 1 | \n",
445 | " 36 | \n",
446 | " 1.0 | \n",
447 | "
\n",
448 | " \n",
449 | " | 39 | \n",
450 | " 1 | \n",
451 | " 39 | \n",
452 | " 1.0 | \n",
453 | "
\n",
454 | " \n",
455 | " | 40 | \n",
456 | " 1 | \n",
457 | " 40 | \n",
458 | " 1.0 | \n",
459 | "
\n",
460 | " \n",
461 | " | 41 | \n",
462 | " 1 | \n",
463 | " 41 | \n",
464 | " 1.0 | \n",
465 | "
\n",
466 | " \n",
467 | " | 42 | \n",
468 | " 1 | \n",
469 | " 42 | \n",
470 | " 1.0 | \n",
471 | "
\n",
472 | " \n",
473 | " | 43 | \n",
474 | " 1 | \n",
475 | " 42 | \n",
476 | " 1.0 | \n",
477 | "
\n",
478 | " \n",
479 | " | 45 | \n",
480 | " 1 | \n",
481 | " 44 | \n",
482 | " 1.0 | \n",
483 | "
\n",
484 | " \n",
485 | " | ... | \n",
486 | " ... | \n",
487 | " ... | \n",
488 | " ... | \n",
489 | "
\n",
490 | " \n",
491 | " | 731641 | \n",
492 | " 25675 | \n",
493 | " 9109 | \n",
494 | " 1.0 | \n",
495 | "
\n",
496 | " \n",
497 | " | 731642 | \n",
498 | " 25675 | \n",
499 | " 8917 | \n",
500 | " 1.0 | \n",
501 | "
\n",
502 | " \n",
503 | " | 731643 | \n",
504 | " 25675 | \n",
505 | " 18629 | \n",
506 | " 1.0 | \n",
507 | "
\n",
508 | " \n",
509 | " | 731644 | \n",
510 | " 25675 | \n",
511 | " 2578 | \n",
512 | " 1.0 | \n",
513 | "
\n",
514 | " \n",
515 | " | 731645 | \n",
516 | " 25675 | \n",
517 | " 22557 | \n",
518 | " 1.0 | \n",
519 | "
\n",
520 | " \n",
521 | " | 731646 | \n",
522 | " 25675 | \n",
523 | " 20292 | \n",
524 | " 1.0 | \n",
525 | "
\n",
526 | " \n",
527 | " | 731647 | \n",
528 | " 25675 | \n",
529 | " 22555 | \n",
530 | " 1.0 | \n",
531 | "
\n",
532 | " \n",
533 | " | 731648 | \n",
534 | " 25675 | \n",
535 | " 24168 | \n",
536 | " 1.0 | \n",
537 | "
\n",
538 | " \n",
539 | " | 731649 | \n",
540 | " 25675 | \n",
541 | " 24255 | \n",
542 | " 1.0 | \n",
543 | "
\n",
544 | " \n",
545 | " | 731650 | \n",
546 | " 25675 | \n",
547 | " 24693 | \n",
548 | " 1.0 | \n",
549 | "
\n",
550 | " \n",
551 | " | 731651 | \n",
552 | " 25676 | \n",
553 | " 8850 | \n",
554 | " 1.0 | \n",
555 | "
\n",
556 | " \n",
557 | " | 731652 | \n",
558 | " 25676 | \n",
559 | " 1946 | \n",
560 | " 1.0 | \n",
561 | "
\n",
562 | " \n",
563 | " | 731653 | \n",
564 | " 25676 | \n",
565 | " 20837 | \n",
566 | " 1.0 | \n",
567 | "
\n",
568 | " \n",
569 | " | 731654 | \n",
570 | " 25676 | \n",
571 | " 20837 | \n",
572 | " 1.0 | \n",
573 | "
\n",
574 | " \n",
575 | " | 731655 | \n",
576 | " 25676 | \n",
577 | " 1840 | \n",
578 | " 1.0 | \n",
579 | "
\n",
580 | " \n",
581 | " | 731656 | \n",
582 | " 25676 | \n",
583 | " 6033 | \n",
584 | " 1.0 | \n",
585 | "
\n",
586 | " \n",
587 | " | 731657 | \n",
588 | " 25676 | \n",
589 | " 3017 | \n",
590 | " 1.0 | \n",
591 | "
\n",
592 | " \n",
593 | " | 731658 | \n",
594 | " 25676 | \n",
595 | " 20016 | \n",
596 | " 1.0 | \n",
597 | "
\n",
598 | " \n",
599 | " | 731659 | \n",
600 | " 25676 | \n",
601 | " 3042 | \n",
602 | " 1.0 | \n",
603 | "
\n",
604 | " \n",
605 | " | 731660 | \n",
606 | " 25676 | \n",
607 | " 19256 | \n",
608 | " 1.0 | \n",
609 | "
\n",
610 | " \n",
611 | " | 731661 | \n",
612 | " 25677 | \n",
613 | " 5449 | \n",
614 | " 1.0 | \n",
615 | "
\n",
616 | " \n",
617 | " | 731662 | \n",
618 | " 25677 | \n",
619 | " 14839 | \n",
620 | " 1.0 | \n",
621 | "
\n",
622 | " \n",
623 | " | 731663 | \n",
624 | " 25677 | \n",
625 | " 11233 | \n",
626 | " 1.0 | \n",
627 | "
\n",
628 | " \n",
629 | " | 731664 | \n",
630 | " 25677 | \n",
631 | " 166 | \n",
632 | " 1.0 | \n",
633 | "
\n",
634 | " \n",
635 | " | 731665 | \n",
636 | " 25677 | \n",
637 | " 8523 | \n",
638 | " 1.0 | \n",
639 | "
\n",
640 | " \n",
641 | " | 731666 | \n",
642 | " 25677 | \n",
643 | " 22087 | \n",
644 | " 1.0 | \n",
645 | "
\n",
646 | " \n",
647 | " | 731667 | \n",
648 | " 25677 | \n",
649 | " 5615 | \n",
650 | " 1.0 | \n",
651 | "
\n",
652 | " \n",
653 | " | 731668 | \n",
654 | " 25677 | \n",
655 | " 18223 | \n",
656 | " 1.0 | \n",
657 | "
\n",
658 | " \n",
659 | " | 731669 | \n",
660 | " 25677 | \n",
661 | " 5453 | \n",
662 | " 1.0 | \n",
663 | "
\n",
664 | " \n",
665 | " | 731670 | \n",
666 | " 25677 | \n",
667 | " 1589 | \n",
668 | " 1.0 | \n",
669 | "
\n",
670 | " \n",
671 | "
\n",
672 | "
486499 rows × 3 columns
\n",
673 | "
"
674 | ],
675 | "text/plain": [
676 | " userId itemId rating\n",
677 | "0 1 1 1.0\n",
678 | "1 1 1 1.0\n",
679 | "2 1 2 1.0\n",
680 | "3 1 3 1.0\n",
681 | "7 1 7 1.0\n",
682 | "8 1 8 1.0\n",
683 | "10 1 10 1.0\n",
684 | "12 1 12 1.0\n",
685 | "15 1 15 1.0\n",
686 | "16 1 16 1.0\n",
687 | "17 1 17 1.0\n",
688 | "21 1 21 1.0\n",
689 | "22 1 22 1.0\n",
690 | "23 1 23 1.0\n",
691 | "24 1 24 1.0\n",
692 | "27 1 27 1.0\n",
693 | "29 1 29 1.0\n",
694 | "30 1 30 1.0\n",
695 | "31 1 31 1.0\n",
696 | "32 1 32 1.0\n",
697 | "33 1 33 1.0\n",
698 | "34 1 34 1.0\n",
699 | "35 1 35 1.0\n",
700 | "36 1 36 1.0\n",
701 | "39 1 39 1.0\n",
702 | "40 1 40 1.0\n",
703 | "41 1 41 1.0\n",
704 | "42 1 42 1.0\n",
705 | "43 1 42 1.0\n",
706 | "45 1 44 1.0\n",
707 | "... ... ... ...\n",
708 | "731641 25675 9109 1.0\n",
709 | "731642 25675 8917 1.0\n",
710 | "731643 25675 18629 1.0\n",
711 | "731644 25675 2578 1.0\n",
712 | "731645 25675 22557 1.0\n",
713 | "731646 25675 20292 1.0\n",
714 | "731647 25675 22555 1.0\n",
715 | "731648 25675 24168 1.0\n",
716 | "731649 25675 24255 1.0\n",
717 | "731650 25675 24693 1.0\n",
718 | "731651 25676 8850 1.0\n",
719 | "731652 25676 1946 1.0\n",
720 | "731653 25676 20837 1.0\n",
721 | "731654 25676 20837 1.0\n",
722 | "731655 25676 1840 1.0\n",
723 | "731656 25676 6033 1.0\n",
724 | "731657 25676 3017 1.0\n",
725 | "731658 25676 20016 1.0\n",
726 | "731659 25676 3042 1.0\n",
727 | "731660 25676 19256 1.0\n",
728 | "731661 25677 5449 1.0\n",
729 | "731662 25677 14839 1.0\n",
730 | "731663 25677 11233 1.0\n",
731 | "731664 25677 166 1.0\n",
732 | "731665 25677 8523 1.0\n",
733 | "731666 25677 22087 1.0\n",
734 | "731667 25677 5615 1.0\n",
735 | "731668 25677 18223 1.0\n",
736 | "731669 25677 5453 1.0\n",
737 | "731670 25677 1589 1.0\n",
738 | "\n",
739 | "[486499 rows x 3 columns]"
740 | ]
741 | },
742 | "execution_count": 48,
743 | "metadata": {},
744 | "output_type": "execute_result"
745 | }
746 | ],
747 | "source": [
748 | "df.drop(df.index[df['rating'] == 0], axis=0, inplace=True)\n",
749 | "df"
750 | ]
751 | },
752 | {
753 | "cell_type": "code",
754 | "execution_count": 49,
755 | "metadata": {},
756 | "outputs": [],
757 | "source": [
758 | "# df.drop('mean', axis=1, inplace=True)\n",
759 | "df.drop('rating', axis=1, inplace=True)"
760 | ]
761 | },
762 | {
763 | "cell_type": "code",
764 | "execution_count": 50,
765 | "metadata": {
766 | "scrolled": true
767 | },
768 | "outputs": [
769 | {
770 | "data": {
771 | "text/plain": [
772 | "1115 819\n",
773 | "1252 809\n",
774 | "1864 613\n",
775 | "42 600\n",
776 | "151 570\n",
777 | "3309 553\n",
778 | "4149 550\n",
779 | "2253 531\n",
780 | "1901 524\n",
781 | "2954 517\n",
782 | "12 501\n",
783 | "106 497\n",
784 | "767 489\n",
785 | "4500 468\n",
786 | "9 468\n",
787 | "1447 460\n",
788 | "2324 458\n",
789 | "3923 448\n",
790 | "2111 438\n",
791 | "2730 431\n",
792 | "3822 413\n",
793 | "103 405\n",
794 | "1854 398\n",
795 | "3992 394\n",
796 | "1970 394\n",
797 | "621 393\n",
798 | "615 392\n",
799 | "510 388\n",
800 | "512 376\n",
801 | "829 374\n",
802 | " ... \n",
803 | "13373 1\n",
804 | "22115 1\n",
805 | "24162 1\n",
806 | "24602 1\n",
807 | "1129 1\n",
808 | "25290 1\n",
809 | "7109 1\n",
810 | "23501 1\n",
811 | "24094 1\n",
812 | "13504 1\n",
813 | "19663 1\n",
814 | "19660 1\n",
815 | "3130 1\n",
816 | "16238 1\n",
817 | "22150 1\n",
818 | "19811 1\n",
819 | "12861 1\n",
820 | "21710 1\n",
821 | "9664 1\n",
822 | "19555 1\n",
823 | "11119 1\n",
824 | "4968 1\n",
825 | "21199 1\n",
826 | "14274 1\n",
827 | "24527 1\n",
828 | "13072 1\n",
829 | "18975 1\n",
830 | "16926 1\n",
831 | "21088 1\n",
832 | "15213 1\n",
833 | "Name: itemId, Length: 24930, dtype: int64"
834 | ]
835 | },
836 | "execution_count": 50,
837 | "metadata": {},
838 | "output_type": "execute_result"
839 | }
840 | ],
841 | "source": [
842 | "df['user_freq'] = df.groupby('userId')['userId'].transform('count')\n",
843 | "df['item_freq'] = df.groupby('itemId')['itemId'].transform('count')\n",
844 | "df.drop(df.index[df['user_freq'] < 10], inplace=True)\n",
845 | "df['item_freq'] = df.groupby('itemId')['itemId'].transform('count')\n",
846 | "df['itemId'].value_counts()"
847 | ]
848 | },
849 | {
850 | "cell_type": "code",
851 | "execution_count": 51,
852 | "metadata": {
853 | "scrolled": true
854 | },
855 | "outputs": [
856 | {
857 | "data": {
858 | "text/plain": [
859 | "3730 596\n",
860 | "1420 580\n",
861 | "13384 569\n",
862 | "13063 556\n",
863 | "12942 501\n",
864 | "1447 490\n",
865 | "3782 466\n",
866 | "1805 445\n",
867 | "13232 445\n",
868 | "3555 445\n",
869 | "4478 413\n",
870 | "2236 406\n",
871 | "3609 394\n",
872 | "2176 387\n",
873 | "363 366\n",
874 | "4870 337\n",
875 | "5793 325\n",
876 | "14002 322\n",
877 | "16099 305\n",
878 | "4041 298\n",
879 | "14302 293\n",
880 | "3324 285\n",
881 | "4246 285\n",
882 | "4468 282\n",
883 | "3862 277\n",
884 | "1944 275\n",
885 | "9384 274\n",
886 | "2975 274\n",
887 | "22 271\n",
888 | "13880 265\n",
889 | " ... \n",
890 | "2116 3\n",
891 | "18799 3\n",
892 | "25145 3\n",
893 | "18962 3\n",
894 | "2704 3\n",
895 | "6877 3\n",
896 | "7762 3\n",
897 | "2529 3\n",
898 | "19199 3\n",
899 | "25618 3\n",
900 | "18829 3\n",
901 | "25333 3\n",
902 | "13455 3\n",
903 | "25051 3\n",
904 | "10556 3\n",
905 | "11482 2\n",
906 | "24397 2\n",
907 | "11881 2\n",
908 | "25219 2\n",
909 | "12712 2\n",
910 | "25274 2\n",
911 | "10963 2\n",
912 | "1879 2\n",
913 | "17221 2\n",
914 | "8924 2\n",
915 | "25488 2\n",
916 | "10506 2\n",
917 | "19117 1\n",
918 | "25564 1\n",
919 | "1934 1\n",
920 | "Name: userId, Length: 16066, dtype: int64"
921 | ]
922 | },
923 | "execution_count": 51,
924 | "metadata": {},
925 | "output_type": "execute_result"
926 | }
927 | ],
928 | "source": [
929 | "df.drop(df.index[df['item_freq'] < 10], inplace=True)\n",
930 | "df['user_freq'] = df.groupby('userId')['userId'].transform('count')\n",
931 | "df['userId'].value_counts()"
932 | ]
933 | },
934 | {
935 | "cell_type": "code",
936 | "execution_count": 52,
937 | "metadata": {
938 | "scrolled": true
939 | },
940 | "outputs": [
941 | {
942 | "data": {
943 | "text/plain": [
944 | "1252 774\n",
945 | "1115 761\n",
946 | "1864 583\n",
947 | "42 579\n",
948 | "151 548\n",
949 | "3309 527\n",
950 | "4149 525\n",
951 | "2253 509\n",
952 | "2954 497\n",
953 | "1901 492\n",
954 | "106 477\n",
955 | "12 475\n",
956 | "767 472\n",
957 | "9 448\n",
958 | "2324 438\n",
959 | "4500 435\n",
960 | "1447 435\n",
961 | "3923 431\n",
962 | "2111 424\n",
963 | "2730 404\n",
964 | "3822 402\n",
965 | "1854 388\n",
966 | "3992 386\n",
967 | "103 385\n",
968 | "615 378\n",
969 | "1970 377\n",
970 | "621 371\n",
971 | "510 366\n",
972 | "512 359\n",
973 | "829 355\n",
974 | " ... \n",
975 | "13381 6\n",
976 | "7042 6\n",
977 | "13678 6\n",
978 | "16639 6\n",
979 | "13223 6\n",
980 | "19481 6\n",
981 | "12670 6\n",
982 | "13092 6\n",
983 | "2550 6\n",
984 | "1671 6\n",
985 | "14907 6\n",
986 | "22576 6\n",
987 | "6843 6\n",
988 | "5984 6\n",
989 | "20163 6\n",
990 | "2359 6\n",
991 | "15841 6\n",
992 | "11745 6\n",
993 | "13587 5\n",
994 | "1637 5\n",
995 | "19117 5\n",
996 | "13374 5\n",
997 | "24956 5\n",
998 | "3065 5\n",
999 | "6304 5\n",
1000 | "13349 5\n",
1001 | "4306 5\n",
1002 | "24966 4\n",
1003 | "20899 4\n",
1004 | "13388 1\n",
1005 | "Name: itemId, Length: 10112, dtype: int64"
1006 | ]
1007 | },
1008 | "execution_count": 52,
1009 | "metadata": {},
1010 | "output_type": "execute_result"
1011 | }
1012 | ],
1013 | "source": [
1014 | "df.drop(df.index[df['user_freq'] < 10], inplace=True)\n",
1015 | "df['user_freq'] = df.groupby('userId')['userId'].transform('count')\n",
1016 | "df['item_freq'] = df.groupby('itemId')['itemId'].transform('count')\n",
1017 | "df['itemId'].value_counts()"
1018 | ]
1019 | },
1020 | {
1021 | "cell_type": "code",
1022 | "execution_count": 53,
1023 | "metadata": {
1024 | "scrolled": true
1025 | },
1026 | "outputs": [
1027 | {
1028 | "data": {
1029 | "text/plain": [
1030 | "3730 584\n",
1031 | "1420 573\n",
1032 | "13384 559\n",
1033 | "13063 548\n",
1034 | "12942 499\n",
1035 | "1447 484\n",
1036 | "3782 450\n",
1037 | "1805 442\n",
1038 | "13232 442\n",
1039 | "3555 439\n",
1040 | "4478 407\n",
1041 | "2236 402\n",
1042 | "3609 385\n",
1043 | "2176 384\n",
1044 | "363 365\n",
1045 | "4870 333\n",
1046 | "5793 325\n",
1047 | "14002 318\n",
1048 | "16099 304\n",
1049 | "4041 291\n",
1050 | "14302 290\n",
1051 | "4246 284\n",
1052 | "3324 281\n",
1053 | "4468 276\n",
1054 | "1944 273\n",
1055 | "3862 271\n",
1056 | "2975 271\n",
1057 | "9384 270\n",
1058 | "13880 264\n",
1059 | "22 262\n",
1060 | " ... \n",
1061 | "467 8\n",
1062 | "17192 8\n",
1063 | "8484 8\n",
1064 | "9089 8\n",
1065 | "1106 8\n",
1066 | "8190 8\n",
1067 | "2997 8\n",
1068 | "3053 8\n",
1069 | "25089 8\n",
1070 | "3069 7\n",
1071 | "3035 7\n",
1072 | "3205 7\n",
1073 | "25341 7\n",
1074 | "25223 7\n",
1075 | "3141 7\n",
1076 | "25113 7\n",
1077 | "25175 7\n",
1078 | "25311 7\n",
1079 | "3203 7\n",
1080 | "2639 7\n",
1081 | "3097 7\n",
1082 | "3025 7\n",
1083 | "3138 7\n",
1084 | "2037 7\n",
1085 | "24994 6\n",
1086 | "25391 6\n",
1087 | "1131 6\n",
1088 | "3132 6\n",
1089 | "3146 4\n",
1090 | "25119 4\n",
1091 | "Name: userId, Length: 13099, dtype: int64"
1092 | ]
1093 | },
1094 | "execution_count": 53,
1095 | "metadata": {},
1096 | "output_type": "execute_result"
1097 | }
1098 | ],
1099 | "source": [
1100 | "df.drop(df.index[df['item_freq'] < 10], inplace=True)\n",
1101 | "df['user_freq'] = df.groupby('userId')['userId'].transform('count')\n",
1102 | "df['userId'].value_counts()"
1103 | ]
1104 | },
1105 | {
1106 | "cell_type": "code",
1107 | "execution_count": 64,
1108 | "metadata": {
1109 | "scrolled": true
1110 | },
1111 | "outputs": [
1112 | {
1113 | "data": {
1114 | "text/plain": [
1115 | "1252 769\n",
1116 | "1115 756\n",
1117 | "1864 581\n",
1118 | "42 578\n",
1119 | "151 548\n",
1120 | "3309 525\n",
1121 | "4149 523\n",
1122 | "2253 507\n",
1123 | "2954 492\n",
1124 | "1901 489\n",
1125 | "106 472\n",
1126 | "767 471\n",
1127 | "12 470\n",
1128 | "9 445\n",
1129 | "2324 438\n",
1130 | "1447 434\n",
1131 | "4500 430\n",
1132 | "3923 428\n",
1133 | "2111 422\n",
1134 | "3822 400\n",
1135 | "2730 400\n",
1136 | "1854 388\n",
1137 | "3992 384\n",
1138 | "103 382\n",
1139 | "1970 377\n",
1140 | "615 373\n",
1141 | "621 369\n",
1142 | "510 366\n",
1143 | "512 358\n",
1144 | "2016 354\n",
1145 | " ... \n",
1146 | "13621 10\n",
1147 | "830 10\n",
1148 | "6679 10\n",
1149 | "9488 10\n",
1150 | "17460 10\n",
1151 | "17716 10\n",
1152 | "5620 10\n",
1153 | "11971 10\n",
1154 | "20463 10\n",
1155 | "12089 10\n",
1156 | "10296 10\n",
1157 | "788 10\n",
1158 | "15488 10\n",
1159 | "903 10\n",
1160 | "14272 10\n",
1161 | "8387 10\n",
1162 | "14208 10\n",
1163 | "20020 10\n",
1164 | "14611 10\n",
1165 | "24493 10\n",
1166 | "13536 10\n",
1167 | "7141 10\n",
1168 | "14139 10\n",
1169 | "8355 10\n",
1170 | "10979 10\n",
1171 | "6596 10\n",
1172 | "18743 10\n",
1173 | "2367 10\n",
1174 | "16950 10\n",
1175 | "2824 9\n",
1176 | "Name: itemId, Length: 9245, dtype: int64"
1177 | ]
1178 | },
1179 | "execution_count": 64,
1180 | "metadata": {},
1181 | "output_type": "execute_result"
1182 | }
1183 | ],
1184 | "source": [
1185 | "df.drop(df.index[df['user_freq'] < 10], inplace=True)\n",
1186 | "df['user_freq'] = df.groupby('userId')['userId'].transform('count')\n",
1187 | "df['item_freq'] = df.groupby('itemId')['itemId'].transform('count')\n",
1188 | "df['itemId'].value_counts()"
1189 | ]
1190 | },
1191 | {
1192 | "cell_type": "code",
1193 | "execution_count": 65,
1194 | "metadata": {
1195 | "scrolled": true
1196 | },
1197 | "outputs": [
1198 | {
1199 | "data": {
1200 | "text/plain": [
1201 | "3730 584\n",
1202 | "1420 572\n",
1203 | "13384 557\n",
1204 | "13063 548\n",
1205 | "12942 497\n",
1206 | "1447 482\n",
1207 | "3782 448\n",
1208 | "13232 442\n",
1209 | "1805 442\n",
1210 | "3555 439\n",
1211 | "4478 405\n",
1212 | "2236 401\n",
1213 | "3609 385\n",
1214 | "2176 384\n",
1215 | "363 364\n",
1216 | "4870 332\n",
1217 | "5793 324\n",
1218 | "14002 318\n",
1219 | "16099 303\n",
1220 | "4041 291\n",
1221 | "14302 288\n",
1222 | "4246 284\n",
1223 | "3324 279\n",
1224 | "4468 275\n",
1225 | "1944 273\n",
1226 | "2975 271\n",
1227 | "3862 270\n",
1228 | "9384 269\n",
1229 | "13880 264\n",
1230 | "22 262\n",
1231 | " ... \n",
1232 | "9240 10\n",
1233 | "14768 10\n",
1234 | "12115 10\n",
1235 | "10322 10\n",
1236 | "6235 10\n",
1237 | "22611 10\n",
1238 | "24336 10\n",
1239 | "13488 10\n",
1240 | "11443 10\n",
1241 | "19694 10\n",
1242 | "1462 10\n",
1243 | "11490 10\n",
1244 | "21948 10\n",
1245 | "14363 10\n",
1246 | "19438 10\n",
1247 | "2077 10\n",
1248 | "11346 10\n",
1249 | "23997 10\n",
1250 | "19294 10\n",
1251 | "11090 10\n",
1252 | "5812 10\n",
1253 | "16991 10\n",
1254 | "14256 10\n",
1255 | "10162 10\n",
1256 | "4437 10\n",
1257 | "22460 10\n",
1258 | "183 10\n",
1259 | "24763 10\n",
1260 | "22716 10\n",
1261 | "24704 10\n",
1262 | "Name: userId, Length: 12704, dtype: int64"
1263 | ]
1264 | },
1265 | "execution_count": 65,
1266 | "metadata": {},
1267 | "output_type": "execute_result"
1268 | }
1269 | ],
1270 | "source": [
1271 | "df.drop(df.index[df['item_freq'] < 10], inplace=True)\n",
1272 | "df['user_freq'] = df.groupby('userId')['userId'].transform('count')\n",
1273 | "df['userId'].value_counts()"
1274 | ]
1275 | },
1276 | {
1277 | "cell_type": "code",
1278 | "execution_count": 66,
1279 | "metadata": {
1280 | "scrolled": true
1281 | },
1282 | "outputs": [
1283 | {
1284 | "data": {
1285 | "text/plain": [
1286 | "3730 584\n",
1287 | "1420 572\n",
1288 | "13384 557\n",
1289 | "13063 548\n",
1290 | "12942 497\n",
1291 | "1447 482\n",
1292 | "3782 448\n",
1293 | "13232 442\n",
1294 | "1805 442\n",
1295 | "3555 439\n",
1296 | "4478 405\n",
1297 | "2236 401\n",
1298 | "3609 385\n",
1299 | "2176 384\n",
1300 | "363 364\n",
1301 | "4870 332\n",
1302 | "5793 324\n",
1303 | "14002 318\n",
1304 | "16099 303\n",
1305 | "4041 291\n",
1306 | "14302 288\n",
1307 | "4246 284\n",
1308 | "3324 279\n",
1309 | "4468 275\n",
1310 | "1944 273\n",
1311 | "2975 271\n",
1312 | "3862 270\n",
1313 | "9384 269\n",
1314 | "13880 264\n",
1315 | "22 262\n",
1316 | " ... \n",
1317 | "9240 10\n",
1318 | "14768 10\n",
1319 | "12115 10\n",
1320 | "10322 10\n",
1321 | "6235 10\n",
1322 | "22611 10\n",
1323 | "24336 10\n",
1324 | "13488 10\n",
1325 | "11443 10\n",
1326 | "19694 10\n",
1327 | "1462 10\n",
1328 | "11490 10\n",
1329 | "21948 10\n",
1330 | "14363 10\n",
1331 | "19438 10\n",
1332 | "2077 10\n",
1333 | "11346 10\n",
1334 | "23997 10\n",
1335 | "19294 10\n",
1336 | "11090 10\n",
1337 | "5812 10\n",
1338 | "16991 10\n",
1339 | "14256 10\n",
1340 | "10162 10\n",
1341 | "4437 10\n",
1342 | "22460 10\n",
1343 | "183 10\n",
1344 | "24763 10\n",
1345 | "22716 10\n",
1346 | "24704 10\n",
1347 | "Name: userId, Length: 12704, dtype: int64"
1348 | ]
1349 | },
1350 | "execution_count": 66,
1351 | "metadata": {},
1352 | "output_type": "execute_result"
1353 | }
1354 | ],
1355 | "source": [
1356 | "df['userId'].value_counts()"
1357 | ]
1358 | },
1359 | {
1360 | "cell_type": "code",
1361 | "execution_count": 67,
1362 | "metadata": {
1363 | "scrolled": true
1364 | },
1365 | "outputs": [
1366 | {
1367 | "data": {
1368 | "text/plain": [
1369 | "1252 769\n",
1370 | "1115 756\n",
1371 | "1864 581\n",
1372 | "42 578\n",
1373 | "151 548\n",
1374 | "3309 525\n",
1375 | "4149 523\n",
1376 | "2253 507\n",
1377 | "2954 492\n",
1378 | "1901 489\n",
1379 | "106 472\n",
1380 | "767 471\n",
1381 | "12 470\n",
1382 | "9 445\n",
1383 | "2324 438\n",
1384 | "1447 434\n",
1385 | "4500 430\n",
1386 | "3923 428\n",
1387 | "2111 422\n",
1388 | "3822 400\n",
1389 | "2730 400\n",
1390 | "1854 388\n",
1391 | "3992 384\n",
1392 | "103 382\n",
1393 | "1970 377\n",
1394 | "615 373\n",
1395 | "621 369\n",
1396 | "510 366\n",
1397 | "512 358\n",
1398 | "2016 354\n",
1399 | " ... \n",
1400 | "2401 10\n",
1401 | "14181 10\n",
1402 | "23326 10\n",
1403 | "22891 10\n",
1404 | "2657 10\n",
1405 | "13777 10\n",
1406 | "25708 10\n",
1407 | "9427 10\n",
1408 | "20073 10\n",
1409 | "10084 10\n",
1410 | "16231 10\n",
1411 | "10768 10\n",
1412 | "20701 10\n",
1413 | "16745 10\n",
1414 | "10852 10\n",
1415 | "5364 10\n",
1416 | "19560 10\n",
1417 | "12049 10\n",
1418 | "20840 10\n",
1419 | "16081 10\n",
1420 | "9488 10\n",
1421 | "20585 10\n",
1422 | "5620 10\n",
1423 | "4961 10\n",
1424 | "2914 10\n",
1425 | "9319 10\n",
1426 | "3426 10\n",
1427 | "13669 10\n",
1428 | "788 10\n",
1429 | "14172 10\n",
1430 | "Name: itemId, Length: 9244, dtype: int64"
1431 | ]
1432 | },
1433 | "execution_count": 67,
1434 | "metadata": {},
1435 | "output_type": "execute_result"
1436 | }
1437 | ],
1438 | "source": [
1439 | "df['itemId'].value_counts()"
1440 | ]
1441 | },
1442 | {
1443 | "cell_type": "code",
1444 | "execution_count": 68,
1445 | "metadata": {},
1446 | "outputs": [],
1447 | "source": [
1448 | "df.drop('user_freq', axis=1, inplace=True)\n",
1449 | "df.drop('item_freq', axis=1, inplace=True)\n",
1450 | "df.reset_index(drop=True, inplace=True)"
1451 | ]
1452 | },
1453 | {
1454 | "cell_type": "code",
1455 | "execution_count": 69,
1456 | "metadata": {},
1457 | "outputs": [
1458 | {
1459 | "name": "stdout",
1460 | "output_type": "stream",
1461 | "text": [
1462 | "start\n",
1463 | "u20000\n",
1464 | "m10000\n",
1465 | "m20000\n"
1466 | ]
1467 | }
1468 | ],
1469 | "source": [
1470 | "import numpy as np\n",
1471 | "user_table = np.zeros(np.max(df['userId']) + 1)\n",
1472 | "movie_table = np.zeros(np.max(df['itemId']) + 1)\n",
1473 | "user_set = set(df['userId'].tolist())\n",
1474 | "movie_set = set(df['itemId'].tolist())\n",
1475 | "print('start')\n",
1476 | "u = 1\n",
1477 | "for i in range(1, np.max(df['userId']) + 1):\n",
1478 | " if i in user_set:\n",
1479 | " user_table[i] = u\n",
1480 | " u += 1\n",
1481 | " if i % 20000 == 0:\n",
1482 | " print('u' + str(i))\n",
1483 | "m = 1\n",
1484 | "for i in range(1, np.max(df['itemId']) + 1):\n",
1485 | " if i in movie_set:\n",
1486 | " movie_table[i] = m\n",
1487 | " m += 1\n",
1488 | " if i % 10000 == 0:\n",
1489 | " print('m' + str(i))"
1490 | ]
1491 | },
1492 | {
1493 | "cell_type": "code",
1494 | "execution_count": 70,
1495 | "metadata": {},
1496 | "outputs": [
1497 | {
1498 | "name": "stdout",
1499 | "output_type": "stream",
1500 | "text": [
1501 | "12705\n",
1502 | "9245\n"
1503 | ]
1504 | }
1505 | ],
1506 | "source": [
1507 | "print(u)\n",
1508 | "print(m)"
1509 | ]
1510 | },
1511 | {
1512 | "cell_type": "code",
1513 | "execution_count": 71,
1514 | "metadata": {
1515 | "scrolled": true
1516 | },
1517 | "outputs": [
1518 | {
1519 | "name": "stdout",
1520 | "output_type": "stream",
1521 | "text": [
1522 | "0\n",
1523 | "100000\n",
1524 | "200000\n",
1525 | "300000\n"
1526 | ]
1527 | },
1528 | {
1529 | "data": {
1530 | "text/html": [
1531 | "\n",
1532 | "\n",
1545 | "
\n",
1546 | " \n",
1547 | " \n",
1548 | " | \n",
1549 | " userId | \n",
1550 | " itemId | \n",
1551 | "
\n",
1552 | " \n",
1553 | " \n",
1554 | " \n",
1555 | " | 0 | \n",
1556 | " 1 | \n",
1557 | " 1 | \n",
1558 | "
\n",
1559 | " \n",
1560 | " | 1 | \n",
1561 | " 1 | \n",
1562 | " 2 | \n",
1563 | "
\n",
1564 | " \n",
1565 | " | 2 | \n",
1566 | " 1 | \n",
1567 | " 4 | \n",
1568 | "
\n",
1569 | " \n",
1570 | " | 3 | \n",
1571 | " 1 | \n",
1572 | " 6 | \n",
1573 | "
\n",
1574 | " \n",
1575 | " | 4 | \n",
1576 | " 1 | \n",
1577 | " 8 | \n",
1578 | "
\n",
1579 | " \n",
1580 | " | 5 | \n",
1581 | " 1 | \n",
1582 | " 11 | \n",
1583 | "
\n",
1584 | " \n",
1585 | " | 6 | \n",
1586 | " 1 | \n",
1587 | " 12 | \n",
1588 | "
\n",
1589 | " \n",
1590 | " | 7 | \n",
1591 | " 1 | \n",
1592 | " 13 | \n",
1593 | "
\n",
1594 | " \n",
1595 | " | 8 | \n",
1596 | " 1 | \n",
1597 | " 16 | \n",
1598 | "
\n",
1599 | " \n",
1600 | " | 9 | \n",
1601 | " 1 | \n",
1602 | " 17 | \n",
1603 | "
\n",
1604 | " \n",
1605 | " | 10 | \n",
1606 | " 1 | \n",
1607 | " 18 | \n",
1608 | "
\n",
1609 | " \n",
1610 | " | 11 | \n",
1611 | " 1 | \n",
1612 | " 19 | \n",
1613 | "
\n",
1614 | " \n",
1615 | " | 12 | \n",
1616 | " 1 | \n",
1617 | " 21 | \n",
1618 | "
\n",
1619 | " \n",
1620 | " | 13 | \n",
1621 | " 1 | \n",
1622 | " 22 | \n",
1623 | "
\n",
1624 | " \n",
1625 | " | 14 | \n",
1626 | " 1 | \n",
1627 | " 23 | \n",
1628 | "
\n",
1629 | " \n",
1630 | " | 15 | \n",
1631 | " 1 | \n",
1632 | " 24 | \n",
1633 | "
\n",
1634 | " \n",
1635 | " | 16 | \n",
1636 | " 1 | \n",
1637 | " 27 | \n",
1638 | "
\n",
1639 | " \n",
1640 | " | 17 | \n",
1641 | " 1 | \n",
1642 | " 28 | \n",
1643 | "
\n",
1644 | " \n",
1645 | " | 18 | \n",
1646 | " 1 | \n",
1647 | " 28 | \n",
1648 | "
\n",
1649 | " \n",
1650 | " | 19 | \n",
1651 | " 1 | \n",
1652 | " 29 | \n",
1653 | "
\n",
1654 | " \n",
1655 | " | 20 | \n",
1656 | " 1 | \n",
1657 | " 30 | \n",
1658 | "
\n",
1659 | " \n",
1660 | " | 21 | \n",
1661 | " 1 | \n",
1662 | " 31 | \n",
1663 | "
\n",
1664 | " \n",
1665 | " | 22 | \n",
1666 | " 1 | \n",
1667 | " 33 | \n",
1668 | "
\n",
1669 | " \n",
1670 | " | 23 | \n",
1671 | " 1 | \n",
1672 | " 34 | \n",
1673 | "
\n",
1674 | " \n",
1675 | " | 24 | \n",
1676 | " 1 | \n",
1677 | " 35 | \n",
1678 | "
\n",
1679 | " \n",
1680 | " | 25 | \n",
1681 | " 1 | \n",
1682 | " 36 | \n",
1683 | "
\n",
1684 | " \n",
1685 | " | 26 | \n",
1686 | " 1 | \n",
1687 | " 39 | \n",
1688 | "
\n",
1689 | " \n",
1690 | " | 27 | \n",
1691 | " 1 | \n",
1692 | " 40 | \n",
1693 | "
\n",
1694 | " \n",
1695 | " | 28 | \n",
1696 | " 1 | \n",
1697 | " 42 | \n",
1698 | "
\n",
1699 | " \n",
1700 | " | 29 | \n",
1701 | " 1 | \n",
1702 | " 45 | \n",
1703 | "
\n",
1704 | " \n",
1705 | " | ... | \n",
1706 | " ... | \n",
1707 | " ... | \n",
1708 | "
\n",
1709 | " \n",
1710 | " | 318284 | \n",
1711 | " 12702 | \n",
1712 | " 6209 | \n",
1713 | "
\n",
1714 | " \n",
1715 | " | 318285 | \n",
1716 | " 12702 | \n",
1717 | " 4980 | \n",
1718 | "
\n",
1719 | " \n",
1720 | " | 318286 | \n",
1721 | " 12702 | \n",
1722 | " 7463 | \n",
1723 | "
\n",
1724 | " \n",
1725 | " | 318287 | \n",
1726 | " 12702 | \n",
1727 | " 5212 | \n",
1728 | "
\n",
1729 | " \n",
1730 | " | 318288 | \n",
1731 | " 12702 | \n",
1732 | " 5212 | \n",
1733 | "
\n",
1734 | " \n",
1735 | " | 318289 | \n",
1736 | " 12702 | \n",
1737 | " 5212 | \n",
1738 | "
\n",
1739 | " \n",
1740 | " | 318290 | \n",
1741 | " 12702 | \n",
1742 | " 5212 | \n",
1743 | "
\n",
1744 | " \n",
1745 | " | 318291 | \n",
1746 | " 12702 | \n",
1747 | " 5191 | \n",
1748 | "
\n",
1749 | " \n",
1750 | " | 318292 | \n",
1751 | " 12702 | \n",
1752 | " 8499 | \n",
1753 | "
\n",
1754 | " \n",
1755 | " | 318293 | \n",
1756 | " 12703 | \n",
1757 | " 5006 | \n",
1758 | "
\n",
1759 | " \n",
1760 | " | 318294 | \n",
1761 | " 12703 | \n",
1762 | " 7926 | \n",
1763 | "
\n",
1764 | " \n",
1765 | " | 318295 | \n",
1766 | " 12703 | \n",
1767 | " 7787 | \n",
1768 | "
\n",
1769 | " \n",
1770 | " | 318296 | \n",
1771 | " 12703 | \n",
1772 | " 4000 | \n",
1773 | "
\n",
1774 | " \n",
1775 | " | 318297 | \n",
1776 | " 12703 | \n",
1777 | " 4005 | \n",
1778 | "
\n",
1779 | " \n",
1780 | " | 318298 | \n",
1781 | " 12703 | \n",
1782 | " 5008 | \n",
1783 | "
\n",
1784 | " \n",
1785 | " | 318299 | \n",
1786 | " 12703 | \n",
1787 | " 5008 | \n",
1788 | "
\n",
1789 | " \n",
1790 | " | 318300 | \n",
1791 | " 12703 | \n",
1792 | " 6521 | \n",
1793 | "
\n",
1794 | " \n",
1795 | " | 318301 | \n",
1796 | " 12703 | \n",
1797 | " 8368 | \n",
1798 | "
\n",
1799 | " \n",
1800 | " | 318302 | \n",
1801 | " 12703 | \n",
1802 | " 2852 | \n",
1803 | "
\n",
1804 | " \n",
1805 | " | 318303 | \n",
1806 | " 12703 | \n",
1807 | " 9009 | \n",
1808 | "
\n",
1809 | " \n",
1810 | " | 318304 | \n",
1811 | " 12704 | \n",
1812 | " 4829 | \n",
1813 | "
\n",
1814 | " \n",
1815 | " | 318305 | \n",
1816 | " 12704 | \n",
1817 | " 1103 | \n",
1818 | "
\n",
1819 | " \n",
1820 | " | 318306 | \n",
1821 | " 12704 | \n",
1822 | " 8649 | \n",
1823 | "
\n",
1824 | " \n",
1825 | " | 318307 | \n",
1826 | " 12704 | \n",
1827 | " 8649 | \n",
1828 | "
\n",
1829 | " \n",
1830 | " | 318308 | \n",
1831 | " 12704 | \n",
1832 | " 1026 | \n",
1833 | "
\n",
1834 | " \n",
1835 | " | 318309 | \n",
1836 | " 12704 | \n",
1837 | " 3684 | \n",
1838 | "
\n",
1839 | " \n",
1840 | " | 318310 | \n",
1841 | " 12704 | \n",
1842 | " 1709 | \n",
1843 | "
\n",
1844 | " \n",
1845 | " | 318311 | \n",
1846 | " 12704 | \n",
1847 | " 8506 | \n",
1848 | "
\n",
1849 | " \n",
1850 | " | 318312 | \n",
1851 | " 12704 | \n",
1852 | " 1727 | \n",
1853 | "
\n",
1854 | " \n",
1855 | " | 318313 | \n",
1856 | " 12704 | \n",
1857 | " 8406 | \n",
1858 | "
\n",
1859 | " \n",
1860 | "
\n",
1861 | "
318314 rows × 2 columns
\n",
1862 | "
"
1863 | ],
1864 | "text/plain": [
1865 | " userId itemId\n",
1866 | "0 1 1\n",
1867 | "1 1 2\n",
1868 | "2 1 4\n",
1869 | "3 1 6\n",
1870 | "4 1 8\n",
1871 | "5 1 11\n",
1872 | "6 1 12\n",
1873 | "7 1 13\n",
1874 | "8 1 16\n",
1875 | "9 1 17\n",
1876 | "10 1 18\n",
1877 | "11 1 19\n",
1878 | "12 1 21\n",
1879 | "13 1 22\n",
1880 | "14 1 23\n",
1881 | "15 1 24\n",
1882 | "16 1 27\n",
1883 | "17 1 28\n",
1884 | "18 1 28\n",
1885 | "19 1 29\n",
1886 | "20 1 30\n",
1887 | "21 1 31\n",
1888 | "22 1 33\n",
1889 | "23 1 34\n",
1890 | "24 1 35\n",
1891 | "25 1 36\n",
1892 | "26 1 39\n",
1893 | "27 1 40\n",
1894 | "28 1 42\n",
1895 | "29 1 45\n",
1896 | "... ... ...\n",
1897 | "318284 12702 6209\n",
1898 | "318285 12702 4980\n",
1899 | "318286 12702 7463\n",
1900 | "318287 12702 5212\n",
1901 | "318288 12702 5212\n",
1902 | "318289 12702 5212\n",
1903 | "318290 12702 5212\n",
1904 | "318291 12702 5191\n",
1905 | "318292 12702 8499\n",
1906 | "318293 12703 5006\n",
1907 | "318294 12703 7926\n",
1908 | "318295 12703 7787\n",
1909 | "318296 12703 4000\n",
1910 | "318297 12703 4005\n",
1911 | "318298 12703 5008\n",
1912 | "318299 12703 5008\n",
1913 | "318300 12703 6521\n",
1914 | "318301 12703 8368\n",
1915 | "318302 12703 2852\n",
1916 | "318303 12703 9009\n",
1917 | "318304 12704 4829\n",
1918 | "318305 12704 1103\n",
1919 | "318306 12704 8649\n",
1920 | "318307 12704 8649\n",
1921 | "318308 12704 1026\n",
1922 | "318309 12704 3684\n",
1923 | "318310 12704 1709\n",
1924 | "318311 12704 8506\n",
1925 | "318312 12704 1727\n",
1926 | "318313 12704 8406\n",
1927 | "\n",
1928 | "[318314 rows x 2 columns]"
1929 | ]
1930 | },
1931 | "execution_count": 71,
1932 | "metadata": {},
1933 | "output_type": "execute_result"
1934 | }
1935 | ],
1936 | "source": [
1937 | "tmp = df.values\n",
1938 | "for i in range(len(df)):\n",
1939 | " tmp[i, 0] = user_table[int(tmp[i, 0])]\n",
1940 | " tmp[i, 1] = movie_table[int(tmp[i, 1])]\n",
1941 | " if i % 100000 == 0:\n",
1942 | "        print(i)\n",
1943 | "df = pd.DataFrame(tmp, columns=['userId', 'itemId'])\n",
1944 | "df"
1945 | ]
1946 | },
1947 | {
1948 | "cell_type": "code",
1949 | "execution_count": 72,
1950 | "metadata": {},
1951 | "outputs": [
1952 | {
1953 | "name": "stdout",
1954 | "output_type": "stream",
1955 | "text": [
1956 | "number of users = 12704\n",
1957 | "number of items = 9244\n",
1958 | "sparsity = 0.002710536864\n"
1959 | ]
1960 | }
1961 | ],
1962 | "source": [
1963 | "num_user = u - 1\n",
1964 | "num_movie = m - 1\n",
1965 | "print('number of users = ' + str(num_user))\n",
1966 | "print('number of items = ' + str(num_movie))\n",
1967 | "sparsity = len(df) * 1.0 / (num_user * num_movie)\n",
1968 | "print('sparsity = ' + str(sparsity))"
1969 | ]
1970 | },
1971 | {
1972 | "cell_type": "code",
1973 | "execution_count": 35,
1974 | "metadata": {},
1975 | "outputs": [],
1976 | "source": [
1977 | "df.to_csv('./data.csv', index=False)"
1978 | ]
1979 | },
1980 | {
1981 | "cell_type": "code",
1982 | "execution_count": 73,
1983 | "metadata": {},
1984 | "outputs": [],
1985 | "source": [
1986 | "train_df = df.copy()\n",
1987 | "test_df = df.copy()"
1988 | ]
1989 | },
1990 | {
1991 | "cell_type": "code",
1992 | "execution_count": 37,
1993 | "metadata": {},
1994 | "outputs": [
1995 | {
1996 | "name": "stdout",
1997 | "output_type": "stream",
1998 | "text": [
1999 | "500\n",
2000 | "1000\n",
2001 | "1500\n",
2002 | "2000\n",
2003 | "2500\n",
2004 | "3000\n",
2005 | "3500\n",
2006 | "4000\n",
2007 | "4500\n",
2008 | "5000\n",
2009 | "5500\n",
2010 | "6000\n",
2011 | "6500\n",
2012 | "7000\n",
2013 | "7500\n",
2014 | "8000\n",
2015 | "8500\n",
2016 | "9000\n",
2017 | "9500\n",
2018 | "10000\n",
2019 | "10500\n",
2020 | "11000\n",
2021 | "11500\n",
2022 | "12000\n",
2023 | "12500\n",
2024 | "13000\n",
2025 | "13500\n",
2026 | "14000\n",
2027 | "14500\n",
2028 | "15000\n",
2029 | "15500\n",
2030 | "16000\n",
2031 | "16500\n",
2032 | "17000\n",
2033 | "17500\n",
2034 | "18000\n",
2035 | "18500\n",
2036 | "19000\n",
2037 | "19500\n",
2038 | "20000\n",
2039 | "20500\n",
2040 | "21000\n",
2041 | "21500\n",
2042 | "22000\n",
2043 | "22500\n",
2044 | "23000\n",
2045 | "23500\n",
2046 | "24000\n",
2047 | "24500\n",
2048 | "25000\n",
2049 | "25500\n",
2050 | "26000\n",
2051 | "26500\n",
2052 | "27000\n",
2053 | "27500\n",
2054 | "28000\n",
2055 | "28500\n",
2056 | "29000\n",
2057 | "29500\n",
2058 | "30000\n",
2059 | "30500\n",
2060 | "31000\n",
2061 | "31500\n",
2062 | "32000\n",
2063 | "32500\n",
2064 | "33000\n",
2065 | "33500\n",
2066 | "34000\n",
2067 | "34500\n",
2068 | "35000\n",
2069 | "35500\n",
2070 | "36000\n",
2071 | "36500\n",
2072 | "37000\n",
2073 | "37500\n",
2074 | "38000\n",
2075 | "38500\n",
2076 | "39000\n",
2077 | "39500\n",
2078 | "40000\n",
2079 | "40500\n",
2080 | "41000\n",
2081 | "41500\n",
2082 | "42000\n",
2083 | "42500\n",
2084 | "43000\n",
2085 | "43500\n",
2086 | "44000\n",
2087 | "44500\n",
2088 | "45000\n",
2089 | "45500\n",
2090 | "46000\n",
2091 | "46500\n",
2092 | "47000\n",
2093 | "47500\n",
2094 | "48000\n",
2095 | "48500\n",
2096 | "49000\n",
2097 | "49500\n",
2098 | "50000\n",
2099 | "50500\n",
2100 | "51000\n",
2101 | "51500\n",
2102 | "52000\n",
2103 | "52500\n",
2104 | "53000\n",
2105 | "53500\n",
2106 | "54000\n",
2107 | "54500\n",
2108 | "55000\n",
2109 | "55500\n",
2110 | "56000\n",
2111 | "56500\n",
2112 | "57000\n",
2113 | "57500\n",
2114 | "58000\n",
2115 | "58500\n",
2116 | "59000\n",
2117 | "59500\n",
2118 | "60000\n",
2119 | "60500\n",
2120 | "61000\n",
2121 | "61500\n",
2122 | "62000\n",
2123 | "62500\n",
2124 | "63000\n",
2125 | "63500\n",
2126 | "64000\n",
2127 | "64500\n",
2128 | "65000\n",
2129 | "65500\n",
2130 | "66000\n",
2131 | "66500\n",
2132 | "67000\n",
2133 | "67500\n",
2134 | "68000\n",
2135 | "68500\n",
2136 | "69000\n"
2137 | ]
2138 | }
2139 | ],
2140 | "source": [
2141 | "train_ratio = 0.8\n",
2142 | "test_ratio = 1 - train_ratio\n",
2143 | "num_users = np.max(df['userId'])\n",
2144 | "num_items = np.max(df['itemId'])\n",
2145 | "test_idx = []\n",
2146 | "for u in range(1, num_users+1):\n",
2147 | " u_idx = train_df.index[train_df['userId'] == u]\n",
2148 | " idx_len = len(u_idx)\n",
2149 | " test_len = int(idx_len * test_ratio)\n",
2150 | " if test_len == 0:\n",
2151 | " test_len = 1\n",
2152 | " tmp = np.random.choice(u_idx, size=test_len, replace=False)\n",
2153 | " test_idx += tmp.tolist()\n",
2154 | " if u % 500 == 0:\n",
2155 | "        print(u)"
2156 | ]
2157 | },
2158 | {
2159 | "cell_type": "code",
2160 | "execution_count": 38,
2161 | "metadata": {},
2162 | "outputs": [],
2163 | "source": [
2164 | "test_set = set(test_idx)\n",
2165 | "train_set = set(range(len(df)))\n",
2166 | "train_set -= test_set\n",
2167 | "train_idx = list(train_set)\n",
2168 | "train_df.drop(test_idx, axis=0, inplace=True)\n",
2169 | "test_df.drop(train_idx, axis=0, inplace=True)"
2170 | ]
2171 | },
2172 | {
2173 | "cell_type": "code",
2174 | "execution_count": 39,
2175 | "metadata": {},
2176 | "outputs": [
2177 | {
2178 | "name": "stdout",
2179 | "output_type": "stream",
2180 | "text": [
2181 | "1134946\n",
2182 | "4746650\n"
2183 | ]
2184 | }
2185 | ],
2186 | "source": [
2187 | "print(len(test_df))\n",
2188 | "print(len(train_df))"
2189 | ]
2190 | },
2191 | {
2192 | "cell_type": "code",
2193 | "execution_count": 40,
2194 | "metadata": {
2195 | "scrolled": true
2196 | },
2197 | "outputs": [
2198 | {
2199 | "data": {
2200 | "text/html": [
2201 | "\n",
2202 | "\n",
2215 | "
\n",
2216 | " \n",
2217 | " \n",
2218 | " | \n",
2219 | " userId | \n",
2220 | " movieId | \n",
2221 | "
\n",
2222 | " \n",
2223 | " \n",
2224 | " \n",
2225 | " | 0 | \n",
2226 | " 1 | \n",
2227 | " 119 | \n",
2228 | "
\n",
2229 | " \n",
2230 | " | 1 | \n",
2231 | " 1 | \n",
2232 | " 180 | \n",
2233 | "
\n",
2234 | " \n",
2235 | " | 2 | \n",
2236 | " 1 | \n",
2237 | " 225 | \n",
2238 | "
\n",
2239 | " \n",
2240 | " | 3 | \n",
2241 | " 1 | \n",
2242 | " 286 | \n",
2243 | "
\n",
2244 | " \n",
2245 | " | 4 | \n",
2246 | " 1 | \n",
2247 | " 310 | \n",
2248 | "
\n",
2249 | " \n",
2250 | " | 5 | \n",
2251 | " 1 | \n",
2252 | " 322 | \n",
2253 | "
\n",
2254 | " \n",
2255 | " | 6 | \n",
2256 | " 1 | \n",
2257 | " 348 | \n",
2258 | "
\n",
2259 | " \n",
2260 | " | 7 | \n",
2261 | " 1 | \n",
2262 | " 349 | \n",
2263 | "
\n",
2264 | " \n",
2265 | " | 8 | \n",
2266 | " 1 | \n",
2267 | " 355 | \n",
2268 | "
\n",
2269 | " \n",
2270 | " | 9 | \n",
2271 | " 1 | \n",
2272 | " 357 | \n",
2273 | "
\n",
2274 | " \n",
2275 | " | 12 | \n",
2276 | " 1 | \n",
2277 | " 408 | \n",
2278 | "
\n",
2279 | " \n",
2280 | " | 14 | \n",
2281 | " 1 | \n",
2282 | " 468 | \n",
2283 | "
\n",
2284 | " \n",
2285 | " | 15 | \n",
2286 | " 1 | \n",
2287 | " 508 | \n",
2288 | "
\n",
2289 | " \n",
2290 | " | 16 | \n",
2291 | " 1 | \n",
2292 | " 527 | \n",
2293 | "
\n",
2294 | " \n",
2295 | " | 17 | \n",
2296 | " 1 | \n",
2297 | " 570 | \n",
2298 | "
\n",
2299 | " \n",
2300 | " | 18 | \n",
2301 | " 1 | \n",
2302 | " 572 | \n",
2303 | "
\n",
2304 | " \n",
2305 | " | 20 | \n",
2306 | " 1 | \n",
2307 | " 577 | \n",
2308 | "
\n",
2309 | " \n",
2310 | " | 21 | \n",
2311 | " 1 | \n",
2312 | " 597 | \n",
2313 | "
\n",
2314 | " \n",
2315 | " | 22 | \n",
2316 | " 2 | \n",
2317 | " 107 | \n",
2318 | "
\n",
2319 | " \n",
2320 | " | 23 | \n",
2321 | " 2 | \n",
2322 | " 146 | \n",
2323 | "
\n",
2324 | " \n",
2325 | " | 24 | \n",
2326 | " 2 | \n",
2327 | " 208 | \n",
2328 | "
\n",
2329 | " \n",
2330 | " | 26 | \n",
2331 | " 2 | \n",
2332 | " 1083 | \n",
2333 | "
\n",
2334 | " \n",
2335 | " | 27 | \n",
2336 | " 2 | \n",
2337 | " 1173 | \n",
2338 | "
\n",
2339 | " \n",
2340 | " | 28 | \n",
2341 | " 2 | \n",
2342 | " 1179 | \n",
2343 | "
\n",
2344 | " \n",
2345 | " | 29 | \n",
2346 | " 2 | \n",
2347 | " 1203 | \n",
2348 | "
\n",
2349 | " \n",
2350 | " | 31 | \n",
2351 | " 2 | \n",
2352 | " 1461 | \n",
2353 | "
\n",
2354 | " \n",
2355 | " | 32 | \n",
2356 | " 2 | \n",
2357 | " 1491 | \n",
2358 | "
\n",
2359 | " \n",
2360 | " | 34 | \n",
2361 | " 2 | \n",
2362 | " 3152 | \n",
2363 | "
\n",
2364 | " \n",
2365 | " | 35 | \n",
2366 | " 2 | \n",
2367 | " 3403 | \n",
2368 | "
\n",
2369 | " \n",
2370 | " | 36 | \n",
2371 | " 2 | \n",
2372 | " 4199 | \n",
2373 | "
\n",
2374 | " \n",
2375 | " | ... | \n",
2376 | " ... | \n",
2377 | " ... | \n",
2378 | "
\n",
2379 | " \n",
2380 | " | 5881559 | \n",
2381 | " 69412 | \n",
2382 | " 312 | \n",
2383 | "
\n",
2384 | " \n",
2385 | " | 5881560 | \n",
2386 | " 69412 | \n",
2387 | " 332 | \n",
2388 | "
\n",
2389 | " \n",
2390 | " | 5881561 | \n",
2391 | " 69412 | \n",
2392 | " 335 | \n",
2393 | "
\n",
2394 | " \n",
2395 | " | 5881562 | \n",
2396 | " 69412 | \n",
2397 | " 341 | \n",
2398 | "
\n",
2399 | " \n",
2400 | " | 5881563 | \n",
2401 | " 69412 | \n",
2402 | " 373 | \n",
2403 | "
\n",
2404 | " \n",
2405 | " | 5881564 | \n",
2406 | " 69412 | \n",
2407 | " 374 | \n",
2408 | "
\n",
2409 | " \n",
2410 | " | 5881566 | \n",
2411 | " 69412 | \n",
2412 | " 572 | \n",
2413 | "
\n",
2414 | " \n",
2415 | " | 5881567 | \n",
2416 | " 69412 | \n",
2417 | " 574 | \n",
2418 | "
\n",
2419 | " \n",
2420 | " | 5881569 | \n",
2421 | " 69412 | \n",
2422 | " 578 | \n",
2423 | "
\n",
2424 | " \n",
2425 | " | 5881570 | \n",
2426 | " 69413 | \n",
2427 | " 107 | \n",
2428 | "
\n",
2429 | " \n",
2430 | " | 5881572 | \n",
2431 | " 69413 | \n",
2432 | " 254 | \n",
2433 | "
\n",
2434 | " \n",
2435 | " | 5881573 | \n",
2436 | " 69413 | \n",
2437 | " 310 | \n",
2438 | "
\n",
2439 | " \n",
2440 | " | 5881575 | \n",
2441 | " 69413 | \n",
2442 | " 573 | \n",
2443 | "
\n",
2444 | " \n",
2445 | " | 5881576 | \n",
2446 | " 69413 | \n",
2447 | " 741 | \n",
2448 | "
\n",
2449 | " \n",
2450 | " | 5881577 | \n",
2451 | " 69413 | \n",
2452 | " 749 | \n",
2453 | "
\n",
2454 | " \n",
2455 | " | 5881578 | \n",
2456 | " 69413 | \n",
2457 | " 783 | \n",
2458 | "
\n",
2459 | " \n",
2460 | " | 5881579 | \n",
2461 | " 69413 | \n",
2462 | " 848 | \n",
2463 | "
\n",
2464 | " \n",
2465 | " | 5881580 | \n",
2466 | " 69413 | \n",
2467 | " 1020 | \n",
2468 | "
\n",
2469 | " \n",
2470 | " | 5881581 | \n",
2471 | " 69413 | \n",
2472 | " 1072 | \n",
2473 | "
\n",
2474 | " \n",
2475 | " | 5881582 | \n",
2476 | " 69413 | \n",
2477 | " 1126 | \n",
2478 | "
\n",
2479 | " \n",
2480 | " | 5881584 | \n",
2481 | " 69413 | \n",
2482 | " 1139 | \n",
2483 | "
\n",
2484 | " \n",
2485 | " | 5881585 | \n",
2486 | " 69413 | \n",
2487 | " 1143 | \n",
2488 | "
\n",
2489 | " \n",
2490 | " | 5881587 | \n",
2491 | " 69413 | \n",
2492 | " 1281 | \n",
2493 | "
\n",
2494 | " \n",
2495 | " | 5881588 | \n",
2496 | " 69413 | \n",
2497 | " 1314 | \n",
2498 | "
\n",
2499 | " \n",
2500 | " | 5881589 | \n",
2501 | " 69413 | \n",
2502 | " 1430 | \n",
2503 | "
\n",
2504 | " \n",
2505 | " | 5881591 | \n",
2506 | " 69413 | \n",
2507 | " 1657 | \n",
2508 | "
\n",
2509 | " \n",
2510 | " | 5881592 | \n",
2511 | " 69413 | \n",
2512 | " 1755 | \n",
2513 | "
\n",
2514 | " \n",
2515 | " | 5881593 | \n",
2516 | " 69413 | \n",
2517 | " 1758 | \n",
2518 | "
\n",
2519 | " \n",
2520 | " | 5881594 | \n",
2521 | " 69413 | \n",
2522 | " 1865 | \n",
2523 | "
\n",
2524 | " \n",
2525 | " | 5881595 | \n",
2526 | " 69413 | \n",
2527 | " 2116 | \n",
2528 | "
\n",
2529 | " \n",
2530 | "
\n",
2531 | "
4746650 rows × 2 columns
\n",
2532 | "
"
2533 | ],
2534 | "text/plain": [
2535 | " userId movieId\n",
2536 | "0 1 119\n",
2537 | "1 1 180\n",
2538 | "2 1 225\n",
2539 | "3 1 286\n",
2540 | "4 1 310\n",
2541 | "5 1 322\n",
2542 | "6 1 348\n",
2543 | "7 1 349\n",
2544 | "8 1 355\n",
2545 | "9 1 357\n",
2546 | "12 1 408\n",
2547 | "14 1 468\n",
2548 | "15 1 508\n",
2549 | "16 1 527\n",
2550 | "17 1 570\n",
2551 | "18 1 572\n",
2552 | "20 1 577\n",
2553 | "21 1 597\n",
2554 | "22 2 107\n",
2555 | "23 2 146\n",
2556 | "24 2 208\n",
2557 | "26 2 1083\n",
2558 | "27 2 1173\n",
2559 | "28 2 1179\n",
2560 | "29 2 1203\n",
2561 | "31 2 1461\n",
2562 | "32 2 1491\n",
2563 | "34 2 3152\n",
2564 | "35 2 3403\n",
2565 | "36 2 4199\n",
2566 | "... ... ...\n",
2567 | "5881559 69412 312\n",
2568 | "5881560 69412 332\n",
2569 | "5881561 69412 335\n",
2570 | "5881562 69412 341\n",
2571 | "5881563 69412 373\n",
2572 | "5881564 69412 374\n",
2573 | "5881566 69412 572\n",
2574 | "5881567 69412 574\n",
2575 | "5881569 69412 578\n",
2576 | "5881570 69413 107\n",
2577 | "5881572 69413 254\n",
2578 | "5881573 69413 310\n",
2579 | "5881575 69413 573\n",
2580 | "5881576 69413 741\n",
2581 | "5881577 69413 749\n",
2582 | "5881578 69413 783\n",
2583 | "5881579 69413 848\n",
2584 | "5881580 69413 1020\n",
2585 | "5881581 69413 1072\n",
2586 | "5881582 69413 1126\n",
2587 | "5881584 69413 1139\n",
2588 | "5881585 69413 1143\n",
2589 | "5881587 69413 1281\n",
2590 | "5881588 69413 1314\n",
2591 | "5881589 69413 1430\n",
2592 | "5881591 69413 1657\n",
2593 | "5881592 69413 1755\n",
2594 | "5881593 69413 1758\n",
2595 | "5881594 69413 1865\n",
2596 | "5881595 69413 2116\n",
2597 | "\n",
2598 | "[4746650 rows x 2 columns]"
2599 | ]
2600 | },
2601 | "execution_count": 40,
2602 | "metadata": {},
2603 | "output_type": "execute_result"
2604 | }
2605 | ],
2606 | "source": [
2607 | "train_df"
2608 | ]
2609 | },
2610 | {
2611 | "cell_type": "code",
2612 | "execution_count": 41,
2613 | "metadata": {
2614 | "scrolled": true
2615 | },
2616 | "outputs": [
2617 | {
2618 | "data": {
2619 | "text/html": [
2620 | "\n",
2621 | "\n",
2634 | "
\n",
2635 | " \n",
2636 | " \n",
2637 | " | \n",
2638 | " userId | \n",
2639 | " movieId | \n",
2640 | "
\n",
2641 | " \n",
2642 | " \n",
2643 | " \n",
2644 | " | 10 | \n",
2645 | " 1 | \n",
2646 | " 363 | \n",
2647 | "
\n",
2648 | " \n",
2649 | " | 11 | \n",
2650 | " 1 | \n",
2651 | " 370 | \n",
2652 | "
\n",
2653 | " \n",
2654 | " | 13 | \n",
2655 | " 1 | \n",
2656 | " 454 | \n",
2657 | "
\n",
2658 | " \n",
2659 | " | 19 | \n",
2660 | " 1 | \n",
2661 | " 573 | \n",
2662 | "
\n",
2663 | " \n",
2664 | " | 25 | \n",
2665 | " 2 | \n",
2666 | " 574 | \n",
2667 | "
\n",
2668 | " \n",
2669 | " | 30 | \n",
2670 | " 2 | \n",
2671 | " 1329 | \n",
2672 | "
\n",
2673 | " \n",
2674 | " | 33 | \n",
2675 | " 2 | \n",
2676 | " 1562 | \n",
2677 | "
\n",
2678 | " \n",
2679 | " | 45 | \n",
2680 | " 2 | \n",
2681 | " 7220 | \n",
2682 | "
\n",
2683 | " \n",
2684 | " | 46 | \n",
2685 | " 2 | \n",
2686 | " 7224 | \n",
2687 | "
\n",
2688 | " \n",
2689 | " | 50 | \n",
2690 | " 3 | \n",
2691 | " 34 | \n",
2692 | "
\n",
2693 | " \n",
2694 | " | 54 | \n",
2695 | " 3 | \n",
2696 | " 156 | \n",
2697 | "
\n",
2698 | " \n",
2699 | " | 62 | \n",
2700 | " 3 | \n",
2701 | " 468 | \n",
2702 | "
\n",
2703 | " \n",
2704 | " | 67 | \n",
2705 | " 3 | \n",
2706 | " 573 | \n",
2707 | "
\n",
2708 | " \n",
2709 | " | 73 | \n",
2710 | " 4 | \n",
2711 | " 47 | \n",
2712 | "
\n",
2713 | " \n",
2714 | " | 75 | \n",
2715 | " 4 | \n",
2716 | " 108 | \n",
2717 | "
\n",
2718 | " \n",
2719 | " | 81 | \n",
2720 | " 4 | \n",
2721 | " 327 | \n",
2722 | "
\n",
2723 | " \n",
2724 | " | 88 | \n",
2725 | " 4 | \n",
2726 | " 520 | \n",
2727 | "
\n",
2728 | " \n",
2729 | " | 90 | \n",
2730 | " 4 | \n",
2731 | " 529 | \n",
2732 | "
\n",
2733 | " \n",
2734 | " | 98 | \n",
2735 | " 4 | \n",
2736 | " 869 | \n",
2737 | "
\n",
2738 | " \n",
2739 | " | 100 | \n",
2740 | " 4 | \n",
2741 | " 873 | \n",
2742 | "
\n",
2743 | " \n",
2744 | " | 101 | \n",
2745 | " 4 | \n",
2746 | " 876 | \n",
2747 | "
\n",
2748 | " \n",
2749 | " | 104 | \n",
2750 | " 4 | \n",
2751 | " 990 | \n",
2752 | "
\n",
2753 | " \n",
2754 | " | 115 | \n",
2755 | " 4 | \n",
2756 | " 1136 | \n",
2757 | "
\n",
2758 | " \n",
2759 | " | 124 | \n",
2760 | " 4 | \n",
2761 | " 1207 | \n",
2762 | "
\n",
2763 | " \n",
2764 | " | 132 | \n",
2765 | " 5 | \n",
2766 | " 1124 | \n",
2767 | "
\n",
2768 | " \n",
2769 | " | 139 | \n",
2770 | " 5 | \n",
2771 | " 1319 | \n",
2772 | "
\n",
2773 | " \n",
2774 | " | 143 | \n",
2775 | " 5 | \n",
2776 | " 1521 | \n",
2777 | "
\n",
2778 | " \n",
2779 | " | 145 | \n",
2780 | " 5 | \n",
2781 | " 1618 | \n",
2782 | "
\n",
2783 | " \n",
2784 | " | 153 | \n",
2785 | " 5 | \n",
2786 | " 3699 | \n",
2787 | "
\n",
2788 | " \n",
2789 | " | 158 | \n",
2790 | " 6 | \n",
2791 | " 99 | \n",
2792 | "
\n",
2793 | " \n",
2794 | " | ... | \n",
2795 | " ... | \n",
2796 | " ... | \n",
2797 | "
\n",
2798 | " \n",
2799 | " | 5881455 | \n",
2800 | " 69411 | \n",
2801 | " 1044 | \n",
2802 | "
\n",
2803 | " \n",
2804 | " | 5881456 | \n",
2805 | " 69411 | \n",
2806 | " 1061 | \n",
2807 | "
\n",
2808 | " \n",
2809 | " | 5881459 | \n",
2810 | " 69411 | \n",
2811 | " 1126 | \n",
2812 | "
\n",
2813 | " \n",
2814 | " | 5881461 | \n",
2815 | " 69411 | \n",
2816 | " 1134 | \n",
2817 | "
\n",
2818 | " \n",
2819 | " | 5881468 | \n",
2820 | " 69411 | \n",
2821 | " 1157 | \n",
2822 | "
\n",
2823 | " \n",
2824 | " | 5881471 | \n",
2825 | " 69411 | \n",
2826 | " 1162 | \n",
2827 | "
\n",
2828 | " \n",
2829 | " | 5881475 | \n",
2830 | " 69411 | \n",
2831 | " 1177 | \n",
2832 | "
\n",
2833 | " \n",
2834 | " | 5881477 | \n",
2835 | " 69411 | \n",
2836 | " 1179 | \n",
2837 | "
\n",
2838 | " \n",
2839 | " | 5881479 | \n",
2840 | " 69411 | \n",
2841 | " 1197 | \n",
2842 | "
\n",
2843 | " \n",
2844 | " | 5881480 | \n",
2845 | " 69411 | \n",
2846 | " 1205 | \n",
2847 | "
\n",
2848 | " \n",
2849 | " | 5881486 | \n",
2850 | " 69411 | \n",
2851 | " 1329 | \n",
2852 | "
\n",
2853 | " \n",
2854 | " | 5881491 | \n",
2855 | " 69411 | \n",
2856 | " 1779 | \n",
2857 | "
\n",
2858 | " \n",
2859 | " | 5881498 | \n",
2860 | " 69411 | \n",
2861 | " 1895 | \n",
2862 | "
\n",
2863 | " \n",
2864 | " | 5881500 | \n",
2865 | " 69411 | \n",
2866 | " 2158 | \n",
2867 | "
\n",
2868 | " \n",
2869 | " | 5881505 | \n",
2870 | " 69411 | \n",
2871 | " 2347 | \n",
2872 | "
\n",
2873 | " \n",
2874 | " | 5881509 | \n",
2875 | " 69411 | \n",
2876 | " 2435 | \n",
2877 | "
\n",
2878 | " \n",
2879 | " | 5881512 | \n",
2880 | " 69411 | \n",
2881 | " 2791 | \n",
2882 | "
\n",
2883 | " \n",
2884 | " | 5881524 | \n",
2885 | " 69411 | \n",
2886 | " 3163 | \n",
2887 | "
\n",
2888 | " \n",
2889 | " | 5881529 | \n",
2890 | " 69411 | \n",
2891 | " 3238 | \n",
2892 | "
\n",
2893 | " \n",
2894 | " | 5881537 | \n",
2895 | " 69412 | \n",
2896 | " 1 | \n",
2897 | "
\n",
2898 | " \n",
2899 | " | 5881547 | \n",
2900 | " 69412 | \n",
2901 | " 122 | \n",
2902 | "
\n",
2903 | " \n",
2904 | " | 5881552 | \n",
2905 | " 69412 | \n",
2906 | " 229 | \n",
2907 | "
\n",
2908 | " \n",
2909 | " | 5881557 | \n",
2910 | " 69412 | \n",
2911 | " 290 | \n",
2912 | "
\n",
2913 | " \n",
2914 | " | 5881565 | \n",
2915 | " 69412 | \n",
2916 | " 565 | \n",
2917 | "
\n",
2918 | " \n",
2919 | " | 5881568 | \n",
2920 | " 69412 | \n",
2921 | " 575 | \n",
2922 | "
\n",
2923 | " \n",
2924 | " | 5881571 | \n",
2925 | " 69413 | \n",
2926 | " 191 | \n",
2927 | "
\n",
2928 | " \n",
2929 | " | 5881574 | \n",
2930 | " 69413 | \n",
2931 | " 468 | \n",
2932 | "
\n",
2933 | " \n",
2934 | " | 5881583 | \n",
2935 | " 69413 | \n",
2936 | " 1130 | \n",
2937 | "
\n",
2938 | " \n",
2939 | " | 5881586 | \n",
2940 | " 69413 | \n",
2941 | " 1246 | \n",
2942 | "
\n",
2943 | " \n",
2944 | " | 5881590 | \n",
2945 | " 69413 | \n",
2946 | " 1599 | \n",
2947 | "
\n",
2948 | " \n",
2949 | "
\n",
2950 | "
1134946 rows × 2 columns
\n",
2951 | "
"
2952 | ],
2953 | "text/plain": [
2954 | " userId movieId\n",
2955 | "10 1 363\n",
2956 | "11 1 370\n",
2957 | "13 1 454\n",
2958 | "19 1 573\n",
2959 | "25 2 574\n",
2960 | "30 2 1329\n",
2961 | "33 2 1562\n",
2962 | "45 2 7220\n",
2963 | "46 2 7224\n",
2964 | "50 3 34\n",
2965 | "54 3 156\n",
2966 | "62 3 468\n",
2967 | "67 3 573\n",
2968 | "73 4 47\n",
2969 | "75 4 108\n",
2970 | "81 4 327\n",
2971 | "88 4 520\n",
2972 | "90 4 529\n",
2973 | "98 4 869\n",
2974 | "100 4 873\n",
2975 | "101 4 876\n",
2976 | "104 4 990\n",
2977 | "115 4 1136\n",
2978 | "124 4 1207\n",
2979 | "132 5 1124\n",
2980 | "139 5 1319\n",
2981 | "143 5 1521\n",
2982 | "145 5 1618\n",
2983 | "153 5 3699\n",
2984 | "158 6 99\n",
2985 | "... ... ...\n",
2986 | "5881455 69411 1044\n",
2987 | "5881456 69411 1061\n",
2988 | "5881459 69411 1126\n",
2989 | "5881461 69411 1134\n",
2990 | "5881468 69411 1157\n",
2991 | "5881471 69411 1162\n",
2992 | "5881475 69411 1177\n",
2993 | "5881477 69411 1179\n",
2994 | "5881479 69411 1197\n",
2995 | "5881480 69411 1205\n",
2996 | "5881486 69411 1329\n",
2997 | "5881491 69411 1779\n",
2998 | "5881498 69411 1895\n",
2999 | "5881500 69411 2158\n",
3000 | "5881505 69411 2347\n",
3001 | "5881509 69411 2435\n",
3002 | "5881512 69411 2791\n",
3003 | "5881524 69411 3163\n",
3004 | "5881529 69411 3238\n",
3005 | "5881537 69412 1\n",
3006 | "5881547 69412 122\n",
3007 | "5881552 69412 229\n",
3008 | "5881557 69412 290\n",
3009 | "5881565 69412 565\n",
3010 | "5881568 69412 575\n",
3011 | "5881571 69413 191\n",
3012 | "5881574 69413 468\n",
3013 | "5881583 69413 1130\n",
3014 | "5881586 69413 1246\n",
3015 | "5881590 69413 1599\n",
3016 | "\n",
3017 | "[1134946 rows x 2 columns]"
3018 | ]
3019 | },
3020 | "execution_count": 41,
3021 | "metadata": {},
3022 | "output_type": "execute_result"
3023 | }
3024 | ],
3025 | "source": [
3026 | "test_df"
3027 | ]
3028 | },
3029 | {
3030 | "cell_type": "code",
3031 | "execution_count": 45,
3032 | "metadata": {},
3033 | "outputs": [],
3034 | "source": [
3035 | "train_df.to_csv('./train.csv', index=False)\n",
3036 | "test_df.to_csv('./test.csv', index=False)"
3037 | ]
3038 | },
3039 | {
3040 | "cell_type": "code",
3041 | "execution_count": 43,
3042 | "metadata": {},
3043 | "outputs": [],
3044 | "source": [
3045 | "train_df.reset_index(drop=True, inplace=True)"
3046 | ]
3047 | },
3048 | {
3049 | "cell_type": "code",
3050 | "execution_count": 44,
3051 | "metadata": {},
3052 | "outputs": [
3053 | {
3054 | "name": "stdout",
3055 | "output_type": "stream",
3056 | "text": [
3057 | "3000\n",
3058 | "6000\n",
3059 | "9000\n",
3060 | "12000\n",
3061 | "15000\n",
3062 | "18000\n",
3063 | "21000\n",
3064 | "24000\n",
3065 | "27000\n",
3066 | "30000\n",
3067 | "33000\n",
3068 | "36000\n",
3069 | "39000\n",
3070 | "42000\n",
3071 | "45000\n",
3072 | "48000\n",
3073 | "51000\n",
3074 | "54000\n",
3075 | "57000\n",
3076 | "60000\n",
3077 | "63000\n",
3078 | "66000\n",
3079 | "69000\n",
3080 | "3000\n",
3081 | "6000\n",
3082 | "9000\n",
3083 | "12000\n",
3084 | "15000\n",
3085 | "18000\n",
3086 | "21000\n",
3087 | "24000\n",
3088 | "27000\n",
3089 | "30000\n",
3090 | "33000\n",
3091 | "36000\n",
3092 | "39000\n",
3093 | "42000\n",
3094 | "45000\n",
3095 | "48000\n",
3096 | "51000\n",
3097 | "54000\n",
3098 | "57000\n",
3099 | "60000\n",
3100 | "63000\n",
3101 | "66000\n",
3102 | "69000\n",
3103 | "3000\n",
3104 | "6000\n",
3105 | "9000\n",
3106 | "12000\n",
3107 | "15000\n",
3108 | "18000\n",
3109 | "21000\n",
3110 | "24000\n",
3111 | "27000\n",
3112 | "30000\n",
3113 | "33000\n",
3114 | "36000\n",
3115 | "39000\n",
3116 | "42000\n",
3117 | "45000\n",
3118 | "48000\n",
3119 | "51000\n",
3120 | "54000\n",
3121 | "57000\n",
3122 | "60000\n",
3123 | "63000\n",
3124 | "66000\n",
3125 | "69000\n",
3126 | "3000\n",
3127 | "6000\n",
3128 | "9000\n",
3129 | "12000\n",
3130 | "15000\n",
3131 | "18000\n",
3132 | "21000\n",
3133 | "24000\n",
3134 | "27000\n",
3135 | "30000\n",
3136 | "33000\n",
3137 | "36000\n",
3138 | "39000\n",
3139 | "42000\n",
3140 | "45000\n",
3141 | "48000\n",
3142 | "51000\n",
3143 | "54000\n",
3144 | "57000\n",
3145 | "60000\n",
3146 | "63000\n",
3147 | "66000\n",
3148 | "69000\n",
3149 | "3000\n",
3150 | "6000\n",
3151 | "9000\n",
3152 | "12000\n",
3153 | "15000\n",
3154 | "18000\n",
3155 | "21000\n",
3156 | "24000\n",
3157 | "27000\n",
3158 | "30000\n",
3159 | "33000\n",
3160 | "36000\n",
3161 | "39000\n",
3162 | "42000\n",
3163 | "45000\n",
3164 | "48000\n",
3165 | "51000\n",
3166 | "54000\n",
3167 | "57000\n",
3168 | "60000\n",
3169 | "63000\n",
3170 | "66000\n",
3171 | "69000\n"
3172 | ]
3173 | }
3174 | ],
3175 | "source": [
3176 | "train_ratio = 0.9\n",
3177 | "vali_ratio = 1 - train_ratio\n",
3178 | "for i in range(5):\n",
3179 | " train_tmp_df = train_df.copy()\n",
3180 | " vali_df = train_df.copy()\n",
3181 | " vali_idx = []\n",
3182 | " for u in range(1, num_users+1):\n",
3183 | " u_idx = train_tmp_df.index[train_tmp_df['userId'] == u]\n",
3184 | " idx_len = len(u_idx)\n",
3185 | " vali_len = int(idx_len * vali_ratio)\n",
3186 | " if vali_len == 0:\n",
3187 | " vali_len = 1\n",
3188 | " tmp = np.random.choice(u_idx, size=vali_len, replace=False)\n",
3189 | " vali_idx += tmp.tolist()\n",
3190 | " if u % 3000 == 0:\n",
3191 | " print u\n",
3192 | " vali_set = set(vali_idx)\n",
3193 | " train_set = set(range(len(train_tmp_df)))\n",
3194 | " train_set -= vali_set\n",
3195 | " train_tmp_idx = list(train_set)\n",
3196 | " train_tmp_df.drop(vali_idx, axis=0, inplace=True)\n",
3197 | " vali_df.drop(train_tmp_idx, axis=0, inplace=True)\n",
3198 | " train_tmp_df.to_csv('./train_'+str(i)+'.csv', index=False)\n",
3199 | " vali_df.to_csv('./vali_'+str(i)+'.csv', index=False)"
3200 | ]
3201 | }
3202 | ],
3203 | "metadata": {
3204 | "kernelspec": {
3205 | "display_name": "Python 2",
3206 | "language": "python",
3207 | "name": "python2"
3208 | },
3209 | "language_info": {
3210 | "codemirror_mode": {
3211 | "name": "ipython",
3212 | "version": 2
3213 | },
3214 | "file_extension": ".py",
3215 | "mimetype": "text/x-python",
3216 | "name": "python",
3217 | "nbconvert_exporter": "python",
3218 | "pygments_lexer": "ipython2",
3219 | "version": "2.7.12"
3220 | }
3221 | },
3222 | "nbformat": 4,
3223 | "nbformat_minor": 2
3224 | }
3225 |
--------------------------------------------------------------------------------