├── README.md ├── code ├── main.py ├── model │ ├── aggregator.py │ ├── framework.py │ └── score_function.py ├── test.py └── utils │ ├── data_helper.py │ ├── link_prediction.py │ └── triplet_classify.py ├── configs ├── lan_LinkPredict.sh ├── lan_TripletClassify.sh ├── lstm_LinkPredict.sh ├── lstm_TripletClassify.sh ├── mean_LinkPredict.sh └── mean_TripletClassify.sh ├── data └── fb15K.zip └── run.sh /README.md: -------------------------------------------------------------------------------- 1 | # Logic Attention Network (LAN) 2 | 3 | This is a TensorFlow implementation of [Logic Attention Based Neighborhood Aggregation for Inductive Knowledge Graph Embedding](https://arxiv.org/pdf/1811.01399.pdf) (AAAI 2019). 4 | 5 | ## Run the code 6 | 7 | ```bash 8 | $ ./run.sh ./configs/model_name.sh 9 | ``` 10 | 11 | ## Cite 12 | 13 | ``` 14 | @inproceedings{wang2019logic, 15 | title={Logic attention based neighborhood aggregation for inductive knowledge graph embedding}, 16 | author={Wang, Peifeng and Han, Jialong and Li, Chenliang and Pan, Rong}, 17 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, 18 | volume={33}, 19 | pages={7152--7159}, 20 | year={2019} 21 | } 22 | ``` 23 | -------------------------------------------------------------------------------- /code/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import datetime 4 | import time 5 | import random 6 | import numpy as np 7 | import logging 8 | import sys 9 | 10 | os.environ['TF_CPP_MIN_LOG_LEVEL']='2' 11 | import tensorflow as tf 12 | from model.framework import FrameWork 13 | from utils.data_helper import DataSet 14 | from utils.triplet_classify import run_triplet_classify 15 | from utils.link_prediction import run_link_prediction 16 | 17 | logger = logging.getLogger() 18 | 19 | def run_training(args): 20 | # ----------------------------------------------------- # 21 | 22 | # gpu setting 23 | os.environ.setdefault('CUDA_VISIBLE_DEVICES', args.gpu_device) 24 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_fraction) 25 | 26 | session_conf = tf.ConfigProto( 27 | gpu_options=gpu_options, 28 | allow_soft_placement=args.allow_soft_placement, 29 | log_device_placement=False) 30 | sess = tf.Session(config=session_conf) 31 | 32 | # Checkpoint directory 33 | checkpoint_dir = args.save_dir 34 | checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt') 35 | if not os.path.exists(checkpoint_dir): 36 | os.makedirs(checkpoint_dir) 37 | 38 | # log file 39 | logger.setLevel(logging.INFO) 40 | handler = logging.FileHandler(checkpoint_dir+'train.log', 'w') 41 | handler.setLevel(logging.INFO) 42 | formatter = logging.Formatter('%(asctime)s: %(message)s', datefmt='%Y/%m/%d %H:%M:%S') 43 | handler.setFormatter(formatter) 44 | logger.addHandler(handler) 45 | 46 | logger.info('args: {}'.format(args)) 47 | 48 | # ----------------------------------------------------- # 49 | 50 | # prepare data 51 | logger.info("Loading data...") 52 | dataset = DataSet(args, logger) 53 | logger.info("Loading finish...") 54 | 55 | model = FrameWork(args, dataset.num_training_entity, dataset.num_relation) 56 | 57 | # saver for checkpoint and initialization 58 | saver = tf.train.Saver(max_to_keep=1) 59 | sess.run(tf.global_variables_initializer()) 60 | 61 | checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir) 62 | if checkpoint_file != None: 63 | logger.info('Restore the model from {}'.format(checkpoint_file)) 64 | saver.restore(sess, checkpoint_file) 
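# a checkpoint already exists in save_dir: with it restored, skip training and run evaluation only on the configured task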
65 | logger.info('Start testing on checkpoints') 66 | if dataset.task == 'triplet_classify': 67 | run_triplet_classify(args, sess, model, dataset, 0, logger, is_test=True) 68 | else: 69 | run_link_prediction(args, sess, model, dataset, 0, logger, is_test=True) 70 | logger.info('Testing finish') 71 | return 72 | 73 | # ----------------------------------------------------- # 74 | 75 | # training 76 | num_batch = dataset.num_sample // args.batch_size 77 | logger.info('Train with {} batches'.format(num_batch)) 78 | 79 | best_performance = 0. 80 | for epoch in xrange(args.num_epoch): 81 | st_epoch = time.time() 82 | loss_epoch = 0. 83 | cnt_batch = 0 84 | for batch_data in dataset.batch_iter_epoch(dataset.triplets_train, args.batch_size, args.n_neg): 85 | st_batch = time.time() 86 | batch_weight_ph, batch_weight_pt, batch_weight_nh, batch_weight_nt, batch_positive, batch_negative, batch_relation_ph, batch_relation_pt, batch_relation_nh, batch_relation_nt, batch_neighbor_hp, batch_neighbor_tp, batch_neighbor_hn, batch_neighbor_tn = batch_data 87 | # batch_positive, batch_negative, batch_relation_ph, batch_relation_pt, batch_relation_nh, batch_relation_nt, batch_neighbor_hp, batch_neighbor_tp, batch_neighbor_hn, batch_neighbor_tn = batch_data 88 | feed_dict = { 89 | model.neighbor_head_pos: batch_neighbor_hp, 90 | model.neighbor_tail_pos: batch_neighbor_tp, 91 | model.neighbor_head_neg: batch_neighbor_hn, 92 | model.neighbor_tail_neg: batch_neighbor_tn, 93 | model.input_relation_ph: batch_relation_ph, 94 | model.input_relation_pt: batch_relation_pt, 95 | model.input_relation_nh: batch_relation_nh, 96 | model.input_relation_nt: batch_relation_nt, 97 | model.input_triplet_pos: batch_positive, 98 | model.input_triplet_neg: batch_negative, 99 | model.neighbor_weight_ph: batch_weight_ph, 100 | model.neighbor_weight_pt: batch_weight_pt, 101 | model.neighbor_weight_nh: batch_weight_nh, 102 | model.neighbor_weight_nt: batch_weight_nt 103 | } 104 | 105 | _, loss_batch, _step = sess.run( 106 | [model.train_op, model.loss, model.global_step], 107 | feed_dict=feed_dict 108 | ) 109 | cnt_batch += 1 110 | loss_epoch += loss_batch 111 | en_batch = time.time() 112 | 113 | # print an overview every some batches 114 | if (cnt_batch+1) % args.steps_per_display == 0 or (cnt_batch+1) == num_batch: 115 | logger.info('epoch {}, batch {}, loss: {:.3f}, time: {:.3f}s'.format( 116 | epoch, cnt_batch, loss_batch, en_batch - st_batch)) 117 | 118 | en_epoch = time.time() 119 | logger.info('epoch {}, mean loss: {:.3f}, time: {:.3f}s'.format( 120 | epoch, 121 | loss_epoch / cnt_batch, 122 | en_epoch - st_epoch 123 | )) 124 | 125 | # evaluate the model every some steps 126 | if (epoch + 1) % args.epoch_per_checkpoint == 0 or (epoch + 1) == args.num_epoch: 127 | st_test = time.time() 128 | if dataset.task == 'triplet_classify': 129 | performance = run_triplet_classify(args, sess, model, dataset, epoch, logger, is_test=False) 130 | else: 131 | performance = run_link_prediction(args, sess, model, dataset, epoch, logger, is_test=False) 132 | if performance > best_performance: 133 | best_performance = performance 134 | save_path = saver.save(sess, checkpoint_prefix, global_step=epoch) 135 | time_str = datetime.datetime.now().isoformat() 136 | logger.info('{}: model at epoch {} save in file {}'.format(time_str, epoch, save_path)) 137 | en_test = time.time() 138 | logger.info('testing finished with time: {:.3f}s'.format(en_test - st_test)) 139 | 140 | logger.info('Training finished') 141 | # if dataset.task == 'triplet_classify': 142 
| # run_triplet_classify(args, sess, model, dataset, epoch, logger, is_test=True) 143 | # else: 144 | # run_link_prediction(args, sess, model, dataset, epoch, logger, is_test=True) 145 | checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir) 146 | logger.info('Restore the model from {}'.format(checkpoint_file)) 147 | saver.restore(sess, checkpoint_file) 148 | st_test = time.time() 149 | if dataset.task == 'triplet_classify': 150 | run_triplet_classify(args, sess, model, dataset, epoch, logger, is_test=True) 151 | else: 152 | run_link_prediction(args, sess, model, dataset, epoch, logger, is_test=True) 153 | en_test = time.time() 154 | logger.info('Testing finished with time: {:.3f}s'.format(en_test - st_test)) 155 | 156 | def main(): 157 | parser = argparse.ArgumentParser(description='Run training zero-shot KB model.') 158 | parser.add_argument('--data_dir', '-D', type=str) 159 | parser.add_argument('--save_dir', '-S', type=str) 160 | 161 | # model 162 | parser.add_argument('--use_relation', type=int, default=0) 163 | parser.add_argument('--embedding_dim', '-e', type=int, default=50) 164 | parser.add_argument('--max_neighbor', type=int, default=64) 165 | parser.add_argument('--n_neg', '-n', type=int, default=1) 166 | parser.add_argument('--aggregate_type', type=str, default='gnn_mean') 167 | parser.add_argument('--score_function', type=str, default='TransE') 168 | parser.add_argument('--loss_function', type=str, default='margin') 169 | parser.add_argument('--margin', type=float, default='1.0') 170 | parser.add_argument('--corrupt_mode', type=str, default='both') 171 | 172 | # training 173 | parser.add_argument('--learning_rate', type=float, default=1e-3) 174 | parser.add_argument('--num_epoch', type=int, default=1) 175 | parser.add_argument('--weight_decay', '-w', type=float, default=0.0) 176 | parser.add_argument('--batch_size', type=int, default=128) 177 | parser.add_argument('--evaluate_size', type=int, default=1000) 178 | parser.add_argument('--steps_per_display', type=int, default=100) 179 | parser.add_argument('--epoch_per_checkpoint', type=int, default=50) 180 | 181 | # gpu option 182 | parser.add_argument('--gpu_fraction', type=float, default=0.2) 183 | parser.add_argument('--gpu_device', type=str, default='0') 184 | parser.add_argument('--allow_soft_placement', type=bool, default=False) 185 | 186 | args = parser.parse_args() 187 | 188 | run_training(args=args) 189 | 190 | if __name__ == '__main__': 191 | main() -------------------------------------------------------------------------------- /code/model/aggregator.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | class GNN_MEAN(object): 5 | def __init__(self, num_relation, embedding_dim): 6 | self.embedding_dim = embedding_dim 7 | self.num_relation = num_relation 8 | self.mlp_w = tf.get_variable( 9 | name='mlp_w', 10 | shape=[self.num_relation * 2 + 1, self.embedding_dim], 11 | initializer=tf.contrib.layers.xavier_initializer(uniform=True) 12 | ) 13 | 14 | self.l2_regularization = tf.nn.l2_loss(self.mlp_w) 15 | 16 | def _projection(self, e, n): 17 | norm = tf.nn.l2_normalize(n, 2) 18 | return e - tf.reduce_sum(e * norm, 2, keep_dims = True) * norm 19 | 20 | def __call__(self, input, input_relation): 21 | # input: [batch, len, emb] 22 | projection = tf.nn.embedding_lookup(self.mlp_w, input_relation) 23 | output = self._projection(input, projection) 24 | return tf.reduce_mean(output, -2), 0 25 | 26 | class LSTM(object): 27 | def 
__init__(self, num_relation, embedding_dim): 28 | self.embedding_dim = embedding_dim 29 | self.num_relation = num_relation 30 | self.mlp_w = tf.get_variable( 31 | name='mlp_w', 32 | shape=[self.num_relation * 2 + 1, self.embedding_dim], 33 | initializer=tf.contrib.layers.xavier_initializer(uniform=True) 34 | ) 35 | 36 | self.rnn_cell = tf.nn.rnn_cell.LSTMCell(self.embedding_dim) 37 | self.l2_regularization = tf.nn.l2_loss(self.mlp_w) 38 | 39 | def _projection(self, e, n): 40 | norm = tf.nn.l2_normalize(n, 2) 41 | return e - tf.reduce_sum(e * norm, 2, keep_dims = True) * norm 42 | 43 | def __call__(self, input, input_relation): 44 | # input: [batch, len, emb] 45 | input_shape = input.shape 46 | batch_size = input_shape[0].value 47 | max_len = input_shape[1].value 48 | hidden_size = input_shape[2].value 49 | 50 | projection = tf.nn.embedding_lookup(self.mlp_w, input_relation) 51 | hidden = self._projection(input, projection) 52 | 53 | outputs, state = tf.nn.dynamic_rnn( 54 | self.rnn_cell, 55 | hidden, 56 | dtype=tf.float32) 57 | 58 | return tf.reduce_mean(outputs, -2), None 59 | 60 | class ATTENTION(object): 61 | """docstring for Attention""" 62 | def __init__(self, num_relation, num_entity, embedding_dim): 63 | self.embedding_dim = embedding_dim 64 | self.num_relation = num_relation 65 | self.num_entity = num_entity 66 | 67 | self.mlp_w = tf.get_variable( 68 | name='mlp_w', 69 | shape=[self.num_relation * 2 + 1, self.embedding_dim], 70 | initializer=tf.contrib.layers.xavier_initializer(uniform=False) 71 | ) 72 | 73 | self.att_w = tf.get_variable( 74 | name='att_w', 75 | shape=[self.embedding_dim * 2, self.embedding_dim * 2], 76 | initializer=tf.contrib.layers.xavier_initializer(uniform=False) 77 | ) 78 | 79 | self.att_v = tf.get_variable( 80 | name='att_v', 81 | shape=[self.embedding_dim * 2], 82 | initializer=tf.contrib.layers.xavier_initializer(uniform=False)) 83 | 84 | self.att_b = tf.get_variable( 85 | name='att_b', 86 | shape=[self.embedding_dim * 2], 87 | initializer=tf.contrib.layers.xavier_initializer(uniform=False)) 88 | 89 | self.query_relation_embedding = tf.get_variable( 90 | name='query_relation_embedding', 91 | shape=[self.num_relation * 2, self.embedding_dim], 92 | initializer=tf.contrib.layers.xavier_initializer(uniform=False) 93 | ) 94 | 95 | self.l2_regularization = tf.nn.l2_loss(self.att_w) \ 96 | + tf.nn.l2_loss(self.mlp_w) \ 97 | + tf.nn.l2_loss(self.query_relation_embedding) \ 98 | + tf.nn.l2_loss(self.att_v) 99 | # + tf.nn.l2_loss(self.att_b) \ 100 | 101 | self.mask_emb = tf.concat([tf.ones([self.num_entity, 1]), tf.zeros([1, 1])], 0) 102 | self.mask_weight = tf.concat([tf.zeros([self.num_entity, 1]), tf.ones([1, 1])*1e19], 0) 103 | 104 | def _projection(self, e, n): 105 | norm = tf.nn.l2_normalize(n, 2) 106 | return e - tf.reduce_sum(e * norm, 2, keep_dims = True) * norm 107 | 108 | def mlp(self, query, input, max_len): 109 | # query = tf.reshape(query, [-1, 1, self.embedding_dim]) 110 | # query = tf.tile(query, [1, max_len, 1]) 111 | 112 | hidden = tf.concat([query, input], 2) 113 | hidden = tf.reshape(hidden, [-1, self.embedding_dim * 2]) 114 | hidden = tf.tanh(tf.matmul(hidden, self.att_w)) 115 | hidden = tf.reshape(hidden, [-1, max_len, self.embedding_dim * 2]) 116 | attention_logit = tf.reduce_sum(hidden * self.att_v, axis=2) 117 | return attention_logit 118 | 119 | def __call__(self, input, neighbor, query_relation_id, weight): 120 | input_shape = input.shape 121 | max_len = input_shape[1].value 122 | hidden_size = input_shape[2].value 123 | 124 | 
input_relation = neighbor[:, :, 0] 125 | input_entity = neighbor[:, :, 1] 126 | 127 | # [batch, len, emb] -> [batch * len, emb] 128 | projection = tf.nn.embedding_lookup(self.mlp_w, input_relation) 129 | projection = self._projection(input, projection) 130 | mask = tf.nn.embedding_lookup(self.mask_emb, input_entity) 131 | projection = projection * mask 132 | 133 | # query: [batch, emb] 134 | query_relation = tf.nn.embedding_lookup(self.query_relation_embedding, query_relation_id) 135 | query_relation = tf.reshape(query_relation, [-1, 1, self.embedding_dim]) 136 | query_relation = tf.tile(query_relation, [1, max_len, 1]) 137 | 138 | # attention weight 139 | attention_logit = self.mlp(query_relation, projection, max_len) 140 | mask_logit = tf.nn.embedding_lookup(self.mask_weight, input_entity) 141 | attention_logit -= tf.reshape(mask_logit, [-1, max_len]) 142 | attention_weight = tf.nn.softmax(attention_logit) 143 | attention_weight += weight[:, :, 0] / (weight[:, :, 1] + 1) 144 | 145 | # output 146 | attention_weight = tf.reshape(attention_weight, [-1, max_len, 1]) 147 | output = tf.reduce_sum(projection * attention_weight, axis=1) 148 | attention_weight = tf.reshape(attention_weight, [-1, max_len]) 149 | return output, attention_weight 150 | 151 | -------------------------------------------------------------------------------- /code/model/framework.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | class FrameWork(object): 5 | def __init__(self, args, num_entity, num_relation): 6 | self.max_neighbor = args.max_neighbor 7 | self.embedding_dim = args.embedding_dim 8 | self.learning_rate = args.learning_rate 9 | self.aggregate_type = args.aggregate_type 10 | self.score_function = args.score_function 11 | self.loss_function = args.loss_function 12 | self.use_relation = args.use_relation 13 | self.margin = args.margin 14 | self.weight_decay = args.weight_decay 15 | 16 | self.num_entity = num_entity 17 | self.num_relation = num_relation 18 | 19 | with tf.variable_scope('input'): 20 | 21 | self.neighbor_weight_ph = tf.placeholder( 22 | dtype=tf.float32, 23 | shape=[None, self.max_neighbor, 2], 24 | name='neighbor_weight_ph') 25 | 26 | self.neighbor_weight_pt = tf.placeholder( 27 | dtype=tf.float32, 28 | shape=[None, self.max_neighbor, 2], 29 | name='neighbor_weight_pt') 30 | 31 | self.neighbor_weight_nh = tf.placeholder( 32 | dtype=tf.float32, 33 | shape=[None, self.max_neighbor, 2], 34 | name='neighbor_weight_nh') 35 | 36 | self.neighbor_weight_nt = tf.placeholder( 37 | dtype=tf.float32, 38 | shape=[None, self.max_neighbor, 2], 39 | name='neighbor_weight_nt') 40 | 41 | self.neighbor_head_pos = tf.placeholder( 42 | dtype=tf.int32, 43 | shape=[None, self.max_neighbor, 2], 44 | name='neighbor_head_pos') 45 | 46 | self.neighbor_head_neg = tf.placeholder( 47 | dtype=tf.int32, 48 | shape=[None, self.max_neighbor, 2], 49 | name='neighbor_head_neg') 50 | 51 | self.neighbor_tail_pos = tf.placeholder( 52 | dtype=tf.int32, 53 | shape=[None, self.max_neighbor, 2], 54 | name='neighbor_tail_pos') 55 | 56 | self.neighbor_tail_neg = tf.placeholder( 57 | dtype=tf.int32, 58 | shape=[None, self.max_neighbor, 2], 59 | name='neighbor_tail_neg') 60 | 61 | self.input_relation_ph = tf.placeholder( 62 | dtype=tf.int32, 63 | shape=[None], 64 | name='relation_head_pos') 65 | 66 | self.input_relation_pt = tf.placeholder( 67 | dtype=tf.int32, 68 | shape=[None], 69 | name='relation_tail_pos') 70 | 71 | self.input_relation_nh = 
tf.placeholder( 72 | dtype=tf.int32, 73 | shape=[None], 74 | name='relation_head_neg') 75 | 76 | self.input_relation_nt = tf.placeholder( 77 | dtype=tf.int32, 78 | shape=[None], 79 | name='relation_tail_neg') 80 | 81 | self.input_triplet_pos = tf.placeholder( 82 | dtype=tf.int32, 83 | shape=[None, 3], 84 | name='input_triplet_pos') 85 | 86 | self.input_triplet_neg = tf.placeholder( 87 | dtype=tf.int32, 88 | shape=[None, 3], 89 | name='input_triplet_neg') 90 | 91 | self.embedding_placeholder = tf.placeholder( 92 | dtype=tf.float32, 93 | shape=[self.num_entity + 1, self.embedding_dim], 94 | name='embedding_placeholder' 95 | ) 96 | 97 | 98 | with tf.variable_scope('embeddings'): 99 | self.entity_embedding = tf.get_variable( 100 | name='entity_embedding', 101 | shape=[self.num_entity + 1, self.embedding_dim], 102 | initializer=tf.contrib.layers.xavier_initializer(uniform=False) 103 | ) 104 | 105 | self.relation_embedding_out = tf.get_variable( 106 | name='relation_embedding_out', 107 | shape=[self.num_relation, self.embedding_dim], 108 | initializer=tf.contrib.layers.xavier_initializer(uniform=False) 109 | ) 110 | 111 | # self.entity_embedding_init = self.entity_embedding.assign(self.embedding_placeholder) 112 | 113 | # get head, tail, relation embedded 114 | encoder = None 115 | if self.aggregate_type == 'gnn_mean': 116 | from aggregator import GNN_MEAN as Encoder 117 | encoder = Encoder(self.num_relation, self.embedding_dim) 118 | elif self.aggregate_type == 'lstm': 119 | from aggregator import LSTM as Encoder 120 | encoder = Encoder(self.num_relation, self.embedding_dim) 121 | elif self.aggregate_type == 'attention': 122 | from aggregator import ATTENTION as Encoder 123 | encoder = Encoder(self.num_relation, self.num_entity, self.embedding_dim) 124 | else: 125 | print 'Not emplemented yet!' 126 | assert encoder != None 127 | 128 | # aggregate on neighbors input 129 | head_pos_embedded, self.weight_ph = self.aggregate(encoder, self.neighbor_head_pos, self.input_relation_ph, self.neighbor_weight_ph) 130 | tail_pos_embedded, self.weight_pt = self.aggregate(encoder, self.neighbor_tail_pos, self.input_relation_pt, self.neighbor_weight_pt) 131 | 132 | head_neg_embedded, _ = self.aggregate(encoder, self.neighbor_head_neg, self.input_relation_nh, self.neighbor_weight_nh) 133 | tail_neg_embedded, _ = self.aggregate(encoder, self.neighbor_tail_neg, self.input_relation_nt, self.neighbor_weight_nt) 134 | 135 | # get score 136 | decoder = None 137 | if self.score_function == 'TransE': 138 | from score_function import TransE as Decoder 139 | decoder = Decoder() 140 | elif self.score_function == 'Distmult': 141 | from score_function import Distmult as Decoder 142 | decoder = Decoder() 143 | elif self.score_function == 'Complex': 144 | from score_function import Complex as Decoder 145 | decoder = Decoder(self.embedding_dim) 146 | elif self.score_function == 'Analogy': 147 | from score_function import Analogy as Decoder 148 | decoder = Decoder(self.embedding_dim) 149 | else: 150 | print 'Not emplemented yet!' 
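# decoder: the scoring function applied to (head, relation, tail) embeddings; TransE and Distmult score the full vectors, while Complex and Analogy internally split the embedding dimension into separate components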
151 | assert decoder != None 152 | 153 | emb_relation_pos_out = tf.nn.embedding_lookup(self.relation_embedding_out, self.input_relation_ph) 154 | emb_relation_neg_out = tf.nn.embedding_lookup(self.relation_embedding_out, self.input_relation_nh) 155 | 156 | self.positive_score = self.score_triplet(decoder, head_pos_embedded, tail_pos_embedded, emb_relation_pos_out) 157 | negative_score = self.score_triplet(decoder, head_neg_embedded, tail_neg_embedded, emb_relation_neg_out) 158 | 159 | ph_origin_embedded = tf.nn.embedding_lookup(self.entity_embedding, self.input_triplet_pos[:, 0]) 160 | pt_origin_embedded = tf.nn.embedding_lookup(self.entity_embedding, self.input_triplet_pos[:, 2]) 161 | nh_origin_embedded = tf.nn.embedding_lookup(self.entity_embedding, self.input_triplet_neg[:, 0]) 162 | nt_origin_embedded = tf.nn.embedding_lookup(self.entity_embedding, self.input_triplet_neg[:, 2]) 163 | 164 | origin_positive_score = self.score_triplet(decoder, ph_origin_embedded, pt_origin_embedded, emb_relation_pos_out) 165 | origin_negative_score = self.score_triplet(decoder, nh_origin_embedded, nt_origin_embedded, emb_relation_neg_out) 166 | 167 | loss = None 168 | if self.loss_function == 'margin': 169 | loss = tf.reduce_mean(tf.nn.relu(self.margin - self.positive_score + negative_score)) 170 | loss += tf.reduce_mean(tf.nn.relu(self.margin - origin_positive_score + origin_negative_score)) 171 | loss += self.weight_decay * encoder.l2_regularization 172 | elif self.loss_function == 'bce': 173 | labels_positive = tf.ones_like(self.positive_score) 174 | labels_negative = tf.zeros_like(negative_score) 175 | 176 | loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 177 | labels=labels_positive, 178 | logits=self.positive_score)) 179 | loss += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 180 | labels=labels_negative, 181 | logits=negative_score)) 182 | loss += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 183 | labels=labels_positive, 184 | logits=origin_positive_score)) 185 | loss += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 186 | labels=labels_negative, 187 | logits=origin_negative_score)) 188 | loss += self.weight_decay * encoder.l2_regularization 189 | else: 190 | print 'Not such loss!' 
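# both loss variants combine a term scored with the aggregated neighborhood embeddings and a term scored with the entities' own embeddings, plus the encoder's L2 penalty scaled by weight_decay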
191 | assert loss != None 192 | 193 | self.loss = loss 194 | 195 | self.global_step = tf.Variable(0, name='global_step', trainable=False) 196 | self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate) 197 | self.train_op = self.optimizer.minimize(self.loss, global_step=self.global_step) 198 | 199 | def aggregate(self, encoder, neighbor_ids, query_relation, weight): 200 | neighbor_embedded = tf.nn.embedding_lookup(self.entity_embedding, neighbor_ids[:, :, 1]) 201 | if self.use_relation == 1: 202 | return encoder(neighbor_embedded, neighbor_ids, query_relation, weight) 203 | else: 204 | return encoder(neighbor_embedded, neighbor_ids[:, :, 0]) 205 | 206 | def score_triplet(self, decoder, head_embedded, tail_embedded, relation_embedded): 207 | score = decoder(head_embedded, tail_embedded, relation_embedded) 208 | return score 209 | -------------------------------------------------------------------------------- /code/model/score_function.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | class TransE(object): 5 | """docstring for TransE""" 6 | def __init__(self): 7 | super(TransE, self).__init__() 8 | 9 | def __call__(self, head, tail, relation): 10 | # head = tf.clip_by_norm(head, 1.0, 1) 11 | head = tf.nn.l2_normalize(head, 1) 12 | relation = tf.nn.l2_normalize(relation, 1) 13 | tail = tf.nn.l2_normalize(tail, 1) 14 | # head = tf.nn.tanh(head) 15 | # relation = tf.nn.tanh(relation) 16 | # tail = tf.nn.tanh(tail) 17 | dissimilarity = tf.reduce_sum(tf.abs(head + relation - tail), 1) 18 | score = -dissimilarity 19 | return score 20 | 21 | class Distmult(object): 22 | def __init__(self): 23 | super(Distmult, self).__init__() 24 | 25 | def __call__(self, head, tail, relation): 26 | score = tf.reduce_sum(head * relation * tail, 1) 27 | return score 28 | 29 | class Dotproduct(object): 30 | """docstring for Dotproduct""" 31 | def __init__(self): 32 | super(Dotproduct, self).__init__() 33 | 34 | def __call__(self, head, tail, relation): 35 | score = tf.reduce_sum(head * tail, 1) 36 | return score 37 | 38 | class TransL(object): 39 | def __init__(self): 40 | super(TransL, self).__init__() 41 | 42 | def __call__(self, head, tail, relation): 43 | head = tf.nn.l2_normalize(head, 1) 44 | tail = tf.nn.l2_normalize(tail, 1) 45 | score = tf.reduce_sum(tf.abs(head - tail), 1) 46 | return -score 47 | 48 | class Complex(object): 49 | """docstring for Complex""" 50 | def __init__(self, embedding_dimension): 51 | super(Complex, self).__init__() 52 | self.embedding_dimension = embedding_dimension 53 | 54 | def __call__(self, head, tail, relation): 55 | offset = self.embedding_dimension / 2 56 | h1 = head[:, 0:offset] 57 | h2 = head[:, offset:] 58 | r1 = relation[:, 0:offset] 59 | r2 = relation[:, offset:] 60 | t1 = tail[:, 0:offset] 61 | t2 = tail[:, offset:] 62 | 63 | score = tf.reduce_sum(h1 * t1 * r1 + h2 * t2 * r1 + h1 * t2 * r2 - h2 * t1 * r2, 1) 64 | return score 65 | 66 | class Analogy(object): 67 | def __init__(self, embedding_dimension): 68 | super(Analogy, self).__init__() 69 | self.embedding_dimension = embedding_dimension 70 | 71 | def __call__(self, head, tail, relation): 72 | offset = self.embedding_dimension / 4 73 | score = tf.reduce_sum(relation[:, 0:offset*2]*head[:, 0:offset*2]*tail[:,0:offset*2], 1) 74 | h1 = head[:, offset*2:offset*3] 75 | h2 = head[:, offset*3:] 76 | r1 = relation[:, offset*2:offset*3] 77 | r2 = relation[:, offset*3:] 78 | t1 = tail[:, offset*2:offset*3] 79 | t2 = 
tail[:, offset*3:] 80 | 81 | score += tf.reduce_sum(h1 * t1 * r1 + h2 * t2 * r1 + h1 * t2 * r2 - h2 * t1 * r2, 1) 82 | return score 83 | -------------------------------------------------------------------------------- /code/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import datetime 4 | import time 5 | import random 6 | import numpy as np 7 | import logging 8 | import sys 9 | 10 | os.environ['TF_CPP_MIN_LOG_LEVEL']='2' 11 | import tensorflow as tf 12 | from model.zskb import ZSKB 13 | from utils.data_helper import DataSet 14 | from utils.triplet_classify import run_triplet_classify 15 | from utils.link_prediction import run_link_prediction 16 | 17 | logger = logging.getLogger() 18 | 19 | def softmax(x): 20 | """Compute softmax values for each sets of scores in x.""" 21 | e_x = np.exp(x - np.max(x)) 22 | return e_x / e_x.sum() 23 | 24 | def run_training(args): 25 | # ----------------------------------------------------- # 26 | 27 | # gpu setting 28 | os.environ.setdefault('CUDA_VISIBLE_DEVICES', args.gpu_device) 29 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_fraction) 30 | 31 | session_conf = tf.ConfigProto( 32 | gpu_options=gpu_options, 33 | allow_soft_placement=args.allow_soft_placement, 34 | log_device_placement=False) 35 | sess = tf.Session(config=session_conf) 36 | 37 | # Checkpoint directory 38 | checkpoint_dir = args.save_dir 39 | checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt') 40 | assert os.path.exists(checkpoint_dir) 41 | 42 | # log file 43 | logger.setLevel(logging.INFO) 44 | handler = logging.FileHandler(checkpoint_dir+'test.log', 'w') 45 | handler.setLevel(logging.INFO) 46 | formatter = logging.Formatter('%(asctime)s: %(message)s', datefmt='%Y/%m/%d %H:%M:%S') 47 | handler.setFormatter(formatter) 48 | logger.addHandler(handler) 49 | 50 | logger.info('args: {}'.format(args)) 51 | 52 | # ----------------------------------------------------- # 53 | 54 | # prepare data 55 | logger.info("Loading data...") 56 | dataset = DataSet(args, logger) 57 | logger.info("Loading finish...") 58 | 59 | # args.aggregate_type = 'test_attention' 60 | model = ZSKB(args, dataset.num_training_entity, dataset.num_relation) 61 | 62 | # saver for checkpoint and initialization 63 | saver = tf.train.Saver(max_to_keep=1) 64 | sess.run(tf.global_variables_initializer()) 65 | 66 | checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir) 67 | assert checkpoint_file != None 68 | logger.info('Restore the model from {}'.format(checkpoint_file)) 69 | saver.restore(sess, checkpoint_file) 70 | 71 | html_file = open('visualization_mean_tail.html', 'w') 72 | cnt_batch = 0 73 | for batch_eval in dataset.batch_iter_epoch(dataset.triplets_test, 1024, corrupt=False, shuffle=True): 74 | batch_weight_ph, batch_weight_pt, batch_triplet, batch_relation_tail, batch_neighbor_head, batch_neighbor_tail = batch_eval 75 | batch_relation_head = batch_triplet[:, 1] 76 | # batch_weight_head, batch_weight_tail, batch_attention_head, batch_attention_tail = sess.run( 77 | # [model.weight_loss_ph, model.weight_loss_pt, model.attention_logit_ph, model.attention_logit_pt], 78 | # feed_dict={ 79 | # model.neighbor_head_pos: batch_neighbor_head, 80 | # model.neighbor_tail_pos: batch_neighbor_tail, 81 | # model.input_relation_ph: batch_relation_head, 82 | # model.input_relation_pt: batch_relation_tail, 83 | # model.neighbor_weight_ph: batch_weight_ph, 84 | # model.neighbor_weight_pt: batch_weight_pt, 85 | # 
model.dropout_keep_prob: 1.0 86 | # }) 87 | for id_triplet in xrange(len(batch_relation_head)): 88 | try: 89 | head = dataset.i2e[batch_triplet[id_triplet][0]] 90 | except: 91 | head = 'None' 92 | try: 93 | tail = dataset.i2e[batch_triplet[id_triplet][2]] 94 | except: 95 | tail = 'None' 96 | query_relation = dataset.i2r[batch_relation_head[id_triplet]] 97 | html_file.write('
<br>' + head + ' -> ' + query_relation + ' -> ' + tail + '<br>\n') 98 | # html_file.write('head<br>
\n') 99 | 100 | # weight_head = batch_weight_head[id_triplet] 101 | # prior_weight = batch_weight_ph[id_triplet] 102 | # attention_weight = batch_attention_head[id_triplet] 103 | # # print weight_head 104 | # weight_head = weight_head / weight_head.max() 105 | # # prior_weight = softmax(np.log(prior_weight + 1e-9)) 106 | # # prior_weight = prior_weight / prior_weight.max() 107 | # attention_weight = softmax(attention_weight) 108 | # attention_weight = attention_weight / attention_weight.max() 109 | # rank_weight = (-weight_head).argsort() 110 | # neighbors = batch_neighbor_head[id_triplet] 111 | # for rank in rank_weight: 112 | # neighbor_relation = dataset.i2r[neighbors[rank][0]] 113 | # if neighbor_relation == 'PAD': 114 | # continue 115 | # try: 116 | # neighbor_entity = dataset.i2e[neighbors[rank][1]] 117 | # except: 118 | # neighbor_entity = 'None' 119 | # neighbor_weight = weight_head[rank] 120 | # prior_neighbor_weight = prior_weight[rank] 121 | # attention_neighbor_weight = attention_weight[rank] 122 | # html_file.write('%s -> %stail
\n') 126 | # weight_tail = batch_weight_tail[id_triplet] 127 | # prior_weight = batch_weight_pt[id_triplet] 128 | # attention_weight = batch_attention_tail[id_triplet] 129 | # weight_tail = weight_tail / weight_tail.max() 130 | # attention_weight = softmax(attention_weight) 131 | # attention_weight = attention_weight / attention_weight.max() 132 | # rank_weight = (-weight_tail).argsort() 133 | # neighbors = batch_neighbor_tail[id_triplet] 134 | # for rank in rank_weight: 135 | # neighbor_relation = dataset.i2r[neighbors[rank][0]] 136 | # if neighbor_relation == 'PAD': 137 | # continue 138 | # try: 139 | # neighbor_entity = dataset.i2e[neighbors[rank][1]] 140 | # except: 141 | # neighbor_entity = 'None' 142 | # neighbor_weight = weight_tail[rank] 143 | # attention_neighbor_weight = attention_weight[rank] 144 | # html_file.write('%s -> %s' + '-' * 100 + '
') 181 | 182 | cnt_batch += 1 183 | if cnt_batch > 10: 184 | break 185 | html_file.close() 186 | 187 | # ----------------------------------------------------- # 188 | def main(): 189 | parser = argparse.ArgumentParser(description='Run testing zero-shot KB model.') 190 | parser.add_argument('--data_dir', '-D', type=str) 191 | parser.add_argument('--save_dir', '-S', type=str) 192 | 193 | parser.add_argument('--pretrain', action='store_true') 194 | 195 | # model 196 | parser.add_argument('--use_relation', type=int, default=0) 197 | parser.add_argument('--embedding_dim', '-e', type=int, default=50) 198 | parser.add_argument('--max_neighbor', type=int, default=64) 199 | parser.add_argument('--n_neg', '-n', type=int, default=1) 200 | parser.add_argument('--aggregate_type', type=str, default='gnn_mean') 201 | parser.add_argument('--iter_routing', '-r', type=int, default=1) 202 | parser.add_argument('--score_function', type=str, default='TransE') 203 | parser.add_argument('--loss_function', type=str, default='margin') 204 | parser.add_argument('--margin', type=float, default='1.0') 205 | parser.add_argument('--corrupt_mode', type=str, default='both') 206 | 207 | # training 208 | parser.add_argument('--learning_rate', type=float, default=1e-3) 209 | parser.add_argument('--num_epoch', type=int, default=1) 210 | parser.add_argument('--weight_decay', '-w', type=float, default=0.0) 211 | parser.add_argument('--dropout_keep_prob', type=float, default=1.0) 212 | parser.add_argument('--dis_weight', type=float, default=0.0) 213 | parser.add_argument('--batch_size', type=int, default=128) 214 | parser.add_argument('--evaluate_size', type=int, default=1000) 215 | parser.add_argument('--steps_per_display', type=int, default=100) 216 | parser.add_argument('--epoch_per_checkpoint', type=int, default=50) 217 | 218 | # gpu option 219 | parser.add_argument('--gpu_fraction', type=float, default=0.2) 220 | parser.add_argument('--gpu_device', type=str, default='0') 221 | parser.add_argument('--allow_soft_placement', type=bool, default=False) 222 | 223 | args = parser.parse_args() 224 | 225 | run_training(args=args) 226 | 227 | if __name__ == '__main__': 228 | main() -------------------------------------------------------------------------------- /code/utils/data_helper.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import os 4 | import numpy as np 5 | import random 6 | from collections import defaultdict 7 | import pandas as pd 8 | import math 9 | 10 | 11 | class DataSet: 12 | def __init__(self, args, logger): 13 | self.data_dir = args.data_dir 14 | self.max_neighbor = args.max_neighbor 15 | self.corrupt_mode = args.corrupt_mode 16 | self.load_data(logger) 17 | 18 | def load_data(self, logger): 19 | train_path = os.path.join(self.data_dir, 'train') 20 | dev_path = os.path.join(self.data_dir, 'dev') 21 | test_path = os.path.join(self.data_dir, 'test') 22 | aux_path = os.path.join(self.data_dir, 'aux') 23 | relation_path = os.path.join(self.data_dir, 'relation2id.txt') 24 | entity_idx_path = os.path.join(self.data_dir, 'entity2id.txt') 25 | entity_name_path = os.path.join(self.data_dir, 'entity_name.txt') 26 | 27 | triplets_train, triplets_aux, graph_train, self.num_relation, self.num_entity = self.doc_to_tensor_graph(train_path, aux_path) 28 | self.r_PAD = self.num_relation * 2 29 | self.e_PAD = self.num_training_entity 30 | self.num_sample = len(triplets_train) 31 | try: 32 | self.i2r = self.build_relation_dict(relation_path, self.r_PAD, 
self.num_relation) 33 | except: 34 | self.i2r = None 35 | try: 36 | self.i2e = self.build_entity_dict2(entity_idx_path, entity_name_path, self.num_training_entity) 37 | # self.i2e = self.build_entity_dict(entity_idx_path) 38 | except: 39 | self.i2e = None 40 | 41 | logger.info('got {} entities for training'.format(self.num_training_entity)) 42 | logger.info('got {} relations for training'.format(self.num_relation)) 43 | self.graph_train, self.weight_graph = self.sample_neighbor(graph_train) 44 | 45 | triplets_test = self.doc_to_tensor(test_path) 46 | triplets_dev = self.doc_to_tensor(dev_path) 47 | if len(triplets_dev[0]) == 4: 48 | self.task = 'triplet_classify' 49 | else: 50 | self.task = 'link_prediciton' 51 | # consturct answer poor for filter results 52 | self.triplets_train_pool = set(triplets_train + triplets_dev) 53 | self.triplets_true_pool = set(triplets_train + triplets_dev + triplets_test + triplets_aux) 54 | self.predict_mode = self.data_dir.split('/')[-1] 55 | 56 | self.triplets_train = np.asarray(triplets_train) 57 | self.triplets_dev = np.asarray(triplets_dev) 58 | self.triplets_test = np.asarray(triplets_test) 59 | # self.triplets_sample = np.asarray(triplets_sample) 60 | logger.info('got {} triplets for train'.format(len(self.triplets_train))) 61 | logger.info('got {} triplets for valid'.format(len(self.triplets_dev))) 62 | logger.info('got {} triplets for test'.format(len(self.triplets_test))) 63 | 64 | def count_imply(self, graph, cnt_relation): 65 | co_relation = np.zeros((cnt_relation*2+1, cnt_relation*2+1), dtype=np.dtype('float32')) 66 | freq_relation = defaultdict(int) 67 | 68 | for entity in graph: 69 | relation_list = list(set([neighbor[0] for neighbor in graph[entity]])) 70 | for n_i in xrange(len(relation_list)): 71 | r_i = relation_list[n_i] 72 | freq_relation[r_i] += 1 73 | for n_j in xrange(n_i+1, len(relation_list)): 74 | r_j = relation_list[n_j] 75 | co_relation[r_i][r_j] += 1 76 | co_relation[r_j][r_i] += 1 77 | 78 | for r_i in xrange(cnt_relation*2): 79 | co_relation[r_i] = (co_relation[r_i] * 1.0) / freq_relation[r_i] 80 | # co_relation[r_i][r_i] = 1.0 81 | self.co_relation = co_relation.transpose() 82 | for r_i in xrange(cnt_relation*2): 83 | co_relation[r_i][r_i] = co_relation[r_i].mean() 84 | print 'finish calculating co relation' 85 | 86 | def doc_to_tensor_graph(self, data_path_train, data_path_aux): 87 | triplet_train = [] 88 | triplet_aux = [] 89 | graph = defaultdict(list) 90 | train_entity = {} 91 | cnt_entity = 0 92 | cnt_relation = 0 93 | with open(data_path_train, 'rb') as fr: 94 | for line in fr: 95 | line = line.strip().split('\t') 96 | line = [int(_id) for _id in line] 97 | assert len(line) == 3 98 | head, relation, tail = line 99 | triplet_train.append((head, relation, tail)) 100 | # graph[head].append((relation, tail)) 101 | # graph[tail].append((relation, head)) 102 | train_entity[head] = 1 103 | train_entity[tail] = 1 104 | if head >= cnt_entity: 105 | cnt_entity = head + 1 106 | if tail >= cnt_entity: 107 | cnt_entity = tail + 1 108 | if relation >= cnt_relation: 109 | cnt_relation = relation + 1 110 | 111 | self.num_training_entity = cnt_entity 112 | 113 | with open(data_path_aux, 'rb') as fr: 114 | for line in fr: 115 | line = line.strip().split('\t') 116 | line = [int(_id) for _id in line] 117 | assert len(line) == 3 118 | head, relation, tail = line 119 | if relation >= cnt_relation: 120 | continue 121 | triplet_aux.append((head, relation, tail)) 122 | if head >= cnt_entity: 123 | cnt_entity = head + 1 124 | if tail >= 
cnt_entity: 125 | cnt_entity = tail + 1 126 | # if relation >= cnt_relation: 127 | # cnt_relation = relation + 1 128 | 129 | for triplet in triplet_train: 130 | head, relation, tail = triplet 131 | # hpt, tph = self.relation_dist[relation] 132 | graph[head].append([relation, tail, 0.]) 133 | graph[tail].append([relation+cnt_relation, head, 0.]) 134 | 135 | cnt_train = len(graph) 136 | 137 | self.count_imply(graph, cnt_relation) 138 | 139 | for triplet in triplet_aux: 140 | head, relation, tail = triplet 141 | if not head in train_entity and tail in train_entity: 142 | graph[head].append([relation, tail, 0.]) 143 | if not tail in train_entity and head in train_entity: 144 | graph[tail].append([relation+cnt_relation, head, 0.]) 145 | 146 | graph = self.process_graph(graph) 147 | cnt_all = len(graph) 148 | 149 | return triplet_train, triplet_aux, graph, cnt_relation, cnt_entity 150 | 151 | def process_graph(self, graph): 152 | for entity in graph: 153 | # relation_list = list(set([neighbor[0] for neighbor in graph[entity]])) 154 | relation_list = defaultdict(int) 155 | for neighbor in graph[entity]: 156 | relation_list[neighbor[0]] += 1 157 | if len(relation_list) == 1: 158 | continue 159 | for rel_i in relation_list: 160 | other_relation_list = [rel for rel in relation_list if rel != rel_i] 161 | imply_i = self.co_relation[rel_i] 162 | j_imply_i = imply_i[other_relation_list].max() 163 | for _idx, neighbor in enumerate(graph[entity]): 164 | if neighbor[0] == rel_i: 165 | graph[entity][_idx][2] = j_imply_i 166 | print 'finish processing graph' 167 | return graph 168 | 169 | def doc_to_tensor(self, data_path): 170 | triplet_tensor = [] 171 | with open(data_path, 'rb') as fr: 172 | for line in fr: 173 | line = line.strip().split('\t') 174 | line = [int(_id) for _id in line] 175 | if line[0] >= self.num_entity or line[2] >= self.num_entity: 176 | continue 177 | if line[1] >= self.num_relation: 178 | continue 179 | if len(line) == 4: 180 | head, relation, tail, label = line 181 | if label != 1: 182 | label = -1 183 | triplet_tensor.append((head, relation, tail, label)) 184 | else: 185 | head, relation, tail = line 186 | triplet_tensor.append((head, relation, tail)) 187 | return triplet_tensor 188 | 189 | def build_relation_dict(self, data_path, pad, cnt): 190 | i2n = {} 191 | with open(data_path, 'rb') as fr: 192 | for line in fr: 193 | line = line.strip().split('\t') 194 | # name = '/'.join(line[0].split('/')[-2:]) 195 | name = line[0] 196 | idx = int(line[1]) 197 | # if idx >= cnt: 198 | # continue 199 | i2n[idx] = name 200 | i2n[idx + cnt] = '*' + name 201 | i2n[pad] = 'PAD' 202 | return i2n 203 | 204 | def build_entity_dict(self, data_path): 205 | i2n = {} 206 | with open(data_path, 'rb') as fr: 207 | for line in fr: 208 | line = line.strip().split('\t') 209 | # name = '/'.join(line[0].split('/')[-2:]) 210 | name = line[0] 211 | idx = int(line[1]) 212 | # if idx >= cnt: 213 | # continue 214 | i2n[idx] = name 215 | i2n[self.e_PAD] = 'PAD' 216 | return i2n 217 | 218 | def build_entity_dict2(self, index_path, name_path, cnt): 219 | m2i = {} 220 | with open(index_path, 'rb') as fr: 221 | for line in fr: 222 | line = line.strip().split('\t') 223 | m2i[line[0]] = int(line[1]) 224 | 225 | i2n = {} 226 | with open(name_path, 'rb') as fr: 227 | for line in fr: 228 | line = line.strip().split('\t') 229 | try: 230 | idx = m2i[line[0]] 231 | except: 232 | continue 233 | i2n[idx] = line[1] 234 | i2n[self.e_PAD] = 'PAD' 235 | return i2n 236 | 237 | def sample_neighbor(self, graph): 238 | sample_graph = 
np.ones((self.num_entity, self.max_neighbor, 2), dtype=np.dtype('int64')) 239 | weight_graph = np.ones((self.num_entity, self.max_neighbor), dtype=np.dtype('float32')) 240 | sample_graph[:, :, 0] *= self.r_PAD 241 | sample_graph[:, :, 1] *= self.e_PAD 242 | 243 | cnt = 0 244 | for entity in graph: 245 | num_neighbor = len(graph[entity]) 246 | cnt += num_neighbor 247 | num_sample = min(num_neighbor, self.max_neighbor) 248 | # sample_id = random.sample(xrange(len(graph[entity])), num_sample) 249 | sample_id = range(len(graph[entity]))[:num_sample] 250 | # sample_graph[entity][:num_sample] = np.asarray(graph[entity])[sample_id] 251 | sample_graph[entity][:num_sample] = np.asarray(graph[entity])[sample_id][:, 0:2] 252 | weight_graph[entity][:num_sample] = np.asarray(graph[entity])[sample_id][:, 2] 253 | 254 | return sample_graph, weight_graph 255 | 256 | def batch_iter_epoch(self, data, batch_size, num_negative=1, corrupt=True, shuffle=True): 257 | data_size = len(data) 258 | if data_size % batch_size == 0: 259 | num_batches_per_epoch = int(data_size/batch_size) 260 | else: 261 | num_batches_per_epoch = int(data_size/batch_size) + 1 262 | # Shuffle the data at each epoch 263 | if shuffle: 264 | shuffled_indices = np.random.permutation(np.arange(data_size)) 265 | else: 266 | shuffled_indices = np.arange(data_size) 267 | for batch_num in range(num_batches_per_epoch): 268 | start_index = batch_num * batch_size 269 | end_index = min((batch_num + 1) * batch_size, data_size) 270 | real_batch_num = end_index - start_index 271 | batch_indices = shuffled_indices[start_index:end_index] 272 | batch_positive = data[batch_indices] 273 | neighbor_head_pos = self.graph_train[batch_positive[:, 0]] #[:, :, 0:2] 274 | neighbor_tail_pos = self.graph_train[batch_positive[:, 2]] #[:, :, 0:2] 275 | batch_relation_ph = np.asarray(batch_positive[:, 1]) 276 | batch_relation_pt = batch_relation_ph + self.num_relation 277 | neighbor_imply_ph = self.weight_graph[batch_positive[:, 0]].reshape(-1, self.max_neighbor, 1) 278 | neighbor_imply_pt = self.weight_graph[batch_positive[:, 2]].reshape(-1, self.max_neighbor, 1) 279 | query_weight_ph = self.co_relation[batch_relation_ph] 280 | query_weight_pt = self.co_relation[batch_relation_pt] 281 | batch_weight_ph = query_weight_ph[np.arange(real_batch_num).repeat(self.max_neighbor), neighbor_head_pos[:, :, 0].reshape(-1)].reshape(real_batch_num, self.max_neighbor, 1) 282 | batch_weight_pt = query_weight_pt[np.arange(real_batch_num).repeat(self.max_neighbor), neighbor_tail_pos[:, :, 0].reshape(-1)].reshape(real_batch_num, self.max_neighbor, 1) 283 | 284 | batch_weight_ph = np.concatenate((batch_weight_ph, neighbor_imply_ph), axis=2) 285 | batch_weight_pt = np.concatenate((batch_weight_pt, neighbor_imply_pt), axis=2) 286 | if corrupt: 287 | batch_negative = [] 288 | for triplet in batch_positive: 289 | id_head_corrupted = triplet[0] 290 | id_tail_corrupted = triplet[2] 291 | id_relation = triplet[1] 292 | 293 | for n_neg in xrange(num_negative): 294 | if self.corrupt_mode == 'both': 295 | head_prob = np.random.binomial(1, 0.5) 296 | if head_prob: 297 | id_head_corrupted = random.sample(xrange(self.num_training_entity), 1)[0] 298 | else: 299 | id_tail_corrupted = random.sample(xrange(self.num_training_entity), 1)[0] 300 | else: 301 | if 'tail' in self.predict_mode: 302 | id_head_corrupted = random.sample(xrange(self.num_training_entity), 1)[0] 303 | elif 'head' in self.predict_mode: 304 | id_tail_corrupted = random.sample(xrange(self.num_training_entity), 1)[0] 305 | 
batch_negative.append([id_head_corrupted, triplet[1], id_tail_corrupted]) 306 | 307 | batch_negative = np.asarray(batch_negative) 308 | neighbor_head_neg = self.graph_train[batch_negative[:, 0]] 309 | neighbor_tail_neg = self.graph_train[batch_negative[:, 2]] 310 | neighbor_imply_nh = self.weight_graph[batch_negative[:, 0]].reshape(-1, self.max_neighbor, 1) 311 | neighbor_imply_nt = self.weight_graph[batch_negative[:, 2]].reshape(-1, self.max_neighbor, 1) 312 | 313 | batch_relation_nh = batch_negative[:, 1] 314 | batch_relation_nt = batch_relation_nh + self.num_relation 315 | query_weight_nh = self.co_relation[batch_relation_nh] 316 | query_weight_nt = self.co_relation[batch_relation_nt] 317 | batch_weight_nh = query_weight_nh[np.arange(real_batch_num).repeat(self.max_neighbor), neighbor_head_neg[:, :, 0].reshape(-1)].reshape(real_batch_num, self.max_neighbor, 1) 318 | batch_weight_nt = query_weight_nt[np.arange(real_batch_num).repeat(self.max_neighbor), neighbor_tail_neg[:, :, 0].reshape(-1)].reshape(real_batch_num, self.max_neighbor, 1) 319 | batch_weight_nh = np.concatenate((batch_weight_nh, neighbor_imply_nh), axis=2) 320 | batch_weight_nt = np.concatenate((batch_weight_nt, neighbor_imply_nt), axis=2) 321 | yield [batch_weight_ph, batch_weight_pt, batch_weight_nh, batch_weight_nt, 322 | batch_positive, batch_negative, batch_relation_ph, batch_relation_pt, batch_relation_nh, batch_relation_nt, neighbor_head_pos, neighbor_tail_pos, neighbor_head_neg, neighbor_tail_neg] 323 | else: 324 | yield [batch_weight_ph, batch_weight_pt, 325 | batch_positive, batch_relation_pt, neighbor_head_pos, neighbor_tail_pos] 326 | 327 | def next_sample_eval(self, triplet_evaluate, is_test): 328 | if is_test: 329 | answer_pool = self.triplets_true_pool 330 | else: 331 | answer_pool = self.triplets_train_pool 332 | # # construct two batches for head and tail prediction 333 | batch_predict_head = [triplet_evaluate] 334 | # replacing head 335 | id_heads_corrupted_list = xrange(self.num_training_entity) 336 | id_heads_corrupted_set = set(id_heads_corrupted_list) 337 | id_heads_corrupted_set.discard(triplet_evaluate[0]) # remove the golden head 338 | for head in id_heads_corrupted_list: 339 | if (head, triplet_evaluate[1], triplet_evaluate[2]) in answer_pool: 340 | id_heads_corrupted_set.discard(head) 341 | batch_predict_head.extend([(head, triplet_evaluate[1], triplet_evaluate[2]) for head in id_heads_corrupted_set]) 342 | 343 | batch_predict_tail = [triplet_evaluate] 344 | # replacing tail 345 | # id_tails_corrupted = set(random.sample(xrange(self.num_entity), 1000)) 346 | id_tails_corrupted_list = xrange(self.num_training_entity) 347 | id_tails_corrupted_set = set(id_tails_corrupted_list) 348 | id_tails_corrupted_set.discard(triplet_evaluate[2]) # remove the golden tail 349 | for tail in id_tails_corrupted_list: 350 | if (triplet_evaluate[0], triplet_evaluate[1], tail) in answer_pool: 351 | id_tails_corrupted_set.discard(tail) 352 | batch_predict_tail.extend([(triplet_evaluate[0], triplet_evaluate[1], tail) for tail in id_tails_corrupted_set]) 353 | 354 | if 'head' in self.predict_mode: # and self.corrupt_mode == 'partial': 355 | return np.asarray(batch_predict_tail) 356 | elif 'tail' in self.predict_mode: # and self.corrupt_mode == 'partial': 357 | return np.asarray(batch_predict_head) 358 | 359 | -------------------------------------------------------------------------------- /code/utils/link_prediction.py: -------------------------------------------------------------------------------- 1 | import argparse 2 
| import os 3 | import time 4 | import datetime 5 | import random 6 | import numpy as np 7 | 8 | def run_link_prediction(args, sess, model, dataset, epoch, logger, is_test=False): 9 | logger.info('evaluating the current model...') 10 | rank_head = 0 11 | hit10_head = 0 12 | hit3_head = 0 13 | max_rank_head = 0 14 | min_rank_head = None 15 | acc_head = 0 16 | rec_rank_head = 0 17 | 18 | if is_test: 19 | evaluate_size = len(dataset.triplets_test) 20 | evaluate_data = dataset.triplets_test 21 | else: 22 | if args.evaluate_size == 0: 23 | evaluate_size = len(dataset.triplets_dev) 24 | else: 25 | evaluate_size = args.evaluate_size 26 | evaluate_data = dataset.triplets_dev 27 | 28 | cnt_sample = 0 29 | for triplet in random.sample(evaluate_data, evaluate_size): 30 | sample_predict_head = dataset.next_sample_eval(triplet, is_test=is_test) 31 | def eval_by_batch(data_eval): 32 | prediction_all = [] 33 | for batch_eval in dataset.batch_iter_epoch(data_eval, 4096, corrupt=False, shuffle=False): 34 | batch_weight_ph, batch_weight_pt, batch_triplet, batch_relation_tail, batch_neighbor_head, batch_neighbor_tail = batch_eval 35 | # batch_triplet, batch_relation_tail, batch_neighbor_head, batch_neighbor_tail = batch_eval 36 | batch_relation_head = batch_triplet[:, 1] 37 | prediction_batch = sess.run(model.positive_score, 38 | feed_dict={ 39 | model.neighbor_head_pos: batch_neighbor_head, 40 | model.neighbor_tail_pos: batch_neighbor_tail, 41 | model.input_relation_ph: batch_relation_head, 42 | model.input_relation_pt: batch_relation_tail, 43 | model.neighbor_weight_ph: batch_weight_ph, 44 | model.neighbor_weight_pt: batch_weight_pt, 45 | }) 46 | prediction_all.extend(prediction_batch) 47 | return np.asarray(prediction_all) 48 | 49 | prediction_head = eval_by_batch(sample_predict_head) 50 | 51 | rank_head_current = (-prediction_head).argsort().argmin() + 1 52 | 53 | rank_head += rank_head_current 54 | rec_rank_head += 1.0 / rank_head_current 55 | if rank_head_current <= 10: 56 | hit10_head += 1 57 | if rank_head_current <= 3: 58 | hit3_head += 1 59 | if max_rank_head < rank_head_current: 60 | max_rank_head = rank_head_current 61 | if min_rank_head == None: 62 | min_rank_head = rank_head_current 63 | elif min_rank_head > rank_head_current: 64 | min_rank_head = rank_head_current 65 | if rank_head_current == 1: 66 | acc_head += 1 67 | 68 | rank_head_mean = rank_head // evaluate_size 69 | hit10_head = hit10_head * 1.0 / evaluate_size 70 | hit3_head = hit3_head * 1.0 / evaluate_size 71 | acc_head = acc_head * 1.0 / evaluate_size 72 | rec_rank_head = rec_rank_head / evaluate_size 73 | 74 | logger.info('epoch {} MR: {:d}, MRR: {:.3f}, hit@10: {:.3f}%, hit@3: {:.3f}%, hit@1: {:.3f}%'.format(\ 75 | epoch, rank_head_mean, rec_rank_head, hit10_head * 100, hit3_head * 100, acc_head * 100)) 76 | return rec_rank_head 77 | 78 | -------------------------------------------------------------------------------- /code/utils/triplet_classify.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import argparse 4 | import os 5 | import time 6 | import datetime 7 | import random 8 | import tensorflow as tf 9 | import numpy as np 10 | 11 | def run_triplet_classify(args, sess, model, dataset, epoch, logger, is_test=True): 12 | 13 | def get_score(data_eval): 14 | score_all = [] 15 | for batch_eval in dataset.batch_iter_epoch(data_eval, args.batch_size, corrupt=False, shuffle=False): 16 | # batch_weight_ph, batch_weight_pt, batch_triplet, batch_relation_tail, 
batch_neighbor_head, batch_neighbor_tail = batch_eval 17 | batch_triplet, batch_relation_tail, batch_neighbor_head, batch_neighbor_tail = batch_eval 18 | batch_relation_head = batch_triplet[:, 1] 19 | score_batch = sess.run(model.positive_score, 20 | feed_dict={ 21 | model.neighbor_head_pos: batch_neighbor_head, 22 | model.neighbor_tail_pos: batch_neighbor_tail, 23 | model.input_relation_ph: batch_relation_head, 24 | model.input_relation_pt: batch_relation_tail, 25 | # model.neighbor_weight_ph: batch_weight_ph, 26 | # model.neighbor_weight_pt: batch_weight_pt, 27 | model.dropout_keep_prob: 1.0 28 | }) 29 | score_all.extend(score_batch) 30 | score_all = np.asarray(score_all) 31 | return score_all 32 | 33 | def get_threshold(): 34 | score_all = get_score(dataset.triplets_dev) 35 | 36 | min_score = min(score_all) 37 | max_score = max(score_all) 38 | 39 | best_thresholds = [] 40 | best_accuracies = [] 41 | for i in xrange(dataset.num_relation): 42 | best_thresholds.append(min_score) 43 | best_accuracies.append(-1) 44 | 45 | score = min_score 46 | increment = 0.01 47 | while(score <= max_score): 48 | for i in xrange(dataset.num_relation): 49 | current_relation_list = (dataset.triplets_dev[:, 1] == i) 50 | predictions = (score_all[current_relation_list] >= score) * 2 -1 51 | accuracy = np.mean(predictions == dataset.triplets_dev[current_relation_list, 3]) 52 | 53 | if accuracy > best_accuracies[i]: 54 | best_accuracies[i] = accuracy 55 | best_thresholds[i] = score 56 | 57 | score += increment 58 | # logger.info('thresholds on valid set:') 59 | # for i, th in enumerate(best_thresholds): 60 | # logger.info('{}\t{:.3f}'.format(i, th)) 61 | 62 | return best_thresholds 63 | 64 | def get_prediction(triplets_test): 65 | score_all = get_score(triplets_test) 66 | best_thresholds = get_threshold() 67 | prediction_all = [] 68 | for i in xrange(len(triplets_test)): 69 | rel = triplets_test[i, 1] 70 | if score_all[i] >= best_thresholds[rel]: 71 | prediction_all.append(1) 72 | else: 73 | prediction_all.append(-1) 74 | 75 | return np.asarray(prediction_all) 76 | 77 | if is_test: 78 | triplets_test = dataset.triplets_test 79 | else: 80 | triplets_test = dataset.triplets_dev 81 | prediction_all = get_prediction(triplets_test) 82 | precision = sum([1 for res in (prediction_all == triplets_test[:, 3]) if res]) 83 | precision = precision * 100.0 / len(triplets_test) 84 | logger.info('Epoch {} Triplet classify precision: {:.3f}%'.format(epoch, precision)) 85 | return precision 86 | 87 | -------------------------------------------------------------------------------- /configs/lan_LinkPredict.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | data_dir="fb15K/head-10" 4 | aggregate_type='attention' 5 | iter_routing=0 6 | use_relation=1 7 | score_function='TransE' 8 | n_neg=1 9 | loss_function='margin' 10 | margin=1.0 11 | weight_decay=0 12 | corrupt_mode='partial' 13 | max_neighbor=64 14 | embedding_dim=100 15 | batch_size=1024 16 | learning_rate=1e-3 17 | num_epoch=1000 18 | epoch_per_checkpoint=50 19 | gpu_device="0" 20 | gpu_fraction="0.2" 21 | hparam="w${weight_decay}_${score_function}_${loss_function}${margin}_corrupt-${corrupt_mode}${n_neg}_e${embedding_dim}r${use_relation}_n${max_neighbor}_b${batch_size}_lr${learning_rate}" 22 | save_dir="./checkpoints/${data_dir}/${aggregate_type}/${hparam}/" 23 | run_log="${save_dir}run.log" 24 | -------------------------------------------------------------------------------- /configs/lan_TripletClassify.sh: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | data_dir="fb15K/head-10" 4 | aggregate_type='attention' 5 | iter_routing=0 6 | use_relation=1 7 | score_function='TransE' 8 | n_neg=1 9 | loss_function='margin' 10 | margin=300.0 11 | weight_decay=1e-3 12 | corrupt_mode='both' 13 | max_neighbor=64 14 | embedding_dim=100 15 | batch_size=1024 16 | learning_rate=1e-3 17 | num_epoch=500 18 | epoch_per_checkpoint=1 19 | gpu_device="0" 20 | gpu_fraction="0.2" 21 | hparam="w${weight_decay}_${score_function}_${loss_function}${margin}_corrupt-${corrupt_mode}${n_neg}_e${embedding_dim}r${use_relation}_n${max_neighbor}_b${batch_size}_lr${learning_rate}" 22 | save_dir="./checkpoints/${data_dir}/${aggregate_type}/${hparam}/" 23 | run_log="${save_dir}run.log" 24 | -------------------------------------------------------------------------------- /configs/lstm_LinkPredict.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | data_dir="fb15K/head-10" 4 | aggregate_type='lstm' 5 | iter_routing=0 6 | use_relation=0 7 | score_function='TransE' 8 | n_neg=1 9 | loss_function='margin' 10 | margin=1.0 11 | weight_decay=0 12 | corrupt_mode='partial' 13 | max_neighbor=64 14 | embedding_dim=100 15 | batch_size=1024 16 | learning_rate=1e-3 17 | num_epoch=1000 18 | epoch_per_checkpoint=50 19 | gpu_device="0" 20 | gpu_fraction="0.2" 21 | hparam="w${weight_decay}_${score_function}_${loss_function}${margin}_corrupt-${corrupt_mode}${n_neg}_e${embedding_dim}r${use_relation}_n${max_neighbor}_b${batch_size}_lr${learning_rate}" 22 | save_dir="./checkpoints/${data_dir}/${aggregate_type}/${hparam}/" 23 | run_log="${save_dir}run.log" 24 | -------------------------------------------------------------------------------- /configs/lstm_TripletClassify.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | data_dir="fb15K/head-10" 4 | aggregate_type='lstm' 5 | iter_routing=0 6 | use_relation=0 7 | score_function='TransE' 8 | n_neg=1 9 | loss_function='margin' 10 | margin=300.0 11 | weight_decay=1e-3 12 | corrupt_mode='both' 13 | max_neighbor=64 14 | embedding_dim=100 15 | batch_size=1024 16 | learning_rate=1e-3 17 | num_epoch=500 18 | epoch_per_checkpoint=1 19 | gpu_device="0" 20 | gpu_fraction="0.2" 21 | hparam="w${weight_decay}_${score_function}_${loss_function}${margin}_corrupt-${corrupt_mode}${n_neg}_e${embedding_dim}r${use_relation}_n${max_neighbor}_b${batch_size}_lr${learning_rate}" 22 | save_dir="./checkpoints/${data_dir}/${aggregate_type}/${hparam}/" 23 | run_log="${save_dir}run.log" 24 | -------------------------------------------------------------------------------- /configs/mean_LinkPredict.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | data_dir="fb15K/head-10" 4 | aggregate_type='gnn_mean' 5 | iter_routing=0 6 | use_relation=0 7 | score_function='TransE' 8 | n_neg=1 9 | loss_function='margin' 10 | margin=1.0 11 | weight_decay=0 12 | corrupt_mode='partial' 13 | max_neighbor=64 14 | embedding_dim=100 15 | batch_size=1024 16 | learning_rate=1e-3 17 | num_epoch=1000 18 | epoch_per_checkpoint=50 19 | gpu_device="0" 20 | gpu_fraction="0.2" 21 | hparam="w${weight_decay}_${score_function}_${loss_function}${margin}_corrupt-${corrupt_mode}${n_neg}_e${embedding_dim}r${use_relation}_n${max_neighbor}_b${batch_size}_lr${learning_rate}" 22 | 
save_dir="./checkpoints/${data_dir}/${aggregate_type}/${hparam}/" 23 | run_log="${save_dir}run.log" 24 | -------------------------------------------------------------------------------- /configs/mean_TripletClassify.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | data_dir="fb15K/head-10" 4 | aggregate_type='gnn_mean' 5 | iter_routing=0 6 | use_relation=0 7 | score_function='TransE' 8 | n_neg=1 9 | loss_function='margin' 10 | margin=300.0 11 | weight_decay=1e-3 12 | corrupt_mode='both' 13 | max_neighbor=64 14 | embedding_dim=100 15 | batch_size=1024 16 | learning_rate=1e-3 17 | num_epoch=500 18 | epoch_per_checkpoint=1 19 | gpu_device="0" 20 | gpu_fraction="0.2" 21 | hparam="w${weight_decay}_${score_function}_${loss_function}${margin}_corrupt-${corrupt_mode}${n_neg}_e${embedding_dim}r${use_relation}_n${max_neighbor}_b${batch_size}_lr${learning_rate}" 22 | save_dir="./checkpoints/${data_dir}/${aggregate_type}/${hparam}/" 23 | run_log="${save_dir}run.log" 24 | -------------------------------------------------------------------------------- /data/fb15K.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangpf3/LAN/bd590fd316eafbdedfce586d6d3bbbfa8c65e0e7/data/fb15K.zip -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | source $1 4 | 5 | mkdir -p $save_dir 6 | 7 | nohup python -u ./code/main.py \ 8 | -S $save_dir \ 9 | -D "./data/${data_dir}" \ 10 | --use_relation $use_relation \ 11 | --aggregate_type $aggregate_type \ 12 | --score_function $score_function\ 13 | --loss_function $loss_function \ 14 | --margin $margin \ 15 | --weight_decay $weight_decay \ 16 | --corrupt_mode $corrupt_mode \ 17 | --max_neighbor $max_neighbor \ 18 | --embedding_dim $embedding_dim \ 19 | --batch_size $batch_size \ 20 | --learning_rate $learning_rate \ 21 | --num_epoch $num_epoch \ 22 | --epoch_per_checkpoint $epoch_per_checkpoint \ 23 | --gpu_device $gpu_device \ 24 | --gpu_fraction $gpu_fraction \ 25 | > $run_log 2>&1 & 26 | --------------------------------------------------------------------------------
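A typical end-to-end run from the repository root, as a sketch assuming the archive unpacks into `data/fb15K/` (its internal layout is not shown here) and that a Python 2 / TensorFlow 1.x environment is on the path (the code relies on `xrange`, `print` statements and `tf.contrib`):

```bash
# unpack the dataset (assumed to yield data/fb15K/head-10 and the other splits)
unzip data/fb15K.zip -d data/

# launch training with periodic evaluation in the background; run.sh sources the
# chosen config, creates the checkpoint directory and redirects output to run.log
chmod +x run.sh
./run.sh ./configs/lan_LinkPredict.sh

# follow progress (the hparam directory name is built from the config values)
tail -f checkpoints/fb15K/head-10/attention/*/run.log
```

Each data split read by `utils/data_helper.py` (`train`, `dev`, `test`, plus `aux` for neighbors of entities unseen at training time) is a tab-separated file of integer ids `head	relation	tail`, with an optional fourth ±1 label on `dev`/`test` for triplet classification, and `relation2id.txt` / `entity2id.txt` / `entity_name.txt` providing the id mappings.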