├── src
│   ├── __init__.py
│   ├── reservoir.py
│   ├── dataset.py
│   ├── train.py
│   ├── sampling.py
│   ├── AGCN.py
│   ├── main.py
│   ├── model.py
│   └── se_data_process.py
├── README.md
└── LICENSE
/src/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on 17/9/2019
4 | @author: RuihongQiu
5 | """
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GAG
2 | "GAG: Global Attributed Graph Neural Network for Streaming Session-based Recommendation", SIGIR 2020
3 | 
4 | If you find this repo helpful, please cite this paper.
5 | 
--------------------------------------------------------------------------------
/src/reservoir.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on 19/9/2019
4 | @author: LeiGuo, RuihongQiu
5 | """
6 | 
7 | import numpy as np
8 | 
9 | 
10 | class Reservoir(object):
11 |     def __init__(self, train, size_denominator):
12 |         super(Reservoir, self).__init__()
13 |         self.r_size = len(train[0]) / size_denominator
14 |         self.t = 0
15 |         self.data = ([], [], [])
16 | 
17 |     def add(self, x, y, u):  # one list represents one sample
18 |         # global t
19 |         if self.t < self.r_size:
20 |             self.data[0].append(x)
21 |             self.data[1].append(y)
22 |             self.data[2].append(u)
23 |         else:
24 |             p = self.r_size / self.t
25 |             s = False
26 |             random = np.random.rand()
27 |             if random <= p:
28 |                 s = True
29 |             if s:
30 |                 random = np.random.rand()
31 |                 index = int(random * (len(self.data[0]) - 1))
32 |                 self.data[0][index] = x
33 |                 self.data[1][index] = y
34 |                 self.data[2][index] = u
35 |         self.t += 1
36 | 
37 |     def update(self, data):
38 |         for index in range(len(data[0])):
39 |             x = data[0][index]
40 |             y = data[1][index]
41 |             user = data[2][index]
42 |             self.add(x, y, user)
43 | 
--------------------------------------------------------------------------------
/src/dataset.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on 17/9/2019
4 | @author: RuihongQiu
5 | """
6 | 
7 | import torch
8 | import collections
9 | import logging
10 | from torch_geometric.data import InMemoryDataset, Data
11 | from se_data_process import load_data_valid, load_testdata
12 | 
13 | 
14 | class MultiSessionsGraph(InMemoryDataset):
15 |     """Every session is a graph."""
16 |     def __init__(self, root, phrase=None, transform=None, pre_transform=None, sampled_data=None):
17 |         """
18 |         Args:
19 |             root: address of the dataset
20 |             phrase: 'train' or 'test1' ~ 'test5' or 'sampled***' or 'uni***'
21 |         """
22 |         self.phrase, self.sampled_data = phrase, sampled_data
23 |         logging.warning(self.phrase)
24 |         super(MultiSessionsGraph, self).__init__(root, transform, pre_transform)
25 |         self.data, self.slices = torch.load(self.processed_paths[0])
26 | 
27 |     @property
28 |     def raw_file_names(self):
29 |         return [self.phrase + '.txt.csv']
30 | 
31 |     @property
32 |     def processed_file_names(self):
33 |         return [self.phrase + '.pt']
34 | 
35 |     def download(self):
36 |         pass
37 | 
38 |     def process(self):
39 |         # data = [[x], [y], [user]]
40 |         if self.sampled_data is not None:
41 |             data = self.sampled_data
42 |         else:
43 |             if self.phrase == 'train':
44 |                 data, valid = load_data_valid(self.raw_dir + '/' + self.raw_file_names[0], 0)
45 |             else:
46 |                 data = load_testdata(self.raw_dir + '/' + self.raw_file_names[0])
47 |         data_list = []
48
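        # Each (sequence, y, userid) triple below is turned into one torch_geometric Data
        # graph: nodes are the unique items of the session (re-indexed from 0), consecutive
        # clicks form directed edges, repeated edges are collapsed with their multiplicity
        # kept in edge_count, num_count stores the in-session frequency of every node, and
        # the in-/out-degree inverses are later combined with edge_count as edge weights
        # for the AGCN layers.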
| for sequence, y, userid in zip(data[0], data[1], data[2]): 49 | count = collections.Counter(sequence) 50 | i = 0 51 | nodes = {} # dict{15: 0, 16: 1, 18: 2, ...} 52 | senders = [] 53 | x = [] 54 | for node in sequence: 55 | if node not in nodes: 56 | nodes[node] = i 57 | x.append([node]) 58 | i += 1 59 | senders.append(nodes[node]) 60 | receivers = senders[:] 61 | num_count = [count[i[0]] for i in x] 62 | 63 | if len(senders) != 1: 64 | del senders[-1] # the last item is a receiver 65 | del receivers[0] # the first item is a sender 66 | 67 | pair = {} 68 | sur_senders = senders[:] 69 | sur_receivers = receivers[:] 70 | i = 0 71 | for sender, receiver in zip(sur_senders, sur_receivers): 72 | if str(sender) + '-' + str(receiver) in pair: 73 | pair[str(sender) + '-' + str(receiver)] += 1 74 | del senders[i] 75 | del receivers[i] 76 | else: 77 | pair[str(sender) + '-' + str(receiver)] = 1 78 | i += 1 79 | 80 | count = collections.Counter(senders) 81 | out_degree_inv = [1 / count[i] for i in senders] 82 | 83 | count = collections.Counter(receivers) 84 | in_degree_inv = [1 / count[i] for i in receivers] 85 | 86 | in_degree_inv = torch.tensor(in_degree_inv, dtype=torch.float) 87 | out_degree_inv = torch.tensor(out_degree_inv, dtype=torch.float) 88 | 89 | edge_count = [pair[str(senders[i]) + '-' + str(receivers[i])] for i in range(len(senders))] 90 | edge_count = torch.tensor(edge_count, dtype=torch.float) 91 | 92 | edge_index = torch.tensor([senders, receivers], dtype=torch.long) 93 | x = torch.tensor(x, dtype=torch.long) 94 | y = torch.tensor([y], dtype=torch.long) 95 | userid = torch.tensor([userid], dtype=torch.long) 96 | num_count = torch.tensor(num_count, dtype=torch.float) 97 | sequence = torch.tensor(sequence, dtype=torch.long) 98 | sequence_len = torch.tensor([len(sequence)], dtype=torch.long) 99 | session_graph = Data(x=x, y=y, num_count=num_count, 100 | edge_index=edge_index, edge_count=edge_count, 101 | sequence=sequence, sequence_len=sequence_len, 102 | in_degree_inv=in_degree_inv, out_degree_inv=out_degree_inv, userid=userid) 103 | data_list.append(session_graph) 104 | 105 | data, slices = self.collate(data_list) 106 | torch.save((data, slices), self.processed_paths[0]) 107 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on 5/4/2019 4 | @author: RuihongQiu 5 | """ 6 | 7 | import numpy as np 8 | import torch 9 | import pandas as pd 10 | from torch.nn.functional import softmax 11 | from torch.distributions.categorical import Categorical 12 | from scipy.stats import wasserstein_distance 13 | 14 | 15 | def forward(model, loader, device, writer, epoch, optimizer=None, train_flag=True, max_item_id=0, last_update=0): 16 | if train_flag: 17 | model.train() 18 | else: 19 | model.eval() 20 | hit20, mrr20, hit10, mrr10, hit5, mrr5, hit1, mrr1 = [], [], [], [], [], [], [], [] 21 | 22 | mean_loss = 0.0 23 | 24 | for i, batch in enumerate(loader): 25 | if train_flag: 26 | optimizer.zero_grad() 27 | scores = model(batch.to(device), max_item_id) 28 | targets = batch.y - 1 29 | loss = model.loss_function(scores, targets) 30 | 31 | if train_flag: 32 | loss.backward() 33 | optimizer.step() 34 | writer.add_scalar('loss/train_batch_loss', loss.item(), last_update + i) 35 | writer.add_scalar('embedding/user_embedding', model.user_embedding.weight.mean(), last_update + i) 36 | writer.add_scalar('embedding/item_embedding', 
model.item_embedding.weight.mean(), last_update + i) 37 | else: 38 | sub_scores = scores.topk(20)[1] # batch * top_k indices 39 | for score, target in zip(sub_scores.detach().cpu().numpy(), targets.detach().cpu().numpy()): 40 | hit20.append(np.isin(target, score)) 41 | if len(np.where(score == target)[0]) == 0: 42 | mrr20.append(0) 43 | else: 44 | mrr20.append(1 / (np.where(score == target)[0][0] + 1)) 45 | 46 | sub_scores = scores.topk(10)[1] # batch * top_k indices 47 | for score, target in zip(sub_scores.detach().cpu().numpy(), targets.detach().cpu().numpy()): 48 | hit10.append(np.isin(target, score)) 49 | if len(np.where(score == target)[0]) == 0: 50 | mrr10.append(0) 51 | else: 52 | mrr10.append(1 / (np.where(score == target)[0][0] + 1)) 53 | 54 | sub_scores = scores.topk(5)[1] # batch * top_k indices 55 | for score, target in zip(sub_scores.detach().cpu().numpy(), targets.detach().cpu().numpy()): 56 | hit5.append(np.isin(target, score)) 57 | if len(np.where(score == target)[0]) == 0: 58 | mrr5.append(0) 59 | else: 60 | mrr5.append(1 / (np.where(score == target)[0][0] + 1)) 61 | 62 | sub_scores = scores.topk(1)[1] # batch * top_k indices 63 | for score, target in zip(sub_scores.detach().cpu().numpy(), targets.detach().cpu().numpy()): 64 | hit1.append(np.isin(target, score)) 65 | if len(np.where(score == target)[0]) == 0: 66 | mrr1.append(0) 67 | else: 68 | mrr1.append(1 / (np.where(score == target)[0][0] + 1)) 69 | 70 | mean_loss += loss / batch.num_graphs 71 | 72 | if train_flag: 73 | writer.add_scalar('loss/train_loss', mean_loss.item(), epoch) 74 | else: 75 | writer.add_scalar('loss/test_loss', mean_loss.item(), epoch) 76 | hit20 = np.mean(hit20) * 100 77 | mrr20 = np.mean(mrr20) * 100 78 | writer.add_scalar('index/hit20', hit20, epoch) 79 | writer.add_scalar('index/mrr20', mrr20, epoch) 80 | hit10 = np.mean(hit10) * 100 81 | mrr10 = np.mean(mrr10) * 100 82 | writer.add_scalar('index/hit10', hit10, epoch) 83 | writer.add_scalar('index/mrr10', mrr10, epoch) 84 | hit5 = np.mean(hit5) * 100 85 | mrr5 = np.mean(mrr5) * 100 86 | writer.add_scalar('index/hit5', hit5, epoch) 87 | writer.add_scalar('index/mrr5', mrr5, epoch) 88 | hit1 = np.mean(hit1) * 100 89 | mrr1 = np.mean(mrr1) * 100 90 | writer.add_scalar('index/hit1', hit1, epoch) 91 | writer.add_scalar('index/mrr1', mrr1, epoch) 92 | 93 | 94 | def forward_entropy(model, loader, device, max_item_id=0): 95 | for i, batch in enumerate(loader): 96 | scores = softmax(model(batch.to(device), max_item_id), dim=1) 97 | dis_score = Categorical(scores) 98 | if i == 0: 99 | entropy = dis_score.entropy() 100 | else: 101 | entropy = torch.cat((entropy, dis_score.entropy())) 102 | 103 | # pro = softmax(entropy).cpu().detach().numpy() 104 | pro = entropy.cpu().detach().numpy() 105 | weights = np.exp((pd.Series(pro).rank() / len(pro)).values) 106 | return weights / np.sum(weights) 107 | # return pro / pro.sum() 108 | 109 | 110 | def forward_cross_entropy(model, loader, device, max_item_id=0): 111 | for i, batch in enumerate(loader): 112 | scores = softmax(model(batch.to(device), max_item_id), dim=1) 113 | targets = batch.y - 1 114 | if i == 0: 115 | cross_entropy = torch.nn.functional.cross_entropy(scores, targets, reduction='none') 116 | else: 117 | cross_entropy = torch.cat((cross_entropy, torch.nn.functional.cross_entropy(scores, targets, reduction='none'))) 118 | 119 | pro = cross_entropy.cpu().detach().numpy() 120 | return pro / pro.sum() 121 | 122 | 123 | def forward_wass(model, loader, device, max_item_id=0): 124 | distance = [] 125 | for i, 
batch in enumerate(loader): 126 | scores = softmax(model(batch.to(device), max_item_id), dim=1) 127 | targets = batch.y - 1 128 | targets_1hot = torch.zeros_like(scores).scatter_(1, targets.view(-1, 1), 1).cpu().numpy() 129 | distance += list(wasserstein_distance(score, target) for score, target in zip(scores.cpu().numpy(), targets_1hot)) 130 | # distance += list(map(wasserstein_distance, scores.cpu().numpy(), targets_1hot)) 131 | 132 | weights = np.exp((pd.Series(distance).rank() / len(distance)).values) 133 | return weights / np.sum(weights) 134 | # return distance / np.sum(distance) 135 | -------------------------------------------------------------------------------- /src/sampling.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on 23/9/2019 4 | @author: RuihongQiu 5 | """ 6 | 7 | import numpy as np 8 | import torch 9 | import os 10 | from dataset import MultiSessionsGraph 11 | from torch_geometric.data import DataLoader 12 | from train import forward_entropy, forward_cross_entropy, forward_wass 13 | 14 | 15 | def random_on_union(current_res, current_win, win_size, p=None): 16 | # R' = R U R^{new} 17 | uni_x = current_res[0] + current_win[0] 18 | uni_y = current_res[1] + current_win[1] 19 | uni_user = current_res[2] + current_win[2] 20 | 21 | # random sampling on the union set 22 | res_index = [i for i in range(len(uni_x))] 23 | sampled_index = np.random.choice(res_index, win_size, replace=False, p=p) 24 | sampled_x = np.array(uni_x)[sampled_index].tolist() 25 | sampled_y = np.array(uni_y)[sampled_index].tolist() 26 | sampled_user = np.array(uni_user)[sampled_index].tolist() 27 | return sampled_x, sampled_y, sampled_user 28 | 29 | 30 | def random_on_new(current_win, win_size, p=None): 31 | # random sampling on the new set 32 | res_index = [i for i in range(len(current_win[0]))] 33 | sampled_index = np.random.choice(res_index, win_size, replace=False, p=p) 34 | sampled_x = np.array(current_win[0])[sampled_index].tolist() 35 | sampled_y = np.array(current_win[1])[sampled_index].tolist() 36 | sampled_user = np.array(current_win[2])[sampled_index].tolist() 37 | return sampled_x, sampled_y, sampled_user 38 | 39 | 40 | def fix_new(current_win, win_size, max_item, max_user): 41 | # random sampling on the old items and users while must include all new items and users 42 | sampled_x, sampled_y, sampled_user = [], [], [] 43 | deleted_index = [] 44 | 45 | for i in range(len(current_win[0])): 46 | if max(current_win[0][i]) > max_item or current_win[1][i] > max_item or current_win[2][i] > max_user: 47 | sampled_x.append(current_win[0][i]) 48 | sampled_y.append(current_win[1][i]) 49 | sampled_user.append(current_win[2][i]) 50 | win_size -= 1 51 | deleted_index.append(i) 52 | 53 | left_win = tuple(np.delete(data, deleted_index).tolist() for data in current_win) 54 | 55 | return sampled_x, sampled_y, sampled_user, left_win, win_size 56 | 57 | 58 | def fix_new_random_on_new(current_win, win_size, max_item, max_user): 59 | sampled_x, sampled_y, sampled_user, left_win, win_size = fix_new(current_win, win_size, max_item, max_user) 60 | sampled_old = random_on_new(left_win, win_size) 61 | 62 | return sampled_x + sampled_old[0], sampled_y + sampled_old[1], sampled_user + sampled_old[2] 63 | 64 | 65 | def fix_new_random_on_union(current_res, current_win, win_size, max_item, max_user): 66 | sampled_x, sampled_y, sampled_user, left_win, win_size = fix_new(current_win, win_size, max_item, max_user) 67 | sampled_old = 
random_on_union(current_res, left_win, win_size) 68 | 69 | return sampled_x + sampled_old[0], sampled_y + sampled_old[1], sampled_user + sampled_old[2] 70 | 71 | 72 | def entropy_on_union(cur_dir, now, opt, model, device, current_res, current_win, win_size, ent='entropy'): 73 | # R' = R U R^{new} 74 | uni_x = current_res[0] + current_win[0] 75 | uni_y = current_res[1] + current_win[1] 76 | uni_user = current_res[2] + current_win[2] 77 | uni_data = (uni_x, uni_y, uni_user) 78 | 79 | uni_dataset = MultiSessionsGraph(cur_dir + '/../datasets/' + opt.dataset, 80 | phrase='uni' + now, 81 | sampled_data=uni_data) 82 | uni_loader = DataLoader(uni_dataset, batch_size=opt.batch_size, shuffle=False) 83 | 84 | with torch.no_grad(): 85 | if ent == 'entropy': 86 | pro = forward_entropy(model, uni_loader, device, max(max(max(current_win[0])), max(current_win[1]))) 87 | elif ent == 'cross': 88 | pro = forward_cross_entropy(model, uni_loader, device, max(max(max(current_win[0])), max(current_win[1]))) 89 | elif ent == 'wass': 90 | pro = forward_wass(model, uni_loader, device, max(max(max(current_win[0])), max(current_win[1]))) 91 | 92 | os.remove('../datasets/' + opt.dataset + '/processed/uni' + now + '.pt') 93 | 94 | return random_on_union(current_res, current_win, win_size, p=pro) 95 | 96 | 97 | def fix_new_entropy_on_union(cur_dir, now, opt, model, device, current_res, current_win, win_size, max_item, max_user, 98 | ent='entropy'): 99 | sampled_x, sampled_y, sampled_user, left_win, left_win_size = fix_new(current_win, win_size, max_item, max_user) 100 | if left_win_size > 0: 101 | sampled_old = entropy_on_union(cur_dir, now, opt, model, device, current_res, left_win, left_win_size, ent=ent) 102 | 103 | return sampled_x + sampled_old[0], sampled_y + sampled_old[1], sampled_user + sampled_old[2] 104 | else: 105 | return entropy_on_new(cur_dir, now, opt, model, device, (sampled_x, sampled_y, sampled_user), win_size) 106 | 107 | 108 | def entropy_on_new(cur_dir, now, opt, model, device, current_win, win_size): 109 | new_dataset = MultiSessionsGraph(cur_dir + '/../datasets/' + opt.dataset, 110 | phrase='new' + now, 111 | sampled_data=current_win) 112 | new_loader = DataLoader(new_dataset, batch_size=opt.batch_size, shuffle=False) 113 | 114 | with torch.no_grad(): 115 | pro = forward_entropy(model, new_loader, device, max(max(max(current_win[0])), max(current_win[1]))) 116 | 117 | os.remove('../datasets/' + opt.dataset + '/processed/new' + now + '.pt') 118 | 119 | return random_on_new(current_win, win_size, p=pro) 120 | 121 | 122 | def fix_new_entropy_on_new(cur_dir, now, opt, model, device, current_win, win_size, max_item, max_user): 123 | sampled_x, sampled_y, sampled_user, left_win, win_size = fix_new(current_win, win_size, max_item, max_user) 124 | sampled_old = entropy_on_new(cur_dir, now, opt, model, device, left_win, win_size) 125 | 126 | return sampled_x + sampled_old[0], sampled_y + sampled_old[1], sampled_user + sampled_old[2] 127 | 128 | -------------------------------------------------------------------------------- /src/AGCN.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on 26/9/2019 4 | @author: RuihongQiu 5 | """ 6 | 7 | import torch 8 | from torch.nn import Parameter 9 | from torch_scatter import scatter_add 10 | from torch_geometric.nn.conv import MessagePassing 11 | from torch_geometric.utils import add_remaining_self_loops 12 | 13 | from torch_geometric.nn.inits import glorot, zeros 14 | 15 | 16 | class 
AGCN(MessagePassing): 17 | r"""The graph convolutional operator from the `"Semi-supervised 18 | Classification with Graph Convolutional Networks" 19 | `_ paper 20 | 21 | .. math:: 22 | \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} 23 | \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta}, 24 | 25 | where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the 26 | adjacency matrix with inserted self-loops and 27 | :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix. 28 | 29 | Args: 30 | in_channels (int): Size of each input sample. 31 | out_channels (int): Size of each output sample. 32 | improved (bool, optional): If set to :obj:`True`, the layer computes 33 | :math:`\mathbf{\hat{A}}` as :math:`\mathbf{A} + 2\mathbf{I}`. 34 | (default: :obj:`False`) 35 | cached (bool, optional): If set to :obj:`True`, the layer will cache 36 | the computation of :math:`\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} 37 | \mathbf{\hat{D}}^{-1/2}` on first execution, and will use the 38 | cached version for further executions. 39 | This parameter should only be set to :obj:`True` in transductive 40 | learning scenarios. (default: :obj:`False`) 41 | bias (bool, optional): If set to :obj:`False`, the layer will not learn 42 | an additive bias. (default: :obj:`True`) 43 | **kwargs (optional): Additional arguments of 44 | :class:`torch_geometric.nn.conv.MessagePassing`. 45 | """ 46 | 47 | def __init__(self, in_channels, out_channels, improved=False, cached=False, 48 | bias=True, **kwargs): 49 | super(AGCN, self).__init__(aggr='add', **kwargs) 50 | 51 | self.in_channels = in_channels 52 | self.out_channels = out_channels 53 | self.improved = improved 54 | self.cached = cached 55 | 56 | self.weight = Parameter(torch.Tensor(2, in_channels, out_channels)) 57 | self.out_linear = torch.nn.Linear(in_channels, out_channels) 58 | 59 | if bias: 60 | self.bias = Parameter(torch.Tensor(out_channels)) 61 | else: 62 | self.register_parameter('bias', None) 63 | 64 | self.reset_parameters() 65 | 66 | def reset_parameters(self): 67 | glorot(self.weight) 68 | zeros(self.bias) 69 | self.cached_result = None 70 | self.cached_num_edges = None 71 | 72 | @staticmethod 73 | def norm(edge_index, num_nodes, edge_weight=None, improved=False, 74 | dtype=None): 75 | if edge_weight is None: 76 | edge_weight = torch.ones((edge_index.size(1),), dtype=dtype, 77 | device=edge_index.device) 78 | 79 | fill_value = 1 if not improved else 2 80 | edge_index, edge_weight = add_remaining_self_loops( 81 | edge_index, edge_weight, fill_value, num_nodes) 82 | 83 | row, col = edge_index 84 | deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) 85 | deg_inv_sqrt = deg.pow(-0.5) 86 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 87 | 88 | return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] 89 | 90 | def forward(self, x, edge_index, edge_weight=None, u=None): 91 | """""" 92 | x = x.view(-1, self.out_channels) 93 | u = u.view(-1, self.out_channels) 94 | x0 = torch.matmul(torch.cat((x, u), dim=-1), self.weight[0]) 95 | 96 | if self.cached and self.cached_result is not None: 97 | if edge_index.size(1) != self.cached_num_edges: 98 | raise RuntimeError( 99 | 'Cached {} number of edges, but found {}. 
Please ' 100 | 'disable the caching behavior of this layer by removing ' 101 | 'the `cached=True` argument in its constructor.'.format( 102 | self.cached_num_edges, edge_index.size(1))) 103 | 104 | if not self.cached or self.cached_result is None: 105 | self.cached_num_edges = edge_index.size(1) 106 | edge_index1, norm = self.norm(edge_index, x.size(0), edge_weight[0], 107 | self.improved, x.dtype) 108 | self.cached_result = edge_index1, norm 109 | 110 | edge_index1, norm = self.cached_result 111 | 112 | self.flow = 'source_to_target' 113 | m0 = self.propagate(edge_index1, x=x0, norm=norm) 114 | 115 | x1 = torch.matmul(torch.cat((x, u), dim=-1), self.weight[1]) 116 | 117 | if self.cached and self.cached_result is not None: 118 | if edge_index.size(1) != self.cached_num_edges: 119 | raise RuntimeError( 120 | 'Cached {} number of edges, but found {}. Please ' 121 | 'disable the caching behavior of this layer by removing ' 122 | 'the `cached=True` argument in its constructor.'.format( 123 | self.cached_num_edges, edge_index.size(1))) 124 | 125 | if not self.cached or self.cached_result is None: 126 | self.cached_num_edges = edge_index.size(1) 127 | edge_index2, norm = self.norm(edge_index, x.size(0), edge_weight[1], 128 | self.improved, x.dtype) 129 | self.cached_result = edge_index2, norm 130 | 131 | edge_index2, norm = self.cached_result 132 | 133 | self.flow = 'target_to_source' 134 | m1 = self.propagate(edge_index2, x=x1, norm=norm) 135 | 136 | return self.out_linear(torch.cat((m0, m1), dim=-1)) 137 | 138 | def message(self, x_j, norm): 139 | return norm.view(-1, 1) * x_j 140 | 141 | def update(self, aggr_out): 142 | if self.bias is not None: 143 | aggr_out = aggr_out + self.bias 144 | return aggr_out 145 | 146 | def __repr__(self): 147 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, 148 | self.out_channels) 149 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on 17/9/2019 4 | @author: RuihongQiu 5 | """ 6 | 7 | import argparse 8 | import logging 9 | import time 10 | from tqdm import tqdm 11 | from model import GNNModel 12 | from train import forward 13 | from torch.utils.tensorboard import SummaryWriter 14 | from se_data_process import load_data_valid, load_testdata 15 | from reservoir import Reservoir 16 | from sampling import * 17 | 18 | # Logger configuration 19 | logging.basicConfig(level=logging.DEBUG, 20 | format='%(asctime)s %(filename)s[line:%(lineno)d] %(message)s') 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--dataset', default='gowalla', help='dataset name: gowalla/lastfm') 24 | parser.add_argument('--batch_size', type=int, default=100, help='input batch size') 25 | parser.add_argument('--hidden_size', type=int, default=200, help='hidden state size') 26 | parser.add_argument('--epoch', type=int, default=4, help='the number of epochs to train for') 27 | parser.add_argument('--lr', type=float, default=0.003, help='learning rate') # [0.001, 0.0005, 0.0001] 28 | parser.add_argument('--lr_dc', type=float, default=1.0, help='learning rate decay rate') 29 | parser.add_argument('--lr_dc_step', type=int, default=3, help='the number of steps after which the learning rate decay') 30 | parser.add_argument('--l2', type=float, default=1e-5, help='l2 penalty') # [0.001, 0.0005, 0.0001, 0.00005, 0.00001] 31 | parser.add_argument('--u', type=int, default=1, help='the 
number of layer with u') 32 | parser.add_argument('--res_size', type=int, default=100, help='the denominator of the reservoir size') 33 | parser.add_argument('--win_size', type=int, default=1, help='the denominator of the window size') 34 | opt = parser.parse_args() 35 | logging.warning(opt) 36 | 37 | 38 | def main(): 39 | assert opt.dataset in ['gowalla', 'lastfm'] 40 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 41 | 42 | cur_dir = os.getcwd() 43 | 44 | train_dataset = MultiSessionsGraph(cur_dir + '/../datasets/' + opt.dataset, phrase='train') 45 | train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True) 46 | train_for_res, _ = load_data_valid( 47 | os.path.expanduser(os.path.normpath(cur_dir + '/../datasets/' + opt.dataset + '/raw/train.txt.csv')), 0) 48 | max_train_item = max(max(max(train_for_res[0])), max(train_for_res[1])) 49 | max_train_user = max(train_for_res[2]) 50 | 51 | test_dataset = MultiSessionsGraph(cur_dir + '/../datasets/' + opt.dataset, phrase='test1') 52 | test_loader = DataLoader(test_dataset, batch_size=opt.batch_size, shuffle=False) 53 | test_for_res = load_testdata( 54 | os.path.expanduser(os.path.normpath(cur_dir + '/../datasets/' + opt.dataset + '/raw/test1.txt.csv'))) 55 | max_item = max(max(max(test_for_res[0])), max(test_for_res[1])) 56 | max_user = max(test_for_res[2]) 57 | pre_max_item = max_train_item 58 | pre_max_user = max_train_user 59 | 60 | log_dir = cur_dir + '/../log/' + str(opt.dataset) + '/paper200/' + str( 61 | opt) + '_fix_new_entropy(rank)_on_union+' + str(opt.u) + 'tanh*u_AGCN***GAG-win' + str(opt.win_size) \ 62 | + '***concat3_linear_tanh_in_e2s_' + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) 63 | if not os.path.exists(log_dir): 64 | os.makedirs(log_dir) 65 | logging.warning('logging to {}'.format(log_dir)) 66 | writer = SummaryWriter(log_dir) 67 | 68 | if opt.dataset == 'gowalla': 69 | n_item = 30000 70 | n_user = 33005 71 | else: 72 | n_item = 10000 73 | n_user = 984 74 | 75 | model = GNNModel(hidden_size=opt.hidden_size, n_item=n_item, n_user=n_user, u=opt.u).to(device) 76 | 77 | optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, weight_decay=opt.l2) 78 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[2, 3], gamma=opt.lr_dc) 79 | 80 | logging.warning(model) 81 | 82 | # offline training on 'train' and test on 'test1' 83 | logging.warning('*********Begin offline training*********') 84 | updates_per_epoch = len(train_loader) 85 | updates_count = 0 86 | for train_epoch in tqdm(range(opt.epoch)): 87 | forward(model, train_loader, device, writer, train_epoch, optimizer=optimizer, 88 | train_flag=True, max_item_id=max_train_item, last_update=updates_count) 89 | scheduler.step() 90 | updates_count += updates_per_epoch 91 | with torch.no_grad(): 92 | forward(model, test_loader, device, writer, train_epoch, train_flag=False, max_item_id=max_item) 93 | 94 | # reservoir construction with 'train' 95 | logging.warning('*********Constructing the reservoir with offline training data*********') 96 | res = Reservoir(train_for_res, opt.res_size) 97 | res.update(train_for_res) 98 | 99 | # test and online training on 'test2~5' 100 | logging.warning('*********Begin online training*********') 101 | now = time.asctime() 102 | for test_epoch in tqdm(range(1, 6)): 103 | if test_epoch != 1: 104 | test_dataset = MultiSessionsGraph(cur_dir + '/../datasets/' + opt.dataset, phrase='test' + str(test_epoch)) 105 | test_loader = DataLoader(test_dataset, batch_size=opt.batch_size, 
shuffle=False) 106 | 107 | test_for_res = load_testdata( 108 | os.path.expanduser(os.path.normpath( 109 | cur_dir + '/../datasets/' + opt.dataset + '/raw/test' + str(test_epoch) + '.txt.csv'))) 110 | pre_max_item = max_item 111 | pre_max_user = max_user 112 | max_item = max(max(max(test_for_res[0])), max(test_for_res[1])) 113 | max_user = max(test_for_res[2]) 114 | 115 | # test on the current test set 116 | # no need to test on test1 because it's done in the online training part 117 | # epoch + 10 is a number only for the visualization convenience 118 | with torch.no_grad(): 119 | forward(model, test_loader, device, writer, test_epoch + 10, 120 | train_flag=False, max_item_id=max_item) 121 | 122 | # reservoir sampling 123 | sampled_data = fix_new_entropy_on_union(cur_dir, now, opt, model, device, res.data, test_for_res, 124 | len(test_for_res[0]) // opt.win_size, pre_max_item, pre_max_user, 125 | ent='wass') 126 | 127 | # cast the sampled set to dataset 128 | sampled_dataset = MultiSessionsGraph(cur_dir + '/../datasets/' + opt.dataset, 129 | phrase='sampled' + now, 130 | sampled_data=sampled_data) 131 | sampled_loader = DataLoader(sampled_dataset, batch_size=opt.batch_size, shuffle=True) 132 | 133 | # update with the sampled set 134 | forward(model, sampled_loader, device, writer, test_epoch + opt.epoch, optimizer=optimizer, 135 | train_flag=True, max_item_id=max_item, last_update=updates_count) 136 | 137 | updates_count += len(test_loader) 138 | 139 | scheduler.step() 140 | 141 | res.update(test_for_res) 142 | os.remove('../datasets/' + opt.dataset + '/processed/sampled' + now + '.pt') 143 | 144 | 145 | if __name__ == '__main__': 146 | main() 147 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on 17/9/2019 4 | @author: RuihongQiu 5 | """ 6 | 7 | import math 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from AGCN import AGCN 12 | 13 | 14 | class Embedding2Score(nn.Module): 15 | def __init__(self, hidden_size): 16 | super(Embedding2Score, self).__init__() 17 | self.hidden_size = hidden_size 18 | self.W_1 = nn.Linear(self.hidden_size, self.hidden_size) 19 | self.W_2 = nn.Linear(self.hidden_size, self.hidden_size) 20 | self.q = nn.Linear(self.hidden_size, 1) 21 | self.W_3 = nn.Linear(2 * self.hidden_size, self.hidden_size) 22 | self.user_linear = nn.Linear(self.hidden_size, self.hidden_size) 23 | 24 | def forward(self, node_embedding, item_embedding_table, sections, num_count, user_embedding, max_item_id, u_n_repeat): 25 | v_i = torch.split(node_embedding, tuple(sections.cpu().numpy())) # split whole x back into graphs G_i 26 | v_n_repeat = tuple(nodes[-1].view(1, -1).repeat(nodes.shape[0], 1) for nodes in v_i) # repeat |V|_i times for the last node embedding 27 | 28 | # Eq(6) 29 | alpha = self.q(torch.sigmoid(self.W_1(torch.cat(v_n_repeat, dim=0)) + self.W_2(node_embedding))) # |V|_i * 1 30 | s_g_whole = num_count.view(-1, 1) * alpha * node_embedding # |V|_i * hidden_size 31 | s_g_split = torch.split(s_g_whole, tuple(sections.cpu().numpy())) # split whole s_g into graphs G_i 32 | s_g = tuple(torch.sum(embeddings, dim=0).view(1, -1) for embeddings in s_g_split) 33 | 34 | # Eq(7) 35 | v_n = tuple(nodes[-1].view(1, -1) for nodes in v_i) 36 | s_h = self.W_3(torch.cat((torch.cat(v_n, dim=0), torch.cat(s_g, dim=0)), dim=1)) 37 | 38 | return s_h 39 | 40 | 41 | class 
Embedding2ScoreWithU(nn.Module): 42 | def __init__(self, hidden_size): 43 | super(Embedding2ScoreWithU, self).__init__() 44 | self.hidden_size = hidden_size 45 | self.W_1 = nn.Linear(self.hidden_size, 1) 46 | self.W_2 = nn.Linear(2 * self.hidden_size + self.hidden_size, self.hidden_size) 47 | self.W_3 = nn.Linear(self.hidden_size, self.hidden_size) 48 | self.W_4 = nn.Linear(self.hidden_size, self.hidden_size) 49 | self.W_5 = nn.Linear(2 * self.hidden_size, self.hidden_size) 50 | self.user_linear = nn.Linear(self.hidden_size, self.hidden_size) 51 | self.user_out = nn.Linear(2 * self.hidden_size, self.hidden_size) 52 | 53 | def forward(self, node_embedding, item_embedding_table, sections, num_count, user_embedding, max_item_id, 54 | u_n_repeat): 55 | if list(sections.size())[0] == 1: 56 | u_n_repeat = u_n_repeat.view(1, -1) 57 | node_embedding = node_embedding.view(-1, self.hidden_size) 58 | v_n_repeat = tuple(node_embedding.repeat(sections[0], 1)) 59 | alpha = self.W_1( 60 | torch.sigmoid(self.W_2(torch.cat((v_n_repeat[0].view(1, -1), node_embedding, u_n_repeat), dim=-1)))) 61 | else: 62 | v_i = torch.split(node_embedding, tuple(sections.cpu().numpy())) # split whole x back into graphs G_i 63 | v_n_repeat = tuple(nodes[-1].view(1, -1).repeat(nodes.shape[0], 1) for nodes in 64 | v_i) # repeat |V|_i times for the last node embedding 65 | 66 | alpha = self.W_1( 67 | torch.sigmoid(self.W_2(torch.cat((torch.cat(v_n_repeat, dim=0), node_embedding, u_n_repeat), dim=-1)))) 68 | s_g_whole = num_count.view(-1, 1) * alpha * node_embedding # |V|_i * hidden_size 69 | if list(sections.size())[0] == 1: 70 | s_g = tuple(torch.sum(s_g_whole.view(-1, self.hidden_size), dim=0).view(1, -1)) 71 | s_h = self.W_5(torch.cat((node_embedding, s_g[0].view(-1, self.hidden_size)), dim=-1)) 72 | else: 73 | s_g_split = torch.split(s_g_whole, tuple(sections.cpu().numpy())) # split whole s_g into graphs G_i 74 | s_g = tuple(torch.sum(embeddings, dim=0).view(1, -1) for embeddings in s_g_split) 75 | 76 | v_n = tuple(nodes[-1].view(1, -1) for nodes in v_i) 77 | stack_v_n = torch.cat(v_n, dim=0) 78 | s_h = self.W_5(torch.cat((stack_v_n, torch.cat(s_g, dim=0)), dim=-1)) 79 | 80 | s_h += self.user_linear(user_embedding).tanh() 81 | return s_h 82 | 83 | 84 | class GNNModel(nn.Module): 85 | """ 86 | Args: 87 | hidden_size: the number of units in a hidden layer. 88 | n_item: the number of items in the whole item set for embedding layer. 
89 | n_user: the number of users 90 | """ 91 | def __init__(self, hidden_size, n_item, n_user=None, heads=None, u=1): 92 | super(GNNModel, self).__init__() 93 | self.hidden_size, self.n_item, self.n_user, self.heads, self.u = hidden_size, n_item, n_user, heads, u 94 | self.item_embedding = nn.Embedding(self.n_item, self.hidden_size) 95 | if self.n_user: 96 | self.user_embedding = nn.Embedding(self.n_user, self.hidden_size) 97 | if self.u > 0: 98 | self.gnn = [] 99 | for i in range(self.u): 100 | self.gnn.append(AGCN(2 * self.hidden_size, self.hidden_size).cuda()) 101 | else: 102 | self.gnn = AGCN(self.hidden_size, self.hidden_size) 103 | self.e2s = Embedding2ScoreWithU(self.hidden_size) 104 | self.loss_function = nn.CrossEntropyLoss() 105 | self.reset_parameters() 106 | 107 | def reset_parameters(self): 108 | stdv = 1.0 / math.sqrt(self.hidden_size) 109 | for weight in self.parameters(): 110 | weight.data.uniform_(-stdv, stdv) 111 | 112 | def forward(self, data, max_item_id=0): 113 | x, edge_index, batch, edge_count, in_degree_inv, out_degree_inv, sequence, num_count, userid = \ 114 | data.x - 1, data.edge_index, data.batch, data.edge_count, data.in_degree_inv, data.out_degree_inv,\ 115 | data.sequence, data.num_count, data.userid - 1 116 | 117 | hidden = self.item_embedding(x).squeeze() 118 | sections = torch.bincount(batch) 119 | u = self.user_embedding(userid).squeeze() 120 | 121 | if self.u > 0: 122 | for layer in range(self.u): 123 | if list(sections.size())[0] == 1: 124 | u_n_repeat = tuple(u.view(1, -1).repeat(sections[0], 1)) 125 | else: 126 | u_n_repeat = tuple(u.view(1, -1).repeat(times, 1) for (u, times) in zip(u, sections)) 127 | hidden = self.gnn[layer](hidden, edge_index, 128 | [edge_count * in_degree_inv, edge_count * out_degree_inv], 129 | u=torch.cat(u_n_repeat, dim=0)) 130 | if self.heads is not None: 131 | hidden = torch.stack(hidden.chunk(self.heads, dim=-1), dim=1).mean(dim=1) 132 | hidden = torch.tanh(hidden) 133 | u = self.e2s(hidden, self.item_embedding, sections, num_count, u, max_item_id, torch.cat(u_n_repeat, dim=0)) 134 | else: 135 | hidden = self.gnn(hidden, edge_index, [edge_count * in_degree_inv, edge_count * out_degree_inv], u=None) 136 | if self.heads is not None: 137 | hidden = torch.stack(hidden.chunk(self.heads, dim=-1), dim=1).mean(dim=1) 138 | u = self.e2s(hidden, self.item_embedding, sections, num_count, u, max_item_id) 139 | 140 | z_i_hat = torch.mm(u, self.item_embedding.weight[:max_item_id].transpose(1, 0)) 141 | return z_i_hat 142 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 |    You may obtain a copy of the License at
194 | 
195 |        http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 |    Unless required by applicable law or agreed to in writing, software
198 |    distributed under the License is distributed on an "AS IS" BASIS,
199 |    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 |    See the License for the specific language governing permissions and
201 |    limitations under the License.
202 | 
--------------------------------------------------------------------------------
/src/se_data_process.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on 17/9/2019
4 | @author: LeiGuo, RuihongQiu
5 | """
6 | 
7 | import pickle
8 | import numpy
9 | 
10 | 
11 | def prepare_data(seqs, labels):
12 |     """Create the matrices from the datasets.
13 |     This pads each sequence to the same length: the length of the
14 |     longest sequence or maxlen.
15 |     If maxlen is set, we will cut all sequences to this maximum
16 |     length.
17 |     This swaps the axes!
18 |     """
19 |     lengths = [len(s) for s in seqs]
20 |     n_samples = len(seqs)
21 |     maxlen = numpy.max(lengths)
22 | 
23 |     x = numpy.zeros((maxlen, n_samples)).astype('int64')
24 |     x_mask = numpy.ones((maxlen, n_samples)).astype('float32')  # float mask over non-padded entries
25 |     for idx, s in enumerate(seqs):
26 |         x[:lengths[idx], idx] = s
27 | 
28 |     x_mask *= (1 - (x == 0))
29 |     return x, x_mask, labels
30 | 
31 | 
32 | def load_data(path, valid_portion=0.1, maxlen=19, sort_by_len=False):
33 |     """Loads the dataset
34 |     :type path: String
35 |     :param path: The path to the dataset (here RSC2015)
36 |     :type n_items: int
37 |     :param n_items: The number of items.
38 |     :type valid_portion: float
39 |     :param valid_portion: The proportion of the full train set used for
40 |         the validation set.
41 |     :type maxlen: None or positive int
42 |     :param maxlen: the max sequence length we use in the train/valid set.
43 |     :type sort_by_len: bool
44 |     :name sort_by_len: Sort by the sequence length for the train,
45 |         valid and test set. This allows faster execution as it causes
46 |         less padding per minibatch. Another mechanism must be used to
47 |         shuffle the train set at each epoch.
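    :returns: three (x, y, user) tuples in the order train, valid, test, where x is a
        list of item-id sequences, y the matching next-item labels and user the matching
        user ids, aligned by index.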
48 | """ 49 | path_train_data = path+'train_final.csv' 50 | path_test_data = path+'test_final.csv' 51 | 52 | f1 = open(path_train_data, 'rb') 53 | train_set = pickle.load(f1) 54 | f1.close() 55 | 56 | f2 = open(path_test_data, 'rb') 57 | test_set = pickle.load(f2) 58 | f2.close() 59 | 60 | if maxlen: 61 | new_train_set_x = [] 62 | new_train_set_y = [] 63 | new_train_set_u = [] 64 | for x, y, u in list(zip(train_set[0], train_set[1],train_set[2])): 65 | if len(x) < maxlen: 66 | new_train_set_x.append(x) 67 | new_train_set_y.append(y) 68 | new_train_set_u.append(u) 69 | else: 70 | new_train_set_x.append(x[:maxlen]) 71 | new_train_set_y.append(y) 72 | new_train_set_u.append(u) 73 | train_set = (new_train_set_x, new_train_set_y, new_train_set_u) 74 | del new_train_set_x, new_train_set_y, new_train_set_u 75 | 76 | new_test_set_x = [] 77 | new_test_set_y = [] 78 | new_test_set_u = [] 79 | for xx, yy, uu in zip(test_set[0], test_set[1], test_set[2]): 80 | if len(xx) < maxlen: 81 | new_test_set_x.append(xx) 82 | new_test_set_y.append(yy) 83 | new_test_set_u.append(uu) 84 | else: 85 | new_test_set_x.append(xx[:maxlen]) 86 | new_test_set_y.append(yy) 87 | new_test_set_u.append(uu) 88 | test_set = (new_test_set_x, new_test_set_y, new_test_set_u) 89 | del new_test_set_x, new_test_set_y, new_test_set_u 90 | 91 | # split training set into validation set 92 | train_set_x, train_set_y, train_set_u = train_set 93 | n_samples = len(train_set_x) 94 | sidx = numpy.arange(n_samples, dtype='int32') 95 | numpy.random.shuffle(sidx) 96 | n_train = int(numpy.round(n_samples * (1. - valid_portion))) 97 | valid_set_x = [train_set_x[s] for s in sidx[n_train:]] 98 | valid_set_y = [train_set_y[s] for s in sidx[n_train:]] 99 | valid_set_u = [train_set_u[s] for s in sidx[n_train:]] 100 | 101 | train_set_x = [train_set_x[s] for s in sidx[:n_train]] 102 | train_set_y = [train_set_y[s] for s in sidx[:n_train]] 103 | train_set_u = [train_set_u[s] for s in sidx[:n_train]] 104 | 105 | train_set = (train_set_x, train_set_y, train_set_u) 106 | valid_set = (valid_set_x, valid_set_y, valid_set_u) 107 | 108 | test_set_x, test_set_y, test_set_u = test_set 109 | valid_set_x, valid_set_y, valid_set_u = valid_set 110 | train_set_x, train_set_y, train_set_u = train_set 111 | 112 | def len_argsort(seq): 113 | return sorted(range(len(seq)), key=lambda x: len(seq[x])) 114 | 115 | if sort_by_len: 116 | sorted_index = len_argsort(test_set_x) 117 | test_set_x = [test_set_x[i] for i in sorted_index] 118 | test_set_y = [test_set_y[i] for i in sorted_index] 119 | 120 | sorted_index = len_argsort(valid_set_x) 121 | valid_set_x = [valid_set_x[i] for i in sorted_index] 122 | valid_set_y = [valid_set_y[i] for i in sorted_index] 123 | 124 | train = (train_set_x, train_set_y, train_set_u) 125 | valid = (valid_set_x, valid_set_y, valid_set_u) 126 | test = (test_set_x, test_set_y, test_set_u) 127 | 128 | return train, valid, test 129 | 130 | 131 | def load_traindata(trainFile, validFile, valid_portion=0.1, maxlen=19, sort_by_len=False): 132 | """Loads the dataset 133 | :type path: String 134 | :param path: The path to the dataset (here RSC2015) 135 | :type n_items: int 136 | :param n_items: The number of items. 137 | :type valid_portion: float 138 | :param valid_portion: The proportion of the full train set used for 139 | the validation set. 140 | :type maxlen: None or positive int 141 | :param maxlen: the max sequence length we use in the train/valid set. 
142 |     :type sort_by_len: bool
143 |     :name sort_by_len: Sort by the sequence length for the train,
144 |         valid and test set. This allows faster execution as it causes
145 |         less padding per minibatch. Another mechanism must be used to
146 |         shuffle the train set at each epoch.
147 |     """
148 | 
149 |     path_train_data = trainFile
150 |     path_test_data = validFile
151 | 
152 |     f1 = open(path_train_data, 'rb')
153 |     train_set = pickle.load(f1)
154 |     f1.close()
155 | 
156 |     f2 = open(path_test_data, 'rb')
157 |     test_set = pickle.load(f2)
158 |     f2.close()
159 | 
160 |     if maxlen:
161 |         new_train_set_x = []
162 |         new_train_set_y = []
163 |         new_train_set_u = []
164 |         for x, y, u in list(zip(train_set[0], train_set[1], train_set[2])):
165 |             if len(x) < maxlen:
166 |                 new_train_set_x.append(x)
167 |                 new_train_set_y.append(y)
168 |                 new_train_set_u.append(u)
169 |             else:
170 |                 new_train_set_x.append(x[:maxlen])
171 |                 new_train_set_y.append(y)
172 |                 new_train_set_u.append(u)
173 |         train_set = (new_train_set_x, new_train_set_y, new_train_set_u)
174 |         del new_train_set_x, new_train_set_y, new_train_set_u
175 | 
176 |         new_test_set_x = []
177 |         new_test_set_y = []
178 |         new_test_set_u = []
179 |         for xx, yy, uu in zip(test_set[0], test_set[1], test_set[2]):
180 |             if len(xx) < maxlen:
181 |                 new_test_set_x.append(xx)
182 |                 new_test_set_y.append(yy)
183 |                 new_test_set_u.append(uu)
184 |             else:
185 |                 new_test_set_x.append(xx[:maxlen])
186 |                 new_test_set_y.append(yy)
187 |                 new_test_set_u.append(uu)
188 |         test_set = (new_test_set_x, new_test_set_y, new_test_set_u)
189 |         del new_test_set_x, new_test_set_y, new_test_set_u
190 | 
191 |     test_set_x, test_set_y, test_set_u = test_set
192 |     train_set_x, train_set_y, train_set_u = train_set
193 | 
194 |     def len_argsort(seq):
195 |         return sorted(range(len(seq)), key=lambda x: len(seq[x]))
196 | 
197 |     if sort_by_len:
198 |         sorted_index = len_argsort(test_set_x)
199 |         test_set_x = [test_set_x[i] for i in sorted_index]
200 |         test_set_y = [test_set_y[i] for i in sorted_index]
201 | 
202 |     train = (train_set_x, train_set_y, train_set_u)
203 |     test = (test_set_x, test_set_y, test_set_u)
204 | 
205 |     return train, test
206 | 
207 | 
208 | def load_testdata(testFile, maxlen=19, sort_by_len=False):
209 |     """Loads the dataset
210 |     :type path: String
211 |     :param path: The path to the dataset (here RSC2015)
212 |     :type n_items: int
213 |     :param n_items: The number of items.
214 |     :type valid_portion: float
215 |     :param valid_portion: The proportion of the full train set used for
216 |         the validation set.
217 |     :type maxlen: None or positive int
218 |     :param maxlen: the max sequence length we use in the train/valid set.
219 |     :type sort_by_len: bool
220 |     :name sort_by_len: Sort by the sequence length for the train,
221 |         valid and test set. This allows faster execution as it causes
222 |         less padding per minibatch. Another mechanism must be used to
223 |         shuffle the train set at each epoch.
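    :returns: a single (x, y, user) tuple for the test set, with the three lists
        aligned by index.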
224 | """ 225 | path_test_data = testFile 226 | 227 | f2 = open(path_test_data, 'rb') 228 | test_set = pickle.load(f2) 229 | f2.close() 230 | 231 | if maxlen: 232 | new_test_set_x = [] 233 | new_test_set_y = [] 234 | new_test_set_u = [] 235 | for xx, yy, uu in zip(test_set[0], test_set[1], test_set[2]): 236 | if len(xx) < maxlen: 237 | new_test_set_x.append(xx) 238 | new_test_set_y.append(yy) 239 | new_test_set_u.append(uu) 240 | else: 241 | new_test_set_x.append(xx[:maxlen]) 242 | new_test_set_y.append(yy) 243 | new_test_set_u.append(uu) 244 | test_set = (new_test_set_x, new_test_set_y, new_test_set_u) 245 | del new_test_set_x, new_test_set_y, new_test_set_u 246 | 247 | test_set_x, test_set_y, test_set_u = test_set 248 | 249 | def len_argsort(seq): 250 | return sorted(range(len(seq)), key=lambda x: len(seq[x])) 251 | 252 | if sort_by_len: 253 | sorted_index = len_argsort(test_set_x) 254 | test_set_x = [test_set_x[i] for i in sorted_index] 255 | test_set_y = [test_set_y[i] for i in sorted_index] 256 | 257 | test = (test_set_x, test_set_y, test_set_u) 258 | 259 | return test 260 | 261 | 262 | def load_data_valid(train_file, valid_portion=0.1, maxlen=19, sort_by_len=False): 263 | path_train_data = train_file 264 | 265 | f1 = open(path_train_data, 'rb') 266 | train_set = pickle.load(f1) 267 | f1.close() 268 | if maxlen: 269 | new_train_set_x = [] 270 | new_train_set_y = [] 271 | new_train_set_u = [] 272 | for x, y, u in list(zip(train_set[0], train_set[1],train_set[2])): 273 | if len(x) < maxlen: 274 | new_train_set_x.append(x) 275 | new_train_set_y.append(y) 276 | new_train_set_u.append(u) 277 | else: 278 | new_train_set_x.append(x[:maxlen]) 279 | new_train_set_y.append(y) 280 | new_train_set_u.append(u) 281 | train_set = (new_train_set_x, new_train_set_y, new_train_set_u) 282 | del new_train_set_x, new_train_set_y, new_train_set_u 283 | 284 | # split training set into validation set 285 | train_set_x, train_set_y, train_set_u = train_set 286 | n_samples = len(train_set_x) 287 | sidx = numpy.arange(n_samples, dtype='int32') 288 | numpy.random.shuffle(sidx) 289 | n_train = int(numpy.round(n_samples * (1. - valid_portion))) 290 | valid_set_x = [train_set_x[s] for s in sidx[n_train:]] 291 | valid_set_y = [train_set_y[s] for s in sidx[n_train:]] 292 | valid_set_u = [train_set_u[s] for s in sidx[n_train:]] 293 | 294 | train_set_x = [train_set_x[s] for s in sidx[:n_train]] 295 | train_set_y = [train_set_y[s] for s in sidx[:n_train]] 296 | train_set_u = [train_set_u[s] for s in sidx[:n_train]] 297 | 298 | train_set = (train_set_x, train_set_y, train_set_u) 299 | valid_set = (valid_set_x, valid_set_y, valid_set_u) 300 | 301 | valid_set_x, valid_set_y, valid_set_u = valid_set 302 | train_set_x, train_set_y, train_set_u = train_set 303 | 304 | def len_argsort(seq): 305 | return sorted(range(len(seq)), key=lambda x: len(seq[x])) 306 | 307 | if sort_by_len: 308 | 309 | sorted_index = len_argsort(valid_set_x) 310 | valid_set_x = [valid_set_x[i] for i in sorted_index] 311 | valid_set_y = [valid_set_y[i] for i in sorted_index] 312 | 313 | train = (train_set_x, train_set_y, train_set_u) 314 | valid = (valid_set_x, valid_set_y, valid_set_u) 315 | 316 | return train, valid 317 | --------------------------------------------------------------------------------
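A note on the raw-data layout (inferred from the loaders above, not stated elsewhere in the repository): each *.txt.csv file under datasets/<dataset>/raw/ is a pickled tuple of three parallel lists: item-id sequences, next-item labels and user ids. A minimal sketch with made-up ids and a hypothetical toy file name:

    import pickle
    from se_data_process import load_data_valid  # run from within src/

    sessions = [[15, 16, 18, 16], [3, 7], [42, 42, 5]]  # clicked item ids per session
    targets = [17, 9, 6]                                # next item to predict for each session
    users = [2, 0, 1]                                   # user id of each session

    with open('toy_train.txt.csv', 'wb') as f:          # hypothetical toy file
        pickle.dump((sessions, targets, users), f)

    train, valid = load_data_valid('toy_train.txt.csv', valid_portion=0.1)
    print(len(train[0]), 'train sessions;', len(valid[0]), 'validation sessions')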