├── Data ├── README ├── amazon-book │ ├── README.md │ ├── item_list.txt │ ├── test.txt │ ├── train.txt │ └── user_list.txt ├── gowalla │ ├── README.md │ ├── item_list.txt │ ├── test.txt │ ├── train.txt │ └── user_list.txt └── yelp2018 │ ├── item_list.txt │ ├── test.txt │ ├── train.txt │ └── user_list.txt ├── LightGCN.py ├── README.md ├── evaluator ├── __init__.py ├── cpp │ ├── apt_evaluate_foldout.pyx │ ├── apt_tools.pyx │ ├── evaluate_foldout.py │ └── include │ │ ├── evaluate_foldout.h │ │ ├── thread_pool.h │ │ └── tools.h └── python │ ├── evaluate_foldout.py │ └── evaluate_loo.py ├── setup.py └── utility ├── README.md ├── batch_test.py ├── helper.py ├── load_data.py └── parser.py /Data/README: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /Data/amazon-book/README.md: -------------------------------------------------------------------------------- 1 | Looking for the full dataset? Please visit the [website](http://jmcauley.ucsd.edu/data/amazon). 2 | -------------------------------------------------------------------------------- /Data/gowalla/README.md: -------------------------------------------------------------------------------- 1 | Looking for the full dataset? 2 | Please visit the [website](https://snap.stanford.edu/data/loc-gowalla.html). 3 | -------------------------------------------------------------------------------- /LightGCN.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Oct 10, 2018 3 | Tensorflow Implementation of Neural Graph Collaborative Filtering (NGCF) model in: 4 | Wang Xiang et al. Neural Graph Collaborative Filtering. In SIGIR 2019. 5 | @author: Xiang Wang (xiangwang@u.nus.edu) 6 | version: 7 | Parallelized sampling on CPU 8 | C++ evaluation for top-k recommendation 9 | ''' 10 | 11 | import os 12 | import sys 13 | import threading 14 | import tensorflow as tf 15 | from tensorflow.python.client import device_lib 16 | from utility.helper import * 17 | from utility.batch_test import * 18 | os.environ['TF_CPP_MIN_LOG_LEVEL']='2' 19 | 20 | cpus = [x.name for x in device_lib.list_local_devices() if x.device_type == 'CPU'] 21 | 22 | class LightGCN(object): 23 | def __init__(self, data_config, pretrain_data): 24 | # argument settings 25 | self.model_type = 'LightGCN' 26 | self.adj_type = args.adj_type 27 | self.alg_type = args.alg_type 28 | self.pretrain_data = pretrain_data 29 | self.n_users = data_config['n_users'] 30 | self.n_items = data_config['n_items'] 31 | self.n_fold = 100 32 | self.norm_adj = data_config['norm_adj'] 33 | self.n_nonzero_elems = self.norm_adj.count_nonzero() 34 | self.lr = args.lr 35 | self.emb_dim = args.embed_size 36 | self.batch_size = args.batch_size 37 | self.weight_size = eval(args.layer_size) 38 | self.n_layers = len(self.weight_size) 39 | self.regs = eval(args.regs) 40 | self.decay = self.regs[0] 41 | self.log_dir=self.create_model_str() 42 | self.verbose = args.verbose 43 | self.Ks = eval(args.Ks) 44 | 45 | 46 | ''' 47 | ********************************************************* 48 | Create Placeholder for Input Data & Dropout.
49 | ''' 50 | # placeholder definition 51 | self.users = tf.placeholder(tf.int32, shape=(None,)) 52 | self.pos_items = tf.placeholder(tf.int32, shape=(None,)) 53 | self.neg_items = tf.placeholder(tf.int32, shape=(None,)) 54 | 55 | self.node_dropout_flag = args.node_dropout_flag 56 | self.node_dropout = tf.placeholder(tf.float32, shape=[None]) 57 | self.mess_dropout = tf.placeholder(tf.float32, shape=[None]) 58 | with tf.name_scope('TRAIN_LOSS'): 59 | self.train_loss = tf.placeholder(tf.float32) 60 | tf.summary.scalar('train_loss', self.train_loss) 61 | self.train_mf_loss = tf.placeholder(tf.float32) 62 | tf.summary.scalar('train_mf_loss', self.train_mf_loss) 63 | self.train_emb_loss = tf.placeholder(tf.float32) 64 | tf.summary.scalar('train_emb_loss', self.train_emb_loss) 65 | self.train_reg_loss = tf.placeholder(tf.float32) 66 | tf.summary.scalar('train_reg_loss', self.train_reg_loss) 67 | self.merged_train_loss = tf.summary.merge(tf.get_collection(tf.GraphKeys.SUMMARIES, 'TRAIN_LOSS')) 68 | 69 | 70 | with tf.name_scope('TRAIN_ACC'): 71 | self.train_rec_first = tf.placeholder(tf.float32) 72 | #record for top(Ks[0]) 73 | tf.summary.scalar('train_rec_first', self.train_rec_first) 74 | self.train_rec_last = tf.placeholder(tf.float32) 75 | #record for top(Ks[-1]) 76 | tf.summary.scalar('train_rec_last', self.train_rec_last) 77 | self.train_ndcg_first = tf.placeholder(tf.float32) 78 | tf.summary.scalar('train_ndcg_first', self.train_ndcg_first) 79 | self.train_ndcg_last = tf.placeholder(tf.float32) 80 | tf.summary.scalar('train_ndcg_last', self.train_ndcg_last) 81 | self.merged_train_acc = tf.summary.merge(tf.get_collection(tf.GraphKeys.SUMMARIES, 'TRAIN_ACC')) 82 | 83 | with tf.name_scope('TEST_LOSS'): 84 | self.test_loss = tf.placeholder(tf.float32) 85 | tf.summary.scalar('test_loss', self.test_loss) 86 | self.test_mf_loss = tf.placeholder(tf.float32) 87 | tf.summary.scalar('test_mf_loss', self.test_mf_loss) 88 | self.test_emb_loss = tf.placeholder(tf.float32) 89 | tf.summary.scalar('test_emb_loss', self.test_emb_loss) 90 | self.test_reg_loss = tf.placeholder(tf.float32) 91 | tf.summary.scalar('test_reg_loss', self.test_reg_loss) 92 | self.merged_test_loss = tf.summary.merge(tf.get_collection(tf.GraphKeys.SUMMARIES, 'TEST_LOSS')) 93 | 94 | with tf.name_scope('TEST_ACC'): 95 | self.test_rec_first = tf.placeholder(tf.float32) 96 | tf.summary.scalar('test_rec_first', self.test_rec_first) 97 | self.test_rec_last = tf.placeholder(tf.float32) 98 | tf.summary.scalar('test_rec_last', self.test_rec_last) 99 | self.test_ndcg_first = tf.placeholder(tf.float32) 100 | tf.summary.scalar('test_ndcg_first', self.test_ndcg_first) 101 | self.test_ndcg_last = tf.placeholder(tf.float32) 102 | tf.summary.scalar('test_ndcg_last', self.test_ndcg_last) 103 | self.merged_test_acc = tf.summary.merge(tf.get_collection(tf.GraphKeys.SUMMARIES, 'TEST_ACC')) 104 | """ 105 | ********************************************************* 106 | Create Model Parameters (i.e., Initialize Weights). 107 | """ 108 | # initialization of model parameters 109 | self.weights = self._init_weights() 110 | 111 | """ 112 | ********************************************************* 113 | Compute Graph-based Representations of all users & items via Message-Passing Mechanism of Graph Neural Networks. 114 | Different Convolutional Layers: 115 | 1. ngcf: defined in 'Neural Graph Collaborative Filtering', SIGIR2019; 116 | 2. gcn: defined in 'Semi-Supervised Classification with Graph Convolutional Networks', ICLR2018; 117 | 3. 
gcmc: defined in 'Graph Convolutional Matrix Completion', KDD2018; 118 | """ 119 | if self.alg_type in ['lightgcn']: 120 | self.ua_embeddings, self.ia_embeddings = self._create_lightgcn_embed() 121 | 122 | elif self.alg_type in ['ngcf']: 123 | self.ua_embeddings, self.ia_embeddings = self._create_ngcf_embed() 124 | 125 | elif self.alg_type in ['gcn']: 126 | self.ua_embeddings, self.ia_embeddings = self._create_gcn_embed() 127 | 128 | elif self.alg_type in ['gcmc']: 129 | self.ua_embeddings, self.ia_embeddings = self._create_gcmc_embed() 130 | 131 | """ 132 | ********************************************************* 133 | Establish the final representations for user-item pairs in batch. 134 | """ 135 | self.u_g_embeddings = tf.nn.embedding_lookup(self.ua_embeddings, self.users) 136 | self.pos_i_g_embeddings = tf.nn.embedding_lookup(self.ia_embeddings, self.pos_items) 137 | self.neg_i_g_embeddings = tf.nn.embedding_lookup(self.ia_embeddings, self.neg_items) 138 | self.u_g_embeddings_pre = tf.nn.embedding_lookup(self.weights['user_embedding'], self.users) 139 | self.pos_i_g_embeddings_pre = tf.nn.embedding_lookup(self.weights['item_embedding'], self.pos_items) 140 | self.neg_i_g_embeddings_pre = tf.nn.embedding_lookup(self.weights['item_embedding'], self.neg_items) 141 | 142 | """ 143 | ********************************************************* 144 | Inference for the testing phase. 145 | """ 146 | self.batch_ratings = tf.matmul(self.u_g_embeddings, self.pos_i_g_embeddings, transpose_a=False, transpose_b=True) 147 | 148 | """ 149 | ********************************************************* 150 | Generate Predictions & Optimize via BPR loss. 151 | """ 152 | self.mf_loss, self.emb_loss, self.reg_loss = self.create_bpr_loss(self.u_g_embeddings, 153 | self.pos_i_g_embeddings, 154 | self.neg_i_g_embeddings) 155 | self.loss = self.mf_loss + self.emb_loss 156 | 157 | self.opt = tf.train.AdamOptimizer(learning_rate=self.lr).minimize(self.loss) 158 | 159 | 160 | def create_model_str(self): 161 | log_dir = '/' + self.alg_type+'/layers_'+str(self.n_layers)+'/dim_'+str(self.emb_dim) 162 | log_dir+='/'+args.dataset+'/lr_' + str(self.lr) + '/reg_' + str(self.decay) 163 | return log_dir 164 | 165 | 166 | def _init_weights(self): 167 | all_weights = dict() 168 | initializer = tf.random_normal_initializer(stddev=0.01) #tf.contrib.layers.xavier_initializer() 169 | if self.pretrain_data is None: 170 | all_weights['user_embedding'] = tf.Variable(initializer([self.n_users, self.emb_dim]), name='user_embedding') 171 | all_weights['item_embedding'] = tf.Variable(initializer([self.n_items, self.emb_dim]), name='item_embedding') 172 | print('using random initialization')#print('using xavier initialization') 173 | else: 174 | all_weights['user_embedding'] = tf.Variable(initial_value=self.pretrain_data['user_embed'], trainable=True, 175 | name='user_embedding', dtype=tf.float32) 176 | all_weights['item_embedding'] = tf.Variable(initial_value=self.pretrain_data['item_embed'], trainable=True, 177 | name='item_embedding', dtype=tf.float32) 178 | print('using pretrained initialization') 179 | 180 | self.weight_size_list = [self.emb_dim] + self.weight_size 181 | 182 | for k in range(self.n_layers): 183 | all_weights['W_gc_%d' %k] = tf.Variable( 184 | initializer([self.weight_size_list[k], self.weight_size_list[k+1]]), name='W_gc_%d' % k) 185 | all_weights['b_gc_%d' %k] = tf.Variable( 186 | initializer([1, self.weight_size_list[k+1]]), name='b_gc_%d' % k) 187 | 188 | all_weights['W_bi_%d' % k] = tf.Variable( 189 | 
initializer([self.weight_size_list[k], self.weight_size_list[k + 1]]), name='W_bi_%d' % k) 190 | all_weights['b_bi_%d' % k] = tf.Variable( 191 | initializer([1, self.weight_size_list[k + 1]]), name='b_bi_%d' % k) 192 | 193 | all_weights['W_mlp_%d' % k] = tf.Variable( 194 | initializer([self.weight_size_list[k], self.weight_size_list[k+1]]), name='W_mlp_%d' % k) 195 | all_weights['b_mlp_%d' % k] = tf.Variable( 196 | initializer([1, self.weight_size_list[k+1]]), name='b_mlp_%d' % k) 197 | 198 | return all_weights 199 | def _split_A_hat(self, X): 200 | A_fold_hat = [] 201 | 202 | fold_len = (self.n_users + self.n_items) // self.n_fold 203 | for i_fold in range(self.n_fold): 204 | start = i_fold * fold_len 205 | if i_fold == self.n_fold -1: 206 | end = self.n_users + self.n_items 207 | else: 208 | end = (i_fold + 1) * fold_len 209 | 210 | A_fold_hat.append(self._convert_sp_mat_to_sp_tensor(X[start:end])) 211 | return A_fold_hat 212 | 213 | def _split_A_hat_node_dropout(self, X): 214 | A_fold_hat = [] 215 | 216 | fold_len = (self.n_users + self.n_items) // self.n_fold 217 | for i_fold in range(self.n_fold): 218 | start = i_fold * fold_len 219 | if i_fold == self.n_fold -1: 220 | end = self.n_users + self.n_items 221 | else: 222 | end = (i_fold + 1) * fold_len 223 | 224 | temp = self._convert_sp_mat_to_sp_tensor(X[start:end]) 225 | n_nonzero_temp = X[start:end].count_nonzero() 226 | A_fold_hat.append(self._dropout_sparse(temp, 1 - self.node_dropout[0], n_nonzero_temp)) 227 | 228 | return A_fold_hat 229 | 230 | def _create_lightgcn_embed(self): 231 | if self.node_dropout_flag: 232 | A_fold_hat = self._split_A_hat_node_dropout(self.norm_adj) 233 | else: 234 | A_fold_hat = self._split_A_hat(self.norm_adj) 235 | 236 | ego_embeddings = tf.concat([self.weights['user_embedding'], self.weights['item_embedding']], axis=0) 237 | all_embeddings = [ego_embeddings] 238 | 239 | for k in range(0, self.n_layers): 240 | 241 | temp_embed = [] 242 | for f in range(self.n_fold): 243 | temp_embed.append(tf.sparse_tensor_dense_matmul(A_fold_hat[f], ego_embeddings)) 244 | 245 | side_embeddings = tf.concat(temp_embed, 0) 246 | ego_embeddings = side_embeddings 247 | all_embeddings += [ego_embeddings] 248 | all_embeddings=tf.stack(all_embeddings,1) 249 | all_embeddings=tf.reduce_mean(all_embeddings,axis=1,keepdims=False) 250 | u_g_embeddings, i_g_embeddings = tf.split(all_embeddings, [self.n_users, self.n_items], 0) 251 | return u_g_embeddings, i_g_embeddings 252 | 253 | def _create_ngcf_embed(self): 254 | if self.node_dropout_flag: 255 | A_fold_hat = self._split_A_hat_node_dropout(self.norm_adj) 256 | else: 257 | A_fold_hat = self._split_A_hat(self.norm_adj) 258 | 259 | ego_embeddings = tf.concat([self.weights['user_embedding'], self.weights['item_embedding']], axis=0) 260 | 261 | all_embeddings = [ego_embeddings] 262 | 263 | for k in range(0, self.n_layers): 264 | 265 | temp_embed = [] 266 | for f in range(self.n_fold): 267 | temp_embed.append(tf.sparse_tensor_dense_matmul(A_fold_hat[f], ego_embeddings)) 268 | 269 | side_embeddings = tf.concat(temp_embed, 0) 270 | sum_embeddings = tf.nn.leaky_relu(tf.matmul(side_embeddings, self.weights['W_gc_%d' % k]) + self.weights['b_gc_%d' % k]) 271 | 272 | 273 | 274 | # bi messages of neighbors. 275 | bi_embeddings = tf.multiply(ego_embeddings, side_embeddings) 276 | # transformed bi messages of neighbors. 277 | bi_embeddings = tf.nn.leaky_relu(tf.matmul(bi_embeddings, self.weights['W_bi_%d' % k]) + self.weights['b_bi_%d' % k]) 278 | # non-linear activation. 
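            # combine the transformed sum aggregation with the transformed bilinear (element-wise) interaction; both terms were already passed through LeakyReLU above.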
279 | ego_embeddings = sum_embeddings + bi_embeddings 280 | 281 | # message dropout. 282 | # ego_embeddings = tf.nn.dropout(ego_embeddings, 1 - self.mess_dropout[k]) 283 | 284 | # normalize the distribution of embeddings. 285 | norm_embeddings = tf.nn.l2_normalize(ego_embeddings, axis=1) 286 | 287 | all_embeddings += [norm_embeddings] 288 | 289 | all_embeddings = tf.concat(all_embeddings, 1) 290 | u_g_embeddings, i_g_embeddings = tf.split(all_embeddings, [self.n_users, self.n_items], 0) 291 | return u_g_embeddings, i_g_embeddings 292 | 293 | 294 | def _create_gcn_embed(self): 295 | A_fold_hat = self._split_A_hat(self.norm_adj) 296 | embeddings = tf.concat([self.weights['user_embedding'], self.weights['item_embedding']], axis=0) 297 | 298 | 299 | all_embeddings = [embeddings] 300 | 301 | for k in range(0, self.n_layers): 302 | temp_embed = [] 303 | for f in range(self.n_fold): 304 | temp_embed.append(tf.sparse_tensor_dense_matmul(A_fold_hat[f], embeddings)) 305 | 306 | embeddings = tf.concat(temp_embed, 0) 307 | embeddings = tf.nn.leaky_relu(tf.matmul(embeddings, self.weights['W_gc_%d' %k]) + self.weights['b_gc_%d' %k]) 308 | # embeddings = tf.nn.dropout(embeddings, 1 - self.mess_dropout[k]) 309 | 310 | all_embeddings += [embeddings] 311 | 312 | all_embeddings = tf.concat(all_embeddings, 1) 313 | u_g_embeddings, i_g_embeddings = tf.split(all_embeddings, [self.n_users, self.n_items], 0) 314 | return u_g_embeddings, i_g_embeddings 315 | 316 | def _create_gcmc_embed(self): 317 | A_fold_hat = self._split_A_hat(self.norm_adj) 318 | 319 | embeddings = tf.concat([self.weights['user_embedding'], self.weights['item_embedding']], axis=0) 320 | 321 | all_embeddings = [] 322 | 323 | for k in range(0, self.n_layers): 324 | temp_embed = [] 325 | for f in range(self.n_fold): 326 | temp_embed.append(tf.sparse_tensor_dense_matmul(A_fold_hat[f], embeddings)) 327 | embeddings = tf.concat(temp_embed, 0) 328 | # convolutional layer. 329 | embeddings = tf.nn.leaky_relu(tf.matmul(embeddings, self.weights['W_gc_%d' % k]) + self.weights['b_gc_%d' % k]) 330 | # dense layer. 331 | mlp_embeddings = tf.matmul(embeddings, self.weights['W_mlp_%d' %k]) + self.weights['b_mlp_%d' %k] 332 | # mlp_embeddings = tf.nn.dropout(mlp_embeddings, 1 - self.mess_dropout[k]) 333 | 334 | all_embeddings += [mlp_embeddings] 335 | all_embeddings = tf.concat(all_embeddings, 1) 336 | 337 | u_g_embeddings, i_g_embeddings = tf.split(all_embeddings, [self.n_users, self.n_items], 0) 338 | return u_g_embeddings, i_g_embeddings 339 | 340 | def create_bpr_loss(self, users, pos_items, neg_items): 341 | pos_scores = tf.reduce_sum(tf.multiply(users, pos_items), axis=1) 342 | neg_scores = tf.reduce_sum(tf.multiply(users, neg_items), axis=1) 343 | 344 | regularizer = tf.nn.l2_loss(self.u_g_embeddings_pre) + tf.nn.l2_loss( 345 | self.pos_i_g_embeddings_pre) + tf.nn.l2_loss(self.neg_i_g_embeddings_pre) 346 | regularizer = regularizer / self.batch_size 347 | 348 | mf_loss = tf.reduce_mean(tf.nn.softplus(-(pos_scores - neg_scores))) 349 | 350 | 351 | emb_loss = self.decay * regularizer 352 | 353 | reg_loss = tf.constant(0.0, tf.float32, [1]) 354 | 355 | return mf_loss, emb_loss, reg_loss 356 | 357 | def _convert_sp_mat_to_sp_tensor(self, X): 358 | coo = X.tocoo().astype(np.float32) 359 | indices = np.mat([coo.row, coo.col]).transpose() 360 | return tf.SparseTensor(indices, coo.data, coo.shape) 361 | 362 | def _dropout_sparse(self, X, keep_prob, n_nonzero_elems): 363 | """ 364 | Dropout for sparse tensors. 
365 | """ 366 | noise_shape = [n_nonzero_elems] 367 | random_tensor = keep_prob 368 | random_tensor += tf.random_uniform(noise_shape) 369 | dropout_mask = tf.cast(tf.floor(random_tensor), dtype=tf.bool) 370 | pre_out = tf.sparse_retain(X, dropout_mask) 371 | 372 | return pre_out * tf.div(1., keep_prob) 373 | 374 | def load_pretrained_data(): 375 | pretrain_path = '%spretrain/%s/%s.npz' % (args.proj_path, args.dataset, 'embedding') 376 | try: 377 | pretrain_data = np.load(pretrain_path) 378 | print('load the pretrained embeddings.') 379 | except Exception: 380 | pretrain_data = None 381 | return pretrain_data 382 | 383 | # parallelized sampling on CPU 384 | class sample_thread(threading.Thread): 385 | def __init__(self): 386 | threading.Thread.__init__(self) 387 | def run(self): 388 | with tf.device(cpus[0]): 389 | self.data = data_generator.sample() 390 | 391 | class sample_thread_test(threading.Thread): 392 | def __init__(self): 393 | threading.Thread.__init__(self) 394 | def run(self): 395 | with tf.device(cpus[0]): 396 | self.data = data_generator.sample_test() 397 | 398 | # training on GPU 399 | class train_thread(threading.Thread): 400 | def __init__(self,model, sess, sample): 401 | threading.Thread.__init__(self) 402 | self.model = model 403 | self.sess = sess 404 | self.sample = sample 405 | def run(self): 406 | 407 | users, pos_items, neg_items = self.sample.data 408 | self.data = sess.run([self.model.opt, self.model.loss, self.model.mf_loss, self.model.emb_loss, self.model.reg_loss], 409 | feed_dict={model.users: users, model.pos_items: pos_items, 410 | model.node_dropout: eval(args.node_dropout), 411 | model.mess_dropout: eval(args.mess_dropout), 412 | model.neg_items: neg_items}) 413 | 414 | class train_thread_test(threading.Thread): 415 | def __init__(self,model, sess, sample): 416 | threading.Thread.__init__(self) 417 | self.model = model 418 | self.sess = sess 419 | self.sample = sample 420 | def run(self): 421 | 422 | users, pos_items, neg_items = self.sample.data 423 | self.data = sess.run([self.model.loss, self.model.mf_loss, self.model.emb_loss], 424 | feed_dict={model.users: users, model.pos_items: pos_items, 425 | model.neg_items: neg_items, 426 | model.node_dropout: eval(args.node_dropout), 427 | model.mess_dropout: eval(args.mess_dropout)}) 428 | 429 | if __name__ == '__main__': 430 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id) 431 | f0 = time() 432 | 433 | config = dict() 434 | config['n_users'] = data_generator.n_users 435 | config['n_items'] = data_generator.n_items 436 | 437 | """ 438 | ********************************************************* 439 | Generate the Laplacian matrix, where each entry defines the decay factor (e.g., p_ui) between two connected nodes. 
440 | """ 441 | plain_adj, norm_adj, mean_adj,pre_adj = data_generator.get_adj_mat() 442 | if args.adj_type == 'plain': 443 | config['norm_adj'] = plain_adj 444 | print('use the plain adjacency matrix') 445 | elif args.adj_type == 'norm': 446 | config['norm_adj'] = norm_adj 447 | print('use the normalized adjacency matrix') 448 | elif args.adj_type == 'gcmc': 449 | config['norm_adj'] = mean_adj 450 | print('use the gcmc adjacency matrix') 451 | elif args.adj_type=='pre': 452 | config['norm_adj']=pre_adj 453 | print('use the pre adjcency matrix') 454 | else: 455 | config['norm_adj'] = mean_adj + sp.eye(mean_adj.shape[0]) 456 | print('use the mean adjacency matrix') 457 | t0 = time() 458 | if args.pretrain == -1: 459 | pretrain_data = load_pretrained_data() 460 | else: 461 | pretrain_data = None 462 | model = LightGCN(data_config=config, pretrain_data=pretrain_data) 463 | 464 | """ 465 | ********************************************************* 466 | Save the model parameters. 467 | """ 468 | saver = tf.train.Saver() 469 | 470 | if args.save_flag == 1: 471 | layer = '-'.join([str(l) for l in eval(args.layer_size)]) 472 | weights_save_path = '%sweights/%s/%s/%s/l%s_r%s' % (args.weights_path, args.dataset, model.model_type, layer, 473 | str(args.lr), '-'.join([str(r) for r in eval(args.regs)])) 474 | ensureDir(weights_save_path) 475 | save_saver = tf.train.Saver(max_to_keep=1) 476 | 477 | config = tf.ConfigProto() 478 | config.gpu_options.allow_growth = True 479 | sess = tf.Session(config=config) 480 | 481 | """ 482 | ********************************************************* 483 | Reload the pretrained model parameters. 484 | """ 485 | if args.pretrain == 1: 486 | layer = '-'.join([str(l) for l in eval(args.layer_size)]) 487 | 488 | pretrain_path = '%sweights/%s/%s/%s/l%s_r%s' % (args.weights_path, args.dataset, model.model_type, layer, 489 | str(args.lr), '-'.join([str(r) for r in eval(args.regs)])) 490 | 491 | 492 | ckpt = tf.train.get_checkpoint_state(os.path.dirname(pretrain_path + '/checkpoint')) 493 | if ckpt and ckpt.model_checkpoint_path: 494 | sess.run(tf.global_variables_initializer()) 495 | saver.restore(sess, ckpt.model_checkpoint_path) 496 | print('load the pretrained model parameters from: ', pretrain_path) 497 | 498 | # ********************************************************* 499 | # get the performance from pretrained model. 500 | if args.report != 1: 501 | users_to_test = list(data_generator.test_set.keys()) 502 | ret = test(sess, model, users_to_test, drop_flag=True) 503 | cur_best_pre_0 = ret['recall'][0] 504 | 505 | pretrain_ret = 'pretrained model recall=[%s], precision=[%s], '\ 506 | 'ndcg=[%s]' % \ 507 | (', '.join(['%.5f' % r for r in ret['recall']]), 508 | ', '.join(['%.5f' % r for r in ret['precision']]), 509 | ', '.join(['%.5f' % r for r in ret['ndcg']])) 510 | print(pretrain_ret) 511 | else: 512 | sess.run(tf.global_variables_initializer()) 513 | cur_best_pre_0 = 0. 514 | print('without pretraining.') 515 | 516 | else: 517 | sess.run(tf.global_variables_initializer()) 518 | cur_best_pre_0 = 0. 519 | print('without pretraining.') 520 | 521 | """ 522 | ********************************************************* 523 | Get the performance w.r.t. different sparsity levels. 
524 | """ 525 | if args.report == 1: 526 | assert args.test_flag == 'full' 527 | users_to_test_list, split_state = data_generator.get_sparsity_split() 528 | users_to_test_list.append(list(data_generator.test_set.keys())) 529 | split_state.append('all') 530 | 531 | report_path = '%sreport/%s/%s.result' % (args.proj_path, args.dataset, model.model_type) 532 | ensureDir(report_path) 533 | f = open(report_path, 'w') 534 | f.write( 535 | 'embed_size=%d, lr=%.4f, layer_size=%s, keep_prob=%s, regs=%s, loss_type=%s, adj_type=%s\n' 536 | % (args.embed_size, args.lr, args.layer_size, args.keep_prob, args.regs, args.loss_type, args.adj_type)) 537 | 538 | for i, users_to_test in enumerate(users_to_test_list): 539 | ret = test(sess, model, users_to_test, drop_flag=True) 540 | 541 | final_perf = "recall=[%s], precision=[%s], ndcg=[%s]" % \ 542 | (', '.join(['%.5f' % r for r in ret['recall']]), 543 | ', '.join(['%.5f' % r for r in ret['precision']]), 544 | ', '.join(['%.5f' % r for r in ret['ndcg']])) 545 | 546 | f.write('\t%s\n\t%s\n' % (split_state[i], final_perf)) 547 | f.close() 548 | exit() 549 | 550 | """ 551 | ********************************************************* 552 | Train. 553 | """ 554 | tensorboard_model_path = 'tensorboard/' 555 | if not os.path.exists(tensorboard_model_path): 556 | os.makedirs(tensorboard_model_path) 557 | run_time = 1 558 | while (True): 559 | if os.path.exists(tensorboard_model_path + model.log_dir +'/run_' + str(run_time)): 560 | run_time += 1 561 | else: 562 | break 563 | train_writer = tf.summary.FileWriter(tensorboard_model_path +model.log_dir+ '/run_' + str(run_time), sess.graph) 564 | 565 | 566 | loss_loger, pre_loger, rec_loger, ndcg_loger, hit_loger = [], [], [], [], [] 567 | stopping_step = 0 568 | should_stop = False 569 | 570 | 571 | for epoch in range(1, args.epoch + 1): 572 | t1 = time() 573 | loss, mf_loss, emb_loss, reg_loss = 0., 0., 0., 0. 574 | n_batch = data_generator.n_train // args.batch_size + 1 575 | loss_test,mf_loss_test,emb_loss_test,reg_loss_test=0.,0.,0.,0. 
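        # pipelined CPU sampling / GPU training: while the current batch is being trained,
        # a separate thread (sample_next) already draws the next batch; sample_last hands
        # the pre-sampled data to the trainer on the following iteration.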
576 | ''' 577 | ********************************************************* 578 | parallelized sampling 579 | ''' 580 | sample_last = sample_thread() 581 | sample_last.start() 582 | sample_last.join() 583 | for idx in range(n_batch): 584 | train_cur = train_thread(model, sess, sample_last) 585 | sample_next = sample_thread() 586 | 587 | train_cur.start() 588 | sample_next.start() 589 | 590 | sample_next.join() 591 | train_cur.join() 592 | 593 | users, pos_items, neg_items = sample_last.data 594 | _, batch_loss, batch_mf_loss, batch_emb_loss, batch_reg_loss = train_cur.data 595 | sample_last = sample_next 596 | 597 | loss += batch_loss/n_batch 598 | mf_loss += batch_mf_loss/n_batch 599 | emb_loss += batch_emb_loss/n_batch 600 | 601 | summary_train_loss= sess.run(model.merged_train_loss, 602 | feed_dict={model.train_loss: loss, model.train_mf_loss: mf_loss, 603 | model.train_emb_loss: emb_loss, model.train_reg_loss: reg_loss}) 604 | train_writer.add_summary(summary_train_loss, epoch) 605 | if np.isnan(loss) == True: 606 | print('ERROR: loss is nan.') 607 | sys.exit() 608 | 609 | if (epoch % 20) != 0: 610 | if args.verbose > 0 and epoch % args.verbose == 0: 611 | perf_str = 'Epoch %d [%.1fs]: train==[%.5f=%.5f + %.5f]' % ( 612 | epoch, time() - t1, loss, mf_loss, emb_loss) 613 | print(perf_str) 614 | continue 615 | users_to_test = list(data_generator.train_items.keys()) 616 | ret = test(sess, model, users_to_test ,drop_flag=True,train_set_flag=1) 617 | perf_str = 'Epoch %d: train==[%.5f=%.5f + %.5f + %.5f], recall=[%s], precision=[%s], ndcg=[%s]' % \ 618 | (epoch, loss, mf_loss, emb_loss, reg_loss, 619 | ', '.join(['%.5f' % r for r in ret['recall']]), 620 | ', '.join(['%.5f' % r for r in ret['precision']]), 621 | ', '.join(['%.5f' % r for r in ret['ndcg']])) 622 | print(perf_str) 623 | summary_train_acc = sess.run(model.merged_train_acc, feed_dict={model.train_rec_first: ret['recall'][0], 624 | model.train_rec_last: ret['recall'][-1], 625 | model.train_ndcg_first: ret['ndcg'][0], 626 | model.train_ndcg_last: ret['ndcg'][-1]}) 627 | train_writer.add_summary(summary_train_acc, epoch // 20) 628 | 629 | ''' 630 | ********************************************************* 631 | parallelized sampling 632 | ''' 633 | sample_last= sample_thread_test() 634 | sample_last.start() 635 | sample_last.join() 636 | for idx in range(n_batch): 637 | train_cur = train_thread_test(model, sess, sample_last) 638 | sample_next = sample_thread_test() 639 | 640 | train_cur.start() 641 | sample_next.start() 642 | 643 | sample_next.join() 644 | train_cur.join() 645 | 646 | users, pos_items, neg_items = sample_last.data 647 | batch_loss_test, batch_mf_loss_test, batch_emb_loss_test = train_cur.data 648 | sample_last = sample_next 649 | 650 | loss_test += batch_loss_test / n_batch 651 | mf_loss_test += batch_mf_loss_test / n_batch 652 | emb_loss_test += batch_emb_loss_test / n_batch 653 | 654 | summary_test_loss = sess.run(model.merged_test_loss, 655 | feed_dict={model.test_loss: loss_test, model.test_mf_loss: mf_loss_test, 656 | model.test_emb_loss: emb_loss_test, model.test_reg_loss: reg_loss_test}) 657 | train_writer.add_summary(summary_test_loss, epoch // 20) 658 | t2 = time() 659 | users_to_test = list(data_generator.test_set.keys()) 660 | ret = test(sess, model, users_to_test, drop_flag=True) 661 | summary_test_acc = sess.run(model.merged_test_acc, 662 | feed_dict={model.test_rec_first: ret['recall'][0], model.test_rec_last: ret['recall'][-1], 663 | model.test_ndcg_first: ret['ndcg'][0], model.test_ndcg_last: 
ret['ndcg'][-1]}) 664 | train_writer.add_summary(summary_test_acc, epoch // 20) 665 | 666 | 667 | t3 = time() 668 | 669 | loss_loger.append(loss) 670 | rec_loger.append(ret['recall']) 671 | pre_loger.append(ret['precision']) 672 | ndcg_loger.append(ret['ndcg']) 673 | 674 | if args.verbose > 0: 675 | perf_str = 'Epoch %d [%.1fs + %.1fs]: test==[%.5f=%.5f + %.5f + %.5f], recall=[%s], ' \ 676 | 'precision=[%s], ndcg=[%s]' % \ 677 | (epoch, t2 - t1, t3 - t2, loss_test, mf_loss_test, emb_loss_test, reg_loss_test, 678 | ', '.join(['%.5f' % r for r in ret['recall']]), 679 | ', '.join(['%.5f' % r for r in ret['precision']]), 680 | ', '.join(['%.5f' % r for r in ret['ndcg']])) 681 | print(perf_str) 682 | 683 | cur_best_pre_0, stopping_step, should_stop = early_stopping(ret['recall'][0], cur_best_pre_0, 684 | stopping_step, expected_order='acc', flag_step=5) 685 | 686 | # ********************************************************* 687 | # early stopping when cur_best_pre_0 is decreasing for ten successive steps. 688 | if should_stop == True: 689 | break 690 | 691 | # ********************************************************* 692 | # save the user & item embeddings for pretraining. 693 | if ret['recall'][0] == cur_best_pre_0 and args.save_flag == 1: 694 | save_saver.save(sess, weights_save_path + '/weights', global_step=epoch) 695 | print('save the weights in path: ', weights_save_path) 696 | recs = np.array(rec_loger) 697 | pres = np.array(pre_loger) 698 | ndcgs = np.array(ndcg_loger) 699 | 700 | best_rec_0 = max(recs[:, 0]) 701 | idx = list(recs[:, 0]).index(best_rec_0) 702 | 703 | final_perf = "Best Iter=[%d]@[%.1f]\trecall=[%s], precision=[%s], ndcg=[%s]" % \ 704 | (idx, time() - t0, '\t'.join(['%.5f' % r for r in recs[idx]]), 705 | '\t'.join(['%.5f' % r for r in pres[idx]]), 706 | '\t'.join(['%.5f' % r for r in ndcgs[idx]])) 707 | print(final_perf) 708 | 709 | save_path = '%soutput/%s/%s.result' % (args.proj_path, args.dataset, model.model_type) 710 | ensureDir(save_path) 711 | f = open(save_path, 'a') 712 | 713 | f.write( 714 | 'embed_size=%d, lr=%.4f, layer_size=%s, node_dropout=%s, mess_dropout=%s, regs=%s, adj_type=%s\n\t%s\n' 715 | % (args.embed_size, args.lr, args.layer_size, args.node_dropout, args.mess_dropout, args.regs, 716 | args.adj_type, final_perf)) 717 | f.close() 718 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LightGCN 2 | This is our Tensorflow implementation for our SIGIR 2020 paper: 3 | 4 | >Xiangnan He, Kuan Deng ,Xiang Wang, Yan Li, Yongdong Zhang, Meng Wang(2020). LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation, [Paper in arXiv](https://arxiv.org/abs/2002.02126). 5 | 6 | Contributors: Dr. Xiangnan He (staff.ustc.edu.cn/~hexn/), Kuan Deng, Yingxin Wu. 7 | 8 | (We also provide Pytorch implementation for LightGCN : https://github.com/gusye1234/LightGCN-PyTorch. Contributors: Jianbai Ye.) 9 | 10 | ## Introduction 11 | In this work, we aim to simplify the design of GCN to make it more concise and appropriate for recommendation. We propose a new model named LightGCN, including only the most essential component in GCN—neighborhood aggregation—for collaborative filtering. 12 | 13 | ## Environment Requirement 14 | The code has been tested running under Python 3.6.5. 
The required packages are as follows: 15 | * tensorflow == 1.11.0 16 | * numpy == 1.14.3 17 | * scipy == 1.1.0 18 | * sklearn == 0.19.1 19 | * cython == 0.29.15 20 | ## C++ evaluator 21 | We have implemented C++ code to output metrics during and after training, which is much more efficient than python evaluator. It needs to be compiled first using the following command. 22 | ``` 23 | python setup.py build_ext --inplace 24 | ``` 25 | After compilation, the C++ code will run by default instead of Python code. 26 | 27 | ## Examples to run a 3-layer LightGCN 28 | The instruction of commands has been clearly stated in the codes (see the parser function in LightGCN/utility/parser.py). 29 | ### Gowalla dataset 30 | * Command 31 | ``` 32 | python LightGCN.py --dataset gowalla --regs [1e-4] --embed_size 64 --layer_size [64,64,64] --lr 0.001 --batch_size 2048 --epoch 1000 33 | ``` 34 | * Output log : 35 | ``` 36 | eval_score_matrix_foldout with cpp 37 | n_users=29858, n_items=40981 38 | n_interactions=1027370 39 | n_train=810128, n_test=217242, sparsity=0.00084 40 | ... 41 | Epoch 1 [30.3s]: train==[0.46925=0.46911 + 0.00014] 42 | Epoch 2 [27.1s]: train==[0.21866=0.21817 + 0.00048] 43 | ... 44 | Epoch 879 [81.6s + 31.3s]: test==[0.13271=0.12645 + 0.00626 + 0.00000], recall=[0.18201], precision=[0.05601], ndcg=[0.15555] 45 | Early stopping is trigger at step: 5 log:0.18201370537281036 46 | Best Iter=[38]@[32829.6] recall=[0.18236], precision=[0.05607], ndcg=[0.15539] 47 | ``` 48 | 49 | 50 | ### Yelp2018 dataset 51 | * Command 52 | ``` 53 | python LightGCN.py --dataset yelp2018 --regs [1e-4] --embed_size 64 --layer_size [64,64,64] --lr 0.001 --batch_size 2048 --epoch 1000 54 | ``` 55 | * Output log : 56 | ``` 57 | eval_score_matrix_foldout with cpp 58 | n_users=31668, n_items=38048 59 | n_interactions=1561406 60 | n_train=1237259, n_test=324147, sparsity=0.00130 61 | ... 62 | Epoch 1 [56.5s]: train==[0.33843=0.33815 + 0.00028] 63 | Epoch 2 [53.1s]: train==[0.16253=0.16192 + 0.00061] 64 | ... 65 | Epoch 679 [104.6s + 12.9s]: test==[0.17217=0.16289 + 0.00929 + 0.00000], recall=[0.06359], precision=[0.02874], ndcg=[0.05240] 66 | Early stopping is trigger at step: 5 log:0.06359195709228516 67 | Best Iter=[28]@[42815.0] recall=[0.06367], precision=[0.02868], ndcg=[0.05236] 68 | ``` 69 | ### Amazon-book dataset 70 | * Command 71 | ``` 72 | python LightGCN.py --dataset amazon-book --regs [1e-4] --embed_size 64 --layer_size [64,64,64] --lr 0.001 --batch_size 8192 --epoch 1000 73 | ``` 74 | * Output log : 75 | ``` 76 | eval_score_matrix_foldout with cpp 77 | n_users=52643, n_items=91599 78 | n_interactions=2984108 79 | n_train=2380730, n_test=603378, sparsity=0.00062 80 | ... 81 | Epoch 1 [53.2s]: train==[0.57471=0.57463 + 0.00008] 82 | Epoch 2 [47.3s]: train==[0.31518=0.31478 + 0.00040] 83 | ... 84 | Epoch 779 [181.7s + 79.0s]: test==[0.20300=0.19434 + 0.00866 + 0.00000], recall=[0.04120], precision=[0.01703], ndcg=[0.03186] 85 | Early stopping is trigger at step: 5 log:0.04119725897908211 86 | Best Iter=[33]@[49875.4] recall=[0.04123], precision=[0.01710], ndcg=[0.03189] 87 | ``` 88 | NOTE : the duration of training and testing depends on the running environment. 89 | ## Dataset 90 | We provide three processed datasets: Gowalla, Yelp2018 and Amazon-book. 91 | * `train.txt` 92 | * Train file. 93 | * Each line is a user with her/his positive interactions with items: userID\t a list of itemID\n. 94 | 95 | * `test.txt` 96 | * Test file (positive instances). 
97 | * Each line is a user with her/his positive interactions with items: userID\t a list of itemID\n. 98 | * Note that here we treat all unobserved interactions as the negative instances when reporting performance. 99 | 100 | * `user_list.txt` 101 | * User file. 102 | * Each line is a triplet (org_id, remap_id) for one user, where org_id and remap_id represent the ID of the user in the original and our datasets, respectively. 103 | 104 | * `item_list.txt` 105 | * Item file. 106 | * Each line is a triplet (org_id, remap_id) for one item, where org_id and remap_id represent the ID of the item in the original and our datasets, respectively. 107 | 108 | ## Efficiency Improvements: 109 | * Parallelized sampling on CPU 110 | * C++ evaluation for top-k recommendation 111 | 112 | ======= 113 | -------------------------------------------------------------------------------- /evaluator/__init__.py: -------------------------------------------------------------------------------- 1 | # import eval_score_matrix_foldout 2 | try: 3 | from evaluator.cpp.evaluate_foldout import eval_score_matrix_foldout 4 | print("eval_score_matrix_foldout with cpp") 5 | except: 6 | from evaluator.python.evaluate_foldout import eval_score_matrix_foldout 7 | print("eval_score_matrix_foldout with python") 8 | 9 | -------------------------------------------------------------------------------- /evaluator/cpp/apt_evaluate_foldout.pyx: -------------------------------------------------------------------------------- 1 | # distutils: language = c++ 2 | """ 3 | @author: Zhongchuan Sun 4 | """ 5 | import numpy as np 6 | cimport numpy as np 7 | import os 8 | from .apt_tools import get_float_type, get_int_type, is_ndarray 9 | from cpython.mem cimport PyMem_Malloc, PyMem_Free 10 | 11 | cdef extern from "include/tools.h": 12 | void c_top_k_array_index(float *scores_pt, int columns_num, int rows_num, 13 | int top_k, int thread_num, int *rankings_pt) 14 | 15 | cdef extern from "include/evaluate_foldout.h": 16 | void evaluate_foldout(int users_num, 17 | int *rankings, int rank_len, 18 | int **ground_truths, int *ground_truths_num, 19 | int thread_num, float *results) 20 | 21 | 22 | def apt_evaluate_foldout(ranking_scores, ground_truth, top_k = 20, thread_num=None): 23 | metrics_num = 5 24 | users_num, rank_len = np.shape(ranking_scores) 25 | if users_num != len(ground_truth): 26 | raise Exception("The lengths of 'ranking_scores' and 'ground_truth' are different.") 27 | thread_num = (thread_num or (os.cpu_count() or 1) * 5) 28 | 29 | float_type = get_float_type() 30 | int_type = get_int_type() 31 | 32 | if not is_ndarray(ranking_scores, float_type): 33 | ranking_scores = np.array(ranking_scores, dtype=float_type) 34 | 35 | # get the pointer of ranking scores 36 | cdef float *scores_pt = np.PyArray_DATA(ranking_scores) 37 | 38 | # store ranks results 39 | top_rankings = np.zeros([users_num, top_k], dtype=int_type) 40 | cdef int *rankings_pt = np.PyArray_DATA(top_rankings) 41 | 42 | # get top k rating index 43 | c_top_k_array_index(scores_pt, rank_len, users_num, top_k, thread_num, rankings_pt) 44 | 45 | # the pointer of ground truth, the pointer of the length array of ground truth 46 | ground_truth_pt = PyMem_Malloc(users_num * sizeof(int *)) 47 | ground_truth_num = np.zeros([users_num], dtype=int_type) 48 | ground_truth_num_pt = np.PyArray_DATA(ground_truth_num) 49 | for u in range(users_num): 50 | if not is_ndarray(ground_truth[u], int_type): 51 | ground_truth[u] = np.array(ground_truth[u], dtype=int_type, copy=True) 52 | 
ground_truth_pt[u] = np.PyArray_DATA(ground_truth[u]) 53 | ground_truth_num[u] = len(ground_truth[u]) 54 | 55 | #evaluate results 56 | results = np.zeros([users_num, metrics_num*top_k], dtype=float_type) 57 | results_pt = np.PyArray_DATA(results) 58 | 59 | #evaluate 60 | evaluate_foldout(users_num, rankings_pt, top_k, ground_truth_pt, ground_truth_num_pt, thread_num, results_pt) 61 | 62 | #release the allocated space 63 | PyMem_Free(ground_truth_pt) 64 | 65 | return results 66 | -------------------------------------------------------------------------------- /evaluator/cpp/apt_tools.pyx: -------------------------------------------------------------------------------- 1 | # distutils: language = c++ 2 | """ 3 | @author: Zhongchuan Sun 4 | """ 5 | import numpy as np 6 | 7 | 8 | def get_float_type(): 9 | cdef size_of_float = sizeof(float)*8 10 | if size_of_float == 32: 11 | return np.float32 12 | elif size_of_float == 64: 13 | return np.float64 14 | else: 15 | raise EnvironmentError("The size of 'float' is %d, but 32 or 64." % size_of_float) 16 | 17 | def get_int_type(): 18 | cdef size_of_int = sizeof(int)*8 19 | if size_of_int == 16: 20 | return np.int16 21 | elif size_of_int == 32: 22 | return np.int32 23 | else: 24 | raise EnvironmentError("The size of 'int' is %d, but 16 or 32." % size_of_int) 25 | 26 | def is_ndarray(array, dtype): 27 | if not isinstance(array, np.ndarray): 28 | return False 29 | if array.dtype != dtype: 30 | return False 31 | if array.base is not None: 32 | return False 33 | return True 34 | -------------------------------------------------------------------------------- /evaluator/cpp/evaluate_foldout.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Zhongchuan Sun 3 | """ 4 | try: 5 | from .apt_evaluate_foldout import apt_evaluate_foldout 6 | except: 7 | raise ImportError("Import apt_evaluate_foldout error!") 8 | import numpy as np 9 | import os 10 | import sys 11 | 12 | def eval_score_matrix_foldout(score_matrix, test_items, top_k=20, thread_num=None): 13 | if len(score_matrix) != len(test_items): 14 | raise ValueError("The lengths of score_matrix and test_items are not equal.") 15 | thread_num = (thread_num or (os.cpu_count() or 1) * 5) 16 | results = apt_evaluate_foldout(score_matrix, test_items, top_k, thread_num) 17 | 18 | return results 19 | -------------------------------------------------------------------------------- /evaluator/cpp/include/evaluate_foldout.h: -------------------------------------------------------------------------------- 1 | /* 2 | @author: Zhongchuan Sun 3 | */ 4 | #ifndef EVALUATE_FOLDOUT_H 5 | #define EVALUATE_FOLDOUT_H 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include "thread_pool.h" 11 | 12 | using std::vector; 13 | using std::set; 14 | using std::future; 15 | 16 | vector precision(int *rank, int top_k, int *truth, int truth_len) 17 | { 18 | vector result(top_k); 19 | int hits = 0; 20 | set truth_set(truth, truth+truth_len); 21 | for(int i=0; i recall(int *rank, int top_k, int *truth, int truth_len) 33 | { 34 | vector result(top_k); 35 | int hits = 0; 36 | set truth_set(truth, truth+truth_len); 37 | for(int i=0; i ap(int *rank, int top_k, int *truth, int truth_len) 49 | { 50 | vector result(top_k); // = precision(rank, top_k, truth, truth_len); 51 | int hits = 0; 52 | float pre = 0; 53 | float sum_pre = 0; 54 | set truth_set(truth, truth+truth_len); 55 | for(int i=0; i ndcg(int *rank, int top_k, int *truth, int truth_len) 69 | { 70 | vector result(top_k); 71 | 
float iDCG = 0; 72 | float DCG = 0; 73 | set truth_set(truth, truth+truth_len); 74 | for(int i=0; i mrr(int *rank, int top_k, int *truth, int truth_len) 90 | { 91 | vector result(top_k); 92 | float rr = 0; 93 | set truth_set(truth, truth+truth_len); 94 | for(int i=0; i > > sync_pre_results; 123 | vector< future< vector > > sync_recall_results; 124 | vector< future< vector > > sync_ap_results; 125 | vector< future< vector > > sync_ndcg_results; 126 | vector< future< vector > > sync_mrr_results; 127 | 128 | for(int uid=0; uid 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | class ThreadPool { 15 | public: 16 | ThreadPool(size_t); 17 | template 18 | auto enqueue(F&& f, Args&&... args) 19 | -> std::future::type>; 20 | ~ThreadPool(); 21 | private: 22 | // need to keep track of threads so we can join them 23 | std::vector< std::thread > workers; 24 | // the task queue 25 | std::queue< std::function > tasks; 26 | 27 | // synchronization 28 | std::mutex queue_mutex; 29 | std::condition_variable condition; 30 | bool stop; 31 | }; 32 | 33 | // the constructor just launches some amount of workers 34 | inline ThreadPool::ThreadPool(size_t threads) 35 | : stop(false) 36 | { 37 | for(size_t i = 0;i task; 44 | 45 | { 46 | std::unique_lock lock(this->queue_mutex); 47 | this->condition.wait(lock, 48 | [this]{ return this->stop || !this->tasks.empty(); }); 49 | if(this->stop && this->tasks.empty()) 50 | return; 51 | task = std::move(this->tasks.front()); 52 | this->tasks.pop(); 53 | } 54 | 55 | task(); 56 | } 57 | } 58 | ); 59 | } 60 | 61 | // add new work item to the pool 62 | template 63 | auto ThreadPool::enqueue(F&& f, Args&&... args) 64 | -> std::future::type> 65 | { 66 | using return_type = typename std::result_of::type; 67 | 68 | auto task = std::make_shared< std::packaged_task >( 69 | std::bind(std::forward(f), std::forward(args)...) 
70 | ); 71 | 72 | std::future res = task->get_future(); 73 | { 74 | std::unique_lock lock(queue_mutex); 75 | 76 | // don't allow enqueueing after stopping the pool 77 | if(stop) 78 | throw std::runtime_error("enqueue on stopped ThreadPool"); 79 | 80 | tasks.emplace([task](){ (*task)(); }); 81 | } 82 | condition.notify_one(); 83 | return res; 84 | } 85 | 86 | // the destructor joins all threads 87 | inline ThreadPool::~ThreadPool() 88 | { 89 | { 90 | std::unique_lock lock(queue_mutex); 91 | stop = true; 92 | } 93 | condition.notify_all(); 94 | for(std::thread &worker: workers) 95 | worker.join(); 96 | } 97 | 98 | #endif 99 | -------------------------------------------------------------------------------- /evaluator/cpp/include/tools.h: -------------------------------------------------------------------------------- 1 | /* 2 | @author: Zhongchuan Sun 3 | */ 4 | #ifndef TOOLS_H 5 | #define TOOLS_H 6 | 7 | #include "thread_pool.h" 8 | #include 9 | #include 10 | using std::vector; 11 | 12 | 13 | void c_top_k_index(float *ratings, int rating_len, int top_k, int *result) 14 | { 15 | vector index(rating_len); 16 | for(auto i=0; ibool{return ratings[x1]>ratings[x2];}); 22 | } 23 | 24 | void c_top_k_array_index(float *scores_pt, int columns_num, int rows_num, int top_k, int thread_num, int *rankings_pt) 25 | { 26 | ThreadPool pool(thread_num); 27 | for(int i=0; i= best_value) or (expected_order == 'dec' and log_value <= best_value): 40 | stopping_step = 0 41 | best_value = log_value 42 | else: 43 | stopping_step += 1 44 | 45 | if stopping_step >= flag_step: 46 | print("Early stopping is trigger at step: {} log:{}".format(flag_step, log_value)) 47 | should_stop = True 48 | else: 49 | should_stop = False 50 | return best_value, stopping_step, should_stop 51 | -------------------------------------------------------------------------------- /utility/load_data.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Oct 10, 2018 3 | Tensorflow Implementation of Neural Graph Collaborative Filtering (NGCF) model in: 4 | Wang Xiang et al. Neural Graph Collaborative Filtering. In SIGIR 2019. 
5 | 6 | @author: Xiang Wang (xiangwang@u.nus.edu) 7 | ''' 8 | import numpy as np 9 | import random as rd 10 | import scipy.sparse as sp 11 | from time import time 12 | 13 | class Data(object): 14 | def __init__(self, path, batch_size): 15 | self.path = path 16 | self.batch_size = batch_size 17 | 18 | train_file = path + '/train.txt' 19 | test_file = path + '/test.txt' 20 | 21 | self.n_users, self.n_items = 0, 0 22 | self.n_train, self.n_test = 0, 0 23 | self.neg_pools = {} 24 | 25 | self.exist_users = [] 26 | 27 | with open(train_file) as f: 28 | for l in f.readlines(): 29 | if len(l) > 0: 30 | l = l.strip('\n').split(' ') 31 | items = [int(i) for i in l[1:]] 32 | uid = int(l[0]) 33 | self.exist_users.append(uid) 34 | self.n_items = max(self.n_items, max(items)) 35 | self.n_users = max(self.n_users, uid) 36 | self.n_train += len(items) 37 | 38 | with open(test_file) as f: 39 | for l in f.readlines(): 40 | if len(l) > 0: 41 | l = l.strip('\n') 42 | try: 43 | items = [int(i) for i in l.split(' ')[1:]] 44 | except Exception: 45 | continue 46 | self.n_items = max(self.n_items, max(items)) 47 | self.n_test += len(items) 48 | self.n_items += 1 49 | self.n_users += 1 50 | self.print_statistics() 51 | self.R = sp.dok_matrix((self.n_users, self.n_items), dtype=np.float32) 52 | self.train_items, self.test_set = {}, {} 53 | with open(train_file) as f_train: 54 | with open(test_file) as f_test: 55 | for l in f_train.readlines(): 56 | if len(l) == 0: break 57 | l = l.strip('\n') 58 | items = [int(i) for i in l.split(' ')] 59 | uid, train_items = items[0], items[1:] 60 | 61 | for i in train_items: 62 | self.R[uid, i] = 1. 63 | 64 | self.train_items[uid] = train_items 65 | 66 | for l in f_test.readlines(): 67 | if len(l) == 0: break 68 | l = l.strip('\n') 69 | try: 70 | items = [int(i) for i in l.split(' ')] 71 | except Exception: 72 | continue 73 | 74 | uid, test_items = items[0], items[1:] 75 | self.test_set[uid] = test_items 76 | 77 | def get_adj_mat(self): 78 | try: 79 | t1 = time() 80 | adj_mat = sp.load_npz(self.path + '/s_adj_mat.npz') 81 | norm_adj_mat = sp.load_npz(self.path + '/s_norm_adj_mat.npz') 82 | mean_adj_mat = sp.load_npz(self.path + '/s_mean_adj_mat.npz') 83 | print('already load adj matrix', adj_mat.shape, time() - t1) 84 | 85 | except Exception: 86 | adj_mat, norm_adj_mat, mean_adj_mat = self.create_adj_mat() 87 | sp.save_npz(self.path + '/s_adj_mat.npz', adj_mat) 88 | sp.save_npz(self.path + '/s_norm_adj_mat.npz', norm_adj_mat) 89 | sp.save_npz(self.path + '/s_mean_adj_mat.npz', mean_adj_mat) 90 | 91 | try: 92 | pre_adj_mat = sp.load_npz(self.path + '/s_pre_adj_mat.npz') 93 | except Exception: 94 | adj_mat=adj_mat 95 | rowsum = np.array(adj_mat.sum(1)) 96 | d_inv = np.power(rowsum, -0.5).flatten() 97 | 98 | d_inv[np.isinf(d_inv)] = 0. 
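            # symmetric normalization D^{-1/2} A D^{-1/2}; this is the 'pre' adjacency
            # matrix that LightGCN consumes when --adj_type is 'pre' (the default).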
99 | d_mat_inv = sp.diags(d_inv) 100 | norm_adj = d_mat_inv.dot(adj_mat) 101 | norm_adj = norm_adj.dot(d_mat_inv) 102 | print('generate pre adjacency matrix.') 103 | pre_adj_mat = norm_adj.tocsr() 104 | sp.save_npz(self.path + '/s_pre_adj_mat.npz', norm_adj) 105 | 106 | return adj_mat, norm_adj_mat, mean_adj_mat,pre_adj_mat 107 | 108 | def create_adj_mat(self): 109 | t1 = time() 110 | adj_mat = sp.dok_matrix((self.n_users + self.n_items, self.n_users + self.n_items), dtype=np.float32) 111 | adj_mat = adj_mat.tolil() 112 | R = self.R.tolil() 113 | # prevent memory from overflowing 114 | for i in range(5): 115 | adj_mat[int(self.n_users*i/5.0):int(self.n_users*(i+1.0)/5), self.n_users:] =\ 116 | R[int(self.n_users*i/5.0):int(self.n_users*(i+1.0)/5)] 117 | adj_mat[self.n_users:,int(self.n_users*i/5.0):int(self.n_users*(i+1.0)/5)] =\ 118 | R[int(self.n_users*i/5.0):int(self.n_users*(i+1.0)/5)].T 119 | adj_mat = adj_mat.todok() 120 | print('already create adjacency matrix', adj_mat.shape, time() - t1) 121 | 122 | t2 = time() 123 | def normalized_adj_single(adj): 124 | rowsum = np.array(adj.sum(1)) 125 | 126 | d_inv = np.power(rowsum, -1).flatten() 127 | d_inv[np.isinf(d_inv)] = 0. 128 | d_mat_inv = sp.diags(d_inv) 129 | 130 | norm_adj = d_mat_inv.dot(adj) 131 | print('generate single-normalized adjacency matrix.') 132 | return norm_adj.tocoo() 133 | 134 | def check_adj_if_equal(adj): 135 | dense_A = np.array(adj.todense()) 136 | degree = np.sum(dense_A, axis=1, keepdims=False) 137 | 138 | temp = np.dot(np.diag(np.power(degree, -1)), dense_A) 139 | print('check normalized adjacency matrix whether equal to this laplacian matrix.') 140 | return temp 141 | 142 | norm_adj_mat = normalized_adj_single(adj_mat + sp.eye(adj_mat.shape[0])) 143 | mean_adj_mat = normalized_adj_single(adj_mat) 144 | 145 | print('already normalize adjacency matrix', time() - t2) 146 | return adj_mat.tocsr(), norm_adj_mat.tocsr(), mean_adj_mat.tocsr() 147 | 148 | def negative_pool(self): 149 | t1 = time() 150 | for u in self.train_items.keys(): 151 | neg_items = list(set(range(self.n_items)) - set(self.train_items[u])) 152 | pools = [rd.choice(neg_items) for _ in range(100)] 153 | self.neg_pools[u] = pools 154 | print('refresh negative pools', time() - t1) 155 | 156 | def sample(self): 157 | if self.batch_size <= self.n_users: 158 | users = rd.sample(self.exist_users, self.batch_size) 159 | else: 160 | users = [rd.choice(self.exist_users) for _ in range(self.batch_size)] 161 | 162 | 163 | def sample_pos_items_for_u(u, num): 164 | pos_items = self.train_items[u] 165 | n_pos_items = len(pos_items) 166 | pos_batch = [] 167 | while True: 168 | if len(pos_batch) == num: break 169 | pos_id = np.random.randint(low=0, high=n_pos_items, size=1)[0] 170 | pos_i_id = pos_items[pos_id] 171 | 172 | if pos_i_id not in pos_batch: 173 | pos_batch.append(pos_i_id) 174 | return pos_batch 175 | 176 | def sample_neg_items_for_u(u, num): 177 | neg_items = [] 178 | while True: 179 | if len(neg_items) == num: break 180 | neg_id = np.random.randint(low=0, high=self.n_items,size=1)[0] 181 | if neg_id not in self.train_items[u] and neg_id not in neg_items: 182 | neg_items.append(neg_id) 183 | return neg_items 184 | 185 | def sample_neg_items_for_u_from_pools(u, num): 186 | neg_items = list(set(self.neg_pools[u]) - set(self.train_items[u])) 187 | return rd.sample(neg_items, num) 188 | 189 | pos_items, neg_items = [], [] 190 | for u in users: 191 | pos_items += sample_pos_items_for_u(u, 1) 192 | neg_items += sample_neg_items_for_u(u, 1) 193 | 194 | 
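        # each sampled user contributes exactly one positive and one negative item,
        # i.e. the (user, pos_item, neg_item) triples consumed by the BPR loss.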
return users, pos_items, neg_items 195 | 196 | def sample_test(self): 197 | if self.batch_size <= self.n_users: 198 | users = rd.sample(self.test_set.keys(), self.batch_size) 199 | else: 200 | users = [rd.choice(self.exist_users) for _ in range(self.batch_size)] 201 | 202 | def sample_pos_items_for_u(u, num): 203 | pos_items = self.test_set[u] 204 | n_pos_items = len(pos_items) 205 | pos_batch = [] 206 | while True: 207 | if len(pos_batch) == num: break 208 | pos_id = np.random.randint(low=0, high=n_pos_items, size=1)[0] 209 | pos_i_id = pos_items[pos_id] 210 | 211 | if pos_i_id not in pos_batch: 212 | pos_batch.append(pos_i_id) 213 | return pos_batch 214 | 215 | def sample_neg_items_for_u(u, num): 216 | neg_items = [] 217 | while True: 218 | if len(neg_items) == num: break 219 | neg_id = np.random.randint(low=0, high=self.n_items, size=1)[0] 220 | if neg_id not in (self.test_set[u]+self.train_items[u]) and neg_id not in neg_items: 221 | neg_items.append(neg_id) 222 | return neg_items 223 | 224 | def sample_neg_items_for_u_from_pools(u, num): 225 | neg_items = list(set(self.neg_pools[u]) - set(self.train_items[u])) 226 | return rd.sample(neg_items, num) 227 | 228 | pos_items, neg_items = [], [] 229 | for u in users: 230 | pos_items += sample_pos_items_for_u(u, 1) 231 | neg_items += sample_neg_items_for_u(u, 1) 232 | 233 | return users, pos_items, neg_items 234 | 235 | 236 | 237 | 238 | 239 | 240 | def get_num_users_items(self): 241 | return self.n_users, self.n_items 242 | 243 | def print_statistics(self): 244 | print('n_users=%d, n_items=%d' % (self.n_users, self.n_items)) 245 | print('n_interactions=%d' % (self.n_train + self.n_test)) 246 | print('n_train=%d, n_test=%d, sparsity=%.5f' % (self.n_train, self.n_test, (self.n_train + self.n_test)/(self.n_users * self.n_items))) 247 | 248 | 249 | def get_sparsity_split(self): 250 | try: 251 | split_uids, split_state = [], [] 252 | lines = open(self.path + '/sparsity.split', 'r').readlines() 253 | 254 | for idx, line in enumerate(lines): 255 | if idx % 2 == 0: 256 | split_state.append(line.strip()) 257 | print(line.strip()) 258 | else: 259 | split_uids.append([int(uid) for uid in line.strip().split(' ')]) 260 | print('get sparsity split.') 261 | 262 | except Exception: 263 | split_uids, split_state = self.create_sparsity_split() 264 | f = open(self.path + '/sparsity.split', 'w') 265 | for idx in range(len(split_state)): 266 | f.write(split_state[idx] + '\n') 267 | f.write(' '.join([str(uid) for uid in split_uids[idx]]) + '\n') 268 | print('create sparsity split.') 269 | 270 | return split_uids, split_state 271 | 272 | 273 | 274 | def create_sparsity_split(self): 275 | all_users_to_test = list(self.test_set.keys()) 276 | user_n_iid = dict() 277 | 278 | # generate a dictionary to store (key=n_iids, value=a list of uid). 279 | for uid in all_users_to_test: 280 | train_iids = self.train_items[uid] 281 | test_iids = self.test_set[uid] 282 | 283 | n_iids = len(train_iids) + len(test_iids) 284 | 285 | if n_iids not in user_n_iid.keys(): 286 | user_n_iid[n_iids] = [uid] 287 | else: 288 | user_n_iid[n_iids].append(uid) 289 | split_uids = list() 290 | 291 | # split the whole user set into four subset. 
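        # users are processed in increasing order of interaction count; a group is
        # closed once it accumulates roughly 25% of all train+test interactions.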
292 | temp = [] 293 | count = 1 294 | fold = 4 295 | n_count = (self.n_train + self.n_test) 296 | n_rates = 0 297 | 298 | split_state = [] 299 | for idx, n_iids in enumerate(sorted(user_n_iid)): 300 | temp += user_n_iid[n_iids] 301 | n_rates += n_iids * len(user_n_iid[n_iids]) 302 | n_count -= n_iids * len(user_n_iid[n_iids]) 303 | 304 | if n_rates >= count * 0.25 * (self.n_train + self.n_test): 305 | split_uids.append(temp) 306 | 307 | state = '#inter per user<=[%d], #users=[%d], #all rates=[%d]' %(n_iids, len(temp), n_rates) 308 | split_state.append(state) 309 | print(state) 310 | 311 | temp = [] 312 | n_rates = 0 313 | fold -= 1 314 | 315 | if idx == len(user_n_iid.keys()) - 1 or n_count == 0: 316 | split_uids.append(temp) 317 | 318 | state = '#inter per user<=[%d], #users=[%d], #all rates=[%d]' % (n_iids, len(temp), n_rates) 319 | split_state.append(state) 320 | print(state) 321 | 322 | 323 | 324 | return split_uids, split_state 325 | -------------------------------------------------------------------------------- /utility/parser.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Oct 10, 2018 3 | Tensorflow Implementation of Neural Graph Collaborative Filtering (NGCF) model in: 4 | Wang Xiang et al. Neural Graph Collaborative Filtering. In SIGIR 2019. 5 | 6 | @author: Xiang Wang (xiangwang@u.nus.edu) 7 | ''' 8 | import argparse 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser(description="Run NGCF.") 12 | parser.add_argument('--weights_path', nargs='?', default='', 13 | help='Store model path.') 14 | parser.add_argument('--data_path', nargs='?', default='Data/', 15 | help='Input data path.') 16 | parser.add_argument('--proj_path', nargs='?', default='', 17 | help='Project path.') 18 | 19 | parser.add_argument('--dataset', nargs='?', default='gowalla', 20 | help='Choose a dataset from {gowalla, yelp2018, amazon-book}') 21 | parser.add_argument('--pretrain', type=int, default=0, 22 | help='0: No pretrain, -1: Pretrain with the learned embeddings, 1:Pretrain with stored models.') 23 | parser.add_argument('--verbose', type=int, default=1, 24 | help='Interval of evaluation.') 25 | parser.add_argument('--is_norm', type=int, default=1, 26 | help='Interval of evaluation.') 27 | parser.add_argument('--epoch', type=int, default=1000, 28 | help='Number of epoch.') 29 | 30 | parser.add_argument('--embed_size', type=int, default=64, 31 | help='Embedding size.') 32 | parser.add_argument('--layer_size', nargs='?', default='[64, 64, 64, 64]', 33 | help='Output sizes of every layer') 34 | parser.add_argument('--batch_size', type=int, default=1024, 35 | help='Batch size.') 36 | 37 | parser.add_argument('--regs', nargs='?', default='[1e-5,1e-5,1e-2]', 38 | help='Regularizations.') 39 | parser.add_argument('--lr', type=float, default=0.01, 40 | help='Learning rate.') 41 | 42 | parser.add_argument('--model_type', nargs='?', default='lightgcn', 43 | help='Specify the name of model (lightgcn).') 44 | parser.add_argument('--adj_type', nargs='?', default='pre', 45 | help='Specify the type of the adjacency (laplacian) matrix from {plain, norm, mean}.') 46 | parser.add_argument('--alg_type', nargs='?', default='lightgcn', 47 | help='Specify the type of the graph convolutional layer from {ngcf, gcn, gcmc}.') 48 | 49 | parser.add_argument('--gpu_id', type=int, default=0, 50 | help='0 for NAIS_prod, 1 for NAIS_concat') 51 | 52 | parser.add_argument('--node_dropout_flag', type=int, default=0, 53 | help='0: Disable node dropout, 1: Activate node 
dropout') 54 | parser.add_argument('--node_dropout', nargs='?', default='[0.1]', 55 | help='Keep probability w.r.t. node dropout (i.e., 1-dropout_ratio) for each deep layer. 1: no dropout.') 56 | parser.add_argument('--mess_dropout', nargs='?', default='[0.1]', 57 | help='Keep probability w.r.t. message dropout (i.e., 1-dropout_ratio) for each deep layer. 1: no dropout.') 58 | 59 | parser.add_argument('--Ks', nargs='?', default='[20]', 60 | help='Top k(s) to recommend') 61 | 62 | parser.add_argument('--save_flag', type=int, default=0, 63 | help='0: Disable model saver, 1: Activate model saver') 64 | 65 | parser.add_argument('--test_flag', nargs='?', default='part', 66 | help='Specify the test type from {part, full}, indicating whether the inference is done in mini-batch') 67 | 68 | parser.add_argument('--report', type=int, default=0, 69 | help='0: Disable performance report w.r.t. sparsity levels, 1: Show performance report w.r.t. sparsity levels') 70 | 71 | return parser.parse_args() 72 | --------------------------------------------------------------------------------
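As a supplementary illustration (not code from this repository), the propagation implemented in `_create_lightgcn_embed` can be restated in a few lines of plain NumPy/SciPy: build the normalized bipartite adjacency, repeatedly multiply the stacked user/item embeddings by it, and average the per-layer results. The sketch below assumes a SciPy sparse interaction matrix such as the `R` constructed in `utility/load_data.py`; the function name `lightgcn_propagate` and its arguments are illustrative only.
```
import numpy as np
import scipy.sparse as sp

def lightgcn_propagate(R, emb_dim=64, n_layers=3, seed=0):
    # R: sparse user-item interaction matrix of shape (n_users, n_items).
    n_users, n_items = R.shape

    # Bipartite adjacency A = [[0, R], [R^T, 0]], normalized as D^{-1/2} A D^{-1/2}
    # (the same 'pre' adjacency computed in utility/load_data.py).
    A = sp.bmat([[None, R], [R.T, None]], format='csr')
    rowsum = np.asarray(A.sum(axis=1)).flatten()
    with np.errstate(divide='ignore'):
        d_inv_sqrt = np.power(rowsum, -0.5)
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
    A_hat = sp.diags(d_inv_sqrt) @ A @ sp.diags(d_inv_sqrt)

    # Layer-0 embeddings (randomly initialized here, trainable in the real model).
    rng = np.random.default_rng(seed)
    E = rng.normal(scale=0.01, size=(n_users + n_items, emb_dim))

    # LightGCN propagation: E^(k+1) = A_hat E^(k), no feature transform, no nonlinearity.
    layer_embs = [E]
    for _ in range(n_layers):
        E = A_hat @ E
        layer_embs.append(E)

    # Layer combination: average the embeddings of all layers, including layer 0.
    final = np.mean(layer_embs, axis=0)
    return final[:n_users], final[n_users:]
```
For example, `lightgcn_propagate(data_generator.R.tocsr(), emb_dim=64, n_layers=3)` would return user and item matrices with the same shapes as `ua_embeddings` and `ia_embeddings` in LightGCN.py, only without trained layer-0 embeddings.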