├── pytorch_code ├── main.py ├── utils.py └── model.py ├── tensorflow_code ├── main.py ├── utils.py └── model.py ├── README.md └── datasets └── preprocess.py /pytorch_code/main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python36 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on July, 2018 5 | 6 | @author: Tangrizzly 7 | """ 8 | 9 | import argparse 10 | import pickle 11 | import time 12 | from utils import build_graph, Data, split_validation 13 | from model import * 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--dataset', default='sample', help='dataset name: diginetica/yoochoose1_4/yoochoose1_64/sample') 17 | parser.add_argument('--batchSize', type=int, default=100, help='input batch size') 18 | parser.add_argument('--hiddenSize', type=int, default=100, help='hidden state size') 19 | parser.add_argument('--epoch', type=int, default=30, help='the number of epochs to train for') 20 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate') # [0.001, 0.0005, 0.0001] 21 | parser.add_argument('--lr_dc', type=float, default=0.1, help='learning rate decay rate') 22 | parser.add_argument('--lr_dc_step', type=int, default=3, help='the number of steps after which the learning rate decay') 23 | parser.add_argument('--l2', type=float, default=1e-5, help='l2 penalty') # [0.001, 0.0005, 0.0001, 0.00005, 0.00001] 24 | parser.add_argument('--step', type=int, default=1, help='gnn propogation steps') 25 | parser.add_argument('--patience', type=int, default=10, help='the number of epoch to wait before early stop ') 26 | parser.add_argument('--nonhybrid', action='store_true', help='only use the global preference to predict') 27 | parser.add_argument('--validation', action='store_true', help='validation') 28 | parser.add_argument('--valid_portion', type=float, default=0.1, help='split the portion of training set as validation set') 29 | opt = parser.parse_args() 30 | print(opt) 31 | 32 | 33 | def main(): 34 | train_data = pickle.load(open('../datasets/' + opt.dataset + '/train.txt', 'rb')) 35 | if opt.validation: 36 | train_data, valid_data = split_validation(train_data, opt.valid_portion) 37 | test_data = valid_data 38 | else: 39 | test_data = pickle.load(open('../datasets/' + opt.dataset + '/test.txt', 'rb')) 40 | # all_train_seq = pickle.load(open('../datasets/' + opt.dataset + '/all_train_seq.txt', 'rb')) 41 | # g = build_graph(all_train_seq) 42 | train_data = Data(train_data, shuffle=True) 43 | test_data = Data(test_data, shuffle=False) 44 | # del all_train_seq, g 45 | if opt.dataset == 'diginetica': 46 | n_node = 43098 47 | elif opt.dataset == 'yoochoose1_64' or opt.dataset == 'yoochoose1_4': 48 | n_node = 37484 49 | else: 50 | n_node = 310 51 | 52 | model = trans_to_cuda(SessionGraph(opt, n_node)) 53 | 54 | start = time.time() 55 | best_result = [0, 0] 56 | best_epoch = [0, 0] 57 | bad_counter = 0 58 | for epoch in range(opt.epoch): 59 | print('-------------------------------------------------------') 60 | print('epoch: ', epoch) 61 | hit, mrr = train_test(model, train_data, test_data) 62 | flag = 0 63 | if hit >= best_result[0]: 64 | best_result[0] = hit 65 | best_epoch[0] = epoch 66 | flag = 1 67 | if mrr >= best_result[1]: 68 | best_result[1] = mrr 69 | best_epoch[1] = epoch 70 | flag = 1 71 | print('Best Result:') 72 | print('\tRecall@20:\t%.4f\tMMR@20:\t%.4f\tEpoch:\t%d,\t%d'% (best_result[0], best_result[1], best_epoch[0], best_epoch[1])) 73 | bad_counter += 1 - flag 74 | if bad_counter >= 
opt.patience: 75 | break 76 | print('-------------------------------------------------------') 77 | end = time.time() 78 | print("Run time: %f s" % (end - start)) 79 | 80 | 81 | if __name__ == '__main__': 82 | main() 83 | -------------------------------------------------------------------------------- /tensorflow_code/main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2018/10/17 5:40 4 | # @Author : {ZM7} 5 | # @File : main.py 6 | # @Software: PyCharm 7 | 8 | from __future__ import division 9 | import numpy as np 10 | from model import * 11 | from utils import build_graph, Data, split_validation 12 | import pickle 13 | import argparse 14 | import datetime 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--dataset', default='sample', help='dataset name: diginetica/yoochoose1_4/yoochoose1_64/sample') 18 | parser.add_argument('--method', type=str, default='ggnn', help='ggnn/gat/gcn') 19 | parser.add_argument('--validation', action='store_true', help='validation') 20 | parser.add_argument('--epoch', type=int, default=30, help='number of epochs to train for') 21 | parser.add_argument('--batchSize', type=int, default=100, help='input batch size') 22 | parser.add_argument('--hiddenSize', type=int, default=100, help='hidden state size') 23 | parser.add_argument('--l2', type=float, default=1e-5, help='l2 penalty') 24 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate') 25 | parser.add_argument('--step', type=int, default=1, help='gnn propogation steps') 26 | parser.add_argument('--nonhybrid', action='store_true', help='global preference') 27 | parser.add_argument('--lr_dc', type=float, default=0.1, help='learning rate decay rate') 28 | parser.add_argument('--lr_dc_step', type=int, default=3, help='the number of steps after which the learning rate decay') 29 | opt = parser.parse_args() 30 | train_data = pickle.load(open('../datasets/' + opt.dataset + '/train.txt', 'rb')) 31 | test_data = pickle.load(open('../datasets/' + opt.dataset + '/test.txt', 'rb')) 32 | # all_train_seq = pickle.load(open('../datasets/' + opt.dataset + '/all_train_seq.txt', 'rb')) 33 | if opt.dataset == 'diginetica': 34 | n_node = 43098 35 | elif opt.dataset == 'yoochoose1_64' or opt.dataset == 'yoochoose1_4': 36 | n_node = 37484 37 | else: 38 | n_node = 310 39 | # g = build_graph(all_train_seq) 40 | train_data = Data(train_data, sub_graph=True, method=opt.method, shuffle=True) 41 | test_data = Data(test_data, sub_graph=True, method=opt.method, shuffle=False) 42 | model = GGNN(hidden_size=opt.hiddenSize, out_size=opt.hiddenSize, batch_size=opt.batchSize, n_node=n_node, 43 | lr=opt.lr, l2=opt.l2, step=opt.step, decay=opt.lr_dc_step * len(train_data.inputs) / opt.batchSize, lr_dc=opt.lr_dc, 44 | nonhybrid=opt.nonhybrid) 45 | print(opt) 46 | best_result = [0, 0] 47 | best_epoch = [0, 0] 48 | for epoch in range(opt.epoch): 49 | print('epoch: ', epoch, '===========================================') 50 | slices = train_data.generate_batch(model.batch_size) 51 | fetches = [model.opt, model.loss_train, model.global_step] 52 | print('start training: ', datetime.datetime.now()) 53 | loss_ = [] 54 | for i, j in zip(slices, np.arange(len(slices))): 55 | adj_in, adj_out, alias, item, mask, targets = train_data.get_slice(i) 56 | _, loss, _ = model.run(fetches, targets, item, adj_in, adj_out, alias, mask) 57 | loss_.append(loss) 58 | loss = np.mean(loss_) 59 | slices = 
test_data.generate_batch(model.batch_size) 60 | print('start predicting: ', datetime.datetime.now()) 61 | hit, mrr, test_loss_ = [], [],[] 62 | for i, j in zip(slices, np.arange(len(slices))): 63 | adj_in, adj_out, alias, item, mask, targets = test_data.get_slice(i) 64 | scores, test_loss = model.run([model.score_test, model.loss_test], targets, item, adj_in, adj_out, alias, mask) 65 | test_loss_.append(test_loss) 66 | index = np.argsort(scores, 1)[:, -20:] 67 | for score, target in zip(index, targets): 68 | hit.append(np.isin(target - 1, score)) 69 | if len(np.where(score == target - 1)[0]) == 0: 70 | mrr.append(0) 71 | else: 72 | mrr.append(1 / (20-np.where(score == target - 1)[0][0])) 73 | hit = np.mean(hit)*100 74 | mrr = np.mean(mrr)*100 75 | test_loss = np.mean(test_loss_) 76 | if hit >= best_result[0]: 77 | best_result[0] = hit 78 | best_epoch[0] = epoch 79 | if mrr >= best_result[1]: 80 | best_result[1] = mrr 81 | best_epoch[1]=epoch 82 | print('train_loss:\t%.4f\ttest_loss:\t%4f\tRecall@20:\t%.4f\tMMR@20:\t%.4f\tEpoch:\t%d,\t%d'% 83 | (loss, test_loss, best_result[0], best_result[1], best_epoch[0], best_epoch[1])) 84 | -------------------------------------------------------------------------------- /pytorch_code/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python36 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on July, 2018 5 | 6 | @author: Tangrizzly 7 | """ 8 | 9 | import networkx as nx 10 | import numpy as np 11 | 12 | 13 | def build_graph(train_data): 14 | graph = nx.DiGraph() 15 | for seq in train_data: 16 | for i in range(len(seq) - 1): 17 | if graph.get_edge_data(seq[i], seq[i + 1]) is None: 18 | weight = 1 19 | else: 20 | weight = graph.get_edge_data(seq[i], seq[i + 1])['weight'] + 1 21 | graph.add_edge(seq[i], seq[i + 1], weight=weight) 22 | for node in graph.nodes: 23 | sum = 0 24 | for j, i in graph.in_edges(node): 25 | sum += graph.get_edge_data(j, i)['weight'] 26 | if sum != 0: 27 | for j, i in graph.in_edges(i): 28 | graph.add_edge(j, i, weight=graph.get_edge_data(j, i)['weight'] / sum) 29 | return graph 30 | 31 | 32 | def data_masks(all_usr_pois, item_tail): 33 | us_lens = [len(upois) for upois in all_usr_pois] 34 | len_max = max(us_lens) 35 | us_pois = [upois + item_tail * (len_max - le) for upois, le in zip(all_usr_pois, us_lens)] 36 | us_msks = [[1] * le + [0] * (len_max - le) for le in us_lens] 37 | return us_pois, us_msks, len_max 38 | 39 | 40 | def split_validation(train_set, valid_portion): 41 | train_set_x, train_set_y = train_set 42 | n_samples = len(train_set_x) 43 | sidx = np.arange(n_samples, dtype='int32') 44 | np.random.shuffle(sidx) 45 | n_train = int(np.round(n_samples * (1. 
- valid_portion))) 46 | valid_set_x = [train_set_x[s] for s in sidx[n_train:]] 47 | valid_set_y = [train_set_y[s] for s in sidx[n_train:]] 48 | train_set_x = [train_set_x[s] for s in sidx[:n_train]] 49 | train_set_y = [train_set_y[s] for s in sidx[:n_train]] 50 | 51 | return (train_set_x, train_set_y), (valid_set_x, valid_set_y) 52 | 53 | 54 | class Data(): 55 | def __init__(self, data, shuffle=False, graph=None): 56 | inputs = data[0] 57 | inputs, mask, len_max = data_masks(inputs, [0]) 58 | self.inputs = np.asarray(inputs) 59 | self.mask = np.asarray(mask) 60 | self.len_max = len_max 61 | self.targets = np.asarray(data[1]) 62 | self.length = len(inputs) 63 | self.shuffle = shuffle 64 | self.graph = graph 65 | 66 | def generate_batch(self, batch_size): 67 | if self.shuffle: 68 | shuffled_arg = np.arange(self.length) 69 | np.random.shuffle(shuffled_arg) 70 | self.inputs = self.inputs[shuffled_arg] 71 | self.mask = self.mask[shuffled_arg] 72 | self.targets = self.targets[shuffled_arg] 73 | n_batch = int(self.length / batch_size) 74 | if self.length % batch_size != 0: 75 | n_batch += 1 76 | slices = np.split(np.arange(n_batch * batch_size), n_batch) 77 | slices[-1] = slices[-1][:(self.length - batch_size * (n_batch - 1))] 78 | return slices 79 | 80 | def get_slice(self, i): 81 | inputs, mask, targets = self.inputs[i], self.mask[i], self.targets[i] 82 | items, n_node, A, alias_inputs = [], [], [], [] 83 | for u_input in inputs: 84 | n_node.append(len(np.unique(u_input))) 85 | max_n_node = np.max(n_node) 86 | for u_input in inputs: 87 | node = np.unique(u_input) 88 | items.append(node.tolist() + (max_n_node - len(node)) * [0]) 89 | u_A = np.zeros((max_n_node, max_n_node)) 90 | for i in np.arange(len(u_input) - 1): 91 | if u_input[i + 1] == 0: 92 | break 93 | u = np.where(node == u_input[i])[0][0] 94 | v = np.where(node == u_input[i + 1])[0][0] 95 | u_A[u][v] = 1 96 | u_sum_in = np.sum(u_A, 0) 97 | u_sum_in[np.where(u_sum_in == 0)] = 1 98 | u_A_in = np.divide(u_A, u_sum_in) 99 | u_sum_out = np.sum(u_A, 1) 100 | u_sum_out[np.where(u_sum_out == 0)] = 1 101 | u_A_out = np.divide(u_A.transpose(), u_sum_out) 102 | u_A = np.concatenate([u_A_in, u_A_out]).transpose() 103 | A.append(u_A) 104 | alias_inputs.append([np.where(node == i)[0][0] for i in u_input]) 105 | return alias_inputs, A, items, mask, targets 106 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SR-GNN 2 | 3 | ## Paper data and code 4 | 5 | This is the code for the AAAI 2019 Paper: [Session-based Recommendation with Graph Neural Networks](https://arxiv.org/abs/1811.00855). We have implemented our methods in both **Tensorflow** and **Pytorch**. 6 | 7 | Here are two datasets we used in our paper. After downloaded the datasets, you can put them in the folder `datasets/`: 8 | 9 | - YOOCHOOSE: or 10 | 11 | - DIGINETICA: or 12 | 13 | There is a small dataset `sample` included in the folder `datasets/`, which can be used to test the correctness of the code. 14 | 15 | We have also written a [blog](https://sxkdz.github.io/research/SR-GNN) explaining the paper. 16 | 17 | ## Usage 18 | 19 | You need to run the file `datasets/preprocess.py` first to preprocess the data. 
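As a quick orientation, here is a minimal sketch of which raw click log each `--dataset` option reads and what the script writes out (file names taken from `preprocess.py`; the raw DIGINETICA/YOOCHOOSE logs must be downloaded separately and placed directly in `datasets/`):

```bash
# Raw inputs expected by datasets/preprocess.py (place them directly in datasets/):
#   --dataset=diginetica  reads train-item-views.csv
#   --dataset=yoochoose   reads yoochoose-clicks.dat
#   --dataset=sample      reads sample_train-item-views.csv (the included sample data)
cd datasets
python preprocess.py --dataset=diginetica   # writes diginetica/{train,test,all_train_seq}.txt
python preprocess.py --dataset=yoochoose    # writes the yoochoose1_4/ and yoochoose1_64/ splits
```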
20 | 21 | For example: `cd datasets; python preprocess.py --dataset=sample` 22 | 23 | ```bash 24 | usage: preprocess.py [-h] [--dataset DATASET] 25 | 26 | optional arguments: 27 | -h, --help show this help message and exit 28 | --dataset DATASET dataset name: diginetica/yoochoose/sample 29 | ``` 30 | 31 | Then you can run the file `pytorch_code/main.py` or `tensorflow_code/main.py` to train the model. 32 | 33 | For example: `cd pytorch_code; python main.py --dataset=sample` 34 | 35 | You can add the flag `--nonhybrid` to recommend with only the global preference of a session graph instead of the hybrid preference. 36 | 37 | You can also change other parameters according to the usage: 38 | 39 | ```bash 40 | usage: main.py [-h] [--dataset DATASET] [--batchSize BATCHSIZE] 41 | [--hiddenSize HIDDENSIZE] [--epoch EPOCH] [--lr LR] 42 | [--lr_dc LR_DC] [--lr_dc_step LR_DC_STEP] [--l2 L2] 43 | [--step STEP] [--patience PATIENCE] [--nonhybrid] 44 | [--validation] [--valid_portion VALID_PORTION] 45 | 46 | optional arguments: 47 | -h, --help show this help message and exit 48 | --dataset DATASET dataset name: 49 | diginetica/yoochoose1_4/yoochoose1_64/sample 50 | --batchSize BATCHSIZE 51 | input batch size 52 | --hiddenSize HIDDENSIZE 53 | hidden state size 54 | --epoch EPOCH the number of epochs to train for 55 | --lr LR learning rate 56 | --lr_dc LR_DC learning rate decay rate 57 | --lr_dc_step LR_DC_STEP 58 | the number of epochs after which the learning rate 59 | decays 60 | --l2 L2 l2 penalty 61 | --step STEP gnn propagation steps 62 | --patience PATIENCE the number of epochs to wait before early stopping 63 | --nonhybrid only use the global preference to predict 64 | --validation validation 65 | --valid_portion VALID_PORTION 66 | portion of the training set to split off as the validation set 67 | ``` 68 | 69 | ## Requirements 70 | 71 | - Python 3 72 | - PyTorch 0.4.0 or Tensorflow 1.9.0 73 | 74 | ## Other Implementations for Reference 75 | Other implementations are available for reference: 76 | - Implementation based on PaddlePaddle by Baidu [[Link]](https://github.com/PaddlePaddle/models/tree/develop/PaddleRec/gnn) 77 | - Implementation based on PyTorch Geometric [[Link]](https://github.com/RuihongQiu/SR-GNN_PyTorch-Geometric) 78 | - Another implementation based on Tensorflow [[Link]](https://github.com/jimanvlad/SR-GNN) 79 | - Yet another implementation based on Tensorflow [[Link]](https://github.com/loserChen/TensorFlow-In-Practice/tree/master/SRGNN) 80 | 81 | ## Citation 82 | 83 | Please cite our paper if you use the code: 84 | 85 | ``` 86 | @inproceedings{Wu:2019ke, 87 | title = {{Session-based Recommendation with Graph Neural Networks}}, 88 | author = {Wu, Shu and Tang, Yuyuan and Zhu, Yanqiao and Wang, Liang and Xie, Xing and Tan, Tieniu}, 89 | year = 2019, 90 | booktitle = {Proceedings of the Thirty-Third AAAI Conference on Artificial Intelligence}, 91 | location = {Honolulu, HI, USA}, 92 | month = jul, 93 | volume = 33, 94 | number = 1, 95 | series = {AAAI '19}, 96 | pages = {346--353}, 97 | url = {https://aaai.org/ojs/index.php/AAAI/article/view/3804}, 98 | doi = {10.1609/aaai.v33i01.3301346}, 99 | editor = {Pascal Van Hentenryck and Zhi-Hua Zhou}, 100 | } 101 | ``` 102 | 103 | -------------------------------------------------------------------------------- /tensorflow_code/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2018/9/23 2:52 4 | # @Author : {ZM7} 5 | # @File : utils.py 6 | # 
@Software: PyCharm 7 | 8 | import networkx as nx 9 | import numpy as np 10 | 11 | 12 | def build_graph(train_data): 13 | graph = nx.DiGraph() 14 | for seq in train_data: 15 | for i in range(len(seq) - 1): 16 | if graph.get_edge_data(seq[i], seq[i + 1]) is None: 17 | weight = 1 18 | else: 19 | weight = graph.get_edge_data(seq[i], seq[i + 1])['weight'] + 1 20 | graph.add_edge(seq[i], seq[i + 1], weight=weight) 21 | for node in graph.nodes: 22 | sum = 0 23 | for j, i in graph.in_edges(node): 24 | sum += graph.get_edge_data(j, i)['weight'] 25 | if sum != 0: 26 | for j, i in graph.in_edges(i): 27 | graph.add_edge(j, i, weight=graph.get_edge_data(j, i)['weight'] / sum) 28 | return graph 29 | 30 | 31 | def data_masks(all_usr_pois, item_tail): 32 | us_lens = [len(upois) for upois in all_usr_pois] 33 | len_max = max(us_lens) 34 | us_pois = [upois + item_tail * (len_max - le) for upois, le in zip(all_usr_pois, us_lens)] 35 | us_msks = [[1] * le + [0] * (len_max - le) for le in us_lens] 36 | return us_pois, us_msks, len_max 37 | 38 | 39 | def split_validation(train_set, valid_portion): 40 | train_set_x, train_set_y = train_set 41 | n_samples = len(train_set_x) 42 | sidx = np.arange(n_samples, dtype='int32') 43 | np.random.shuffle(sidx) 44 | n_train = int(np.round(n_samples * (1. - valid_portion))) 45 | valid_set_x = [train_set_x[s] for s in sidx[n_train:]] 46 | valid_set_y = [train_set_y[s] for s in sidx[n_train:]] 47 | train_set_x = [train_set_x[s] for s in sidx[:n_train]] 48 | train_set_y = [train_set_y[s] for s in sidx[:n_train]] 49 | 50 | return (train_set_x, train_set_y), (valid_set_x, valid_set_y) 51 | 52 | 53 | class Data(): 54 | def __init__(self, data, sub_graph=False, method='ggnn', sparse=False, shuffle=False): 55 | inputs = data[0] 56 | inputs, mask, len_max = data_masks(inputs, [0]) 57 | self.inputs = np.asarray(inputs) 58 | self.mask = np.asarray(mask) 59 | self.len_max = len_max 60 | self.targets = np.asarray(data[1]) 61 | self.length = len(inputs) 62 | self.shuffle = shuffle 63 | self.sub_graph = sub_graph 64 | self.sparse = sparse 65 | self.method = method 66 | 67 | def generate_batch(self, batch_size): 68 | if self.shuffle: 69 | shuffled_arg = np.arange(self.length) 70 | np.random.shuffle(shuffled_arg) 71 | self.inputs = self.inputs[shuffled_arg] 72 | self.mask = self.mask[shuffled_arg] 73 | self.targets = self.targets[shuffled_arg] 74 | n_batch = int(self.length / batch_size) 75 | if self.length % batch_size != 0: 76 | n_batch += 1 77 | slices = np.split(np.arange(n_batch * batch_size), n_batch) 78 | slices[-1] = np.arange(self.length-batch_size, self.length) 79 | return slices 80 | 81 | def get_slice(self, index): 82 | if 1: 83 | items, n_node, A_in, A_out, alias_inputs = [], [], [], [], [] 84 | for u_input in self.inputs[index]: 85 | n_node.append(len(np.unique(u_input))) 86 | max_n_node = np.max(n_node) 87 | if self.method == 'ggnn': 88 | for u_input in self.inputs[index]: 89 | node = np.unique(u_input) 90 | items.append(node.tolist() + (max_n_node - len(node)) * [0]) 91 | u_A = np.zeros((max_n_node, max_n_node)) 92 | for i in np.arange(len(u_input) - 1): 93 | if u_input[i + 1] == 0: 94 | break 95 | u = np.where(node == u_input[i])[0][0] 96 | v = np.where(node == u_input[i + 1])[0][0] 97 | u_A[u][v] = 1 98 | u_sum_in = np.sum(u_A, 0) 99 | u_sum_in[np.where(u_sum_in == 0)] = 1 100 | u_A_in = np.divide(u_A, u_sum_in) 101 | u_sum_out = np.sum(u_A, 1) 102 | u_sum_out[np.where(u_sum_out == 0)] = 1 103 | u_A_out = np.divide(u_A.transpose(), u_sum_out) 104 | 105 | A_in.append(u_A_in) 
106 | A_out.append(u_A_out) 107 | alias_inputs.append([np.where(node == i)[0][0] for i in u_input]) 108 | return A_in, A_out, alias_inputs, items, self.mask[index], self.targets[index] 109 | elif self.method == 'gat': 110 | A_in = [] 111 | A_out = [] 112 | for u_input in self.inputs[index]: 113 | node = np.unique(u_input) 114 | items.append(node.tolist() + (max_n_node - len(node)) * [0]) 115 | u_A = np.eye(max_n_node) 116 | for i in np.arange(len(u_input) - 1): 117 | if u_input[i + 1] == 0: 118 | break 119 | u = np.where(node == u_input[i])[0][0] 120 | v = np.where(node == u_input[i + 1])[0][0] 121 | u_A[u][v] = 1 122 | A_in.append(-1e9 * (1 - u_A)) 123 | A_out.append(-1e9 * (1 - u_A.transpose())) 124 | alias_inputs.append([np.where(node == i)[0][0] for i in u_input]) 125 | return A_in, A_out, alias_inputs, items, self.mask[index], self.targets[index] 126 | 127 | else: 128 | return self.inputs[index], self.mask[index], self.targets[index] -------------------------------------------------------------------------------- /pytorch_code/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python36 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on July, 2018 5 | 6 | @author: Tangrizzly 7 | """ 8 | 9 | import datetime 10 | import math 11 | import numpy as np 12 | import torch 13 | from torch import nn 14 | from torch.nn import Module, Parameter 15 | import torch.nn.functional as F 16 | 17 | 18 | class GNN(Module): 19 | def __init__(self, hidden_size, step=1): 20 | super(GNN, self).__init__() 21 | self.step = step 22 | self.hidden_size = hidden_size 23 | self.input_size = hidden_size * 2 24 | self.gate_size = 3 * hidden_size 25 | self.w_ih = Parameter(torch.Tensor(self.gate_size, self.input_size)) 26 | self.w_hh = Parameter(torch.Tensor(self.gate_size, self.hidden_size)) 27 | self.b_ih = Parameter(torch.Tensor(self.gate_size)) 28 | self.b_hh = Parameter(torch.Tensor(self.gate_size)) 29 | self.b_iah = Parameter(torch.Tensor(self.hidden_size)) 30 | self.b_oah = Parameter(torch.Tensor(self.hidden_size)) 31 | 32 | self.linear_edge_in = nn.Linear(self.hidden_size, self.hidden_size, bias=True) 33 | self.linear_edge_out = nn.Linear(self.hidden_size, self.hidden_size, bias=True) 34 | self.linear_edge_f = nn.Linear(self.hidden_size, self.hidden_size, bias=True) 35 | 36 | def GNNCell(self, A, hidden): 37 | input_in = torch.matmul(A[:, :, :A.shape[1]], self.linear_edge_in(hidden)) + self.b_iah 38 | input_out = torch.matmul(A[:, :, A.shape[1]: 2 * A.shape[1]], self.linear_edge_out(hidden)) + self.b_oah 39 | inputs = torch.cat([input_in, input_out], 2) 40 | gi = F.linear(inputs, self.w_ih, self.b_ih) 41 | gh = F.linear(hidden, self.w_hh, self.b_hh) 42 | i_r, i_i, i_n = gi.chunk(3, 2) 43 | h_r, h_i, h_n = gh.chunk(3, 2) 44 | resetgate = torch.sigmoid(i_r + h_r) 45 | inputgate = torch.sigmoid(i_i + h_i) 46 | newgate = torch.tanh(i_n + resetgate * h_n) 47 | hy = newgate + inputgate * (hidden - newgate) 48 | return hy 49 | 50 | def forward(self, A, hidden): 51 | for i in range(self.step): 52 | hidden = self.GNNCell(A, hidden) 53 | return hidden 54 | 55 | 56 | class SessionGraph(Module): 57 | def __init__(self, opt, n_node): 58 | super(SessionGraph, self).__init__() 59 | self.hidden_size = opt.hiddenSize 60 | self.n_node = n_node 61 | self.batch_size = opt.batchSize 62 | self.nonhybrid = opt.nonhybrid 63 | self.embedding = nn.Embedding(self.n_node, self.hidden_size) 64 | self.gnn = GNN(self.hidden_size, step=opt.step) 65 | self.linear_one = 
nn.Linear(self.hidden_size, self.hidden_size, bias=True) 66 | self.linear_two = nn.Linear(self.hidden_size, self.hidden_size, bias=True) 67 | self.linear_three = nn.Linear(self.hidden_size, 1, bias=False) 68 | self.linear_transform = nn.Linear(self.hidden_size * 2, self.hidden_size, bias=True) 69 | self.loss_function = nn.CrossEntropyLoss() 70 | self.optimizer = torch.optim.Adam(self.parameters(), lr=opt.lr, weight_decay=opt.l2) 71 | self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=opt.lr_dc_step, gamma=opt.lr_dc) 72 | self.reset_parameters() 73 | 74 | def reset_parameters(self): 75 | stdv = 1.0 / math.sqrt(self.hidden_size) 76 | for weight in self.parameters(): 77 | weight.data.uniform_(-stdv, stdv) 78 | 79 | def compute_scores(self, hidden, mask): 80 | ht = hidden[torch.arange(mask.shape[0]).long(), torch.sum(mask, 1) - 1] # batch_size x latent_size 81 | q1 = self.linear_one(ht).view(ht.shape[0], 1, ht.shape[1]) # batch_size x 1 x latent_size 82 | q2 = self.linear_two(hidden) # batch_size x seq_length x latent_size 83 | alpha = self.linear_three(torch.sigmoid(q1 + q2)) 84 | a = torch.sum(alpha * hidden * mask.view(mask.shape[0], -1, 1).float(), 1) 85 | if not self.nonhybrid: 86 | a = self.linear_transform(torch.cat([a, ht], 1)) 87 | b = self.embedding.weight[1:] # n_nodes x latent_size 88 | scores = torch.matmul(a, b.transpose(1, 0)) 89 | return scores 90 | 91 | def forward(self, inputs, A): 92 | hidden = self.embedding(inputs) 93 | hidden = self.gnn(A, hidden) 94 | return hidden 95 | 96 | 97 | def trans_to_cuda(variable): 98 | if torch.cuda.is_available(): 99 | return variable.cuda() 100 | else: 101 | return variable 102 | 103 | 104 | def trans_to_cpu(variable): 105 | if torch.cuda.is_available(): 106 | return variable.cpu() 107 | else: 108 | return variable 109 | 110 | 111 | def forward(model, i, data): 112 | alias_inputs, A, items, mask, targets = data.get_slice(i) 113 | alias_inputs = trans_to_cuda(torch.Tensor(alias_inputs).long()) 114 | items = trans_to_cuda(torch.Tensor(items).long()) 115 | A = trans_to_cuda(torch.Tensor(A).float()) 116 | mask = trans_to_cuda(torch.Tensor(mask).long()) 117 | hidden = model(items, A) 118 | get = lambda i: hidden[i][alias_inputs[i]] 119 | seq_hidden = torch.stack([get(i) for i in torch.arange(len(alias_inputs)).long()]) 120 | return targets, model.compute_scores(seq_hidden, mask) 121 | 122 | 123 | def train_test(model, train_data, test_data): 124 | model.scheduler.step() 125 | print('start training: ', datetime.datetime.now()) 126 | model.train() 127 | total_loss = 0.0 128 | slices = train_data.generate_batch(model.batch_size) 129 | for i, j in zip(slices, np.arange(len(slices))): 130 | model.optimizer.zero_grad() 131 | targets, scores = forward(model, i, train_data) 132 | targets = trans_to_cuda(torch.Tensor(targets).long()) 133 | loss = model.loss_function(scores, targets - 1) 134 | loss.backward() 135 | model.optimizer.step() 136 | total_loss += loss 137 | if j % int(len(slices) / 5 + 1) == 0: 138 | print('[%d/%d] Loss: %.4f' % (j, len(slices), loss.item())) 139 | print('\tLoss:\t%.3f' % total_loss) 140 | 141 | print('start predicting: ', datetime.datetime.now()) 142 | model.eval() 143 | hit, mrr = [], [] 144 | slices = test_data.generate_batch(model.batch_size) 145 | for i in slices: 146 | targets, scores = forward(model, i, test_data) 147 | sub_scores = scores.topk(20)[1] 148 | sub_scores = trans_to_cpu(sub_scores).detach().numpy() 149 | for score, target, mask in zip(sub_scores, targets, test_data.mask): 150 | 
hit.append(np.isin(target - 1, score)) 151 | if len(np.where(score == target - 1)[0]) == 0: 152 | mrr.append(0) 153 | else: 154 | mrr.append(1 / (np.where(score == target - 1)[0][0] + 1)) 155 | hit = np.mean(hit) * 100 156 | mrr = np.mean(mrr) * 100 157 | return hit, mrr 158 | -------------------------------------------------------------------------------- /tensorflow_code/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2018/10/16 4:36 4 | # @Author : {ZM7} 5 | # @File : model.py 6 | # @Software: PyCharm 7 | import tensorflow as tf 8 | import math 9 | 10 | 11 | class Model(object): 12 | def __init__(self, hidden_size=100, out_size=100, batch_size=100, nonhybrid=True): 13 | self.hidden_size = hidden_size 14 | self.out_size = out_size 15 | self.batch_size = batch_size 16 | self.mask = tf.placeholder(dtype=tf.float32) 17 | self.alias = tf.placeholder(dtype=tf.int32) # 给给每个输入重新 18 | self.item = tf.placeholder(dtype=tf.int32) # 重新编号的序列构成的矩阵 19 | self.tar = tf.placeholder(dtype=tf.int32) 20 | self.nonhybrid = nonhybrid 21 | self.stdv = 1.0 / math.sqrt(self.hidden_size) 22 | 23 | self.nasr_w1 = tf.get_variable('nasr_w1', [self.out_size, self.out_size], dtype=tf.float32, 24 | initializer=tf.random_uniform_initializer(-self.stdv, self.stdv)) 25 | self.nasr_w2 = tf.get_variable('nasr_w2', [self.out_size, self.out_size], dtype=tf.float32, 26 | initializer=tf.random_uniform_initializer(-self.stdv, self.stdv)) 27 | self.nasr_v = tf.get_variable('nasrv', [1, self.out_size], dtype=tf.float32, 28 | initializer=tf.random_uniform_initializer(-self.stdv, self.stdv)) 29 | self.nasr_b = tf.get_variable('nasr_b', [self.out_size], dtype=tf.float32, initializer=tf.zeros_initializer()) 30 | 31 | def forward(self, re_embedding, train=True): 32 | rm = tf.reduce_sum(self.mask, 1) 33 | last_id = tf.gather_nd(self.alias, tf.stack([tf.range(self.batch_size), tf.to_int32(rm)-1], axis=1)) 34 | last_h = tf.gather_nd(re_embedding, tf.stack([tf.range(self.batch_size), last_id], axis=1)) 35 | seq_h = tf.stack([tf.nn.embedding_lookup(re_embedding[i], self.alias[i]) for i in range(self.batch_size)], 36 | axis=0) #batch_size*T*d 37 | last = tf.matmul(last_h, self.nasr_w1) 38 | seq = tf.matmul(tf.reshape(seq_h, [-1, self.out_size]), self.nasr_w2) 39 | last = tf.reshape(last, [self.batch_size, 1, -1]) 40 | m = tf.nn.sigmoid(last + tf.reshape(seq, [self.batch_size, -1, self.out_size]) + self.nasr_b) 41 | coef = tf.matmul(tf.reshape(m, [-1, self.out_size]), self.nasr_v, transpose_b=True) * tf.reshape( 42 | self.mask, [-1, 1]) 43 | b = self.embedding[1:] 44 | if not self.nonhybrid: 45 | ma = tf.concat([tf.reduce_sum(tf.reshape(coef, [self.batch_size, -1, 1]) * seq_h, 1), 46 | tf.reshape(last, [-1, self.out_size])], -1) 47 | self.B = tf.get_variable('B', [2 * self.out_size, self.out_size], 48 | initializer=tf.random_uniform_initializer(-self.stdv, self.stdv)) 49 | y1 = tf.matmul(ma, self.B) 50 | logits = tf.matmul(y1, b, transpose_b=True) 51 | else: 52 | ma = tf.reduce_sum(tf.reshape(coef, [self.batch_size, -1, 1]) * seq_h, 1) 53 | logits = tf.matmul(ma, b, transpose_b=True) 54 | loss = tf.reduce_mean( 55 | tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.tar - 1, logits=logits)) 56 | self.vars = tf.trainable_variables() 57 | if train: 58 | lossL2 = tf.add_n([tf.nn.l2_loss(v) for v in self.vars if v.name not 59 | in ['bias', 'gamma', 'b', 'g', 'beta']]) * self.L2 60 | loss = loss + lossL2 61 | return loss, logits 62 
| 63 | def run(self, fetches, tar, item, adj_in, adj_out, alias, mask): 64 | return self.sess.run(fetches, feed_dict={self.tar: tar, self.item: item, self.adj_in: adj_in, 65 | self.adj_out: adj_out, self.alias: alias, self.mask: mask}) 66 | 67 | 68 | class GGNN(Model): 69 | def __init__(self,hidden_size=100, out_size=100, batch_size=300, n_node=None, 70 | lr=None, l2=None, step=1, decay=None, lr_dc=0.1, nonhybrid=False): 71 | super(GGNN,self).__init__(hidden_size, out_size, batch_size, nonhybrid) 72 | self.embedding = tf.get_variable(shape=[n_node, hidden_size], name='embedding', dtype=tf.float32, 73 | initializer=tf.random_uniform_initializer(-self.stdv, self.stdv)) 74 | self.adj_in = tf.placeholder(dtype=tf.float32, shape=[self.batch_size, None, None]) 75 | self.adj_out = tf.placeholder(dtype=tf.float32, shape=[self.batch_size, None, None]) 76 | self.n_node = n_node 77 | self.L2 = l2 78 | self.step = step 79 | self.nonhybrid = nonhybrid 80 | self.W_in = tf.get_variable('W_in', shape=[self.out_size, self.out_size], dtype=tf.float32, 81 | initializer=tf.random_uniform_initializer(-self.stdv, self.stdv)) 82 | self.b_in = tf.get_variable('b_in', [self.out_size], dtype=tf.float32, 83 | initializer=tf.random_uniform_initializer(-self.stdv, self.stdv)) 84 | self.W_out = tf.get_variable('W_out', [self.out_size, self.out_size], dtype=tf.float32, 85 | initializer=tf.random_uniform_initializer(-self.stdv, self.stdv)) 86 | self.b_out = tf.get_variable('b_out', [self.out_size], dtype=tf.float32, 87 | initializer=tf.random_uniform_initializer(-self.stdv, self.stdv)) 88 | with tf.variable_scope('ggnn_model', reuse=None): 89 | self.loss_train, _ = self.forward(self.ggnn()) 90 | with tf.variable_scope('ggnn_model', reuse=True): 91 | self.loss_test, self.score_test = self.forward(self.ggnn(), train=False) 92 | self.global_step = tf.Variable(0) 93 | self.learning_rate = tf.train.exponential_decay(lr, global_step=self.global_step, decay_steps=decay, 94 | decay_rate=lr_dc, staircase=True) 95 | self.opt = tf.train.AdamOptimizer(self.learning_rate).minimize(self.loss_train, global_step=self.global_step) 96 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8) 97 | config = tf.ConfigProto(gpu_options=gpu_options) 98 | config.gpu_options.allow_growth = True 99 | self.sess = tf.Session(config=config) 100 | self.sess.run(tf.global_variables_initializer()) 101 | 102 | def ggnn(self): 103 | fin_state = tf.nn.embedding_lookup(self.embedding, self.item) 104 | cell = tf.nn.rnn_cell.GRUCell(self.out_size) 105 | with tf.variable_scope('gru'): 106 | for i in range(self.step): 107 | fin_state = tf.reshape(fin_state, [self.batch_size, -1, self.out_size]) 108 | fin_state_in = tf.reshape(tf.matmul(tf.reshape(fin_state, [-1, self.out_size]), 109 | self.W_in) + self.b_in, [self.batch_size, -1, self.out_size]) 110 | fin_state_out = tf.reshape(tf.matmul(tf.reshape(fin_state, [-1, self.out_size]), 111 | self.W_out) + self.b_out, [self.batch_size, -1, self.out_size]) 112 | av = tf.concat([tf.matmul(self.adj_in, fin_state_in), 113 | tf.matmul(self.adj_out, fin_state_out)], axis=-1) 114 | state_output, fin_state = \ 115 | tf.nn.dynamic_rnn(cell, tf.expand_dims(tf.reshape(av, [-1, 2*self.out_size]), axis=1), 116 | initial_state=tf.reshape(fin_state, [-1, self.out_size])) 117 | return tf.reshape(fin_state, [self.batch_size, -1, self.out_size]) 118 | 119 | 120 | -------------------------------------------------------------------------------- /datasets/preprocess.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python36 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on July, 2018 5 | 6 | @author: Tangrizzly 7 | """ 8 | 9 | import argparse 10 | import time 11 | import csv 12 | import pickle 13 | import operator 14 | import datetime 15 | import os 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--dataset', default='sample', help='dataset name: diginetica/yoochoose/sample') 19 | opt = parser.parse_args() 20 | print(opt) 21 | 22 | dataset = 'sample_train-item-views.csv' 23 | if opt.dataset == 'diginetica': 24 | dataset = 'train-item-views.csv' 25 | elif opt.dataset =='yoochoose': 26 | dataset = 'yoochoose-clicks.dat' 27 | 28 | print("-- Starting @ %ss" % datetime.datetime.now()) 29 | with open(dataset, "r") as f: 30 | if opt.dataset == 'yoochoose': 31 | reader = csv.DictReader(f, delimiter=',') 32 | else: 33 | reader = csv.DictReader(f, delimiter=';') 34 | sess_clicks = {} 35 | sess_date = {} 36 | ctr = 0 37 | curid = -1 38 | curdate = None 39 | for data in reader: 40 | sessid = data['session_id'] 41 | if curdate and not curid == sessid: 42 | date = '' 43 | if opt.dataset == 'yoochoose': 44 | date = time.mktime(time.strptime(curdate[:19], '%Y-%m-%dT%H:%M:%S')) 45 | else: 46 | date = time.mktime(time.strptime(curdate, '%Y-%m-%d')) 47 | sess_date[curid] = date 48 | curid = sessid 49 | if opt.dataset == 'yoochoose': 50 | item = data['item_id'] 51 | else: 52 | item = data['item_id'], int(data['timeframe']) 53 | curdate = '' 54 | if opt.dataset == 'yoochoose': 55 | curdate = data['timestamp'] 56 | else: 57 | curdate = data['eventdate'] 58 | 59 | if sessid in sess_clicks: 60 | sess_clicks[sessid] += [item] 61 | else: 62 | sess_clicks[sessid] = [item] 63 | ctr += 1 64 | date = '' 65 | if opt.dataset == 'yoochoose': 66 | date = time.mktime(time.strptime(curdate[:19], '%Y-%m-%dT%H:%M:%S')) 67 | else: 68 | date = time.mktime(time.strptime(curdate, '%Y-%m-%d')) 69 | for i in list(sess_clicks): 70 | sorted_clicks = sorted(sess_clicks[i], key=operator.itemgetter(1)) 71 | sess_clicks[i] = [c[0] for c in sorted_clicks] 72 | sess_date[curid] = date 73 | print("-- Reading data @ %ss" % datetime.datetime.now()) 74 | 75 | # Filter out length 1 sessions 76 | for s in list(sess_clicks): 77 | if len(sess_clicks[s]) == 1: 78 | del sess_clicks[s] 79 | del sess_date[s] 80 | 81 | # Count number of times each item appears 82 | iid_counts = {} 83 | for s in sess_clicks: 84 | seq = sess_clicks[s] 85 | for iid in seq: 86 | if iid in iid_counts: 87 | iid_counts[iid] += 1 88 | else: 89 | iid_counts[iid] = 1 90 | 91 | sorted_counts = sorted(iid_counts.items(), key=operator.itemgetter(1)) 92 | 93 | length = len(sess_clicks) 94 | for s in list(sess_clicks): 95 | curseq = sess_clicks[s] 96 | filseq = list(filter(lambda i: iid_counts[i] >= 5, curseq)) 97 | if len(filseq) < 2: 98 | del sess_clicks[s] 99 | del sess_date[s] 100 | else: 101 | sess_clicks[s] = filseq 102 | 103 | # Split out test set based on dates 104 | dates = list(sess_date.items()) 105 | maxdate = dates[0][1] 106 | 107 | for _, date in dates: 108 | if maxdate < date: 109 | maxdate = date 110 | 111 | # 7 days for test 112 | splitdate = 0 113 | if opt.dataset == 'yoochoose': 114 | splitdate = maxdate - 86400 * 1 # the number of seconds for a day:86400 115 | else: 116 | splitdate = maxdate - 86400 * 7 117 | 118 | print('Splitting date', splitdate) # Yoochoose: ('Split date', 1411930799.0) 119 | tra_sess = filter(lambda x: x[1] < splitdate, dates) 120 | 
tes_sess = filter(lambda x: x[1] > splitdate, dates) 121 | 122 | # Sort sessions by date 123 | tra_sess = sorted(tra_sess, key=operator.itemgetter(1)) # [(session_id, timestamp), (), ] 124 | tes_sess = sorted(tes_sess, key=operator.itemgetter(1)) # [(session_id, timestamp), (), ] 125 | print(len(tra_sess)) # 186670 # 7966257 126 | print(len(tes_sess)) # 15979 # 15324 127 | print(tra_sess[:3]) 128 | print(tes_sess[:3]) 129 | print("-- Splitting train set and test set @ %ss" % datetime.datetime.now()) 130 | 131 | # Choosing item count >=5 gives approximately the same number of items as reported in paper 132 | item_dict = {} 133 | # Convert training sessions to sequences and renumber items to start from 1 134 | def obtian_tra(): 135 | train_ids = [] 136 | train_seqs = [] 137 | train_dates = [] 138 | item_ctr = 1 139 | for s, date in tra_sess: 140 | seq = sess_clicks[s] 141 | outseq = [] 142 | for i in seq: 143 | if i in item_dict: 144 | outseq += [item_dict[i]] 145 | else: 146 | outseq += [item_ctr] 147 | item_dict[i] = item_ctr 148 | item_ctr += 1 149 | if len(outseq) < 2: # Doesn't occur 150 | continue 151 | train_ids += [s] 152 | train_dates += [date] 153 | train_seqs += [outseq] 154 | print(item_ctr) # 43098, 37484 155 | return train_ids, train_dates, train_seqs 156 | 157 | 158 | # Convert test sessions to sequences, ignoring items that do not appear in training set 159 | def obtian_tes(): 160 | test_ids = [] 161 | test_seqs = [] 162 | test_dates = [] 163 | for s, date in tes_sess: 164 | seq = sess_clicks[s] 165 | outseq = [] 166 | for i in seq: 167 | if i in item_dict: 168 | outseq += [item_dict[i]] 169 | if len(outseq) < 2: 170 | continue 171 | test_ids += [s] 172 | test_dates += [date] 173 | test_seqs += [outseq] 174 | return test_ids, test_dates, test_seqs 175 | 176 | 177 | tra_ids, tra_dates, tra_seqs = obtian_tra() 178 | tes_ids, tes_dates, tes_seqs = obtian_tes() 179 | 180 | 181 | def process_seqs(iseqs, idates): 182 | out_seqs = [] 183 | out_dates = [] 184 | labs = [] 185 | ids = [] 186 | for id, seq, date in zip(range(len(iseqs)), iseqs, idates): 187 | for i in range(1, len(seq)): 188 | tar = seq[-i] 189 | labs += [tar] 190 | out_seqs += [seq[:-i]] 191 | out_dates += [date] 192 | ids += [id] 193 | return out_seqs, out_dates, labs, ids 194 | 195 | 196 | tr_seqs, tr_dates, tr_labs, tr_ids = process_seqs(tra_seqs, tra_dates) 197 | te_seqs, te_dates, te_labs, te_ids = process_seqs(tes_seqs, tes_dates) 198 | tra = (tr_seqs, tr_labs) 199 | tes = (te_seqs, te_labs) 200 | print(len(tr_seqs)) 201 | print(len(te_seqs)) 202 | print(tr_seqs[:3], tr_dates[:3], tr_labs[:3]) 203 | print(te_seqs[:3], te_dates[:3], te_labs[:3]) 204 | all = 0 205 | 206 | for seq in tra_seqs: 207 | all += len(seq) 208 | for seq in tes_seqs: 209 | all += len(seq) 210 | print('avg length: ', all/(len(tra_seqs) + len(tes_seqs) * 1.0)) 211 | if opt.dataset == 'diginetica': 212 | if not os.path.exists('diginetica'): 213 | os.makedirs('diginetica') 214 | pickle.dump(tra, open('diginetica/train.txt', 'wb')) 215 | pickle.dump(tes, open('diginetica/test.txt', 'wb')) 216 | pickle.dump(tra_seqs, open('diginetica/all_train_seq.txt', 'wb')) 217 | elif opt.dataset == 'yoochoose': 218 | if not os.path.exists('yoochoose1_4'): 219 | os.makedirs('yoochoose1_4') 220 | if not os.path.exists('yoochoose1_64'): 221 | os.makedirs('yoochoose1_64') 222 | pickle.dump(tes, open('yoochoose1_4/test.txt', 'wb')) 223 | pickle.dump(tes, open('yoochoose1_64/test.txt', 'wb')) 224 | 225 | split4, split64 = int(len(tr_seqs) / 4), int(len(tr_seqs) / 
64) 226 | print(len(tr_seqs[-split4:])) 227 | print(len(tr_seqs[-split64:])) 228 | 229 | tra4, tra64 = (tr_seqs[-split4:], tr_labs[-split4:]), (tr_seqs[-split64:], tr_labs[-split64:]) 230 | seq4, seq64 = tra_seqs[tr_ids[-split4]:], tra_seqs[tr_ids[-split64]:] 231 | 232 | pickle.dump(tra4, open('yoochoose1_4/train.txt', 'wb')) 233 | pickle.dump(seq4, open('yoochoose1_4/all_train_seq.txt', 'wb')) 234 | 235 | pickle.dump(tra64, open('yoochoose1_64/train.txt', 'wb')) 236 | pickle.dump(seq64, open('yoochoose1_64/all_train_seq.txt', 'wb')) 237 | 238 | else: 239 | if not os.path.exists('sample'): 240 | os.makedirs('sample') 241 | pickle.dump(tra, open('sample/train.txt', 'wb')) 242 | pickle.dump(tes, open('sample/test.txt', 'wb')) 243 | pickle.dump(tra_seqs, open('sample/all_train_seq.txt', 'wb')) 244 | 245 | print('Done.') 246 | --------------------------------------------------------------------------------
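For reference, here is a minimal sketch of how the pickles written by `preprocess.py` feed the `Data` helper defined in `pytorch_code/utils.py`; it assumes the `sample` dataset has already been preprocessed and that the snippet is run from inside `pytorch_code/`, and it only uses calls that appear in this repository.

```python
# Minimal sketch: load the preprocessed sample data and build one mini-batch
# with the Data helper from pytorch_code/utils.py (run from inside pytorch_code/).
import pickle
import numpy as np
from utils import Data

# train.txt is a pickle of (list_of_sessions, list_of_target_items), as written by preprocess.py.
train_seqs, train_labs = pickle.load(open('../datasets/sample/train.txt', 'rb'))
print(len(train_seqs), 'training sessions')
print('example session:', train_seqs[0], '-> target item:', train_labs[0])

train_data = Data((train_seqs, train_labs), shuffle=False)
slices = train_data.generate_batch(100)  # index slices of batch size 100
alias_inputs, A, items, mask, targets = train_data.get_slice(slices[0])

# Each A[k] is a max_n_node x (2 * max_n_node) matrix: the normalized incoming
# adjacency concatenated with the normalized outgoing adjacency of session k,
# which is what SessionGraph/GNN in model.py consumes.
print(np.asarray(A[0]).shape, np.asarray(items).shape, np.asarray(mask).shape)
```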