├── .DS_Store
├── .idea
│   ├── .gitignore
│   ├── GCEGNN.iml
│   ├── deployment.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   └── modules.xml
├── README.md
├── aggregator.py
├── build_graph.py
├── datasets
│   ├── .DS_Store
│   ├── Nowplaying
│   │   ├── all_train_seq.txt
│   │   ├── test.txt
│   │   └── train.txt
│   ├── Tmall
│   │   ├── all_train_seq.txt
│   │   ├── test.txt
│   │   └── train.txt
│   ├── diginetica
│   │   ├── all_train_seq.txt
│   │   ├── test.txt
│   │   └── train.txt
│   ├── process_nowplaying.py
│   └── process_tmall.py
├── main.py
├── model.py
└── utils.py

/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CCIIPLab/GCE-GNN/212575302e42d64b8cf0b9ffe1b71b0db398d5be/.DS_Store
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /../../../:\GCEGNN\.idea/dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GCE-GNN
2 | 
3 | ## Code
4 | 
5 | This is the source code for the SIGIR 2020 paper: _Global Context Enhanced Graph Neural Networks for Session-based Recommendation_.
6 | 
7 | ## Requirements
8 | 
9 | - Python 3
10 | - PyTorch >= 1.3.0
11 | - tqdm
12 | 
13 | ## Usage
14 | 
15 | Data preprocessing:
16 | 
17 | The data preprocessing code can be adapted from [SR-GNN](https://github.com/CRIPAC-DIG/SR-GNN); `datasets/process_tmall.py` and `datasets/process_nowplaying.py` in this repository follow the same procedure.
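
After preprocessing, each dataset directory (`datasets/diginetica`, `datasets/Tmall`, `datasets/Nowplaying`) is expected to contain three pickles: `train.txt` and `test.txt`, each holding a `(sequences, labels)` tuple, and `all_train_seq.txt`, holding the full training sessions that `build_graph.py` reads to build the global item co-occurrence graph. A minimal sketch for sanity-checking the files, assuming the layout shown above:

~~~~
import pickle

# train.txt / test.txt: (list of item-ID prefix sequences, list of target item IDs)
seqs, labels = pickle.load(open('datasets/diginetica/train.txt', 'rb'))
print(len(seqs), len(labels))   # same length: one target per sequence
print(seqs[0], labels[0])

# all_train_seq.txt: complete training sessions, consumed by build_graph.py
all_seq = pickle.load(open('datasets/diginetica/all_train_seq.txt', 'rb'))
print(len(all_seq), all_seq[0])
~~~~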
18 | 19 | Train and evaluate the model: 20 | ~~~~ 21 | python build_graph.py --dataset diginetica --sample_num 12 22 | python main.py --dataset diginetica 23 | ~~~~ 24 | 25 | ## Citation 26 | 27 | ~~~~ 28 | @inproceedings{wang2020global, 29 | title={Global Context Enhanced Graph Neural Networks for Session-based Recommendation}, 30 | author={Wang, Ziyang and Wei, Wei and Cong, Gao and Li, Xiao-Li and Mao, Xian-Ling and Qiu, Minghui}, 31 | booktitle={Proceedings of the 43rd International ACM SIGIR Conference on Research and Development in Information Retrieval}, 32 | pages={169--178}, 33 | year={2020} 34 | } 35 | ~~~~ -------------------------------------------------------------------------------- /aggregator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import Parameter 4 | import torch.nn.functional as F 5 | import numpy 6 | 7 | 8 | class Aggregator(nn.Module): 9 | def __init__(self, batch_size, dim, dropout, act, name=None): 10 | super(Aggregator, self).__init__() 11 | self.dropout = dropout 12 | self.act = act 13 | self.batch_size = batch_size 14 | self.dim = dim 15 | 16 | def forward(self): 17 | pass 18 | 19 | 20 | class LocalAggregator(nn.Module): 21 | def __init__(self, dim, alpha, dropout=0., name=None): 22 | super(LocalAggregator, self).__init__() 23 | self.dim = dim 24 | self.dropout = dropout 25 | 26 | self.a_0 = nn.Parameter(torch.Tensor(self.dim, 1)) 27 | self.a_1 = nn.Parameter(torch.Tensor(self.dim, 1)) 28 | self.a_2 = nn.Parameter(torch.Tensor(self.dim, 1)) 29 | self.a_3 = nn.Parameter(torch.Tensor(self.dim, 1)) 30 | self.bias = nn.Parameter(torch.Tensor(self.dim)) 31 | 32 | self.leakyrelu = nn.LeakyReLU(alpha) 33 | 34 | def forward(self, hidden, adj, mask_item=None): 35 | h = hidden 36 | batch_size = h.shape[0] 37 | N = h.shape[1] 38 | 39 | a_input = (h.repeat(1, 1, N).view(batch_size, N * N, self.dim) 40 | * h.repeat(1, N, 1)).view(batch_size, N, N, self.dim) 41 | 42 | e_0 = torch.matmul(a_input, self.a_0) 43 | e_1 = torch.matmul(a_input, self.a_1) 44 | e_2 = torch.matmul(a_input, self.a_2) 45 | e_3 = torch.matmul(a_input, self.a_3) 46 | 47 | e_0 = self.leakyrelu(e_0).squeeze(-1).view(batch_size, N, N) 48 | e_1 = self.leakyrelu(e_1).squeeze(-1).view(batch_size, N, N) 49 | e_2 = self.leakyrelu(e_2).squeeze(-1).view(batch_size, N, N) 50 | e_3 = self.leakyrelu(e_3).squeeze(-1).view(batch_size, N, N) 51 | 52 | mask = -9e15 * torch.ones_like(e_0) 53 | alpha = torch.where(adj.eq(1), e_0, mask) 54 | alpha = torch.where(adj.eq(2), e_1, alpha) 55 | alpha = torch.where(adj.eq(3), e_2, alpha) 56 | alpha = torch.where(adj.eq(4), e_3, alpha) 57 | alpha = torch.softmax(alpha, dim=-1) 58 | 59 | output = torch.matmul(alpha, h) 60 | return output 61 | 62 | 63 | class GlobalAggregator(nn.Module): 64 | def __init__(self, dim, dropout, act=torch.relu, name=None): 65 | super(GlobalAggregator, self).__init__() 66 | self.dropout = dropout 67 | self.act = act 68 | self.dim = dim 69 | 70 | self.w_1 = nn.Parameter(torch.Tensor(self.dim + 1, self.dim)) 71 | self.w_2 = nn.Parameter(torch.Tensor(self.dim, 1)) 72 | self.w_3 = nn.Parameter(torch.Tensor(2 * self.dim, self.dim)) 73 | self.bias = nn.Parameter(torch.Tensor(self.dim)) 74 | 75 | def forward(self, self_vectors, neighbor_vector, batch_size, masks, neighbor_weight, extra_vector=None): 76 | if extra_vector is not None: 77 | alpha = torch.matmul(torch.cat([extra_vector.unsqueeze(2).repeat(1, 1, neighbor_vector.shape[2], 1)*neighbor_vector, 
neighbor_weight.unsqueeze(-1)], -1), self.w_1).squeeze(-1)
78 |             alpha = F.leaky_relu(alpha, negative_slope=0.2)
79 |             alpha = torch.matmul(alpha, self.w_2).squeeze(-1)
80 |             alpha = torch.softmax(alpha, -1).unsqueeze(-1)
81 |             neighbor_vector = torch.sum(alpha * neighbor_vector, dim=-2)
82 |         else:
83 |             neighbor_vector = torch.mean(neighbor_vector, dim=2)
84 |         # self_vectors = F.dropout(self_vectors, 0.5, training=self.training)
85 |         output = torch.cat([self_vectors, neighbor_vector], -1)
86 |         output = F.dropout(output, self.dropout, training=self.training)
87 |         output = torch.matmul(output, self.w_3)
88 |         output = output.view(batch_size, -1, self.dim)
89 |         output = self.act(output)
90 |         return output
91 | 
--------------------------------------------------------------------------------
/build_graph.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import argparse
3 | 
4 | parser = argparse.ArgumentParser()
5 | parser.add_argument('--dataset', default='diginetica', help='diginetica/Tmall/Nowplaying')
6 | parser.add_argument('--sample_num', type=int, default=12)
7 | opt = parser.parse_args()
8 | 
9 | dataset = opt.dataset
10 | sample_num = opt.sample_num
11 | 
12 | seq = pickle.load(open('datasets/' + dataset + '/all_train_seq.txt', 'rb'))
13 | # number of item IDs per dataset (must match num_node in main.py)
14 | if dataset == 'diginetica':
15 |     num = 43098
16 | elif dataset == "Tmall":
17 |     num = 40728
18 | elif dataset == "Nowplaying":
19 |     num = 60417
20 | else:
21 |     num = 3
22 | 
23 | relation = []
24 | 
25 | 
26 | 
27 | 
28 | adj1 = [dict() for _ in range(num)]
29 | adj = [[] for _ in range(num)]
30 | # collect ordered item pairs that co-occur within 3 positions of each other in a session
31 | for i in range(len(seq)):
32 |     data = seq[i]
33 |     for k in range(1, 4):
34 |         for j in range(len(data)-k):
35 |             relation.append([data[j], data[j+k]])
36 |             relation.append([data[j+k], data[j]])
37 | 
38 | for tup in relation:
39 |     if tup[1] in adj1[tup[0]].keys():
40 |         adj1[tup[0]][tup[1]] += 1
41 |     else:
42 |         adj1[tup[0]][tup[1]] = 1
43 | 
44 | weight = [[] for _ in range(num)]
45 | # for each item, sort its neighbours by co-occurrence count (descending)
46 | for t in range(num):
47 |     x = [v for v in sorted(adj1[t].items(), reverse=True, key=lambda x: x[1])]
48 |     adj[t] = [v[0] for v in x]
49 |     weight[t] = [v[1] for v in x]
50 | # keep only the sample_num most frequent neighbours per item
51 | for i in range(num):
52 |     adj[i] = adj[i][:sample_num]
53 |     weight[i] = weight[i][:sample_num]
54 | # adj_<sample_num>.pkl holds neighbour item IDs, num_<sample_num>.pkl the corresponding counts
55 | pickle.dump(adj, open('datasets/' + dataset + '/adj_' + str(sample_num) + '.pkl', 'wb'))
56 | pickle.dump(weight, open('datasets/' + dataset + '/num_' + str(sample_num) + '.pkl', 'wb'))
57 | 
--------------------------------------------------------------------------------
/datasets/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CCIIPLab/GCE-GNN/212575302e42d64b8cf0b9ffe1b71b0db398d5be/datasets/.DS_Store
--------------------------------------------------------------------------------
/datasets/Nowplaying/all_train_seq.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CCIIPLab/GCE-GNN/212575302e42d64b8cf0b9ffe1b71b0db398d5be/datasets/Nowplaying/all_train_seq.txt
--------------------------------------------------------------------------------
/datasets/Nowplaying/test.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CCIIPLab/GCE-GNN/212575302e42d64b8cf0b9ffe1b71b0db398d5be/datasets/Nowplaying/test.txt
--------------------------------------------------------------------------------
/datasets/Nowplaying/train.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CCIIPLab/GCE-GNN/212575302e42d64b8cf0b9ffe1b71b0db398d5be/datasets/Nowplaying/train.txt -------------------------------------------------------------------------------- /datasets/Tmall/all_train_seq.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CCIIPLab/GCE-GNN/212575302e42d64b8cf0b9ffe1b71b0db398d5be/datasets/Tmall/all_train_seq.txt -------------------------------------------------------------------------------- /datasets/Tmall/test.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CCIIPLab/GCE-GNN/212575302e42d64b8cf0b9ffe1b71b0db398d5be/datasets/Tmall/test.txt -------------------------------------------------------------------------------- /datasets/Tmall/train.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CCIIPLab/GCE-GNN/212575302e42d64b8cf0b9ffe1b71b0db398d5be/datasets/Tmall/train.txt -------------------------------------------------------------------------------- /datasets/diginetica/all_train_seq.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CCIIPLab/GCE-GNN/212575302e42d64b8cf0b9ffe1b71b0db398d5be/datasets/diginetica/all_train_seq.txt -------------------------------------------------------------------------------- /datasets/diginetica/test.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CCIIPLab/GCE-GNN/212575302e42d64b8cf0b9ffe1b71b0db398d5be/datasets/diginetica/test.txt -------------------------------------------------------------------------------- /datasets/diginetica/train.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CCIIPLab/GCE-GNN/212575302e42d64b8cf0b9ffe1b71b0db398d5be/datasets/diginetica/train.txt -------------------------------------------------------------------------------- /datasets/process_nowplaying.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import csv 4 | import pickle 5 | import operator 6 | import datetime 7 | import os 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--dataset', default='Nowplaying', help='dataset name: diginetica/yoochoose/sample') 11 | opt = parser.parse_args() 12 | print(opt) 13 | 14 | dataset = 'nowplaying.csv' 15 | 16 | 17 | print("-- Starting @ %ss" % datetime.datetime.now()) 18 | with open(dataset, "r") as f: 19 | reader = csv.DictReader(f, delimiter='\t') 20 | sess_clicks = {} 21 | sess_date = {} 22 | ctr = 0 23 | curid = -1 24 | curdate = None 25 | for data in reader: 26 | sessid = int(data['SessionId']) 27 | if curdate and not curid == sessid: 28 | date = curdate 29 | sess_date[curid] = date 30 | curid = sessid 31 | 32 | item = int(data['ItemId']) 33 | curdate = float(data['Time']) 34 | 35 | if sessid in sess_clicks: 36 | sess_clicks[sessid] += [item] 37 | else: 38 | sess_clicks[sessid] = [item] 39 | ctr += 1 40 | date = float(data['Time']) 41 | sess_date[curid] = date 42 | print('ctr:', ctr) 43 | print("-- Reading data @ %ss" % datetime.datetime.now()) 44 | 45 | # Filter out length 1 sessions 46 | for s in list(sess_clicks): 47 | if len(sess_clicks[s]) == 1: 
48 | del sess_clicks[s] 49 | del sess_date[s] 50 | 51 | # Count number of times each item appears 52 | iid_counts = {} 53 | for s in sess_clicks: 54 | seq = sess_clicks[s] 55 | for iid in seq: 56 | if iid in iid_counts: 57 | iid_counts[iid] += 1 58 | else: 59 | iid_counts[iid] = 1 60 | 61 | sorted_counts = sorted(iid_counts.items(), key=operator.itemgetter(1)) 62 | 63 | length = len(sess_clicks) 64 | for s in list(sess_clicks): 65 | curseq = sess_clicks[s] 66 | filseq = list(filter(lambda i: iid_counts[i] >= 5, curseq)) 67 | if len(filseq) < 2 or len(filseq) > 30: 68 | del sess_clicks[s] 69 | del sess_date[s] 70 | else: 71 | sess_clicks[s] = filseq 72 | 73 | # Split out test set based on dates 74 | dates = list(sess_date.items()) 75 | maxdate = dates[0][1] 76 | 77 | for _, date in dates: 78 | if maxdate < date: 79 | maxdate = date 80 | 81 | # Two months for test 82 | splitdate = maxdate - 60 * 86400 83 | 84 | print('Splitting date', splitdate) # Yoochoose: ('Split date', 1411930799.0) 85 | tra_sess = filter(lambda x: x[1] < splitdate, dates) 86 | tes_sess = filter(lambda x: x[1] > splitdate, dates) 87 | 88 | # Sort sessions by date 89 | tra_sess = sorted(tra_sess, key=operator.itemgetter(1)) # [(session_id, timestamp), (), ] 90 | tes_sess = sorted(tes_sess, key=operator.itemgetter(1)) # [(session_id, timestamp), (), ] 91 | print(len(tra_sess)) # 186670 # 7966257 92 | print(len(tes_sess)) # 15979 # 15324 93 | print(tra_sess[:3]) 94 | print(tes_sess[:3]) 95 | print("-- Splitting train set and test set @ %ss" % datetime.datetime.now()) 96 | 97 | # Choosing item count >=5 gives approximately the same number of items as reported in paper 98 | item_dict = {} 99 | # Convert training sessions to sequences and renumber items to start from 1 100 | def obtian_tra(): 101 | train_ids = [] 102 | train_seqs = [] 103 | train_dates = [] 104 | item_ctr = 1 105 | for s, date in tra_sess: 106 | seq = sess_clicks[s] 107 | outseq = [] 108 | for i in seq: 109 | if i in item_dict: 110 | outseq += [item_dict[i]] 111 | else: 112 | outseq += [item_ctr] 113 | item_dict[i] = item_ctr 114 | item_ctr += 1 115 | if len(outseq) < 2: # Doesn't occur 116 | continue 117 | train_ids += [s] 118 | train_dates += [date] 119 | train_seqs += [outseq] 120 | print('item_ctr') 121 | print(item_ctr) # 43098, 37484 122 | return train_ids, train_dates, train_seqs 123 | 124 | 125 | # Convert test sessions to sequences, ignoring items that do not appear in training set 126 | def obtian_tes(): 127 | test_ids = [] 128 | test_seqs = [] 129 | test_dates = [] 130 | for s, date in tes_sess: 131 | seq = sess_clicks[s] 132 | outseq = [] 133 | for i in seq: 134 | if i in item_dict: 135 | outseq += [item_dict[i]] 136 | if len(outseq) < 2: 137 | continue 138 | test_ids += [s] 139 | test_dates += [date] 140 | test_seqs += [outseq] 141 | return test_ids, test_dates, test_seqs 142 | 143 | 144 | tra_ids, tra_dates, tra_seqs = obtian_tra() 145 | tes_ids, tes_dates, tes_seqs = obtian_tes() 146 | 147 | 148 | def process_seqs(iseqs, idates): 149 | out_seqs = [] 150 | out_dates = [] 151 | labs = [] 152 | ids = [] 153 | for id, seq, date in zip(range(len(iseqs)), iseqs, idates): 154 | for i in range(1, len(seq)): 155 | tar = seq[-i] 156 | labs += [tar] 157 | out_seqs += [seq[:-i]] 158 | out_dates += [date] 159 | ids += [id] 160 | return out_seqs, out_dates, labs, ids 161 | 162 | 163 | tr_seqs, tr_dates, tr_labs, tr_ids = process_seqs(tra_seqs, tra_dates) 164 | te_seqs, te_dates, te_labs, te_ids = process_seqs(tes_seqs, tes_dates) 165 | tra = (tr_seqs, 
tr_labs) 166 | tes = (te_seqs, te_labs) 167 | print('train_test') 168 | print(len(tr_seqs)) 169 | print(len(te_seqs)) 170 | print(tr_seqs[:3], tr_dates[:3], tr_labs[:3]) 171 | print(te_seqs[:3], te_dates[:3], te_labs[:3]) 172 | all = 0 173 | 174 | for seq in tra_seqs: 175 | all += len(seq) 176 | for seq in tes_seqs: 177 | all += len(seq) 178 | print('avg length: ', all/(len(tra_seqs) + len(tes_seqs) * 1.0)) 179 | print('all:', all) 180 | 181 | if not os.path.exists('Nowplaying'): 182 | os.makedirs('Nowplaying') 183 | pickle.dump(tra, open('Nowplaying/train.txt', 'wb')) 184 | pickle.dump(tes, open('Nowplaying/test.txt', 'wb')) 185 | pickle.dump(tra_seqs, open('Nowplaying/all_train_seq.txt', 'wb')) 186 | -------------------------------------------------------------------------------- /datasets/process_tmall.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import csv 4 | import pickle 5 | import operator 6 | import datetime 7 | import os 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--dataset', default='Tmall', help='dataset name: diginetica/yoochoose/sample') 11 | opt = parser.parse_args() 12 | print(opt) 13 | 14 | with open('tmall_data.csv', 'w') as tmall_data: 15 | with open('tmall/dataset15.csv', 'r') as tmall_file: 16 | header = tmall_file.readline() 17 | tmall_data.write(header) 18 | for line in tmall_file: 19 | data = line[:-1].split('\t') 20 | if int(data[2]) > 120000: 21 | break 22 | tmall_data.write(line) 23 | 24 | print("-- Starting @ %ss" % datetime.datetime.now()) 25 | with open('tmall_data.csv', "r") as f: 26 | reader = csv.DictReader(f, delimiter='\t') 27 | sess_clicks = {} 28 | sess_date = {} 29 | ctr = 0 30 | curid = -1 31 | curdate = None 32 | for data in reader: 33 | sessid = int(data['SessionId']) 34 | if curdate and not curid == sessid: 35 | date = curdate 36 | sess_date[curid] = date 37 | curid = sessid 38 | item = int(data['ItemId']) 39 | curdate = float(data['Time']) 40 | 41 | if sessid in sess_clicks: 42 | sess_clicks[sessid] += [item] 43 | else: 44 | sess_clicks[sessid] = [item] 45 | ctr += 1 46 | date = float(data['Time']) 47 | sess_date[curid] = date 48 | print("-- Reading data @ %ss" % datetime.datetime.now()) 49 | 50 | 51 | # Filter out length 1 sessions 52 | for s in list(sess_clicks): 53 | if len(sess_clicks[s]) == 1: 54 | del sess_clicks[s] 55 | del sess_date[s] 56 | 57 | # Count number of times each item appears 58 | iid_counts = {} 59 | for s in sess_clicks: 60 | seq = sess_clicks[s] 61 | for iid in seq: 62 | if iid in iid_counts: 63 | iid_counts[iid] += 1 64 | else: 65 | iid_counts[iid] = 1 66 | 67 | sorted_counts = sorted(iid_counts.items(), key=operator.itemgetter(1)) 68 | 69 | length = len(sess_clicks) 70 | for s in list(sess_clicks): 71 | curseq = sess_clicks[s] 72 | filseq = list(filter(lambda i: iid_counts[i] >= 5, curseq)) 73 | if len(filseq) < 2 or len(filseq) > 40: 74 | del sess_clicks[s] 75 | del sess_date[s] 76 | else: 77 | sess_clicks[s] = filseq 78 | 79 | # Split out test set based on dates 80 | dates = list(sess_date.items()) 81 | maxdate = dates[0][1] 82 | 83 | for _, date in dates: 84 | if maxdate < date: 85 | maxdate = date 86 | 87 | # the last of 100 seconds for test 88 | splitdate = maxdate - 100 89 | 90 | print('Splitting date', splitdate) # Yoochoose: ('Split date', 1411930799.0) 91 | tra_sess = filter(lambda x: x[1] < splitdate, dates) 92 | tes_sess = filter(lambda x: x[1] > splitdate, dates) 93 | 94 | # Sort sessions by date 95 | 
tra_sess = sorted(tra_sess, key=operator.itemgetter(1)) # [(session_id, timestamp), (), ] 96 | tes_sess = sorted(tes_sess, key=operator.itemgetter(1)) # [(session_id, timestamp), (), ] 97 | print(len(tra_sess)) # 186670 # 7966257 98 | print(len(tes_sess)) # 15979 # 15324 99 | print(tra_sess[:3]) 100 | print(tes_sess[:3]) 101 | print("-- Splitting train set and test set @ %ss" % datetime.datetime.now()) 102 | 103 | # Choosing item count >=5 gives approximately the same number of items as reported in paper 104 | item_dict = {} 105 | # Convert training sessions to sequences and renumber items to start from 1 106 | def obtian_tra(): 107 | train_ids = [] 108 | train_seqs = [] 109 | train_dates = [] 110 | item_ctr = 1 111 | for s, date in tra_sess: 112 | seq = sess_clicks[s] 113 | outseq = [] 114 | for i in seq: 115 | if i in item_dict: 116 | outseq += [item_dict[i]] 117 | else: 118 | outseq += [item_ctr] 119 | item_dict[i] = item_ctr 120 | item_ctr += 1 121 | if len(outseq) < 2: # Doesn't occur 122 | continue 123 | train_ids += [s] 124 | train_dates += [date] 125 | train_seqs += [outseq] 126 | print('item_ctr') 127 | print(item_ctr) # 43098, 37484 128 | return train_ids, train_dates, train_seqs 129 | 130 | # Convert test sessions to sequences, ignoring items that do not appear in training set 131 | def obtian_tes(): 132 | test_ids = [] 133 | test_seqs = [] 134 | test_dates = [] 135 | for s, date in tes_sess: 136 | seq = sess_clicks[s] 137 | outseq = [] 138 | for i in seq: 139 | if i in item_dict: 140 | outseq += [item_dict[i]] 141 | if len(outseq) < 2: 142 | continue 143 | test_ids += [s] 144 | test_dates += [date] 145 | test_seqs += [outseq] 146 | return test_ids, test_dates, test_seqs 147 | 148 | tra_ids, tra_dates, tra_seqs = obtian_tra() 149 | tes_ids, tes_dates, tes_seqs = obtian_tes() 150 | 151 | def process_seqs(iseqs, idates): 152 | out_seqs = [] 153 | out_dates = [] 154 | labs = [] 155 | ids = [] 156 | for id, seq, date in zip(range(len(iseqs)), iseqs, idates): 157 | for i in range(1, len(seq)): 158 | tar = seq[-i] 159 | labs += [tar] 160 | out_seqs += [seq[:-i]] 161 | out_dates += [date] 162 | ids += [id] 163 | return out_seqs, out_dates, labs, ids 164 | 165 | tr_seqs, tr_dates, tr_labs, tr_ids = process_seqs(tra_seqs, tra_dates) 166 | te_seqs, te_dates, te_labs, te_ids = process_seqs(tes_seqs, tes_dates) 167 | tra = (tr_seqs, tr_labs) 168 | tes = (te_seqs, te_labs) 169 | print('train_test') 170 | print(len(tr_seqs)) 171 | print(len(te_seqs)) 172 | print(tr_seqs[:3], tr_dates[:3], tr_labs[:3]) 173 | print(te_seqs[:3], te_dates[:3], te_labs[:3]) 174 | all = 0 175 | 176 | for seq in tra_seqs: 177 | all += len(seq) 178 | for seq in tes_seqs: 179 | all += len(seq) 180 | print('avg length: ', all * 1.0/(len(tra_seqs) + len(tes_seqs))) 181 | 182 | if not os.path.exists('tmall'): 183 | os.makedirs('tmall') 184 | pickle.dump(tra, open('tmall/train.txt', 'wb')) 185 | pickle.dump(tes, open('tmall/test.txt', 'wb')) 186 | pickle.dump(tra_seqs, open('tmall/all_train_seq.txt', 'wb')) 187 | 188 | # Namespace(dataset='Tmall') 189 | # Splitting train set and test set 190 | # item_ctr 191 | # 40728 192 | # train_test 193 | # 351268 194 | # 25898 195 | # avg length: 6.687663052493478 196 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import time 2 | import argparse 3 | import pickle 4 | from model import * 5 | from utils import * 6 | 7 | 8 | def init_seed(seed=None): 9 | if 
seed is None: 10 | seed = int(time.time() * 1000 // 1000) 11 | np.random.seed(seed) 12 | torch.manual_seed(seed) 13 | torch.cuda.manual_seed(seed) 14 | torch.cuda.manual_seed_all(seed) 15 | 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--dataset', default='diginetica', help='diginetica/Nowplaying/Tmall') 19 | parser.add_argument('--hiddenSize', type=int, default=100) 20 | parser.add_argument('--epoch', type=int, default=20) 21 | parser.add_argument('--activate', type=str, default='relu') 22 | parser.add_argument('--n_sample_all', type=int, default=12) 23 | parser.add_argument('--n_sample', type=int, default=12) 24 | parser.add_argument('--batch_size', type=int, default=100) 25 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate.') 26 | parser.add_argument('--lr_dc', type=float, default=0.1, help='learning rate decay.') 27 | parser.add_argument('--lr_dc_step', type=int, default=3, help='the number of steps after which the learning rate decay.') 28 | parser.add_argument('--l2', type=float, default=1e-5, help='l2 penalty ') 29 | parser.add_argument('--n_iter', type=int, default=1) # [1, 2] 30 | parser.add_argument('--dropout_gcn', type=float, default=0, help='Dropout rate.') # [0, 0.2, 0.4, 0.6, 0.8] 31 | parser.add_argument('--dropout_local', type=float, default=0, help='Dropout rate.') # [0, 0.5] 32 | parser.add_argument('--dropout_global', type=float, default=0.5, help='Dropout rate.') 33 | parser.add_argument('--validation', action='store_true', help='validation') 34 | parser.add_argument('--valid_portion', type=float, default=0.1, help='split the portion') 35 | parser.add_argument('--alpha', type=float, default=0.2, help='Alpha for the leaky_relu.') 36 | parser.add_argument('--patience', type=int, default=3) 37 | 38 | opt = parser.parse_args() 39 | 40 | 41 | def main(): 42 | init_seed(2020) 43 | 44 | if opt.dataset == 'diginetica': 45 | num_node = 43098 46 | opt.n_iter = 2 47 | opt.dropout_gcn = 0.2 48 | opt.dropout_local = 0.0 49 | elif opt.dataset == 'Nowplaying': 50 | num_node = 60417 51 | opt.n_iter = 1 52 | opt.dropout_gcn = 0.0 53 | opt.dropout_local = 0.0 54 | elif opt.dataset == 'Tmall': 55 | num_node = 40728 56 | opt.n_iter = 1 57 | opt.dropout_gcn = 0.6 58 | opt.dropout_local = 0.5 59 | else: 60 | num_node = 310 61 | 62 | train_data = pickle.load(open('datasets/' + opt.dataset + '/train.txt', 'rb')) 63 | if opt.validation: 64 | train_data, valid_data = split_validation(train_data, opt.valid_portion) 65 | test_data = valid_data 66 | else: 67 | test_data = pickle.load(open('datasets/' + opt.dataset + '/test.txt', 'rb')) 68 | 69 | adj = pickle.load(open('datasets/' + opt.dataset + '/adj_' + str(opt.n_sample_all) + '.pkl', 'rb')) 70 | num = pickle.load(open('datasets/' + opt.dataset + '/num_' + str(opt.n_sample_all) + '.pkl', 'rb')) 71 | train_data = Data(train_data) 72 | test_data = Data(test_data) 73 | 74 | adj, num = handle_adj(adj, num_node, opt.n_sample_all, num) 75 | model = trans_to_cuda(CombineGraph(opt, num_node, adj, num)) 76 | 77 | print(opt) 78 | start = time.time() 79 | best_result = [0, 0] 80 | best_epoch = [0, 0] 81 | bad_counter = 0 82 | 83 | for epoch in range(opt.epoch): 84 | print('-------------------------------------------------------') 85 | print('epoch: ', epoch) 86 | hit, mrr = train_test(model, train_data, test_data) 87 | flag = 0 88 | if hit >= best_result[0]: 89 | best_result[0] = hit 90 | best_epoch[0] = epoch 91 | flag = 1 92 | if mrr >= best_result[1]: 93 | best_result[1] = mrr 94 | best_epoch[1] = 
epoch
95 |             flag = 1
96 |         print('Current Result:')
97 |         print('\tRecall@20:\t%.4f\tMRR@20:\t%.4f' % (hit, mrr))
98 |         print('Best Result:')
99 |         print('\tRecall@20:\t%.4f\tMRR@20:\t%.4f\tEpoch:\t%d,\t%d' % (
100 |             best_result[0], best_result[1], best_epoch[0], best_epoch[1]))
101 |         bad_counter += 1 - flag
102 |         if bad_counter >= opt.patience:
103 |             break
104 |     print('-------------------------------------------------------')
105 |     end = time.time()
106 |     print("Run time: %f s" % (end - start))
107 | 
108 | 
109 | if __name__ == '__main__':
110 |     main()
111 | 
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import math
3 | import numpy as np
4 | import torch
5 | from torch import nn
6 | from tqdm import tqdm
7 | from aggregator import LocalAggregator, GlobalAggregator
8 | from torch.nn import Module, Parameter
9 | import torch.nn.functional as F
10 | 
11 | 
12 | class CombineGraph(Module):
13 |     def __init__(self, opt, num_node, adj_all, num):
14 |         super(CombineGraph, self).__init__()
15 |         self.opt = opt
16 | 
17 |         self.batch_size = opt.batch_size
18 |         self.num_node = num_node
19 |         self.dim = opt.hiddenSize
20 |         self.dropout_local = opt.dropout_local
21 |         self.dropout_global = opt.dropout_global
22 |         self.hop = opt.n_iter
23 |         self.sample_num = opt.n_sample
24 |         self.adj_all = trans_to_cuda(torch.Tensor(adj_all)).long()
25 |         self.num = trans_to_cuda(torch.Tensor(num)).float()
26 | 
27 |         # Aggregator
28 |         self.local_agg = LocalAggregator(self.dim, self.opt.alpha, dropout=0.0)
29 |         self.global_agg = []
30 |         for i in range(self.hop):
31 |             if opt.activate == 'relu':
32 |                 agg = GlobalAggregator(self.dim, opt.dropout_gcn, act=torch.relu)
33 |             else:
34 |                 agg = GlobalAggregator(self.dim, opt.dropout_gcn, act=torch.tanh)
35 |             self.add_module('agg_gcn_{}'.format(i), agg)
36 |             self.global_agg.append(agg)
37 | 
38 |         # Item representation & Position representation
39 |         self.embedding = nn.Embedding(num_node, self.dim)
40 |         self.pos_embedding = nn.Embedding(200, self.dim)
41 | 
42 |         # Parameters
43 |         self.w_1 = nn.Parameter(torch.Tensor(2 * self.dim, self.dim))
44 |         self.w_2 = nn.Parameter(torch.Tensor(self.dim, 1))
45 |         self.glu1 = nn.Linear(self.dim, self.dim)
46 |         self.glu2 = nn.Linear(self.dim, self.dim, bias=False)
47 |         self.linear_transform = nn.Linear(self.dim, self.dim, bias=False)
48 | 
49 |         self.leakyrelu = nn.LeakyReLU(opt.alpha)
50 |         self.loss_function = nn.CrossEntropyLoss()
51 |         self.optimizer = torch.optim.Adam(self.parameters(), lr=opt.lr, weight_decay=opt.l2)
52 |         self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=opt.lr_dc_step, gamma=opt.lr_dc)
53 | 
54 |         self.reset_parameters()
55 | 
56 |     def reset_parameters(self):
57 |         stdv = 1.0 / math.sqrt(self.dim)
58 |         for weight in self.parameters():
59 |             weight.data.uniform_(-stdv, stdv)
60 | 
61 |     def sample(self, target, n_sample):
62 |         # neighbor = self.adj_all[target.view(-1)]
63 |         # index = np.arange(neighbor.shape[1])
64 |         # np.random.shuffle(index)
65 |         # index = index[:n_sample]
66 |         # return self.adj_all[target.view(-1)][:, index], self.num[target.view(-1)][:, index]
67 |         return self.adj_all[target.view(-1)], self.num[target.view(-1)]
68 | 
69 |     def compute_scores(self, hidden, mask):
70 |         mask = mask.float().unsqueeze(-1)
71 | 
72 |         batch_size = hidden.shape[0]
73 |         len = hidden.shape[1]
74 |         pos_emb = self.pos_embedding.weight[:len]
75 |         pos_emb = 
pos_emb.unsqueeze(0).repeat(batch_size, 1, 1) 76 | 77 | hs = torch.sum(hidden * mask, -2) / torch.sum(mask, 1) 78 | hs = hs.unsqueeze(-2).repeat(1, len, 1) 79 | nh = torch.matmul(torch.cat([pos_emb, hidden], -1), self.w_1) 80 | nh = torch.tanh(nh) 81 | nh = torch.sigmoid(self.glu1(nh) + self.glu2(hs)) 82 | beta = torch.matmul(nh, self.w_2) 83 | beta = beta * mask 84 | select = torch.sum(beta * hidden, 1) 85 | 86 | b = self.embedding.weight[1:] # n_nodes x latent_size 87 | scores = torch.matmul(select, b.transpose(1, 0)) 88 | return scores 89 | 90 | def forward(self, inputs, adj, mask_item, item): 91 | batch_size = inputs.shape[0] 92 | seqs_len = inputs.shape[1] 93 | h = self.embedding(inputs) 94 | 95 | # local 96 | h_local = self.local_agg(h, adj, mask_item) 97 | 98 | # global 99 | item_neighbors = [inputs] 100 | weight_neighbors = [] 101 | support_size = seqs_len 102 | 103 | for i in range(1, self.hop + 1): 104 | item_sample_i, weight_sample_i = self.sample(item_neighbors[-1], self.sample_num) 105 | support_size *= self.sample_num 106 | item_neighbors.append(item_sample_i.view(batch_size, support_size)) 107 | weight_neighbors.append(weight_sample_i.view(batch_size, support_size)) 108 | 109 | entity_vectors = [self.embedding(i) for i in item_neighbors] 110 | weight_vectors = weight_neighbors 111 | 112 | session_info = [] 113 | item_emb = self.embedding(item) * mask_item.float().unsqueeze(-1) 114 | 115 | # mean 116 | sum_item_emb = torch.sum(item_emb, 1) / torch.sum(mask_item.float(), -1).unsqueeze(-1) 117 | 118 | # sum 119 | # sum_item_emb = torch.sum(item_emb, 1) 120 | 121 | sum_item_emb = sum_item_emb.unsqueeze(-2) 122 | for i in range(self.hop): 123 | session_info.append(sum_item_emb.repeat(1, entity_vectors[i].shape[1], 1)) 124 | 125 | for n_hop in range(self.hop): 126 | entity_vectors_next_iter = [] 127 | shape = [batch_size, -1, self.sample_num, self.dim] 128 | for hop in range(self.hop - n_hop): 129 | aggregator = self.global_agg[n_hop] 130 | vector = aggregator(self_vectors=entity_vectors[hop], 131 | neighbor_vector=entity_vectors[hop+1].view(shape), 132 | masks=None, 133 | batch_size=batch_size, 134 | neighbor_weight=weight_vectors[hop].view(batch_size, -1, self.sample_num), 135 | extra_vector=session_info[hop]) 136 | entity_vectors_next_iter.append(vector) 137 | entity_vectors = entity_vectors_next_iter 138 | 139 | h_global = entity_vectors[0].view(batch_size, seqs_len, self.dim) 140 | 141 | # combine 142 | h_local = F.dropout(h_local, self.dropout_local, training=self.training) 143 | h_global = F.dropout(h_global, self.dropout_global, training=self.training) 144 | output = h_local + h_global 145 | 146 | return output 147 | 148 | 149 | def trans_to_cuda(variable): 150 | if torch.cuda.is_available(): 151 | return variable.cuda() 152 | else: 153 | return variable 154 | 155 | 156 | def trans_to_cpu(variable): 157 | if torch.cuda.is_available(): 158 | return variable.cpu() 159 | else: 160 | return variable 161 | 162 | 163 | def forward(model, data): 164 | alias_inputs, adj, items, mask, targets, inputs = data 165 | alias_inputs = trans_to_cuda(alias_inputs).long() 166 | items = trans_to_cuda(items).long() 167 | adj = trans_to_cuda(adj).float() 168 | mask = trans_to_cuda(mask).long() 169 | inputs = trans_to_cuda(inputs).long() 170 | 171 | hidden = model(items, adj, mask, inputs) 172 | get = lambda index: hidden[index][alias_inputs[index]] 173 | seq_hidden = torch.stack([get(i) for i in torch.arange(len(alias_inputs)).long()]) 174 | return targets, model.compute_scores(seq_hidden, 
mask)
175 | 
176 | 
177 | def train_test(model, train_data, test_data):
178 |     print('start training: ', datetime.datetime.now())
179 |     model.train()
180 |     total_loss = 0.0
181 |     train_loader = torch.utils.data.DataLoader(train_data, num_workers=4, batch_size=model.batch_size,
182 |                                                shuffle=True, pin_memory=True)
183 |     for data in tqdm(train_loader):
184 |         model.optimizer.zero_grad()
185 |         targets, scores = forward(model, data)
186 |         targets = trans_to_cuda(targets).long()
187 |         loss = model.loss_function(scores, targets - 1)  # item IDs start at 1; score column j corresponds to item j + 1
188 |         loss.backward()
189 |         model.optimizer.step()
190 |         total_loss += loss.item()
191 |     print('\tLoss:\t%.3f' % total_loss)
192 |     model.scheduler.step()
193 | 
194 |     print('start predicting: ', datetime.datetime.now())
195 |     model.eval()
196 |     test_loader = torch.utils.data.DataLoader(test_data, num_workers=4, batch_size=model.batch_size,
197 |                                               shuffle=False, pin_memory=True)
198 |     result = []
199 |     hit, mrr = [], []
200 |     for data in test_loader:
201 |         targets, scores = forward(model, data)
202 |         sub_scores = scores.topk(20)[1]
203 |         sub_scores = trans_to_cpu(sub_scores).detach().numpy()
204 |         targets = targets.numpy()
205 |         for score, target in zip(sub_scores, targets):
206 |             hit.append(np.isin(target - 1, score))
207 |             if len(np.where(score == target - 1)[0]) == 0:
208 |                 mrr.append(0)
209 |             else:
210 |                 mrr.append(1 / (np.where(score == target - 1)[0][0] + 1))
211 | 
212 |     result.append(np.mean(hit) * 100)
213 |     result.append(np.mean(mrr) * 100)
214 | 
215 |     return result
216 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils.data import Dataset
4 | 
5 | 
6 | def split_validation(train_set, valid_portion):
7 |     train_set_x, train_set_y = train_set
8 |     n_samples = len(train_set_x)
9 |     sidx = np.arange(n_samples, dtype='int32')
10 |     np.random.shuffle(sidx)
11 |     n_train = int(np.round(n_samples * (1.
- valid_portion))) 12 | valid_set_x = [train_set_x[s] for s in sidx[n_train:]] 13 | valid_set_y = [train_set_y[s] for s in sidx[n_train:]] 14 | train_set_x = [train_set_x[s] for s in sidx[:n_train]] 15 | train_set_y = [train_set_y[s] for s in sidx[:n_train]] 16 | 17 | return (train_set_x, train_set_y), (valid_set_x, valid_set_y) 18 | 19 | 20 | def handle_data(inputData, train_len=None): 21 | len_data = [len(nowData) for nowData in inputData] 22 | if train_len is None: 23 | max_len = max(len_data) 24 | else: 25 | max_len = train_len 26 | # reverse the sequence 27 | us_pois = [list(reversed(upois)) + [0] * (max_len - le) if le < max_len else list(reversed(upois[-max_len:])) 28 | for upois, le in zip(inputData, len_data)] 29 | us_msks = [[1] * le + [0] * (max_len - le) if le < max_len else [1] * max_len 30 | for le in len_data] 31 | return us_pois, us_msks, max_len 32 | 33 | 34 | def handle_adj(adj_dict, n_entity, sample_num, num_dict=None): 35 | adj_entity = np.zeros([n_entity, sample_num], dtype=np.int64) 36 | num_entity = np.zeros([n_entity, sample_num], dtype=np.int64) 37 | for entity in range(1, n_entity): 38 | neighbor = list(adj_dict[entity]) 39 | neighbor_weight = list(num_dict[entity]) 40 | n_neighbor = len(neighbor) 41 | if n_neighbor == 0: 42 | continue 43 | if n_neighbor >= sample_num: 44 | sampled_indices = np.random.choice(list(range(n_neighbor)), size=sample_num, replace=False) 45 | else: 46 | sampled_indices = np.random.choice(list(range(n_neighbor)), size=sample_num, replace=True) 47 | adj_entity[entity] = np.array([neighbor[i] for i in sampled_indices]) 48 | num_entity[entity] = np.array([neighbor_weight[i] for i in sampled_indices]) 49 | 50 | return adj_entity, num_entity 51 | 52 | 53 | class Data(Dataset): 54 | def __init__(self, data, train_len=None): 55 | inputs, mask, max_len = handle_data(data[0], train_len) 56 | self.inputs = np.asarray(inputs) 57 | self.targets = np.asarray(data[1]) 58 | self.mask = np.asarray(mask) 59 | self.length = len(data[0]) 60 | self.max_len = max_len 61 | 62 | def __getitem__(self, index): 63 | u_input, mask, target = self.inputs[index], self.mask[index], self.targets[index] 64 | 65 | max_n_node = self.max_len 66 | node = np.unique(u_input) 67 | items = node.tolist() + (max_n_node - len(node)) * [0] 68 | adj = np.zeros((max_n_node, max_n_node)) 69 | for i in np.arange(len(u_input) - 1): 70 | u = np.where(node == u_input[i])[0][0] 71 | adj[u][u] = 1 72 | if u_input[i + 1] == 0: 73 | break 74 | v = np.where(node == u_input[i + 1])[0][0] 75 | if u == v or adj[u][v] == 4: 76 | continue 77 | adj[v][v] = 1 78 | if adj[v][u] == 2: 79 | adj[u][v] = 4 80 | adj[v][u] = 4 81 | else: 82 | adj[u][v] = 2 83 | adj[v][u] = 3 84 | 85 | alias_inputs = [np.where(node == i)[0][0] for i in u_input] 86 | 87 | return [torch.tensor(alias_inputs), torch.tensor(adj), torch.tensor(items), 88 | torch.tensor(mask), torch.tensor(target), torch.tensor(u_input)] 89 | 90 | def __len__(self): 91 | return self.length 92 | --------------------------------------------------------------------------------
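
For reference, `utils.Data.__getitem__` encodes each session as a small typed adjacency matrix over the session's unique items: entries with value 1 mark self-loops, 2 and 3 the two views of a single transition, and 4 transitions observed in both directions, which `LocalAggregator.forward` picks up through `adj.eq(1)` … `adj.eq(4)`. The snippet below is a minimal sketch of inspecting one session graph and one batch; it assumes the preprocessed `diginetica` pickles from the README exist, and its variable names are illustrative only:

~~~~
import pickle
import torch
from utils import Data

# Load the (sequences, labels) tuple produced by preprocessing.
train_data = Data(pickle.load(open('datasets/diginetica/train.txt', 'rb')))

# One sample: alias_inputs, adj, items, mask, target, u_input
alias_inputs, adj, items, mask, target, u_input = train_data[0]
print(adj)   # 1 = self-loop, 2/3 = one-directional transition, 4 = both directions

# Batching mirrors what train_test() in model.py does.
loader = torch.utils.data.DataLoader(train_data, batch_size=4, shuffle=False)
alias_inputs, adj, items, mask, targets, inputs = next(iter(loader))
print(adj.shape, items.shape, mask.shape)
~~~~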