├── .DS_Store
├── .idea
│   ├── .gitignore
│   ├── GCEGNN.iml
│   ├── deployment.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   └── modules.xml
├── README.md
├── aggregator.py
├── build_graph.py
├── datasets
│   ├── .DS_Store
│   ├── Nowplaying
│   │   ├── all_train_seq.txt
│   │   ├── test.txt
│   │   └── train.txt
│   ├── Tmall
│   │   ├── all_train_seq.txt
│   │   ├── test.txt
│   │   └── train.txt
│   ├── diginetica
│   │   ├── all_train_seq.txt
│   │   ├── test.txt
│   │   └── train.txt
│   ├── process_nowplaying.py
│   └── process_tmall.py
├── main.py
├── model.py
└── utils.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CCIIPLab/GCE-GNN/212575302e42d64b8cf0b9ffe1b71b0db398d5be/.DS_Store
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /../../../:\GCEGNN\.idea/dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
9 |
--------------------------------------------------------------------------------
/.idea/GCEGNN.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/deployment.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GCE-GNN
2 |
3 | ## Code
4 |
5 | This is the source code for the SIGIR 2020 paper _Global Context Enhanced Graph Neural Networks for Session-based Recommendation_.
6 |
7 | ## Requirements
8 |
9 | - Python 3
10 | - PyTorch >= 1.3.0
11 | - tqdm
12 |
13 | ## Usage
14 |
15 | Data preprocessing:
16 |
17 | For data preprocessing, please refer to the code of [SR-GNN](https://github.com/CRIPAC-DIG/SR-GNN).
18 |
19 | Train and evaluate the model:
20 | ~~~~
21 | python build_graph.py --dataset diginetica --sample_num 12
22 | python main.py --dataset diginetica
23 | ~~~~
24 |
25 | ## Citation
26 |
27 | ~~~~
28 | @inproceedings{wang2020global,
29 | title={Global Context Enhanced Graph Neural Networks for Session-based Recommendation},
30 | author={Wang, Ziyang and Wei, Wei and Cong, Gao and Li, Xiao-Li and Mao, Xian-Ling and Qiu, Minghui},
31 | booktitle={Proceedings of the 43rd International ACM SIGIR Conference on Research and Development in Information Retrieval},
32 | pages={169--178},
33 | year={2020}
34 | }
35 | ~~~~
--------------------------------------------------------------------------------
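
The preprocessing scripts below (datasets/process_tmall.py, datasets/process_nowplaying.py, and the SR-GNN scripts referenced in the README) write each split as a pickled Python object: train.txt and test.txt hold a (sequences, labels) tuple, and all_train_seq.txt holds the raw training sessions consumed by build_graph.py. A minimal loading sketch, assuming the diginetica files are already in place:

~~~~
import pickle

# A sketch, assuming datasets/diginetica/ already contains the preprocessed files.
# train.txt / test.txt: a (sequences, labels) tuple -- one next-item label per prefix.
# all_train_seq.txt: the full training sessions used by build_graph.py.
train_seqs, train_labels = pickle.load(open('datasets/diginetica/train.txt', 'rb'))
all_train_seq = pickle.load(open('datasets/diginetica/all_train_seq.txt', 'rb'))

print(len(train_seqs), len(train_labels))  # same length: one label per sequence
print(train_seqs[0], train_labels[0])      # a list of item IDs and its target item
~~~~
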
/aggregator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import Parameter
4 | import torch.nn.functional as F
5 | import numpy
6 |
7 |
8 | class Aggregator(nn.Module):
9 | def __init__(self, batch_size, dim, dropout, act, name=None):
10 | super(Aggregator, self).__init__()
11 | self.dropout = dropout
12 | self.act = act
13 | self.batch_size = batch_size
14 | self.dim = dim
15 |
16 | def forward(self):
17 | pass
18 |
19 |
20 | class LocalAggregator(nn.Module):
21 | def __init__(self, dim, alpha, dropout=0., name=None):
22 | super(LocalAggregator, self).__init__()
23 | self.dim = dim
24 | self.dropout = dropout
25 |
26 | self.a_0 = nn.Parameter(torch.Tensor(self.dim, 1))
27 | self.a_1 = nn.Parameter(torch.Tensor(self.dim, 1))
28 | self.a_2 = nn.Parameter(torch.Tensor(self.dim, 1))
29 | self.a_3 = nn.Parameter(torch.Tensor(self.dim, 1))
30 | self.bias = nn.Parameter(torch.Tensor(self.dim))
31 |
32 | self.leakyrelu = nn.LeakyReLU(alpha)
33 |
34 | def forward(self, hidden, adj, mask_item=None):
35 | h = hidden
36 | batch_size = h.shape[0]
37 | N = h.shape[1]
38 |
39 | a_input = (h.repeat(1, 1, N).view(batch_size, N * N, self.dim)
40 | * h.repeat(1, N, 1)).view(batch_size, N, N, self.dim)
41 |
42 | e_0 = torch.matmul(a_input, self.a_0)
43 | e_1 = torch.matmul(a_input, self.a_1)
44 | e_2 = torch.matmul(a_input, self.a_2)
45 | e_3 = torch.matmul(a_input, self.a_3)
46 |
47 | e_0 = self.leakyrelu(e_0).squeeze(-1).view(batch_size, N, N)
48 | e_1 = self.leakyrelu(e_1).squeeze(-1).view(batch_size, N, N)
49 | e_2 = self.leakyrelu(e_2).squeeze(-1).view(batch_size, N, N)
50 | e_3 = self.leakyrelu(e_3).squeeze(-1).view(batch_size, N, N)
51 |
52 | mask = -9e15 * torch.ones_like(e_0)
53 | alpha = torch.where(adj.eq(1), e_0, mask)
54 | alpha = torch.where(adj.eq(2), e_1, alpha)
55 | alpha = torch.where(adj.eq(3), e_2, alpha)
56 | alpha = torch.where(adj.eq(4), e_3, alpha)
57 | alpha = torch.softmax(alpha, dim=-1)
58 |
59 | output = torch.matmul(alpha, h)
60 | return output
61 |
62 |
63 | class GlobalAggregator(nn.Module):
64 | def __init__(self, dim, dropout, act=torch.relu, name=None):
65 | super(GlobalAggregator, self).__init__()
66 | self.dropout = dropout
67 | self.act = act
68 | self.dim = dim
69 |
70 | self.w_1 = nn.Parameter(torch.Tensor(self.dim + 1, self.dim))
71 | self.w_2 = nn.Parameter(torch.Tensor(self.dim, 1))
72 | self.w_3 = nn.Parameter(torch.Tensor(2 * self.dim, self.dim))
73 | self.bias = nn.Parameter(torch.Tensor(self.dim))
74 |
75 | def forward(self, self_vectors, neighbor_vector, batch_size, masks, neighbor_weight, extra_vector=None):
76 | if extra_vector is not None:
77 | alpha = torch.matmul(torch.cat([extra_vector.unsqueeze(2).repeat(1, 1, neighbor_vector.shape[2], 1)*neighbor_vector, neighbor_weight.unsqueeze(-1)], -1), self.w_1).squeeze(-1)
78 | alpha = F.leaky_relu(alpha, negative_slope=0.2)
79 | alpha = torch.matmul(alpha, self.w_2).squeeze(-1)
80 | alpha = torch.softmax(alpha, -1).unsqueeze(-1)
81 | neighbor_vector = torch.sum(alpha * neighbor_vector, dim=-2)
82 | else:
83 | neighbor_vector = torch.mean(neighbor_vector, dim=2)
84 | # self_vectors = F.dropout(self_vectors, 0.5, training=self.training)
85 | output = torch.cat([self_vectors, neighbor_vector], -1)
86 | output = F.dropout(output, self.dropout, training=self.training)
87 | output = torch.matmul(output, self.w_3)
88 | output = output.view(batch_size, -1, self.dim)
89 | output = self.act(output)
90 | return output
91 |
--------------------------------------------------------------------------------
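
A hypothetical standalone check of LocalAggregator (not part of the repo's training flow): its weights are normally initialised by CombineGraph.reset_parameters() in model.py, so this sketch initialises them by hand before calling forward; the tensor shapes and edge codes mirror what CombineGraph passes in.

~~~~
import torch
from torch import nn
from aggregator import LocalAggregator

# Hypothetical smoke test; dim/alpha/batch sizes are arbitrary.
dim, alpha, batch_size, n_nodes = 100, 0.2, 4, 6
agg = LocalAggregator(dim, alpha)
for p in agg.parameters():          # normally done by CombineGraph.reset_parameters()
    nn.init.uniform_(p, -0.1, 0.1)

hidden = torch.randn(batch_size, n_nodes, dim)                      # node embeddings
adj = torch.randint(0, 5, (batch_size, n_nodes, n_nodes)).float()   # edge-type codes 0..4

out = agg(hidden, adj)
print(out.shape)  # torch.Size([4, 6, 100])
~~~~
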
/build_graph.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import argparse
3 |
4 | parser = argparse.ArgumentParser()
5 | parser.add_argument('--dataset', default='diginetica', help='diginetica/Tmall/Nowplaying')
6 | parser.add_argument('--sample_num', type=int, default=12)
7 | opt = parser.parse_args()
8 |
9 | dataset = opt.dataset
10 | sample_num = opt.sample_num
11 |
12 | seq = pickle.load(open('datasets/' + dataset + '/all_train_seq.txt', 'rb'))
13 |
14 | if dataset == 'diginetica':
15 | num = 43098
16 | elif dataset == "Tmall":
17 | num = 40728
18 | elif dataset == "Nowplaying":
19 | num = 60417
20 | else:
21 | num = 3
22 |
23 | relation = []
24 | neighbor = [] * num
25 |
26 | all_test = set()
27 |
28 | adj1 = [dict() for _ in range(num)]
29 | adj = [[] for _ in range(num)]
30 |
31 | for i in range(len(seq)):
32 | data = seq[i]
33 | for k in range(1, 4):
34 | for j in range(len(data)-k):
35 | relation.append([data[j], data[j+k]])
36 | relation.append([data[j+k], data[j]])
37 |
38 | for tup in relation:
39 | if tup[1] in adj1[tup[0]].keys():
40 | adj1[tup[0]][tup[1]] += 1
41 | else:
42 | adj1[tup[0]][tup[1]] = 1
43 |
44 | weight = [[] for _ in range(num)]
45 |
46 | for t in range(num):
47 | x = [v for v in sorted(adj1[t].items(), reverse=True, key=lambda x: x[1])]
48 | adj[t] = [v[0] for v in x]
49 | weight[t] = [v[1] for v in x]
50 |
51 | for i in range(num):
52 | adj[i] = adj[i][:sample_num]
53 | weight[i] = weight[i][:sample_num]
54 |
55 | pickle.dump(adj, open('datasets/' + dataset + '/adj_' + str(sample_num) + '.pkl', 'wb'))
56 | pickle.dump(weight, open('datasets/' + dataset + '/num_' + str(sample_num) + '.pkl', 'wb'))
57 |
--------------------------------------------------------------------------------
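
A quick way to inspect the two pickles written by build_graph.py (a sketch, assuming it was run with --dataset diginetica --sample_num 12): adj_12.pkl maps each item ID to its most frequently co-occurring neighbours, and num_12.pkl holds the matching co-occurrence counts.

~~~~
import pickle

# Sketch: inspect the global-graph files produced by build_graph.py.
adj = pickle.load(open('datasets/diginetica/adj_12.pkl', 'rb'))
weight = pickle.load(open('datasets/diginetica/num_12.pkl', 'rb'))

item_id = 1
print(adj[item_id])     # up to 12 neighbour item IDs, most frequent co-occurrence first
print(weight[item_id])  # the corresponding counts, in the same order
assert len(adj[item_id]) == len(weight[item_id]) <= 12
~~~~
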
/datasets/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CCIIPLab/GCE-GNN/212575302e42d64b8cf0b9ffe1b71b0db398d5be/datasets/.DS_Store
--------------------------------------------------------------------------------
/datasets/Nowplaying/all_train_seq.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CCIIPLab/GCE-GNN/212575302e42d64b8cf0b9ffe1b71b0db398d5be/datasets/Nowplaying/all_train_seq.txt
--------------------------------------------------------------------------------
/datasets/Nowplaying/test.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CCIIPLab/GCE-GNN/212575302e42d64b8cf0b9ffe1b71b0db398d5be/datasets/Nowplaying/test.txt
--------------------------------------------------------------------------------
/datasets/Nowplaying/train.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CCIIPLab/GCE-GNN/212575302e42d64b8cf0b9ffe1b71b0db398d5be/datasets/Nowplaying/train.txt
--------------------------------------------------------------------------------
/datasets/Tmall/all_train_seq.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CCIIPLab/GCE-GNN/212575302e42d64b8cf0b9ffe1b71b0db398d5be/datasets/Tmall/all_train_seq.txt
--------------------------------------------------------------------------------
/datasets/Tmall/test.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CCIIPLab/GCE-GNN/212575302e42d64b8cf0b9ffe1b71b0db398d5be/datasets/Tmall/test.txt
--------------------------------------------------------------------------------
/datasets/Tmall/train.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CCIIPLab/GCE-GNN/212575302e42d64b8cf0b9ffe1b71b0db398d5be/datasets/Tmall/train.txt
--------------------------------------------------------------------------------
/datasets/diginetica/all_train_seq.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CCIIPLab/GCE-GNN/212575302e42d64b8cf0b9ffe1b71b0db398d5be/datasets/diginetica/all_train_seq.txt
--------------------------------------------------------------------------------
/datasets/diginetica/test.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CCIIPLab/GCE-GNN/212575302e42d64b8cf0b9ffe1b71b0db398d5be/datasets/diginetica/test.txt
--------------------------------------------------------------------------------
/datasets/diginetica/train.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CCIIPLab/GCE-GNN/212575302e42d64b8cf0b9ffe1b71b0db398d5be/datasets/diginetica/train.txt
--------------------------------------------------------------------------------
/datasets/process_nowplaying.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import csv
4 | import pickle
5 | import operator
6 | import datetime
7 | import os
8 |
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--dataset', default='Nowplaying', help='dataset name: Nowplaying')
11 | opt = parser.parse_args()
12 | print(opt)
13 |
14 | dataset = 'nowplaying.csv'
15 |
16 |
17 | print("-- Starting @ %ss" % datetime.datetime.now())
18 | with open(dataset, "r") as f:
19 | reader = csv.DictReader(f, delimiter='\t')
20 | sess_clicks = {}
21 | sess_date = {}
22 | ctr = 0
23 | curid = -1
24 | curdate = None
25 | for data in reader:
26 | sessid = int(data['SessionId'])
27 | if curdate and not curid == sessid:
28 | date = curdate
29 | sess_date[curid] = date
30 | curid = sessid
31 |
32 | item = int(data['ItemId'])
33 | curdate = float(data['Time'])
34 |
35 | if sessid in sess_clicks:
36 | sess_clicks[sessid] += [item]
37 | else:
38 | sess_clicks[sessid] = [item]
39 | ctr += 1
40 | date = float(data['Time'])
41 | sess_date[curid] = date
42 | print('ctr:', ctr)
43 | print("-- Reading data @ %ss" % datetime.datetime.now())
44 |
45 | # Filter out length 1 sessions
46 | for s in list(sess_clicks):
47 | if len(sess_clicks[s]) == 1:
48 | del sess_clicks[s]
49 | del sess_date[s]
50 |
51 | # Count number of times each item appears
52 | iid_counts = {}
53 | for s in sess_clicks:
54 | seq = sess_clicks[s]
55 | for iid in seq:
56 | if iid in iid_counts:
57 | iid_counts[iid] += 1
58 | else:
59 | iid_counts[iid] = 1
60 |
61 | sorted_counts = sorted(iid_counts.items(), key=operator.itemgetter(1))
62 |
63 | length = len(sess_clicks)
64 | for s in list(sess_clicks):
65 | curseq = sess_clicks[s]
66 | filseq = list(filter(lambda i: iid_counts[i] >= 5, curseq))
67 | if len(filseq) < 2 or len(filseq) > 30:
68 | del sess_clicks[s]
69 | del sess_date[s]
70 | else:
71 | sess_clicks[s] = filseq
72 |
73 | # Split out test set based on dates
74 | dates = list(sess_date.items())
75 | maxdate = dates[0][1]
76 |
77 | for _, date in dates:
78 | if maxdate < date:
79 | maxdate = date
80 |
81 | # Two months for test
82 | splitdate = maxdate - 60 * 86400
83 |
84 | print('Splitting date', splitdate) # Yoochoose: ('Split date', 1411930799.0)
85 | tra_sess = filter(lambda x: x[1] < splitdate, dates)
86 | tes_sess = filter(lambda x: x[1] > splitdate, dates)
87 |
88 | # Sort sessions by date
89 | tra_sess = sorted(tra_sess, key=operator.itemgetter(1)) # [(session_id, timestamp), (), ]
90 | tes_sess = sorted(tes_sess, key=operator.itemgetter(1)) # [(session_id, timestamp), (), ]
91 | print(len(tra_sess)) # 186670 # 7966257
92 | print(len(tes_sess)) # 15979 # 15324
93 | print(tra_sess[:3])
94 | print(tes_sess[:3])
95 | print("-- Splitting train set and test set @ %ss" % datetime.datetime.now())
96 |
97 | # Choosing item count >=5 gives approximately the same number of items as reported in paper
98 | item_dict = {}
99 | # Convert training sessions to sequences and renumber items to start from 1
100 | def obtian_tra():
101 | train_ids = []
102 | train_seqs = []
103 | train_dates = []
104 | item_ctr = 1
105 | for s, date in tra_sess:
106 | seq = sess_clicks[s]
107 | outseq = []
108 | for i in seq:
109 | if i in item_dict:
110 | outseq += [item_dict[i]]
111 | else:
112 | outseq += [item_ctr]
113 | item_dict[i] = item_ctr
114 | item_ctr += 1
115 | if len(outseq) < 2: # Doesn't occur
116 | continue
117 | train_ids += [s]
118 | train_dates += [date]
119 | train_seqs += [outseq]
120 | print('item_ctr')
121 | print(item_ctr) # 43098, 37484
122 | return train_ids, train_dates, train_seqs
123 |
124 |
125 | # Convert test sessions to sequences, ignoring items that do not appear in training set
126 | def obtian_tes():
127 | test_ids = []
128 | test_seqs = []
129 | test_dates = []
130 | for s, date in tes_sess:
131 | seq = sess_clicks[s]
132 | outseq = []
133 | for i in seq:
134 | if i in item_dict:
135 | outseq += [item_dict[i]]
136 | if len(outseq) < 2:
137 | continue
138 | test_ids += [s]
139 | test_dates += [date]
140 | test_seqs += [outseq]
141 | return test_ids, test_dates, test_seqs
142 |
143 |
144 | tra_ids, tra_dates, tra_seqs = obtian_tra()
145 | tes_ids, tes_dates, tes_seqs = obtian_tes()
146 |
147 |
148 | def process_seqs(iseqs, idates):
149 | out_seqs = []
150 | out_dates = []
151 | labs = []
152 | ids = []
153 | for id, seq, date in zip(range(len(iseqs)), iseqs, idates):
154 | for i in range(1, len(seq)):
155 | tar = seq[-i]
156 | labs += [tar]
157 | out_seqs += [seq[:-i]]
158 | out_dates += [date]
159 | ids += [id]
160 | return out_seqs, out_dates, labs, ids
161 |
162 |
163 | tr_seqs, tr_dates, tr_labs, tr_ids = process_seqs(tra_seqs, tra_dates)
164 | te_seqs, te_dates, te_labs, te_ids = process_seqs(tes_seqs, tes_dates)
165 | tra = (tr_seqs, tr_labs)
166 | tes = (te_seqs, te_labs)
167 | print('train_test')
168 | print(len(tr_seqs))
169 | print(len(te_seqs))
170 | print(tr_seqs[:3], tr_dates[:3], tr_labs[:3])
171 | print(te_seqs[:3], te_dates[:3], te_labs[:3])
172 | all = 0
173 |
174 | for seq in tra_seqs:
175 | all += len(seq)
176 | for seq in tes_seqs:
177 | all += len(seq)
178 | print('avg length: ', all/(len(tra_seqs) + len(tes_seqs) * 1.0))
179 | print('all:', all)
180 |
181 | if not os.path.exists('Nowplaying'):
182 | os.makedirs('Nowplaying')
183 | pickle.dump(tra, open('Nowplaying/train.txt', 'wb'))
184 | pickle.dump(tes, open('Nowplaying/test.txt', 'wb'))
185 | pickle.dump(tra_seqs, open('Nowplaying/all_train_seq.txt', 'wb'))
186 |
--------------------------------------------------------------------------------
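
process_seqs above turns every session into one training example per click: each prefix of the session is an input and the item that follows it is the label. A toy run of the same logic (the item IDs are made up):

~~~~
# Self-contained copy of process_seqs from the script above, applied to a made-up session.
def process_seqs(iseqs, idates):
    out_seqs, out_dates, labs, ids = [], [], [], []
    for id, seq, date in zip(range(len(iseqs)), iseqs, idates):
        for i in range(1, len(seq)):
            labs += [seq[-i]]        # the item to predict
            out_seqs += [seq[:-i]]   # everything clicked before it
            out_dates += [date]
            ids += [id]
    return out_seqs, out_dates, labs, ids

seqs, _, labels, _ = process_seqs([[10, 20, 30, 40]], [0])
print(list(zip(seqs, labels)))
# [([10, 20, 30], 40), ([10, 20], 30), ([10], 20)]
~~~~
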
/datasets/process_tmall.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import csv
4 | import pickle
5 | import operator
6 | import datetime
7 | import os
8 |
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--dataset', default='Tmall', help='dataset name: Tmall')
11 | opt = parser.parse_args()
12 | print(opt)
13 |
14 | with open('tmall_data.csv', 'w') as tmall_data:
15 | with open('tmall/dataset15.csv', 'r') as tmall_file:
16 | header = tmall_file.readline()
17 | tmall_data.write(header)
18 | for line in tmall_file:
19 | data = line[:-1].split('\t')
20 | if int(data[2]) > 120000:
21 | break
22 | tmall_data.write(line)
23 |
24 | print("-- Starting @ %ss" % datetime.datetime.now())
25 | with open('tmall_data.csv', "r") as f:
26 | reader = csv.DictReader(f, delimiter='\t')
27 | sess_clicks = {}
28 | sess_date = {}
29 | ctr = 0
30 | curid = -1
31 | curdate = None
32 | for data in reader:
33 | sessid = int(data['SessionId'])
34 | if curdate and not curid == sessid:
35 | date = curdate
36 | sess_date[curid] = date
37 | curid = sessid
38 | item = int(data['ItemId'])
39 | curdate = float(data['Time'])
40 |
41 | if sessid in sess_clicks:
42 | sess_clicks[sessid] += [item]
43 | else:
44 | sess_clicks[sessid] = [item]
45 | ctr += 1
46 | date = float(data['Time'])
47 | sess_date[curid] = date
48 | print("-- Reading data @ %ss" % datetime.datetime.now())
49 |
50 |
51 | # Filter out length 1 sessions
52 | for s in list(sess_clicks):
53 | if len(sess_clicks[s]) == 1:
54 | del sess_clicks[s]
55 | del sess_date[s]
56 |
57 | # Count number of times each item appears
58 | iid_counts = {}
59 | for s in sess_clicks:
60 | seq = sess_clicks[s]
61 | for iid in seq:
62 | if iid in iid_counts:
63 | iid_counts[iid] += 1
64 | else:
65 | iid_counts[iid] = 1
66 |
67 | sorted_counts = sorted(iid_counts.items(), key=operator.itemgetter(1))
68 |
69 | length = len(sess_clicks)
70 | for s in list(sess_clicks):
71 | curseq = sess_clicks[s]
72 | filseq = list(filter(lambda i: iid_counts[i] >= 5, curseq))
73 | if len(filseq) < 2 or len(filseq) > 40:
74 | del sess_clicks[s]
75 | del sess_date[s]
76 | else:
77 | sess_clicks[s] = filseq
78 |
79 | # Split out test set based on dates
80 | dates = list(sess_date.items())
81 | maxdate = dates[0][1]
82 |
83 | for _, date in dates:
84 | if maxdate < date:
85 | maxdate = date
86 |
87 | # the last of 100 seconds for test
88 | splitdate = maxdate - 100
89 |
90 | print('Splitting date', splitdate) # Yoochoose: ('Split date', 1411930799.0)
91 | tra_sess = filter(lambda x: x[1] < splitdate, dates)
92 | tes_sess = filter(lambda x: x[1] > splitdate, dates)
93 |
94 | # Sort sessions by date
95 | tra_sess = sorted(tra_sess, key=operator.itemgetter(1)) # [(session_id, timestamp), (), ]
96 | tes_sess = sorted(tes_sess, key=operator.itemgetter(1)) # [(session_id, timestamp), (), ]
97 | print(len(tra_sess)) # 186670 # 7966257
98 | print(len(tes_sess)) # 15979 # 15324
99 | print(tra_sess[:3])
100 | print(tes_sess[:3])
101 | print("-- Splitting train set and test set @ %ss" % datetime.datetime.now())
102 |
103 | # Choosing item count >=5 gives approximately the same number of items as reported in paper
104 | item_dict = {}
105 | # Convert training sessions to sequences and renumber items to start from 1
106 | def obtian_tra():
107 | train_ids = []
108 | train_seqs = []
109 | train_dates = []
110 | item_ctr = 1
111 | for s, date in tra_sess:
112 | seq = sess_clicks[s]
113 | outseq = []
114 | for i in seq:
115 | if i in item_dict:
116 | outseq += [item_dict[i]]
117 | else:
118 | outseq += [item_ctr]
119 | item_dict[i] = item_ctr
120 | item_ctr += 1
121 | if len(outseq) < 2: # Doesn't occur
122 | continue
123 | train_ids += [s]
124 | train_dates += [date]
125 | train_seqs += [outseq]
126 | print('item_ctr')
127 | print(item_ctr) # 43098, 37484
128 | return train_ids, train_dates, train_seqs
129 |
130 | # Convert test sessions to sequences, ignoring items that do not appear in training set
131 | def obtian_tes():
132 | test_ids = []
133 | test_seqs = []
134 | test_dates = []
135 | for s, date in tes_sess:
136 | seq = sess_clicks[s]
137 | outseq = []
138 | for i in seq:
139 | if i in item_dict:
140 | outseq += [item_dict[i]]
141 | if len(outseq) < 2:
142 | continue
143 | test_ids += [s]
144 | test_dates += [date]
145 | test_seqs += [outseq]
146 | return test_ids, test_dates, test_seqs
147 |
148 | tra_ids, tra_dates, tra_seqs = obtian_tra()
149 | tes_ids, tes_dates, tes_seqs = obtian_tes()
150 |
151 | def process_seqs(iseqs, idates):
152 | out_seqs = []
153 | out_dates = []
154 | labs = []
155 | ids = []
156 | for id, seq, date in zip(range(len(iseqs)), iseqs, idates):
157 | for i in range(1, len(seq)):
158 | tar = seq[-i]
159 | labs += [tar]
160 | out_seqs += [seq[:-i]]
161 | out_dates += [date]
162 | ids += [id]
163 | return out_seqs, out_dates, labs, ids
164 |
165 | tr_seqs, tr_dates, tr_labs, tr_ids = process_seqs(tra_seqs, tra_dates)
166 | te_seqs, te_dates, te_labs, te_ids = process_seqs(tes_seqs, tes_dates)
167 | tra = (tr_seqs, tr_labs)
168 | tes = (te_seqs, te_labs)
169 | print('train_test')
170 | print(len(tr_seqs))
171 | print(len(te_seqs))
172 | print(tr_seqs[:3], tr_dates[:3], tr_labs[:3])
173 | print(te_seqs[:3], te_dates[:3], te_labs[:3])
174 | all = 0
175 |
176 | for seq in tra_seqs:
177 | all += len(seq)
178 | for seq in tes_seqs:
179 | all += len(seq)
180 | print('avg length: ', all * 1.0/(len(tra_seqs) + len(tes_seqs)))
181 |
182 | if not os.path.exists('tmall'):
183 | os.makedirs('tmall')
184 | pickle.dump(tra, open('tmall/train.txt', 'wb'))
185 | pickle.dump(tes, open('tmall/test.txt', 'wb'))
186 | pickle.dump(tra_seqs, open('tmall/all_train_seq.txt', 'wb'))
187 |
188 | # Namespace(dataset='Tmall')
189 | # Splitting train set and test set
190 | # item_ctr
191 | # 40728
192 | # train_test
193 | # 351268
194 | # 25898
195 | # avg length: 6.687663052493478
196 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import time
2 | import argparse
3 | import pickle
4 | from model import *
5 | from utils import *
6 |
7 |
8 | def init_seed(seed=None):
9 | if seed is None:
10 | seed = int(time.time() * 1000 // 1000)
11 | np.random.seed(seed)
12 | torch.manual_seed(seed)
13 | torch.cuda.manual_seed(seed)
14 | torch.cuda.manual_seed_all(seed)
15 |
16 |
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument('--dataset', default='diginetica', help='diginetica/Nowplaying/Tmall')
19 | parser.add_argument('--hiddenSize', type=int, default=100)
20 | parser.add_argument('--epoch', type=int, default=20)
21 | parser.add_argument('--activate', type=str, default='relu')
22 | parser.add_argument('--n_sample_all', type=int, default=12)
23 | parser.add_argument('--n_sample', type=int, default=12)
24 | parser.add_argument('--batch_size', type=int, default=100)
25 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate.')
26 | parser.add_argument('--lr_dc', type=float, default=0.1, help='learning rate decay.')
27 | parser.add_argument('--lr_dc_step', type=int, default=3, help='the number of steps after which the learning rate decays.')
28 | parser.add_argument('--l2', type=float, default=1e-5, help='l2 penalty ')
29 | parser.add_argument('--n_iter', type=int, default=1) # [1, 2]
30 | parser.add_argument('--dropout_gcn', type=float, default=0, help='Dropout rate.') # [0, 0.2, 0.4, 0.6, 0.8]
31 | parser.add_argument('--dropout_local', type=float, default=0, help='Dropout rate.') # [0, 0.5]
32 | parser.add_argument('--dropout_global', type=float, default=0.5, help='Dropout rate.')
33 | parser.add_argument('--validation', action='store_true', help='validation')
34 | parser.add_argument('--valid_portion', type=float, default=0.1, help='split the portion')
35 | parser.add_argument('--alpha', type=float, default=0.2, help='Alpha for the leaky_relu.')
36 | parser.add_argument('--patience', type=int, default=3)
37 |
38 | opt = parser.parse_args()
39 |
40 |
41 | def main():
42 | init_seed(2020)
43 |
44 | if opt.dataset == 'diginetica':
45 | num_node = 43098
46 | opt.n_iter = 2
47 | opt.dropout_gcn = 0.2
48 | opt.dropout_local = 0.0
49 | elif opt.dataset == 'Nowplaying':
50 | num_node = 60417
51 | opt.n_iter = 1
52 | opt.dropout_gcn = 0.0
53 | opt.dropout_local = 0.0
54 | elif opt.dataset == 'Tmall':
55 | num_node = 40728
56 | opt.n_iter = 1
57 | opt.dropout_gcn = 0.6
58 | opt.dropout_local = 0.5
59 | else:
60 | num_node = 310
61 |
62 | train_data = pickle.load(open('datasets/' + opt.dataset + '/train.txt', 'rb'))
63 | if opt.validation:
64 | train_data, valid_data = split_validation(train_data, opt.valid_portion)
65 | test_data = valid_data
66 | else:
67 | test_data = pickle.load(open('datasets/' + opt.dataset + '/test.txt', 'rb'))
68 |
69 | adj = pickle.load(open('datasets/' + opt.dataset + '/adj_' + str(opt.n_sample_all) + '.pkl', 'rb'))
70 | num = pickle.load(open('datasets/' + opt.dataset + '/num_' + str(opt.n_sample_all) + '.pkl', 'rb'))
71 | train_data = Data(train_data)
72 | test_data = Data(test_data)
73 |
74 | adj, num = handle_adj(adj, num_node, opt.n_sample_all, num)
75 | model = trans_to_cuda(CombineGraph(opt, num_node, adj, num))
76 |
77 | print(opt)
78 | start = time.time()
79 | best_result = [0, 0]
80 | best_epoch = [0, 0]
81 | bad_counter = 0
82 |
83 | for epoch in range(opt.epoch):
84 | print('-------------------------------------------------------')
85 | print('epoch: ', epoch)
86 | hit, mrr = train_test(model, train_data, test_data)
87 | flag = 0
88 | if hit >= best_result[0]:
89 | best_result[0] = hit
90 | best_epoch[0] = epoch
91 | flag = 1
92 | if mrr >= best_result[1]:
93 | best_result[1] = mrr
94 | best_epoch[1] = epoch
95 | flag = 1
96 | print('Current Result:')
97 | print('\tRecall@20:\t%.4f\tMRR@20:\t%.4f' % (hit, mrr))
98 | print('Best Result:')
99 | print('\tRecall@20:\t%.4f\tMRR@20:\t%.4f\tEpoch:\t%d,\t%d' % (
100 | best_result[0], best_result[1], best_epoch[0], best_epoch[1]))
101 | bad_counter += 1 - flag
102 | if bad_counter >= opt.patience:
103 | break
104 | print('-------------------------------------------------------')
105 | end = time.time()
106 | print("Run time: %f s" % (end - start))
107 |
108 |
109 | if __name__ == '__main__':
110 | main()
111 |
--------------------------------------------------------------------------------
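
Beyond the README commands, main.py exposes a --validation switch that carves a validation split out of train.txt with split_validation from utils.py. A possible invocation (flags as defined by the argparse options above):

~~~~
# Usage sketch: build the 12-neighbour global graph once, then train with a
# 10% validation split instead of the test set.
python build_graph.py --dataset Tmall --sample_num 12
python main.py --dataset Tmall --validation --valid_portion 0.1
~~~~
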
/model.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import math
3 | import numpy as np
4 | import torch
5 | from torch import nn
6 | from tqdm import tqdm
7 | from aggregator import LocalAggregator, GlobalAggregator
8 | from torch.nn import Module, Parameter
9 | import torch.nn.functional as F
10 |
11 |
12 | class CombineGraph(Module):
13 | def __init__(self, opt, num_node, adj_all, num):
14 | super(CombineGraph, self).__init__()
15 | self.opt = opt
16 |
17 | self.batch_size = opt.batch_size
18 | self.num_node = num_node
19 | self.dim = opt.hiddenSize
20 | self.dropout_local = opt.dropout_local
21 | self.dropout_global = opt.dropout_global
22 | self.hop = opt.n_iter
23 | self.sample_num = opt.n_sample
24 | self.adj_all = trans_to_cuda(torch.Tensor(adj_all)).long()
25 | self.num = trans_to_cuda(torch.Tensor(num)).float()
26 |
27 | # Aggregator
28 | self.local_agg = LocalAggregator(self.dim, self.opt.alpha, dropout=0.0)
29 | self.global_agg = []
30 | for i in range(self.hop):
31 | if opt.activate == 'relu':
32 | agg = GlobalAggregator(self.dim, opt.dropout_gcn, act=torch.relu)
33 | else:
34 | agg = GlobalAggregator(self.dim, opt.dropout_gcn, act=torch.tanh)
35 | self.add_module('agg_gcn_{}'.format(i), agg)
36 | self.global_agg.append(agg)
37 |
38 | # Item representation & Position representation
39 | self.embedding = nn.Embedding(num_node, self.dim)
40 | self.pos_embedding = nn.Embedding(200, self.dim)
41 |
42 | # Parameters
43 | self.w_1 = nn.Parameter(torch.Tensor(2 * self.dim, self.dim))
44 | self.w_2 = nn.Parameter(torch.Tensor(self.dim, 1))
45 | self.glu1 = nn.Linear(self.dim, self.dim)
46 | self.glu2 = nn.Linear(self.dim, self.dim, bias=False)
47 | self.linear_transform = nn.Linear(self.dim, self.dim, bias=False)
48 |
49 | self.leakyrelu = nn.LeakyReLU(opt.alpha)
50 | self.loss_function = nn.CrossEntropyLoss()
51 | self.optimizer = torch.optim.Adam(self.parameters(), lr=opt.lr, weight_decay=opt.l2)
52 | self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=opt.lr_dc_step, gamma=opt.lr_dc)
53 |
54 | self.reset_parameters()
55 |
56 | def reset_parameters(self):
57 | stdv = 1.0 / math.sqrt(self.dim)
58 | for weight in self.parameters():
59 | weight.data.uniform_(-stdv, stdv)
60 |
61 | def sample(self, target, n_sample):
62 | # neighbor = self.adj_all[target.view(-1)]
63 | # index = np.arange(neighbor.shape[1])
64 | # np.random.shuffle(index)
65 | # index = index[:n_sample]
66 | # return self.adj_all[target.view(-1)][:, index], self.num[target.view(-1)][:, index]
67 | return self.adj_all[target.view(-1)], self.num[target.view(-1)]
68 |
69 | def compute_scores(self, hidden, mask):
70 | mask = mask.float().unsqueeze(-1)
71 |
72 | batch_size = hidden.shape[0]
73 | len = hidden.shape[1]
74 | pos_emb = self.pos_embedding.weight[:len]
75 | pos_emb = pos_emb.unsqueeze(0).repeat(batch_size, 1, 1)
76 |
77 | hs = torch.sum(hidden * mask, -2) / torch.sum(mask, 1)
78 | hs = hs.unsqueeze(-2).repeat(1, len, 1)
79 | nh = torch.matmul(torch.cat([pos_emb, hidden], -1), self.w_1)
80 | nh = torch.tanh(nh)
81 | nh = torch.sigmoid(self.glu1(nh) + self.glu2(hs))
82 | beta = torch.matmul(nh, self.w_2)
83 | beta = beta * mask
84 | select = torch.sum(beta * hidden, 1)
85 |
86 | b = self.embedding.weight[1:] # n_nodes x latent_size
87 | scores = torch.matmul(select, b.transpose(1, 0))
88 | return scores
89 |
90 | def forward(self, inputs, adj, mask_item, item):
91 | batch_size = inputs.shape[0]
92 | seqs_len = inputs.shape[1]
93 | h = self.embedding(inputs)
94 |
95 | # local
96 | h_local = self.local_agg(h, adj, mask_item)
97 |
98 | # global
99 | item_neighbors = [inputs]
100 | weight_neighbors = []
101 | support_size = seqs_len
102 |
103 | for i in range(1, self.hop + 1):
104 | item_sample_i, weight_sample_i = self.sample(item_neighbors[-1], self.sample_num)
105 | support_size *= self.sample_num
106 | item_neighbors.append(item_sample_i.view(batch_size, support_size))
107 | weight_neighbors.append(weight_sample_i.view(batch_size, support_size))
108 |
109 | entity_vectors = [self.embedding(i) for i in item_neighbors]
110 | weight_vectors = weight_neighbors
111 |
112 | session_info = []
113 | item_emb = self.embedding(item) * mask_item.float().unsqueeze(-1)
114 |
115 | # mean
116 | sum_item_emb = torch.sum(item_emb, 1) / torch.sum(mask_item.float(), -1).unsqueeze(-1)
117 |
118 | # sum
119 | # sum_item_emb = torch.sum(item_emb, 1)
120 |
121 | sum_item_emb = sum_item_emb.unsqueeze(-2)
122 | for i in range(self.hop):
123 | session_info.append(sum_item_emb.repeat(1, entity_vectors[i].shape[1], 1))
124 |
125 | for n_hop in range(self.hop):
126 | entity_vectors_next_iter = []
127 | shape = [batch_size, -1, self.sample_num, self.dim]
128 | for hop in range(self.hop - n_hop):
129 | aggregator = self.global_agg[n_hop]
130 | vector = aggregator(self_vectors=entity_vectors[hop],
131 | neighbor_vector=entity_vectors[hop+1].view(shape),
132 | masks=None,
133 | batch_size=batch_size,
134 | neighbor_weight=weight_vectors[hop].view(batch_size, -1, self.sample_num),
135 | extra_vector=session_info[hop])
136 | entity_vectors_next_iter.append(vector)
137 | entity_vectors = entity_vectors_next_iter
138 |
139 | h_global = entity_vectors[0].view(batch_size, seqs_len, self.dim)
140 |
141 | # combine
142 | h_local = F.dropout(h_local, self.dropout_local, training=self.training)
143 | h_global = F.dropout(h_global, self.dropout_global, training=self.training)
144 | output = h_local + h_global
145 |
146 | return output
147 |
148 |
149 | def trans_to_cuda(variable):
150 | if torch.cuda.is_available():
151 | return variable.cuda()
152 | else:
153 | return variable
154 |
155 |
156 | def trans_to_cpu(variable):
157 | if torch.cuda.is_available():
158 | return variable.cpu()
159 | else:
160 | return variable
161 |
162 |
163 | def forward(model, data):
164 | alias_inputs, adj, items, mask, targets, inputs = data
165 | alias_inputs = trans_to_cuda(alias_inputs).long()
166 | items = trans_to_cuda(items).long()
167 | adj = trans_to_cuda(adj).float()
168 | mask = trans_to_cuda(mask).long()
169 | inputs = trans_to_cuda(inputs).long()
170 |
171 | hidden = model(items, adj, mask, inputs)
172 | get = lambda index: hidden[index][alias_inputs[index]]
173 | seq_hidden = torch.stack([get(i) for i in torch.arange(len(alias_inputs)).long()])
174 | return targets, model.compute_scores(seq_hidden, mask)
175 |
176 |
177 | def train_test(model, train_data, test_data):
178 | print('start training: ', datetime.datetime.now())
179 | model.train()
180 | total_loss = 0.0
181 | train_loader = torch.utils.data.DataLoader(train_data, num_workers=4, batch_size=model.batch_size,
182 | shuffle=True, pin_memory=True)
183 | for data in tqdm(train_loader):
184 | model.optimizer.zero_grad()
185 | targets, scores = forward(model, data)
186 | targets = trans_to_cuda(targets).long()
187 | loss = model.loss_function(scores, targets - 1)
188 | loss.backward()
189 | model.optimizer.step()
190 | total_loss += loss
191 | print('\tLoss:\t%.3f' % total_loss)
192 | model.scheduler.step()
193 |
194 | print('start predicting: ', datetime.datetime.now())
195 | model.eval()
196 | test_loader = torch.utils.data.DataLoader(test_data, num_workers=4, batch_size=model.batch_size,
197 | shuffle=False, pin_memory=True)
198 | result = []
199 | hit, mrr = [], []
200 | for data in test_loader:
201 | targets, scores = forward(model, data)
202 | sub_scores = scores.topk(20)[1]
203 | sub_scores = trans_to_cpu(sub_scores).detach().numpy()
204 | targets = targets.numpy()
205 | for score, target, mask in zip(sub_scores, targets, test_data.mask):
206 | hit.append(np.isin(target - 1, score))
207 | if len(np.where(score == target - 1)[0]) == 0:
208 | mrr.append(0)
209 | else:
210 | mrr.append(1 / (np.where(score == target - 1)[0][0] + 1))
211 |
212 | result.append(np.mean(hit) * 100)
213 | result.append(np.mean(mrr) * 100)
214 |
215 | return result
216 |
--------------------------------------------------------------------------------
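
One detail worth spelling out: compute_scores drops the padding row of the embedding table (embedding.weight[1:]), so the score column for item ID t is t - 1. That is why train_test computes the loss on targets - 1 and compares the top-20 indices against target - 1. A small sketch of the same indexing, with made-up sizes:

~~~~
import torch

# Sketch of the item-ID / score-column offset (sizes are made up).
num_node, dim = 10, 4
embedding = torch.nn.Embedding(num_node, dim)   # row 0 is the padding item
select = torch.randn(3, dim)                    # one session representation per row
scores = select @ embedding.weight[1:].t()      # shape (3, num_node - 1)

targets = torch.tensor([5, 2, 9])               # real item IDs (1-based)
loss = torch.nn.CrossEntropyLoss()(scores, targets - 1)
print(scores.shape, loss.item())
~~~~
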
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils.data import Dataset
4 |
5 |
6 | def split_validation(train_set, valid_portion):
7 | train_set_x, train_set_y = train_set
8 | n_samples = len(train_set_x)
9 | sidx = np.arange(n_samples, dtype='int32')
10 | np.random.shuffle(sidx)
11 | n_train = int(np.round(n_samples * (1. - valid_portion)))
12 | valid_set_x = [train_set_x[s] for s in sidx[n_train:]]
13 | valid_set_y = [train_set_y[s] for s in sidx[n_train:]]
14 | train_set_x = [train_set_x[s] for s in sidx[:n_train]]
15 | train_set_y = [train_set_y[s] for s in sidx[:n_train]]
16 |
17 | return (train_set_x, train_set_y), (valid_set_x, valid_set_y)
18 |
19 |
20 | def handle_data(inputData, train_len=None):
21 | len_data = [len(nowData) for nowData in inputData]
22 | if train_len is None:
23 | max_len = max(len_data)
24 | else:
25 | max_len = train_len
26 | # reverse the sequence
27 | us_pois = [list(reversed(upois)) + [0] * (max_len - le) if le < max_len else list(reversed(upois[-max_len:]))
28 | for upois, le in zip(inputData, len_data)]
29 | us_msks = [[1] * le + [0] * (max_len - le) if le < max_len else [1] * max_len
30 | for le in len_data]
31 | return us_pois, us_msks, max_len
32 |
33 |
34 | def handle_adj(adj_dict, n_entity, sample_num, num_dict=None):
35 | adj_entity = np.zeros([n_entity, sample_num], dtype=np.int64)
36 | num_entity = np.zeros([n_entity, sample_num], dtype=np.int64)
37 | for entity in range(1, n_entity):
38 | neighbor = list(adj_dict[entity])
39 | neighbor_weight = list(num_dict[entity])
40 | n_neighbor = len(neighbor)
41 | if n_neighbor == 0:
42 | continue
43 | if n_neighbor >= sample_num:
44 | sampled_indices = np.random.choice(list(range(n_neighbor)), size=sample_num, replace=False)
45 | else:
46 | sampled_indices = np.random.choice(list(range(n_neighbor)), size=sample_num, replace=True)
47 | adj_entity[entity] = np.array([neighbor[i] for i in sampled_indices])
48 | num_entity[entity] = np.array([neighbor_weight[i] for i in sampled_indices])
49 |
50 | return adj_entity, num_entity
51 |
52 |
53 | class Data(Dataset):
54 | def __init__(self, data, train_len=None):
55 | inputs, mask, max_len = handle_data(data[0], train_len)
56 | self.inputs = np.asarray(inputs)
57 | self.targets = np.asarray(data[1])
58 | self.mask = np.asarray(mask)
59 | self.length = len(data[0])
60 | self.max_len = max_len
61 |
62 | def __getitem__(self, index):
63 | u_input, mask, target = self.inputs[index], self.mask[index], self.targets[index]
64 |
65 | max_n_node = self.max_len
66 | node = np.unique(u_input)
67 | items = node.tolist() + (max_n_node - len(node)) * [0]
68 | adj = np.zeros((max_n_node, max_n_node))
69 | for i in np.arange(len(u_input) - 1):
70 | u = np.where(node == u_input[i])[0][0]
71 | adj[u][u] = 1
72 | if u_input[i + 1] == 0:
73 | break
74 | v = np.where(node == u_input[i + 1])[0][0]
75 | if u == v or adj[u][v] == 4:
76 | continue
77 | adj[v][v] = 1
78 | if adj[v][u] == 2:
79 | adj[u][v] = 4
80 | adj[v][u] = 4
81 | else:
82 | adj[u][v] = 2
83 | adj[v][u] = 3
84 |
85 | alias_inputs = [np.where(node == i)[0][0] for i in u_input]
86 |
87 | return [torch.tensor(alias_inputs), torch.tensor(adj), torch.tensor(items),
88 | torch.tensor(mask), torch.tensor(target), torch.tensor(u_input)]
89 |
90 | def __len__(self):
91 | return self.length
92 |
--------------------------------------------------------------------------------
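
The edge codes that LocalAggregator matches with adj.eq(1) through adj.eq(4) are produced here in Data.__getitem__: 1 marks a self loop, 2 an outgoing edge, 3 an incoming edge, and 4 an edge observed in both directions. A toy session (item IDs made up) makes the matrix easy to inspect:

~~~~
from utils import Data

# Toy check of the session-graph construction in Data.__getitem__.
toy = ([[1, 2, 3, 2]], [4])           # one session plus its next-item label
data = Data(toy)
alias_inputs, adj, items, mask, target, u_input = data[0]

print(items)         # unique items of the (reversed) session, zero-padded to max_len
print(adj)           # max_len x max_len matrix of edge-type codes 1..4
print(alias_inputs)  # position of each clicked item inside `items`
~~~~
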