├── contact.tar.gz
├── img
│   ├── framework.png
│   └── data_processing.png
├── utils.py
├── README.md
├── train_model.py
├── model.py
└── baselines
    └── DDNE.py

/contact.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianz94/e-lstm-d/HEAD/contact.tar.gz
--------------------------------------------------------------------------------
/img/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianz94/e-lstm-d/HEAD/img/framework.png
--------------------------------------------------------------------------------
/img/data_processing.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianz94/e-lstm-d/HEAD/img/data_processing.png
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import numpy as np
import os
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
import keras.backend as K
import tensorflow as tf


def get_auc(x, y):
    # AUC of the predicted scores x against the ground-truth adjacency y.
    return roc_auc_score(np.reshape(y, (-1, )), np.reshape(x, (-1, )))


def get_err_rate(x, y):
    # Error rate: total absolute prediction error normalized by the number of true links.
    return np.sum(np.abs(x - y)) / np.sum(y)


def load_data(filePath):
    if not os.path.exists(filePath):
        raise FileNotFoundError(filePath)
    return np.load(filePath)


def build_refined_loss(beta):
    # Weighted MSE: errors on existing links (y_true == 1) are penalized
    # beta times more heavily than errors on absent links, to counter the
    # sparsity of real-world networks.
    def refined_loss(y_true, y_pred):
        weight = y_true * (beta - 1) + 1
        return K.mean(K.sum(tf.multiply(weight, K.square(y_true - y_pred)), axis=1), axis=-1)
    return refined_loss

--------------------------------------------------------------------------------
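To see concretely what `build_refined_loss` computes, here is a minimal NumPy re-implementation (illustrative only, not part of the repository): with `beta = 10`, a miss on an existing link costs ten times as much as an equally large false positive.

```python
import numpy as np

def refined_loss_np(y_true, y_pred, beta):
    # Same computation as refined_loss above, in NumPy: entries where
    # y_true == 1 have their squared error scaled by beta.
    weight = y_true * (beta - 1) + 1
    return np.mean(np.sum(weight * np.square(y_true - y_pred), axis=1))

y_true = np.array([[1., 0., 0.]])
y_pred = np.array([[0.5, 0.5, 0.]])

# The miss on the existing link (error 0.5) costs 10 * 0.25 = 2.5,
# the false positive (also error 0.5) only 0.25, so the total is 2.75.
print(refined_loss_np(y_true, y_pred, beta=10))  # 2.75
```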
/README.md:
--------------------------------------------------------------------------------
# e-lstm-d

This is a TensorFlow implementation of the paper: [E-LSTM-D: A Deep Learning Framework for Dynamic Network Link Prediction](https://arxiv.org/abs/1902.08329). The baselines used in the paper will be released as a toolbox soon.


# Requirements
- tensorflow (1.3.0)
- keras (2.2.4)
- scikit-learn (0.19.0)
- numpy (1.14.2)

# Run the demo
#### The framework of E-LSTM-D

The figure below shows the overall framework of E-LSTM-D and its detailed structure when applied to LKML.

![](img/framework.png)

#### Data processing

![](img/data_processing.png)

1. Prepare the data
```
mkdir data
tar -xzvf contact.tar.gz -C ./data
```

2. Train the model
```
python train_model.py --dataset contact --encoder [128] --lstm [256,256] --decoder [274] --num_epochs 1600 --batch_size 32 --BETA 10 --learning_rate 0.001
```

# Cite
Please cite our paper if you use this code in your own work:
```
@article{chen2019lstm,
  title={E-LSTM-D: A Deep Learning Framework for Dynamic Network Link Prediction},
  author={Chen, Jinyin and Zhang, Jian and Xu, Xuanheng and Fu, Chengbo and Zhang, Dan and Zhang, Qingpeng and Xuan, Qi},
  journal={arXiv preprint arXiv:1902.08329},
  year={2019}
}
```
--------------------------------------------------------------------------------
/train_model.py:
--------------------------------------------------------------------------------
from model import e_lstm_d
import os
import tensorflow as tf
import numpy as np
from utils import *


flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('GPU', '0', 'which GPU device to train on; -1 for CPU')
flags.DEFINE_string('dataset', 'contact', 'the dataset used for training and testing')
flags.DEFINE_integer('historical_len', 10, 'number of historical snapshots in each sample')
flags.DEFINE_string('encoder', None, 'encoder structure parameters, e.g. [128]')
flags.DEFINE_string('lstm', None, 'stacked LSTM structure parameters, e.g. [256,256]')
flags.DEFINE_string('decoder', None, 'decoder structure parameters, e.g. [274]')
flags.DEFINE_integer('num_epochs', 800, 'Number of training epochs.')
flags.DEFINE_integer('batch_size', 64, 'Batch size.')
flags.DEFINE_float('weight_decay', 5e-4, 'Weight of the L2 regularization term.')
flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
flags.DEFINE_float('BETA', 2., 'Weight beta of the refined loss.')

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.GPU


data = load_data('data/{}.npy'.format(FLAGS.dataset))
# Structure flags such as --encoder [128] are parsed by stripping the
# brackets and splitting on commas.
model = e_lstm_d(num_nodes=274, historical_len=FLAGS.historical_len,
                 encoder_units=[int(x) for x in FLAGS.encoder[1:-1].split(',')],
                 lstm_units=[int(x) for x in FLAGS.lstm[1:-1].split(',')],
                 decoder_units=[int(x) for x in FLAGS.decoder[1:-1].split(',')],
                 name=FLAGS.dataset)

# Sliding window over the snapshot sequence: the first 240 windows of
# historical_len snapshots each form the training set, the next 80 the
# test set; each window's target is the snapshot immediately after it.
trainX = np.array([data[k: FLAGS.historical_len+k] for k in range(240)], dtype=np.float32)
trainY = np.array(data[FLAGS.historical_len: 240+FLAGS.historical_len], dtype=np.float32)
testX = np.array([data[240+k: 240+FLAGS.historical_len+k] for k in range(80)], dtype=np.float32)
testY = np.array(data[240+FLAGS.historical_len: 320+FLAGS.historical_len], dtype=np.float32)

history = model.train(trainX, trainY)
loss = history.history['loss']
# np.save('loss.npy', np.array(loss))
aucs, err_rates = model.evaluate(testX, testY)
# model.save_weights('tmp/')
print(np.average(aucs), np.average(err_rates))
--------------------------------------------------------------------------------
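To make the sliding-window slicing above concrete, here is a small standalone sketch (toy sizes, not the real contact data) showing how the overlapping input windows and their target snapshots line up:

```python
import numpy as np

historical_len = 3
num_snapshots = 8                # toy value; the contact data has 320+ snapshots
data = np.arange(num_snapshots)  # stand-in for (num_snapshots, N, N) adjacency matrices

num_samples = num_snapshots - historical_len
X = np.array([data[k: k + historical_len] for k in range(num_samples)])
Y = data[historical_len: historical_len + num_samples]

print(X)  # [[0 1 2], [1 2 3], [2 3 4], [3 4 5], [4 5 6]]
print(Y)  # [3 4 5 6 7] -- each window predicts the snapshot right after it
```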
/model.py:
--------------------------------------------------------------------------------
import os
import keras.backend as K
from keras.layers import Dense, Flatten, Input, LSTM, CuDNNLSTM, Reshape, Dropout, Add, Permute
from keras.models import Sequential, Model
from keras.regularizers import l2
from keras.layers import Layer, TimeDistributed, Lambda
from keras.optimizers import Adam, SGD, Adadelta
import tensorflow as tf
import numpy as np
from sklearn.metrics import roc_auc_score
from utils import *

flags = tf.app.flags
FLAGS = flags.FLAGS
# flags.DEFINE_string('dataset', 'cora', 'Dataset string.')  # 'cora', 'citeseer', 'pubmed'
# flags.DEFINE_string('GPU', '-1', 'Model string.')  # 'gcn', 'gcn_cheby', 'dense'
# flags.DEFINE_float('weight_decay', 0.01, 'Initial learning rate.')


class e_lstm_d():

    def __init__(self, num_nodes, historical_len, encoder_units, lstm_units, decoder_units, name=None):

        self.historical_len = historical_len
        self.num_nodes = num_nodes
        self.encoder_units = encoder_units
        self.stacked_lstm_units = lstm_units
        self.decoder_units = decoder_units
        self.model = None
        self.loss = build_refined_loss(FLAGS.BETA)

        if name:
            self.name = name
        else:
            self.name = 'e_lstm_d'

        self._build()

    def _build(self):
        self.encoder = self._build_encoder()
        self.stacked_lstm = self._build_stack_lstm()
        self.decoder = self._build_decoder()

        # Input: historical_len adjacency snapshots of shape (num_nodes, num_nodes).
        x = Input(shape=(self.historical_len, self.num_nodes, self.num_nodes))
        # Encode each snapshot independently, ...
        h = TimeDistributed(self.encoder)(x)
        h = Reshape((self.historical_len, -1))(h)
        # ... aggregate the encoded snapshots over time, ...
        h = Lambda(lambda x: K.sum(x, axis=1))(h)
        h = Reshape((self.num_nodes, -1))(h)
        # ... then model the per-node dynamics with the stacked LSTM and decode.
        h = self.stacked_lstm(h)
        y = self.decoder(h)

        self.model = Model(inputs=x, outputs=y)

    def _build_encoder(self):
        model = Sequential()
        for i in range(len(self.encoder_units)):
            if i == 0:
                model.add(Dense(self.encoder_units[i], input_shape=(self.historical_len, self.num_nodes, self.num_nodes),
                                activation='relu', kernel_regularizer=l2(FLAGS.weight_decay)))
            else:
                model.add(Dense(self.encoder_units[i], activation='relu', kernel_regularizer=l2(FLAGS.weight_decay)))
        return model

    def _build_decoder(self):
        model = Sequential()
        for i in range(len(self.decoder_units)):
            # The last layer uses sigmoid to output link probabilities; hidden layers use relu.
            activation = 'sigmoid' if i == len(self.decoder_units) - 1 else 'relu'
            if i == 0:
                model.add(Dense(self.decoder_units[i], input_shape=(self.num_nodes, self.stacked_lstm_units[-1]),
                                activation=activation, kernel_regularizer=l2(FLAGS.weight_decay)))
            else:
                model.add(Dense(self.decoder_units[i], activation=activation, kernel_regularizer=l2(FLAGS.weight_decay)))
        return model

    def _build_stack_lstm(self):
        model = Sequential()
        # Fall back to the plain LSTM implementation when running on CPU.
        if FLAGS.GPU == '-1':
            _lstm = LSTM
        else:
            _lstm = CuDNNLSTM
        for i in range(len(self.stacked_lstm_units)):
            if i == 0:
                model.add(_lstm(units=self.stacked_lstm_units[i], input_shape=(self.num_nodes, self.encoder_units[-1]), return_sequences=True))
            else:
                model.add(_lstm(units=self.stacked_lstm_units[i], return_sequences=True))
        return model

    def train(self, x, y):
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True  # do not grab all GPU memory up front; allocate on demand
        session = tf.Session(config=config)
        K.set_session(session)
        self.model.compile(optimizer=Adam(lr=FLAGS.learning_rate), loss=self.loss)
        history = self.model.fit(x, y, batch_size=FLAGS.batch_size, epochs=FLAGS.num_epochs, verbose=1)
        return history

    def evaluate(self, x, y):
        y_preds = self.model.predict(x, batch_size=32)
        # Mask the diagonal: self-loops are excluded from scoring.
        template = np.ones((self.num_nodes, self.num_nodes)) - np.identity(self.num_nodes)
        aucs, err_rates = [], []
        for i in range(y_preds.shape[0]):
            y_pred = y_preds[i] * template
            aucs.append(get_auc(np.reshape(y_pred, (-1, )), np.reshape(y[i], (-1, ))))
            err_rates.append(get_err_rate(y_pred, y[i]))
        return aucs, err_rates

    def predict(self, x):
        return self.model.predict(x, batch_size=32)

    def save_weights(self, path):
        if not os.path.exists(path):
            os.makedirs(path)
        self.model.save_weights(path + self.name + '.h5')

    def load_weights(self, weightFile):
        if not os.path.exists(weightFile):
            raise FileNotFoundError(weightFile)
        self.model.load_weights(weightFile)


if __name__ == "__main__":
    model = e_lstm_d(num_nodes=274, historical_len=10, encoder_units=[128], lstm_units=[256, 256], decoder_units=[274])
    print(model.model.summary())
--------------------------------------------------------------------------------
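Because `model.py` reads its hyper-parameters from `tf.app.flags`, a standalone smoke test has to define those flags first. Below is a minimal sketch (not part of the repository; all flag values and toy sizes are illustrative, assuming the TensorFlow 1.x / Keras versions from the requirements):

```python
import numpy as np
import tensorflow as tf
from model import e_lstm_d

flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('GPU', '-1', 'run on CPU so the plain LSTM layer is used')
flags.DEFINE_float('weight_decay', 5e-4, 'L2 regularization weight')
flags.DEFINE_float('learning_rate', 0.001, 'learning rate')
flags.DEFINE_float('BETA', 10., 'weight of existing links in the refined loss')
flags.DEFINE_integer('batch_size', 2, 'batch size')
flags.DEFINE_integer('num_epochs', 1, 'training epochs')

num_nodes, historical_len = 32, 4
# 6 random samples: each input is historical_len adjacency snapshots,
# each target is the snapshot that follows them.
x = np.random.randint(0, 2, size=(6, historical_len, num_nodes, num_nodes)).astype(np.float32)
y = np.random.randint(0, 2, size=(6, num_nodes, num_nodes)).astype(np.float32)

model = e_lstm_d(num_nodes=num_nodes, historical_len=historical_len,
                 encoder_units=[16], lstm_units=[32, 32], decoder_units=[num_nodes])
model.train(x, y)
aucs, err_rates = model.evaluate(x, y)
print(np.average(aucs), np.average(err_rates))
```

Note that the last decoder unit must equal `num_nodes`, since the decoder reconstructs a full adjacency matrix.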
/baselines/DDNE.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import tensorflow as tf
from utils import edge_wise_loss, build_reconstruction_loss
from tensorflow.contrib.keras.api.keras.layers import GRU, Dense, Reshape, Add, concatenate, Activation, BatchNormalization

flags = tf.app.flags
FLAGS = flags.FLAGS


class DDNE():

    def __init__(self, placeholders, input_shape, **kwargs):

        allowed_kwargs = {'name', 'logging'}
        for kwarg in kwargs.keys():
            assert kwarg in allowed_kwargs, 'Invalid keyword argument: ' + kwarg
        name = kwargs.get('name')
        if not name:
            name = self.__class__.__name__.lower()
        self.name = name

        self.input_s = placeholders['input_s']
        self.input_t = placeholders['input_t']
        self.weight = placeholders['weight']
        self.input_shape = input_shape
        self.output_dim = self.input_shape[-1]
        self.output_s = []
        self.output_t = []
        self.diff = None
        self.historical_len = 2
        self.loss = 0
        self.layers = []
        self.gradients = 0
        self.placeholders = placeholders
        self.optimizer = tf.train.GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
        self.opt_op = None
        # self.trainable = placeholders['in_train_mode']

        self.build()

    def build(self):

        gru = GRU(units=self.input_shape[-1], return_sequences=False, name='gru')
        bn = BatchNormalization()
        decoder_1 = Dense(units=128, activation='sigmoid', name='decoder-1')
        decoder_2 = Dense(units=self.output_dim, activation='sigmoid', name='decoder-2')
        self.layers.append(gru)
        self.layers.append(decoder_1)
        self.layers.append(decoder_2)

        # Forward pass: feed the historical snapshots to the shared GRU in
        # chronological order, for both the source (s) and target (t) inputs.
        activations = {}
        for i in range(self.historical_len):
            s_input = self.input_s[i]
            t_input = self.input_t[i]
            if i == 0:
                # hs = bn(s_input)
                hs = gru(s_input)
                activations['s-r-{}-o'.format(i)] = hs
                hs = Reshape((1, -1))(hs)
                activations['s-r-{}'.format(i)] = hs

                # ht = bn(t_input)
                ht = gru(t_input)
                activations['t-r-{}-o'.format(i)] = ht
                ht = Reshape((1, -1))(ht)
                activations['t-r-{}'.format(i)] = ht

            else:
                # hs = bn(s_input)
                hs = gru(Add()([activations['s-r-{}-o'.format(i - 1)], s_input]))
                activations['s-r-{}-o'.format(i)] = hs
                hs = Reshape((1, -1))(hs)
                activations['s-r-{}'.format(i)] = hs

                # ht = bn(t_input)
                ht = gru(Add()([activations['t-r-{}-o'.format(i - 1)], t_input]))
                activations['t-r-{}-o'.format(i)] = ht
                ht = Reshape((1, -1))(ht)
                activations['t-r-{}'.format(i)] = ht
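
        # Backward pass: the same snapshots are now fed to the shared GRU in
        # reverse chronological order (idx counts down from historical_len - 1),
        # so each node history gets both a left-to-right reading ('-r-' keys
        # above) and a right-to-left reading ('-l-' keys below), in the spirit
        # of a bidirectional encoder.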
        for i in range(self.historical_len):
            idx = self.historical_len - 1 - i
            s_input = self.input_s[idx]
            t_input = self.input_t[idx]
            if i == 0:
                # hs = bn(s_input)
                hs = gru(s_input)
                activations['s-l-{}-o'.format(i)] = hs
                hs = Reshape((1, -1))(hs)
                activations['s-l-{}'.format(i)] = hs

                # ht = bn(t_input)
                ht = gru(t_input)
                activations['t-l-{}-o'.format(i)] = ht
                ht = Reshape((1, -1))(ht)
                activations['t-l-{}'.format(i)] = ht

            else:
                # hs = bn(s_input)
                hs = gru(Add()([activations['s-l-{}-o'.format(i - 1)], s_input]))
                activations['s-l-{}-o'.format(i)] = hs
                hs = Reshape((1, -1))(hs)
                activations['s-l-{}'.format(i)] = hs

                # ht = bn(t_input)
                ht = gru(Add()([activations['t-l-{}-o'.format(i - 1)], t_input]))
                activations['t-l-{}-o'.format(i)] = ht
                ht = Reshape((1, -1))(ht)
                activations['t-l-{}'.format(i)] = ht

        # Concatenate the per-step hidden states of both reading directions.
        s_h_r = concatenate([activations['s-r-{}'.format(i)] for i in range(self.historical_len)], axis=1)
        s_h_l = concatenate([activations['s-l-{}'.format(i)] for i in range(self.historical_len)], axis=1)
        t_h_r = concatenate([activations['t-r-{}'.format(i)] for i in range(self.historical_len)], axis=1)
        t_h_l = concatenate([activations['t-l-{}'.format(i)] for i in range(self.historical_len)], axis=1)

        s_c = concatenate([s_h_r, s_h_l], axis=1)
        t_c = concatenate([t_h_r, t_h_l], axis=1)
        s_c = Activation('relu')(Reshape((-1,))(s_c))
        t_c = Activation('relu')(Reshape((-1,))(t_c))

        # diff = tf.square(tf.subtract(s_c, t_c))
        # tmp = diff.get_shape().as_list()[-1]
        # self.weight = RepeatVector(tmp)(self.weight)
        # self.weight = tf.reshape(self.weight, [-1, tmp])
        # self.diff = tf.multiply(diff, self.weight)

        self.output_s = decoder_2(decoder_1(s_c))
        self.output_t = decoder_2(decoder_1(t_c))

        self.variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        self.vars = {var.name: var for var in self.variables}

        self._loss()
        self._gradients()

        self.opt_op = self.optimizer.minimize(self.loss)

    def _loss(self):
        # L2 regularization over the trainable variables of every layer.
        for layer in self.layers:
            for weight in layer.trainable_weights:
                self.loss += FLAGS.weight_decay * tf.nn.l2_loss(weight)
        reconstruction_loss = build_reconstruction_loss(10)

        self.loss += reconstruction_loss(self.placeholders['true_s'], self.output_s) + \
                     reconstruction_loss(self.placeholders['true_t'], self.output_t)
        # 0*edge_wise_loss(self.diff)

    def _gradients(self):
        self.gradients = tf.gradients(self.loss, self.input_s)

    def predict(self, x):
        pass

    def save(self, sess=None):
        if not sess:
            raise AttributeError("TensorFlow session not provided.")
        saver = tf.train.Saver(self.vars)
        save_path = saver.save(sess, "tmp/%s.ckpt" % self.name)
        print("Model saved in file: %s" % save_path)

    def load(self, sess=None):
        if not sess:
            raise AttributeError("TensorFlow session not provided.")
        saver = tf.train.Saver(self.vars)
        save_path = "tmp/%s.ckpt" % self.name
        saver.restore(sess, save_path)
        print("Model restored from file: %s" % save_path)
--------------------------------------------------------------------------------
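For reference, the per-snapshot scoring performed by `e_lstm_d.evaluate` can be reproduced standalone with the helpers from `utils.py`. A small sketch (random matrices, illustrative values only):

```python
import numpy as np
from utils import get_auc, get_err_rate

num_nodes = 5
y_true = np.random.randint(0, 2, size=(num_nodes, num_nodes)).astype(np.float32)
y_pred = np.random.rand(num_nodes, num_nodes).astype(np.float32)

# Mask out the diagonal: self-loops are not scored.
template = np.ones((num_nodes, num_nodes)) - np.identity(num_nodes)
y_pred = y_pred * template

print(get_auc(y_pred, y_true))       # ranking quality over all node pairs
print(get_err_rate(y_pred, y_true))  # absolute error normalized by the number of true links
```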