├── Rearev_framework.png
├── requirements.txt
├── LICENSE
├── utils.py
├── main.py
├── modules
│   ├── question_encoding
│   │   ├── tokenizers.py
│   │   ├── lstm_encoder.py
│   │   ├── bert_encoder.py
│   │   └── base_encoder.py
│   ├── kg_reasoning
│   │   ├── base_gnn.py
│   │   └── reasongnn.py
│   ├── layer_init.py
│   └── query_update.py
├── parsing.py
├── README.md
├── evaluate.py
├── train_model.py
├── models
│   ├── ReaRev
│   │   └── rearev.py
│   └── base_model.py
└── dataset_load.py

/Rearev_framework.png:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/cmavro/ReaRev_KGQA/HEAD/Rearev_framework.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | Base==1.0.4
2 | numpy==1.19.5
3 | torch==1.7.1+cu110
4 | tqdm==4.59.0
5 | transformers==4.6.1
6 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 Costas Mavromatis
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | 
4 | 
5 | def create_logger(args):
6 |     log_file = os.path.join(args.checkpoint_dir, args.experiment_name + ".log")
7 |     logger = logging.getLogger()
8 |     log_level = logging.DEBUG if args.log_level == "debug" else logging.INFO
9 |     logger.setLevel(level=log_level)
10 |     # Formatter
11 |     formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
12 |     # FileHandler
13 |     file_handler = logging.FileHandler(log_file)
14 |     file_handler.setFormatter(formatter)
15 |     logger.addHandler(file_handler)
16 |     # StreamHandler
17 |     stream_handler = logging.StreamHandler()
18 |     stream_handler.setFormatter(formatter)
19 |     logger.addHandler(stream_handler)
20 | 
21 |     logger.info("PARAMETER" + "-" * 10)
22 |     for attr, value in sorted(args.__dict__.items()):
23 |         logger.info("{}={}".format(attr.upper(), value))
24 |     logger.info("---------" + "-" * 10)
25 | 
26 |     return logger
27 | 
28 | 
29 | def get_dict(data_folder, filename):
30 |     filename_true = os.path.join(data_folder, filename)
31 |     word2id = dict()
32 |     with open(filename_true, encoding='utf-8') as f_in:
33 |         for line in f_in:
34 |             word = line.strip()
35 |             word2id[word] = len(word2id)
36 |     return word2id
37 | 
38 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | from utils import create_logger
4 | import torch
5 | import numpy as np
6 | import os
7 | import time
8 | #from Models.ReaRev.rearev import
9 | from train_model import Trainer_KBQA
10 | from parsing import add_parse_args
11 | 
12 | parser = argparse.ArgumentParser()
13 | add_parse_args(parser)
14 | 
15 | args = parser.parse_args()
16 | args.use_cuda = torch.cuda.is_available()
17 | 
18 | np.random.seed(args.seed)
19 | torch.manual_seed(args.seed)
20 | if not args.experiment_name:  # default is '' (see parsing.py), so test for empty rather than None
21 |     timestamp = str(int(time.time()))
22 |     args.experiment_name = "{}-{}-{}".format(
23 |         args.name,  # dataset name; the parser defines --name, there is no --dataset argument
24 |         args.model_name,
25 |         timestamp,
26 |     )
27 | 
28 | 
29 | def main():
30 |     if not os.path.exists(args.checkpoint_dir):
31 |         os.mkdir(args.checkpoint_dir)
32 |     logger = create_logger(args)
33 |     trainer = Trainer_KBQA(args=vars(args), model_name=args.model_name, logger=logger)
34 |     if not args.is_eval:
35 |         trainer.train(0, args.num_epoch - 1)
36 |     else:
37 |         assert args.load_experiment is not None
38 |         if args.load_experiment is not None:
39 |             ckpt_path = os.path.join(args.checkpoint_dir, args.load_experiment)
40 |             print("Loading pre-trained model from {}".format(ckpt_path))
41 |         else:
42 |             ckpt_path = None
43 |         trainer.evaluate_single(ckpt_path)
44 | 
45 | 
46 | if __name__ == '__main__':
47 |     main()
48 | 
--------------------------------------------------------------------------------
/modules/question_encoding/tokenizers.py:
--------------------------------------------------------------------------------
1 | import re
2 | import numpy as np
3 | from transformers import BertTokenizer
4 | 
5 | class LSTMTokenizer():
6 |     def __init__(self, word2id, max_query_word):
7 |         super(LSTMTokenizer, self).__init__()
8 |         self.word2id = word2id
9 |         self.max_query_word = max_query_word
10 | 
11 |     def tokenize(self, question):
12 |         tokens = self.tokenize_sent(question)
13 |         query_text = np.full(self.max_query_word, len(self.word2id), dtype=int)
14 |         #tokens
= question.split() 15 | #if self.data_type == "train": 16 | # random.shuffle(tokens) 17 | for j, word in enumerate(tokens): 18 | if j < self.max_query_word: 19 | if word in self.word2id: 20 | query_text[j] = self.word2id[word] 21 | 22 | else: 23 | query_text[j] = len(self.word2id) 24 | 25 | return query_text 26 | 27 | @staticmethod 28 | def tokenize_sent(question_text): 29 | question_text = question_text.strip().lower() 30 | question_text = re.sub('\'s', ' s', question_text) 31 | words = [] 32 | toks = enumerate(question_text.split(' ')) 33 | 34 | for w_idx, w in toks: 35 | w = re.sub('^[^a-z0-9]|[^a-z0-9]$', '', w) 36 | if w == '': 37 | continue 38 | words += [w] 39 | return words 40 | 41 | class BERTTokenizer(): 42 | def __init__(self, max_query_word): 43 | super(BERTTokenizer, self).__init__() 44 | self.q_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 45 | self.max_query_word = max_query_word 46 | self.num_word = self.q_tokenizer.encode("[UNK]")[0] #len(self.q_tokenizer.vocab.keys()) 47 | 48 | 49 | 50 | def tokenize(self, question): 51 | query_text = np.full(self.max_query_word, 0, dtype=int) 52 | tokens = self.q_tokenizer.encode_plus(text=question, max_length=self.max_query_word, \ 53 | pad_to_max_length=True, return_attention_mask = False, truncation=True) 54 | return np.array(tokens['input_ids']) -------------------------------------------------------------------------------- /modules/question_encoding/lstm_encoder.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch.nn as nn 4 | from utils import get_dict 5 | from .base_encoder import BaseInstruction 6 | 7 | VERY_SMALL_NUMBER = 1e-10 8 | VERY_NEG_NUMBER = -100000000000 9 | 10 | class LSTMInstruction(BaseInstruction): 11 | 12 | def __init__(self, args, word_embedding, num_word): 13 | super(LSTMInstruction, self).__init__(args) 14 | self.word2id = get_dict(args['data_folder'],args['word2id']) 15 | 16 | self.word_embedding = word_embedding 17 | self.num_word = num_word 18 | self.encoder_def() 19 | entity_dim = self.entity_dim 20 | self.cq_linear = nn.Linear(in_features=4 * entity_dim, out_features=entity_dim) 21 | self.ca_linear = nn.Linear(in_features=entity_dim, out_features=1) 22 | for i in range(self.num_ins): 23 | self.add_module('question_linear' + str(i), nn.Linear(in_features=entity_dim, out_features=entity_dim)) 24 | 25 | def encoder_def(self): 26 | # initialize entity embedding 27 | word_dim = self.word_dim 28 | entity_dim = self.entity_dim 29 | self.node_encoder = nn.LSTM(input_size=word_dim, hidden_size=entity_dim, 30 | batch_first=True, bidirectional=False) 31 | 32 | def encode_question(self, query_text, store=True): 33 | batch_size = query_text.size(0) 34 | query_word_emb = self.word_embedding(query_text) # batch_size, max_query_word, word_dim 35 | query_hidden_emb, (h_n, c_n) = self.node_encoder(self.lstm_drop(query_word_emb), 36 | self.init_hidden(1, batch_size, 37 | self.entity_dim)) # 1, batch_size, entity_dim 38 | if store: 39 | self.instruction_hidden = h_n 40 | self.instruction_mem = c_n 41 | self.query_node_emb = h_n.squeeze(dim=0).unsqueeze(dim=1) # batch_size, 1, entity_dim 42 | self.query_hidden_emb = query_hidden_emb 43 | self.query_mask = (query_text != self.num_word).float() 44 | return query_hidden_emb, self.query_node_emb 45 | else: 46 | return query_hidden_emb 47 | 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /modules/kg_reasoning/base_gnn.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from collections import defaultdict 4 | 5 | VERY_NEG_NUMBER = -100000000000 6 | 7 | class BaseGNNLayer(torch.nn.Module): 8 | """ 9 | Builds sparse tensors that represent structure. 10 | """ 11 | def __init__(self, args, num_entity, num_relation): 12 | super(BaseGNNLayer, self).__init__() 13 | self.num_relation = num_relation 14 | self.num_entity = num_entity 15 | self.device = torch.device('cuda' if args['use_cuda'] else 'cpu') 16 | self.normalized_gnn = args['normalized_gnn'] 17 | 18 | 19 | def build_matrix(self): 20 | batch_heads, batch_rels, batch_tails, batch_ids, fact_ids, weight_list = self.edge_list 21 | num_fact = len(fact_ids) 22 | num_relation = self.num_relation 23 | batch_size = self.batch_size 24 | max_local_entity = self.max_local_entity 25 | self.num_fact = num_fact 26 | fact2head = torch.LongTensor([batch_heads, fact_ids]).to(self.device) 27 | fact2tail = torch.LongTensor([batch_tails, fact_ids]).to(self.device) 28 | head2fact = torch.LongTensor([fact_ids, batch_heads]).to(self.device) 29 | tail2fact = torch.LongTensor([fact_ids, batch_tails]).to(self.device) 30 | rel2fact = torch.LongTensor([fact_ids, batch_rels + batch_ids * num_relation]).to(self.device) 31 | fact2rel = torch.LongTensor([batch_rels + batch_ids * num_relation, fact_ids]).to(self.device) 32 | self.batch_rels = torch.LongTensor(batch_rels).to(self.device) 33 | self.batch_ids = torch.LongTensor(batch_ids).to(self.device) 34 | self.batch_heads = torch.LongTensor(batch_heads).to(self.device) 35 | self.batch_tails = torch.LongTensor(batch_tails).to(self.device) 36 | # self.batch_ids = batch_ids 37 | if self.normalized_gnn: 38 | vals = torch.FloatTensor(weight_list).to(self.device) 39 | else: 40 | vals = torch.ones_like(self.batch_ids).float().to(self.device) 41 | 42 | #vals = torch.ones_like(self.batch_ids).float().to(self.device) 43 | # Sparse Matrix for reason on graph 44 | self.fact2head_mat = self._build_sparse_tensor(fact2head, vals, (batch_size * max_local_entity, num_fact)) 45 | self.head2fact_mat = self._build_sparse_tensor(head2fact, vals, (num_fact, batch_size * max_local_entity)) 46 | self.fact2tail_mat = self._build_sparse_tensor(fact2tail, vals, (batch_size * max_local_entity, num_fact)) 47 | self.tail2fact_mat = self._build_sparse_tensor(tail2fact, vals, (num_fact, batch_size * max_local_entity)) 48 | self.fact2rel_mat = self._build_sparse_tensor(fact2rel, vals, (batch_size * num_relation, num_fact)) 49 | self.rel2fact_mat = self._build_sparse_tensor(rel2fact, vals, (num_fact, batch_size * num_relation)) 50 | 51 | def _build_sparse_tensor(self, indices, values, size): 52 | return torch.sparse.FloatTensor(indices, values, size).to(self.device) 53 | -------------------------------------------------------------------------------- /modules/layer_init.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | VERY_NEG_NUMBER = -100000000000 6 | VERY_SMALL_NUMBER = 1e-10 7 | 8 | 9 | class TypeLayer(nn.Module): 10 | """ 11 | Sparse version GAT layer, similar to https://arxiv.org/abs/1710.10903 12 | """ 13 | 14 | def __init__(self, in_features, out_features, linear_drop, device): 15 | super(TypeLayer, self).__init__() 16 | self.in_features = in_features 17 | self.out_features = out_features 18 | self.linear_drop = linear_drop 19 | # self.kb_head_linear = 
nn.Linear(in_features, out_features)
20 |         self.kb_self_linear = nn.Linear(in_features, out_features)
21 |         # self.kb_tail_linear = nn.Linear(out_features, out_features)
22 |         self.device = device
23 | 
24 |     def forward(self, local_entity, edge_list, rel_features):
25 |         '''
26 |         local_entity: (batch_size, max_local_entity)
27 |         edge_list: (batch_heads, batch_rels, batch_tails, batch_ids, fact_ids, weight_list)
28 |         rel_features: (num_relation, in_features)
29 |         '''
30 |         batch_heads, batch_rels, batch_tails, batch_ids, fact_ids, weight_list = edge_list
31 |         num_fact = len(fact_ids)
32 |         batch_size, max_local_entity = local_entity.size()
33 |         hidden_size = self.in_features
34 |         fact2head = torch.LongTensor([batch_heads, fact_ids]).to(self.device)
35 |         fact2tail = torch.LongTensor([batch_tails, fact_ids]).to(self.device)
36 |         batch_rels = torch.LongTensor(batch_rels).to(self.device)
37 |         batch_ids = torch.LongTensor(batch_ids).to(self.device)
38 |         val_one = torch.ones_like(batch_ids).float().to(self.device)
39 | 
40 | 
41 |         # print("Prepare data:{:.4f}".format(time.time() - st))
42 |         # Step 1: Calculate a value for every fact from its relation features
43 |         fact_rel = torch.index_select(rel_features, dim=0, index=batch_rels)
44 |         # fact_val = F.relu(self.kb_self_linear(fact_rel) + self.kb_head_linear(self.linear_drop(fact_ent)))
45 |         fact_val = self.kb_self_linear(fact_rel)
46 |         # fact_val = self.kb_self_linear(fact_rel)#self.kb_head_linear(self.linear_drop(fact_ent))
47 | 
48 |         # Step 2: Edge aggregation onto head/tail entities with sparse matrix multiplication
49 |         fact2tail_mat = self._build_sparse_tensor(fact2tail, val_one, (batch_size * max_local_entity, num_fact))
50 |         fact2head_mat = self._build_sparse_tensor(fact2head, val_one, (batch_size * max_local_entity, num_fact))
51 | 
52 |         # neighbor_rep = torch.sparse.mm(fact2tail_mat, self.kb_tail_linear(self.linear_drop(fact_val)))
53 |         f2e_emb = F.relu(torch.sparse.mm(fact2tail_mat, fact_val) + torch.sparse.mm(fact2head_mat, fact_val))
54 |         assert not torch.isnan(f2e_emb).any()
55 | 
56 |         f2e_emb = f2e_emb.view(batch_size, max_local_entity, hidden_size)
57 | 
58 |         return f2e_emb
59 | 
60 |     def _build_sparse_tensor(self, indices, values, size):
61 |         return torch.sparse.FloatTensor(indices, values, size).to(self.device)
62 | 
--------------------------------------------------------------------------------
/parsing.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | 
4 | def bool_flag(v):
5 |     if v.lower() in ('yes', 'true', 't', 'y', '1'):
6 |         return True
7 |     elif v.lower() in ('no', 'false', 'f', 'n', '0'):
8 |         return False
9 |     else:
10 |         raise argparse.ArgumentTypeError('Boolean value expected.')
11 | 
12 | def add_shared_args(parser):
13 |     parser.add_argument('--name', default='webqsp', type=str)
14 |     parser.add_argument('--data_folder', default='data/webqsp/', type=str)
15 |     parser.add_argument('--max_train', default=200000, type=int)
16 | 
17 |     # embeddings
18 |     parser.add_argument('--word2id', default='vocab.txt', type=str)
19 |     parser.add_argument('--relation2id', default='relations.txt', type=str)
20 |     parser.add_argument('--entity2id', default='entities.txt', type=str)
21 |     parser.add_argument('--char2id', default='chars.txt', type=str)
22 |     parser.add_argument('--entity_emb_file', default=None, type=str)
23 |     parser.add_argument('--relation_emb_file', default=None, type=str)
24 |     parser.add_argument('--relation_word_emb', default=False, type=bool_flag)
25 |     parser.add_argument('--word_emb_file', default='word_emb.npy', type=str)
26 |     parser.add_argument('--rel_word_ids',
default='rel_word_idx.npy', type=str) 27 | parser.add_argument('--kge_frozen', default=0, type=int) 28 | parser.add_argument('--lm', default='lstm', type=str, choices=['lstm', 'bert', 'roberta', 'sbert', 't5','sbert2']) 29 | parser.add_argument('--lm_frozen', default=1, type=int) 30 | 31 | # dimensions, layers, dropout 32 | parser.add_argument('--entity_dim', default=50, type=int) 33 | parser.add_argument('--kg_dim', default=100, type=int) 34 | parser.add_argument('--word_dim', default=300, type=int) 35 | parser.add_argument('--lm_dropout', default=0.3, type=float) 36 | parser.add_argument('--linear_dropout', default=0.2, type=float) 37 | 38 | # optimization 39 | parser.add_argument('--num_epoch', default=100, type=int) 40 | parser.add_argument('--fact_scale', default=3, type=int) 41 | parser.add_argument('--eval_every', default=2, type=int) 42 | parser.add_argument('--batch_size', default=20, type=int) 43 | parser.add_argument('--gradient_clip', default=1.0, type=float) 44 | parser.add_argument('--lr', default=0.0005, type=float) 45 | parser.add_argument('--decay_rate', default=0.0, type=float) 46 | parser.add_argument('--seed', default=19960626, type=int) 47 | parser.add_argument('--lr_schedule', action='store_true') 48 | parser.add_argument('--label_smooth', default=0.1, type=float) 49 | parser.add_argument('--fact_drop', default=0, type=float) 50 | #parser.add_argument('--encode_type', action='store_true') 51 | 52 | # model options 53 | 54 | parser.add_argument('--is_eval', action='store_true') 55 | parser.add_argument('--checkpoint_dir', default='checkpoint/pretrain/', type=str) 56 | parser.add_argument('--log_level', type=str, default='info') 57 | parser.add_argument('--experiment_name', default='', type=str) 58 | parser.add_argument('--load_experiment', default=None, type=str) 59 | parser.add_argument('--load_ckpt_file', default=None, type=str) 60 | parser.add_argument('--eps', default=0.95, type=float) # threshold for f1 61 | parser.add_argument('--test_batch_size', default=20, type=int) 62 | parser.add_argument('--q_type', default='seq', type=str) 63 | 64 | 65 | 66 | def add_parse_args(parser): 67 | 68 | subparsers = parser.add_subparsers(help='Reason KGQA model') 69 | 70 | parser_rearev = subparsers.add_parser("ReaRev") 71 | create_parser_rearev(parser_rearev) 72 | 73 | 74 | def create_parser_rearev(parser): 75 | 76 | parser.add_argument('--model_name', default='ReaRev', type=str, choices=['ReaRev']) 77 | parser.add_argument('--alg', default='bfs', type=str) 78 | parser.add_argument('--num_iter', default=2, type=int) 79 | parser.add_argument('--num_ins', default=3, type=int) 80 | parser.add_argument('--num_gnn', default=3, type=int) 81 | parser.add_argument('--loss_type', default='kl', type=str) 82 | parser.add_argument('--use_self_loop', default=True, type=bool_flag) 83 | parser.add_argument('--normalized_gnn', default=False, type=bool_flag) 84 | parser.add_argument('--data_eff', action='store_true') 85 | add_shared_args(parser) 86 | -------------------------------------------------------------------------------- /modules/question_encoding/bert_encoder.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | VERY_SMALL_NUMBER = 1e-10 5 | VERY_NEG_NUMBER = -100000000000 6 | 7 | 8 | from transformers import AutoModel, AutoTokenizer #DistilBertModel, BertModel, BertTokenizer, RobertaModel, RobertaTokenizer 9 | from torch.nn import LayerNorm 10 | import warnings 11 | 
warnings.filterwarnings("ignore") 12 | import os 13 | try: 14 | os.environ['TRANSFORMERS_CACHE'] = '/export/scratch/costas/home/mavro016/.cache' 15 | except: 16 | pass 17 | 18 | from .base_encoder import BaseInstruction 19 | 20 | 21 | class BERTInstruction(BaseInstruction): 22 | 23 | def __init__(self, args, word_embedding, num_word, model): 24 | super(BERTInstruction, self).__init__(args) 25 | self.word_embedding = word_embedding 26 | self.num_word = num_word 27 | 28 | entity_dim = self.entity_dim 29 | self.model = model 30 | 31 | 32 | if model == 'bert': 33 | self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased') 34 | self.pretrained_weights = 'bert-base-uncased' 35 | word_dim = 768#self.word_dim 36 | elif model == 'roberta': 37 | self.tokenizer = AutoTokenizer.from_pretrained('roberta-base') 38 | self.pretrained_weights = 'roberta-base' 39 | word_dim = 768#self.word_dim 40 | elif model == 'sbert': 41 | self.tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2') 42 | self.pretrained_weights = 'sentence-transformers/all-MiniLM-L6-v2' 43 | word_dim = 384#self.word_dim 44 | elif model == 'sbert2': 45 | self.tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2') 46 | self.pretrained_weights = 'sentence-transformers/all-mpnet-base-v2' 47 | word_dim = 768#self.word_dim 48 | elif model == 't5': 49 | self.tokenizer = AutoTokenizer.from_pretrained('t5-small') 50 | self.pretrained_weights = 't5-small' 51 | word_dim = 512#self.word_dim 52 | #self.mask = mask 53 | self.pad_val = self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token) 54 | self.word_dim = word_dim 55 | 56 | print('word_dim', self.word_dim) 57 | self.cq_linear = nn.Linear(in_features=4 * entity_dim, out_features=entity_dim) 58 | self.ca_linear = nn.Linear(in_features=entity_dim, out_features=1) 59 | for i in range(self.num_ins): 60 | self.add_module('question_linear' + str(i), nn.Linear(in_features=entity_dim, out_features=entity_dim)) 61 | self.question_emb = nn.Linear(in_features=word_dim, out_features=entity_dim) 62 | 63 | self.encoder_def() 64 | 65 | def encoder_def(self): 66 | # initialize entity embedding 67 | word_dim = self.word_dim 68 | entity_dim = self.entity_dim 69 | self.node_encoder = AutoModel.from_pretrained(self.pretrained_weights) 70 | print('Total Params', sum(p.numel() for p in self.node_encoder.parameters())) 71 | if self.lm_frozen == 1: 72 | print('Freezing LM params') 73 | for param in self.node_encoder.parameters(): 74 | param.requires_grad = False 75 | else: 76 | print('Unfrozen LM params') 77 | 78 | def encode_question(self, query_text, store=True): 79 | batch_size = query_text.size(0) 80 | 81 | if self.model != 't5': 82 | 83 | query_hidden_emb = self.node_encoder(query_text)[0] # 1, batch_size, entity_dim 84 | else: 85 | query_hidden_emb = self.node_encoder.encoder(query_text)[0] 86 | #print(query_hidden_emb.size()) 87 | 88 | 89 | if store: 90 | self.query_hidden_emb = self.question_emb(query_hidden_emb) 91 | self.query_node_emb = query_hidden_emb.transpose(1,0)[0].unsqueeze(1) 92 | #print(self.query_node_emb.size()) 93 | self.query_node_emb = self.question_emb(self.query_node_emb) 94 | 95 | self.query_mask = (query_text != self.pad_val).float() 96 | return query_hidden_emb, self.query_node_emb 97 | else: 98 | return query_hidden_emb 99 | 100 | -------------------------------------------------------------------------------- /modules/question_encoding/base_encoder.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | 5 | VERY_SMALL_NUMBER = 1e-10 6 | VERY_NEG_NUMBER = -100000000000 7 | 8 | class BaseInstruction(torch.nn.Module): 9 | 10 | def __init__(self, args): 11 | super(BaseInstruction, self).__init__() 12 | self._parse_args(args) 13 | self.share_module_def() 14 | 15 | def _parse_args(self, args): 16 | self.device = torch.device('cuda' if args['use_cuda'] else 'cpu') 17 | 18 | 19 | # self.share_encoder = args['share_encoder'] 20 | self.q_type = args['q_type'] 21 | if 'num_step' in args: 22 | self.num_ins = args['num_step'] 23 | elif 'num_ins' in args: 24 | self.num_ins = args['num_ins'] 25 | else: 26 | self.num_ins = 1 27 | 28 | self.lm_dropout = args['lm_dropout'] 29 | self.linear_dropout = args['linear_dropout'] 30 | self.lm_frozen = args['lm_frozen'] 31 | 32 | for k, v in args.items(): 33 | if k.endswith('dim'): 34 | setattr(self, k, v) 35 | if k.endswith('emb_file') or k.endswith('kge_file'): 36 | if v is None: 37 | setattr(self, k, None) 38 | else: 39 | setattr(self, k, args['data_folder'] + v) 40 | 41 | self.reset_time = 0 42 | 43 | def share_module_def(self): 44 | # dropout 45 | self.lstm_drop = nn.Dropout(p=self.lm_dropout) 46 | self.linear_drop = nn.Dropout(p=self.linear_dropout) 47 | 48 | def init_hidden(self, num_layer, batch_size, hidden_size): 49 | return (torch.zeros(num_layer, batch_size, hidden_size).to(self.device), 50 | torch.zeros(num_layer, batch_size, hidden_size).to(self.device)) 51 | 52 | def encode_question(self, *args): 53 | # constituency tree or query_text 54 | pass 55 | 56 | @staticmethod 57 | def get_node_emb(query_hidden_emb, action): 58 | ''' 59 | 60 | :param query_hidden_emb: (batch_size, max_hyper, emb) 61 | :param action: (batch_size) 62 | :return: (batch_size, 1, emb) 63 | ''' 64 | batch_size, max_hyper, _ = query_hidden_emb.size() 65 | row_idx = torch.arange(0, batch_size).type(torch.LongTensor) 66 | q_rep = query_hidden_emb[row_idx, action, :] 67 | return q_rep.unsqueeze(1) 68 | 69 | def init_reason(self, query_text): 70 | self.batch_size = query_text.size(0) 71 | self.max_query_word = query_text.size(1) 72 | self.encode_question(query_text) 73 | self.relational_ins = torch.zeros(self.batch_size, self.entity_dim).to(self.device) 74 | self.instructions = [] 75 | self.attn_list = [] 76 | 77 | def get_instruction(self, relational_ins, step=0, query_node_emb=None): 78 | 79 | query_hidden_emb = self.query_hidden_emb 80 | 81 | query_mask = self.query_mask 82 | if query_node_emb is None: 83 | query_node_emb = self.query_node_emb 84 | 85 | relational_ins = relational_ins.unsqueeze(1) 86 | question_linear = getattr(self, 'question_linear' + str(step)) 87 | q_i = question_linear(self.linear_drop(query_node_emb)) 88 | cq = self.cq_linear(self.linear_drop(torch.cat((relational_ins, q_i, q_i-relational_ins,q_i*relational_ins), dim=-1))) 89 | # batch_size, 1, entity_dim 90 | ca = self.ca_linear(self.linear_drop(cq * query_hidden_emb)) 91 | # batch_size, max_local_entity, 1 92 | # cv = self.softmax_d1(ca + (1 - query_mask.unsqueeze(2)) * VERY_NEG_NUMBER) 93 | attn_weight = F.softmax(ca + (1 - query_mask.unsqueeze(2)) * VERY_NEG_NUMBER, dim=1) 94 | # batch_size, max_local_entity, 1 95 | relational_ins = torch.sum(attn_weight * query_hidden_emb, dim=1) 96 | return relational_ins, attn_weight 97 | 98 | 99 | 100 | def forward(self, query_text): 101 | self.init_reason(query_text) 102 | for i in range(self.num_ins): 103 | 
relational_ins, attn_weight = self.get_instruction(self.relational_ins, step=i)
104 |             self.instructions.append(relational_ins)
105 |             self.attn_list.append(attn_weight)
106 |             self.relational_ins = relational_ins
107 |         return self.instructions, self.attn_list
108 | 
109 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ReaRev [EMNLP 2022]
2 | This is the code for the EMNLP 2022 Findings paper: [ReaRev: Adaptive Reasoning for Question Answering over Knowledge
3 | Graphs](https://arxiv.org/abs/2210.13650).
4 | 
5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rearev-adaptive-reasoning-for-question/semantic-parsing-on-webquestionssp)](https://paperswithcode.com/sota/semantic-parsing-on-webquestionssp?p=rearev-adaptive-reasoning-for-question)
6 | 
7 | ## Overview
8 | Our method improves instruction decoding and execution for KGQA via adaptive reasoning, as shown:
9 | 
10 | ![](./Rearev_framework.png)
11 | 
12 | 
13 | ## Get Started
14 | Dependencies are minimal and listed in `requirements.txt`; you can check right away whether the code runs in your environment.
15 | 
16 | We use the pre-processed data from: https://drive.google.com/drive/folders/1qRXeuoL-ArQY7pJFnMpNnBu0G-cOz6xv
17 | Download it and extract it to a folder named `data`.
18 | 
19 | __Acknowledgements__:
20 | 
21 | [NSM](https://github.com/RichardHGL/WSDM2021_NSM): Datasets (webqsp, CWQ, MetaQA) / Code.
22 | 
23 | [GraftNet](https://github.com/haitian-sun/GraftNet): Datasets (webqsp incomplete, MetaQA) / Code.
24 | 
25 | ## Training
26 | 
27 | To run Webqsp:
28 | ```
29 | python main.py ReaRev --entity_dim 50 --num_epoch 200 --batch_size 8 --eval_every 2 \
30 |     --data_folder data/webqsp/ --lm sbert --num_iter 3 --num_ins 2 --num_gnn 2 \
31 |     --relation_word_emb True --experiment_name Webqsp322 --name webqsp
32 | ```
33 | 
34 | To run CWQ:
35 | ```
36 | python main.py ReaRev --entity_dim 50 --num_epoch 100 --batch_size 8 --eval_every 2 \
37 |     --data_folder data/CWQ/ --lm sbert --num_iter 2 --num_ins 3 --num_gnn 3 \
38 |     --relation_word_emb True --experiment_name CWQ --name cwq
39 | ```
40 | To run MetaQA-3:
41 | ```
42 | python main.py ReaRev --entity_dim 50 --num_epoch 10 --batch_size 8 --eval_every 2 \
43 |     --data_folder data/metaqa-3hop/ --lm lstm --num_iter 2 --num_ins 3 --num_gnn 3 \
44 |     --relation_word_emb False --experiment_name metaqa3 --name metaqa
45 | ```
46 | 
47 | For incomplete Webqsp, see `data/incomplete/` (after obtaining the data from [GraftNet](https://github.com/haitian-sun/GraftNet)). If you cannot afford a lot of memory for CWQ, use the `--data_eff` argument (see the available arguments in `parsing.py`).
48 | 
49 | ## Results
50 | 
51 | We also provide some pretrained ReaRev models (ReaRev_webqsp.ckpt, ReaRev_webqsp_v2.ckpt, ReaRev_CWQ.ckpt). You can download them from [here](https://drive.google.com/file/d/1p7eLSsSKkZQxB32mT5lMsthVP6R_3x1j/view?usp=share_link). Please extract them to the folder `checkpoint/pretrain/`.
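For reference, this is the expected layout after extraction (a sketch, assuming the default `--checkpoint_dir` and the checkpoint names listed above):

```
checkpoint/pretrain/
├── ReaRev_CWQ.ckpt
├── ReaRev_webqsp.ckpt
└── ReaRev_webqsp_v2.ckpt
```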
52 | 
53 | To reproduce Webqsp results, run:
54 | ```
55 | python main.py ReaRev --entity_dim 50 --num_epoch 200 --batch_size 8 --eval_every 2 --data_folder data/webqsp/ --lm sbert --num_iter 3 --num_ins 2 --num_gnn 3 --relation_word_emb True --load_experiment ReaRev_webqsp.ckpt --is_eval --name webqsp
56 | ```
57 | or
58 | ```
59 | python main.py ReaRev --entity_dim 50 --num_epoch 200 --batch_size 8 --eval_every 2 --data_folder data/webqsp/ --lm sbert --num_iter 3 --num_ins 2 --num_gnn 2 --relation_word_emb True --load_experiment ReaRev_webqsp_v2.ckpt --is_eval --name webqsp
60 | ```
61 | 
62 | To reproduce CWQ results, run:
63 | ```
64 | python main.py ReaRev --entity_dim 50 --num_epoch 100 --batch_size 8 --eval_every 2 --data_folder data/CWQ/ --lm sbert --num_iter 2 --num_ins 3 --num_gnn 3 --relation_word_emb True --load_experiment ReaRev_CWQ.ckpt --is_eval --name cwq
65 | ```
66 | 
67 | |Models | Webqsp| CWQ | MetaQA-3hop|
68 | |:---:|:---:|:---:|:---:|
69 | |KV-Mem| 46.7 | 21.1| 48.9 |
70 | |GraftNet| 66.4 | 32.8 |77.7 |
71 | |PullNet| 68.1 | 45.9 | 91.4|
72 | |NSM-distill| 74.3 | 48.8 | **98.9** |
73 | |ReaRev| **76.4** | **52.9** | **98.9** |
74 | 
75 | ## Cite
76 | If you find our code or method useful, please cite our work as
77 | ```
78 | @article{mavromatis2022rearev,
79 |   title={ReaRev: Adaptive Reasoning for Question Answering over Knowledge Graphs},
80 |   author={Mavromatis, Costas and Karypis, George},
81 |   journal={arXiv preprint arXiv:2210.13650},
82 |   year={2022}
83 | }
84 | ```
85 | or
86 | ```
87 | @inproceedings{mavromatis-karypis-2022-rearev,
88 |     title = "{R}ea{R}ev: Adaptive Reasoning for Question Answering over Knowledge Graphs",
89 |     author = "Mavromatis, Costas and
90 |       Karypis, George",
91 |     booktitle = "Findings of the Association for Computational Linguistics: EMNLP 2022",
92 |     month = dec,
93 |     year = "2022",
94 |     address = "Abu Dhabi, United Arab Emirates",
95 |     publisher = "Association for Computational Linguistics",
96 |     url = "https://aclanthology.org/2022.findings-emnlp.181",
97 |     pages = "2447--2458",
98 | }
99 | ```
100 | 
--------------------------------------------------------------------------------
/modules/query_update.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn.functional as F
4 | import torch.nn as nn
5 | 
6 | class Fusion(nn.Module):
7 |     """Gated fusion of two representations of the same dimension."""
8 |     def __init__(self, d_hid):
9 |         super(Fusion, self).__init__()
10 |         self.r = nn.Linear(d_hid*3, d_hid, bias=False)
11 |         self.g = nn.Linear(d_hid*3, d_hid, bias=False)
12 | 
13 |     def forward(self, x, y):
14 |         r_ = self.r(torch.cat([x,y,x-y], dim=-1))#.tanh()
15 |         g_ = torch.sigmoid(self.g(torch.cat([x,y,x-y], dim=-1)))
16 |         return g_ * r_ + (1 - g_) * x
17 | 
18 | class QueryReform(nn.Module):
19 |     """Reforms the query representation with seed-entity information."""
20 |     def __init__(self, h_dim):
21 |         super(QueryReform, self).__init__()
22 |         # self.q_encoder = AttnEncoder(h_dim)
23 |         self.fusion = Fusion(h_dim)
24 |         self.q_ent_attn = nn.Linear(h_dim, h_dim)
25 | 
26 |     def forward(self, q_node, ent_emb, seed_info, ent_mask):
27 |         '''
28 |         q_node: (B, h_dim) question representation
29 |         ent_emb: (B, C, h_dim) candidate entity embeddings
30 |         seed_info: (B, C) seed-entity distribution
31 |         ent_mask: (B, C) mask over candidate entities
32 |         returns: (B, h_dim) reformed query
33 | 
34 |         '''
35 |         # q_node = self.q_encoder(q, q_mask)
36 |         q_ent_attn = (self.q_ent_attn(q_node).unsqueeze(1) * ent_emb).sum(2, keepdim=True)
37 |         q_ent_attn = F.softmax(q_ent_attn - (1 - ent_mask.unsqueeze(2)) * 1e8, dim=1)
38 |         attn_retrieve = (q_ent_attn *
ent_emb).sum(1)
39 | 
40 |         seed_retrieve = torch.bmm(seed_info.unsqueeze(1), ent_emb).squeeze(1)  # (B, h_dim)
41 |         # how to calculate the gate
42 | 
43 |         #return self.fusion(q_node, attn_retrieve)
44 |         return self.fusion(q_node, seed_retrieve)
45 | 
46 | class AttnEncoder(nn.Module):
47 |     """Attention pooling of a sequence into a single vector."""
48 |     def __init__(self, d_hid):
49 |         super(AttnEncoder, self).__init__()
50 |         self.attn_linear = nn.Linear(d_hid, 1, bias=False)
51 | 
52 |     def forward(self, x, x_mask):
53 |         """
54 |         x: (B, len, d_hid)
55 |         x_mask: (B, len)
56 |         return: (B, d_hid)
57 |         """
58 |         x_attn = self.attn_linear(x)
59 |         x_attn = x_attn - (1 - x_mask.unsqueeze(2))*1e8
60 |         x_attn = F.softmax(x_attn, dim=1)
61 |         return (x*x_attn).sum(1)
62 | 
63 | class Attention(nn.Module):
64 |     """ Applies attention mechanism on the `context` using the `query`.
65 | 
66 |     **Thank you** to IBM for their initial implementation of :class:`Attention`. Here is
67 |     their `License
68 |     `__.
69 | 
70 |     Args:
71 |         dimensions (int): Dimensionality of the query and context.
72 |         attention_type (str, optional): How to compute the attention score:
73 | 
74 |             * dot: :math:`score(H_j,q) = H_j^T q`
75 |             * general: :math:`score(H_j, q) = H_j^T W_a q`
76 | 
77 |     Example:
78 | 
79 |          >>> attention = Attention(256)
80 |          >>> query = torch.randn(5, 1, 256)
81 |          >>> context = torch.randn(5, 5, 256)
82 |          >>> output, weights = attention(query, context)
83 |          >>> output.size()
84 |          torch.Size([5, 1, 256])
85 |          >>> weights.size()
86 |          torch.Size([5, 1, 5])
87 |     """
88 | 
89 |     def __init__(self, dimensions, attention_type='general'):
90 |         super(Attention, self).__init__()
91 | 
92 |         if attention_type not in ['dot', 'general']:
93 |             raise ValueError('Invalid attention type selected.')
94 | 
95 |         self.attention_type = attention_type
96 |         if self.attention_type == 'general':
97 |             self.linear_in = nn.Linear(dimensions, dimensions, bias=False)
98 | 
99 |         self.linear_out = nn.Linear(dimensions * 2, dimensions, bias=False)
100 |         self.softmax = nn.Softmax(dim=-1)
101 |         self.tanh = nn.Tanh()
102 | 
103 |     def forward(self, query, context):
104 |         """
105 |         Args:
106 |             query (:class:`torch.FloatTensor` [batch size, output length, dimensions]): Sequence of
107 |                 queries to query the context.
108 |             context (:class:`torch.FloatTensor` [batch size, query length, dimensions]): Data
109 |                 over which to apply the attention mechanism.
110 | 
111 |         Returns:
112 |             :class:`tuple` with `output` and `weights`:
113 |             * **output** (:class:`torch.LongTensor` [batch size, output length, dimensions]):
114 |               Tensor containing the attended features.
115 |             * **weights** (:class:`torch.FloatTensor` [batch size, output length, query length]):
116 |               Tensor containing attention weights.
117 |         """
118 |         batch_size, output_len, dimensions = query.size()
119 |         query_len = context.size(1)
120 | 
121 |         if self.attention_type == "general":
122 |             query = query.reshape(batch_size * output_len, dimensions)
123 |             query = self.linear_in(query)
124 |             query = query.reshape(batch_size, output_len, dimensions)
125 | 
126 |         # TODO: Include mask on PADDING_INDEX?
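        # Editorial shape walk-through (illustrative, using the docstring example
        # with batch_size=5, output_len=1, query_len=5, dimensions=256):
        #   query:            (5, 1, 256)  after the optional linear_in projection
        #   attention_scores: (5, 1, 5)    from bmm(query, context.transpose(1, 2))
        #   mix:              (5, 1, 256)  attention-weighted sum of context rows
        #   output:           (5, 1, 256)  tanh(linear_out([mix; query]))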
127 | 128 | # (batch_size, output_len, dimensions) * (batch_size, query_len, dimensions) -> 129 | # (batch_size, output_len, query_len) 130 | attention_scores = torch.bmm(query, context.transpose(1, 2).contiguous()) 131 | 132 | # Compute weights across every context sequence 133 | attention_scores = attention_scores.view(batch_size * output_len, query_len) 134 | attention_weights = self.softmax(attention_scores) 135 | attention_weights = attention_weights.view(batch_size, output_len, query_len) 136 | 137 | # (batch_size, output_len, query_len) * (batch_size, query_len, dimensions) -> 138 | # (batch_size, output_len, dimensions) 139 | mix = torch.bmm(attention_weights, context) 140 | 141 | # concat -> (batch_size * output_len, 2*dimensions) 142 | combined = torch.cat((mix, query), dim=2) 143 | combined = combined.view(batch_size * output_len, 2 * dimensions) 144 | 145 | # Apply linear_out on every 2nd dimension of concat 146 | # output -> (batch_size, output_len, dimensions) 147 | output = self.linear_out(combined).view(batch_size, output_len, dimensions) 148 | output = self.tanh(output) 149 | 150 | return output, attention_weights -------------------------------------------------------------------------------- /modules/kg_reasoning/reasongnn.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | 6 | 7 | from .base_gnn import BaseGNNLayer 8 | 9 | VERY_NEG_NUMBER = -100000000000 10 | 11 | class ReasonGNNLayer(BaseGNNLayer): 12 | """ 13 | GNN Reasoning 14 | """ 15 | def __init__(self, args, num_entity, num_relation, entity_dim, alg): 16 | super(ReasonGNNLayer, self).__init__(args, num_entity, num_relation) 17 | self.num_entity = num_entity 18 | self.num_relation = num_relation 19 | self.entity_dim = entity_dim 20 | self.alg = alg 21 | self.num_ins = args['num_ins'] 22 | self.num_gnn = args['num_gnn'] 23 | 24 | self.init_layers(args) 25 | 26 | def init_layers(self, args): 27 | entity_dim = self.entity_dim 28 | self.softmax_d1 = nn.Softmax(dim=1) 29 | self.score_func = nn.Linear(in_features=entity_dim, out_features=1) 30 | self.glob_lin = nn.Linear(in_features=entity_dim, out_features=entity_dim) 31 | self.lin = nn.Linear(in_features=2*entity_dim, out_features=entity_dim) 32 | assert self.alg == 'bfs' 33 | self.linear_dropout = args['linear_dropout'] 34 | self.linear_drop = nn.Dropout(p=self.linear_dropout) 35 | for i in range(self.num_gnn): 36 | self.add_module('rel_linear' + str(i), nn.Linear(in_features=entity_dim, out_features=entity_dim)) 37 | if self.alg == 'bfs': 38 | self.add_module('e2e_linear' + str(i), nn.Linear(in_features=2*(self.num_ins)*entity_dim + entity_dim, out_features=entity_dim)) 39 | self.lin_m = nn.Linear(in_features=(self.num_ins)*entity_dim, out_features=entity_dim) 40 | 41 | def init_reason(self, local_entity, kb_adj_mat, local_entity_emb, rel_features, rel_features_inv, query_entities, query_node_emb=None): 42 | batch_size, max_local_entity = local_entity.size() 43 | self.local_entity_mask = (local_entity != self.num_entity).float() 44 | self.batch_size = batch_size 45 | self.max_local_entity = max_local_entity 46 | self.edge_list = kb_adj_mat 47 | self.rel_features = rel_features 48 | self.rel_features_inv = rel_features_inv 49 | self.local_entity_emb = local_entity_emb 50 | self.num_relation = self.rel_features.size(0) 51 | self.possible_cand = [] 52 | self.build_matrix() 53 | self.query_entities = query_entities 54 | 55 | 56 | def reason_layer(self, 
curr_dist, instruction, rel_linear): 57 | """ 58 | Aggregates neighbor representations 59 | """ 60 | batch_size = self.batch_size 61 | max_local_entity = self.max_local_entity 62 | # num_relation = self.num_relation 63 | rel_features = self.rel_features 64 | 65 | fact_rel = torch.index_select(rel_features, dim=0, index=self.batch_rels) 66 | 67 | fact_query = torch.index_select(instruction, dim=0, index=self.batch_ids) 68 | fact_val = F.relu(rel_linear(fact_rel) * fact_query) 69 | fact_prior = torch.sparse.mm(self.head2fact_mat, curr_dist.view(-1, 1)) 70 | 71 | fact_val = fact_val * fact_prior 72 | 73 | f2e_emb = torch.sparse.mm(self.fact2tail_mat, fact_val) 74 | assert not torch.isnan(f2e_emb).any() 75 | 76 | neighbor_rep = f2e_emb.view(batch_size, max_local_entity, self.entity_dim) 77 | 78 | return neighbor_rep 79 | 80 | def reason_layer_inv(self, curr_dist, instruction, rel_linear): 81 | batch_size = self.batch_size 82 | max_local_entity = self.max_local_entity 83 | # num_relation = self.num_relation 84 | rel_features = self.rel_features_inv 85 | 86 | fact_rel = torch.index_select(rel_features, dim=0, index=self.batch_rels) 87 | 88 | fact_query = torch.index_select(instruction, dim=0, index=self.batch_ids) 89 | fact_val = F.relu(rel_linear(fact_rel) * fact_query) 90 | fact_prior = torch.sparse.mm(self.tail2fact_mat, curr_dist.view(-1, 1)) 91 | 92 | 93 | fact_val = fact_val * fact_prior 94 | 95 | f2e_emb = torch.sparse.mm(self.fact2head_mat, fact_val) 96 | assert not torch.isnan(f2e_emb).any() 97 | 98 | neighbor_rep = f2e_emb.view(batch_size, max_local_entity, self.entity_dim) 99 | 100 | return neighbor_rep 101 | 102 | def combine(self,emb): 103 | """ 104 | Combines instruction-specific representations. 105 | """ 106 | local_emb = torch.cat(emb, dim=-1) 107 | local_emb = F.relu(self.lin_m(local_emb)) 108 | 109 | score_func = self.score_func 110 | 111 | score_tp = score_func(self.linear_drop(local_emb)).squeeze(dim=2) 112 | answer_mask = self.local_entity_mask 113 | self.possible_cand.append(answer_mask) 114 | score_tp = score_tp + (1 - answer_mask) * VERY_NEG_NUMBER 115 | current_dist = self.softmax_d1(score_tp) 116 | return current_dist, local_emb 117 | 118 | def forward(self, current_dist, relational_ins, step=0, return_score=False): 119 | """ 120 | Compute next probabilistic vectors and current node representations. 
121 | """ 122 | rel_linear = getattr(self, 'rel_linear' + str(step)) 123 | e2e_linear = getattr(self, 'e2e_linear' + str(step)) 124 | # score_func = getattr(self, 'score_func' + str(step)) 125 | score_func = self.score_func 126 | neighbor_reps = [] 127 | 128 | for j in range(relational_ins.size(1)): 129 | # we do the same procedure for existing and inverse relations 130 | neighbor_rep = self.reason_layer(current_dist, relational_ins[:,j,:], rel_linear) 131 | neighbor_reps.append(neighbor_rep) 132 | 133 | neighbor_rep = self.reason_layer_inv(current_dist, relational_ins[:,j,:], rel_linear) 134 | neighbor_reps.append(neighbor_rep) 135 | 136 | neighbor_reps = torch.cat(neighbor_reps, dim=2) 137 | 138 | 139 | next_local_entity_emb = torch.cat((self.local_entity_emb, neighbor_reps), dim=2) 140 | #print(next_local_entity_emb.size()) 141 | self.local_entity_emb = F.relu(e2e_linear(self.linear_drop(next_local_entity_emb))) 142 | 143 | score_tp = score_func(self.linear_drop(self.local_entity_emb)).squeeze(dim=2) 144 | answer_mask = self.local_entity_mask 145 | self.possible_cand.append(answer_mask) 146 | score_tp = score_tp + (1 - answer_mask) * VERY_NEG_NUMBER 147 | current_dist = self.softmax_d1(score_tp) 148 | if return_score: 149 | return score_tp, current_dist 150 | 151 | 152 | return current_dist, self.local_entity_emb 153 | 154 | 155 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | 2 | from tqdm import tqdm 3 | tqdm.monitor_iterval = 0 4 | import torch 5 | import numpy as np 6 | import math, os 7 | import json 8 | 9 | def cal_accuracy(pred, answer_dist): 10 | """ 11 | pred: batch_size 12 | answer_dist: batch_size, max_local_entity 13 | """ 14 | num_correct = 0.0 15 | num_answerable = 0.0 16 | for i, l in enumerate(pred): 17 | num_correct += (answer_dist[i, l] != 0) 18 | for dist in answer_dist: 19 | if np.sum(dist) != 0: 20 | num_answerable += 1 21 | return num_correct / len(pred), num_answerable / len(pred) 22 | 23 | 24 | def f1_and_hits(answers, candidate2prob, id2entity, eps=0.5): 25 | ans = [] 26 | retrieved = [] 27 | for a in answers: 28 | ans.append(id2entity[a]) 29 | correct = 0 30 | cand_list = sorted(candidate2prob, key=lambda x:x[1], reverse=True) 31 | if len(cand_list) == 0: 32 | best_ans = -1 33 | else: 34 | best_ans = cand_list[0][0] 35 | # max_prob = cand_list[0][1] 36 | tp_prob = 0.0 37 | for c, prob in cand_list: 38 | retrieved.append((id2entity[c], prob)) 39 | tp_prob += prob 40 | if c in answers: 41 | correct += 1 42 | if tp_prob > eps: 43 | break 44 | if len(answers) == 0: 45 | if len(retrieved) == 0: 46 | return 1.0, 1.0, 1.0, 1.0, 0, retrieved, ans # precision, recall, f1, hits 47 | else: 48 | return 0.0, 1.0, 0.0, 1.0, 1, retrieved , ans # precision, recall, f1, hits 49 | else: 50 | hits = float(best_ans in answers) 51 | if len(retrieved) == 0: 52 | return 1.0, 0.0, 0.0, hits, 2, retrieved , ans # precision, recall, f1, hits 53 | else: 54 | p, r = correct / len(retrieved), correct / len(answers) 55 | f1 = 2.0 / (1.0 / p + 1.0 / r) if p != 0 and r != 0 else 0.0 56 | return p, r, f1, hits, 3, retrieved, ans 57 | 58 | 59 | class Evaluator: 60 | 61 | def __init__(self, args, model, entity2id, relation2id, device): 62 | self.model = model 63 | self.args = args 64 | self.eps = args['eps'] 65 | 66 | id2entity = {idx: entity for entity, idx in entity2id.items()} 67 | self.id2entity = id2entity 68 | id2relation = {idx: relation for relation, idx in 
relation2id.items()} 69 | num_rel_ori = len(relation2id) 70 | 71 | if 'use_inverse_relation' in args: 72 | self.use_inverse_relation = args['use_inverse_relation'] 73 | if self.use_inverse_relation: 74 | for i in range(len(id2relation)): 75 | id2relation[i + num_rel_ori] = id2relation[i] + "_rev" 76 | 77 | if 'use_self_loop' in args: 78 | self.use_self_loop = args['use_self_loop'] 79 | if self.use_self_loop: 80 | id2relation[len(id2relation)] = "self_loop" 81 | 82 | self.id2relation = id2relation 83 | self.file_write = None 84 | self.device = device 85 | 86 | def write_info(self, valid_data, tp_list, num_step): 87 | question_list = valid_data.get_quest() 88 | #num_step = steps 89 | obj_list = [] 90 | if tp_list is not None: 91 | # attn_list = [tp[1] for tp in tp_list] 92 | action_list = [tp[0] for tp in tp_list] 93 | for i in range(len(question_list)): 94 | obj_list.append({}) 95 | for j in range(num_step): 96 | if tp_list is None: 97 | actions = None 98 | else: 99 | actions = action_list[j] 100 | actions = actions.cpu().numpy() 101 | # if attn_list is not None: 102 | # attention = attn_list[j].cpu().numpy() 103 | for i in range(len(question_list)): 104 | tp_obj = obj_list[i] 105 | q = question_list[i] 106 | # real_index = self.true_batch_id[i][0] 107 | tp_obj['question'] = q 108 | tp_obj[j] = {} 109 | # print(actions) 110 | if tp_list is not None: 111 | action = actions[i] 112 | rel_action = self.id2relation[action] 113 | tp_obj[j]['rel_action'] = rel_action 114 | tp_obj[j]['action'] = str(action) 115 | # if attn_list is not None: 116 | # attention_tp = attention[i] 117 | # tp_obj[j]['attention'] = attention_tp.tolist() 118 | return obj_list 119 | 120 | def evaluate(self, valid_data, test_batch_size=20, write_info=False): 121 | #write_info = True 122 | self.model.eval() 123 | self.count = 0 124 | eps = self.eps 125 | id2entity = self.id2entity 126 | eval_loss, eval_acc, eval_max_acc = [], [], [] 127 | f1s, hits, precisions, recalls = [], [], [], [] 128 | valid_data.reset_batches(is_sequential=True) 129 | num_epoch = math.ceil(valid_data.num_data / test_batch_size) 130 | if write_info and self.file_write is None: 131 | filename = os.path.join(self.args['checkpoint_dir'], 132 | "{}_test.info".format(self.args['experiment_name'])) 133 | self.file_write = open(filename, "w") 134 | case_ct = {} 135 | max_local_entity = valid_data.max_local_entity 136 | ignore_prob = (1 - eps) / max_local_entity 137 | for iteration in tqdm(range(num_epoch)): 138 | batch = valid_data.get_batch(iteration, test_batch_size, fact_dropout=0.0, test=True) 139 | with torch.no_grad(): 140 | loss, extras, pred_dist, tp_list = self.model(batch[:-1]) 141 | pred = torch.max(pred_dist, dim=1)[1] 142 | local_entity, query_entities, _, query_text, \ 143 | seed_dist, true_batch_id, answer_dist, answer_list = batch 144 | # self.true_batch_id = true_batch_id 145 | if write_info: 146 | obj_list = self.write_info(valid_data, tp_list, self.model.num_iter) 147 | # pred_sum = torch.sum(pred_dist, dim=1) 148 | # print(pred_sum) 149 | candidate_entities = torch.from_numpy(local_entity).type('torch.LongTensor') 150 | true_answers = torch.from_numpy(answer_dist).type('torch.FloatTensor') 151 | query_entities = torch.from_numpy(query_entities).type('torch.LongTensor') 152 | # acc, max_acc = cal_accuracy(pred, true_answers.cpu().numpy()) 153 | eval_loss.append(loss.item()) 154 | # eval_acc.append(acc) 155 | # eval_max_acc.append(max_acc) 156 | #pr_dist2 = pred_dist#.copy() 157 | #pred_dist = pr_dist2[-1] 158 | batch_size = pred_dist.size(0) 
159 |             batch_answers = answer_list
160 |             batch_candidates = candidate_entities
161 |             pad_ent_id = len(id2entity)
162 |             #pr_dist2 = pred_dist.copy()
163 |             #for pred_dist in pr_dist2:
164 |             for batch_id in range(batch_size):
165 |                 answers = batch_answers[batch_id]
166 |                 candidates = batch_candidates[batch_id, :].tolist()
167 |                 probs = pred_dist[batch_id, :].tolist()
168 |                 seed_entities = query_entities[batch_id, :].tolist()
169 |                 #print(seed_entities)
170 |                 #print(candidates)
171 |                 candidate2prob = []
172 |                 for c, p, s in zip(candidates, probs, seed_entities):
173 |                     if s == 1.0:
174 |                         # ignore seed entities
175 |                         #print(c, self.id2entity)
176 |                         # print(c, p, s)
177 |                         # if c < pad_ent_id:
178 |                         #     tp_obj['seed'] = self.id2entity[c]
179 |                         continue
180 |                     if c == pad_ent_id:
181 |                         continue
182 |                     if p < ignore_prob:
183 |                         continue
184 |                     candidate2prob.append((c, p))
185 |                 precision, recall, f1, hit, case, retrieved, ans = f1_and_hits(answers, candidate2prob, self.id2entity, eps)
186 |                 if write_info:
187 |                     tp_obj = obj_list[batch_id]
188 |                     tp_obj['answers'] = ans
189 |                     tp_obj['precision'] = precision
190 |                     tp_obj['recall'] = recall
191 |                     tp_obj['f1'] = f1
192 |                     tp_obj['hit'] = hit
193 |                     tp_obj['cand'] = retrieved
194 |                     self.file_write.write(json.dumps(tp_obj) + "\n")
195 |                 case_ct.setdefault(case, 0)
196 |                 case_ct[case] += 1
197 |                 f1s.append(f1)
198 |                 hits.append(hit)
199 |                 precisions.append(precision)
200 |                 recalls.append(recall)
201 |         print('evaluation.......')
202 |         # print('how many eval samples......', len(f1s))
203 |         # # print('avg_f1', np.mean(f1s))
204 |         # print('avg_hits', np.mean(hits))
205 |         # print('avg_precision', np.mean(precisions))
206 |         # print('avg_recall', np.mean(recalls))
207 |         # print('avg_f1', np.mean(f1s))
208 |         print(case_ct)
209 |         if write_info:
210 |             self.file_write.close()
211 |             self.file_write = None
212 |         return np.mean(f1s), np.mean(hits)
213 | 
214 | 
215 | 
216 | 
--------------------------------------------------------------------------------
/train_model.py:
--------------------------------------------------------------------------------
1 | 
2 | from utils import create_logger
3 | import time
4 | import numpy as np
5 | import os, math
6 | 
7 | import torch
8 | from torch.optim.lr_scheduler import ExponentialLR
9 | import torch.optim as optim
10 | 
11 | from tqdm import tqdm
12 | tqdm.monitor_iterval = 0
13 | 
14 | 
15 | 
16 | from dataset_load import load_data
17 | from models.ReaRev.rearev import ReaRev
18 | from evaluate import Evaluator
19 | 
20 | class Trainer_KBQA(object):
21 |     def __init__(self, args, model_name, logger=None):
22 |         #print('Trainer here')
23 |         self.args = args
24 |         self.logger = logger
25 |         self.best_dev_performance = 0.0
26 |         self.best_h1 = 0.0
27 |         self.best_f1 = 0.0
28 |         self.best_h1b = 0.0
29 |         self.best_f1b = 0.0
30 |         self.eps = args['eps']
31 |         self.learning_rate = self.args['lr']
32 |         self.test_batch_size = args['test_batch_size']
33 |         self.device = torch.device('cuda' if args['use_cuda'] else 'cpu')
34 |         self.reset_time = 0
35 |         self.load_data(args, args['lm'])
36 | 
37 | 
38 | 
39 |         if 'decay_rate' in args:
40 |             self.decay_rate = args['decay_rate']
41 |         else:
42 |             self.decay_rate = 0.98
43 | 
44 |         assert model_name == 'ReaRev'
45 | 
46 |         self.model = ReaRev(self.args, len(self.entity2id), self.num_kb_relation,
47 |                             self.num_word)
48 | 
49 |         if args['relation_word_emb']:
50 |             #self.model.use_rel_texts(self.rel_texts, self.rel_texts_inv)
51 |             self.model.encode_rel_texts(self.rel_texts, self.rel_texts_inv)
52 | 
53 | 
54 |         self.model.to(self.device)
55 |         
self.evaluator = Evaluator(args=args, model=self.model, entity2id=self.entity2id, 56 | relation2id=self.relation2id, device=self.device) 57 | self.load_pretrain() 58 | self.optim_def() 59 | 60 | self.num_relation = self.num_kb_relation 61 | self.num_entity = len(self.entity2id) 62 | self.num_word = len(self.word2id) 63 | 64 | 65 | print("Entity: {}, Relation: {}, Word: {}".format(self.num_entity, self.num_relation, self.num_word)) 66 | 67 | for k, v in args.items(): 68 | if k.endswith('dim'): 69 | setattr(self, k, v) 70 | if k.endswith('emb_file') or k.endswith('kge_file'): 71 | if v is None: 72 | setattr(self, k, None) 73 | else: 74 | setattr(self, k, args['data_folder'] + v) 75 | 76 | def optim_def(self): 77 | 78 | trainable = filter(lambda p: p.requires_grad, self.model.parameters()) 79 | self.optim_model = optim.Adam(trainable, lr=self.learning_rate) 80 | if self.decay_rate > 0: 81 | self.scheduler = ExponentialLR(self.optim_model, self.decay_rate) 82 | 83 | def load_data(self, args, tokenize): 84 | dataset = load_data(args, tokenize) 85 | self.train_data = dataset["train"] 86 | self.valid_data = dataset["valid"] 87 | self.test_data = dataset["test"] 88 | self.entity2id = dataset["entity2id"] 89 | self.relation2id = dataset["relation2id"] 90 | self.word2id = dataset["word2id"] 91 | self.num_word = dataset["num_word"] 92 | self.num_kb_relation = self.test_data.num_kb_relation 93 | self.num_entity = len(self.entity2id) 94 | self.rel_texts = dataset["rel_texts"] 95 | self.rel_texts_inv = dataset["rel_texts_inv"] 96 | 97 | def load_pretrain(self): 98 | args = self.args 99 | if args['load_experiment'] is not None: 100 | ckpt_path = os.path.join(args['checkpoint_dir'], args['load_experiment']) 101 | print("Load ckpt from", ckpt_path) 102 | self.load_ckpt(ckpt_path) 103 | 104 | def evaluate(self, data, test_batch_size=20, write_info=False): 105 | return self.evaluator.evaluate(data, test_batch_size, write_info) 106 | 107 | def train(self, start_epoch, end_epoch): 108 | # self.load_pretrain() 109 | eval_every = self.args['eval_every'] 110 | # eval_acc = inference(self.model, self.valid_data, self.entity2id, self.args) 111 | # self.evaluate(self.test_data, self.test_batch_size) 112 | print("Start Training------------------") 113 | for epoch in range(start_epoch, end_epoch + 1): 114 | st = time.time() 115 | 116 | #self.train_epoch2() 117 | loss, extras, h1_list_all, f1_list_all = self.train_epoch() 118 | 119 | if self.decay_rate > 0: 120 | self.scheduler.step() 121 | 122 | self.logger.info("Epoch: {}, loss : {:.4f}, time: {}".format(epoch + 1, loss, time.time() - st)) 123 | self.logger.info("Training h1 : {:.4f}, f1 : {:.4f}".format(np.mean(h1_list_all), np.mean(f1_list_all))) 124 | 125 | if (epoch + 1) % eval_every == 0: 126 | eval_f1, eval_h1 = self.evaluate(self.valid_data, self.test_batch_size) 127 | self.logger.info("EVAL F1: {:.4f}, H1: {:.4f}".format(eval_f1, eval_h1)) 128 | # eval_f1, eval_h1 = self.evaluate(self.test_data, self.test_batch_size) 129 | # self.logger.info("TEST F1: {:.4f}, H1: {:.4f}".format(eval_f1, eval_h1)) 130 | do_test = False 131 | if eval_h1 > self.best_h1: 132 | self.best_h1 = eval_h1 133 | self.save_ckpt("h1") 134 | self.logger.info("BEST EVAL H1: {:.4f}".format(eval_h1)) 135 | do_test = True 136 | if eval_f1 > self.best_f1: 137 | self.best_f1 = eval_f1 138 | self.save_ckpt("f1") 139 | self.logger.info("BEST EVAL F1: {:.4f}".format(eval_f1)) 140 | do_test = True 141 | 142 | eval_f1, eval_h1 = self.evaluate(self.test_data, self.test_batch_size) 143 | 
self.logger.info("TEST F1: {:.4f}, H1: {:.4f}".format(eval_f1, eval_h1)) 144 | # if do_test: 145 | # eval_f1, eval_h1 = self.evaluate(self.test_data, self.test_batch_size) 146 | # self.logger.info("TEST F1: {:.4f}, H1: {:.4f}".format(eval_f1, eval_h1)) 147 | 148 | # if eval_h1 > self.best_h1: 149 | # self.best_h1 = eval_h1 150 | # self.save_ckpt("h1") 151 | # if eval_f1 > self.best_f1: 152 | # self.best_f1 = eval_f1 153 | # self.save_ckpt("f1") 154 | # self.reset_time = 0 155 | # else: 156 | # self.logger.info('No improvement after one evaluation iter.') 157 | # self.reset_time += 1 158 | # if self.reset_time >= 5: 159 | # self.logger.info('No improvement after 5 evaluation. Early Stopping.') 160 | # break 161 | self.save_ckpt("final") 162 | self.logger.info('Train Done! Evaluate on testset with saved model') 163 | print("End Training------------------") 164 | self.evaluate_best() 165 | 166 | def evaluate_best(self): 167 | filename = os.path.join(self.args['checkpoint_dir'], "{}-h1.ckpt".format(self.args['experiment_name'])) 168 | self.load_ckpt(filename) 169 | eval_f1, eval_h1 = self.evaluate(self.test_data, self.test_batch_size, write_info=False) 170 | self.logger.info("Best h1 evaluation") 171 | self.logger.info("TEST F1: {:.4f}, H1: {:.4f}".format(eval_f1, eval_h1)) 172 | 173 | filename = os.path.join(self.args['checkpoint_dir'], "{}-f1.ckpt".format(self.args['experiment_name'])) 174 | self.load_ckpt(filename) 175 | eval_f1, eval_h1 = self.evaluate(self.test_data, self.test_batch_size, write_info=False) 176 | self.logger.info("Best f1 evaluation") 177 | self.logger.info("TEST F1: {:.4f}, H1: {:.4f}".format(eval_f1, eval_h1)) 178 | 179 | filename = os.path.join(self.args['checkpoint_dir'], "{}-final.ckpt".format(self.args['experiment_name'])) 180 | self.load_ckpt(filename) 181 | eval_f1, eval_h1 = self.evaluate(self.test_data, self.test_batch_size, write_info=False) 182 | self.logger.info("Final evaluation") 183 | self.logger.info("TEST F1: {:.4f}, H1: {:.4f}".format(eval_f1, eval_h1)) 184 | 185 | def evaluate_single(self, filename): 186 | if filename is not None: 187 | self.load_ckpt(filename) 188 | eval_f1, eval_hits = self.evaluate(self.valid_data, self.test_batch_size, write_info=False) 189 | self.logger.info("EVAL F1: {:.4f}, H1: {:.4f}".format(eval_f1, eval_hits)) 190 | test_f1, test_hits = self.evaluate(self.test_data, self.test_batch_size, write_info=True) 191 | self.logger.info("TEST F1: {:.4f}, H1: {:.4f}".format(test_f1, test_hits)) 192 | 193 | def train_epoch(self): 194 | self.model.train() 195 | self.train_data.reset_batches(is_sequential=False) 196 | losses = [] 197 | actor_losses = [] 198 | ent_losses = [] 199 | num_epoch = math.ceil(self.train_data.num_data / self.args['batch_size']) 200 | h1_list_all = [] 201 | f1_list_all = [] 202 | for iteration in tqdm(range(num_epoch)): 203 | batch = self.train_data.get_batch(iteration, self.args['batch_size'], self.args['fact_drop']) 204 | 205 | self.optim_model.zero_grad() 206 | loss, _, _, tp_list = self.model(batch, training=True) 207 | # if tp_list is not None: 208 | h1_list, f1_list = tp_list 209 | h1_list_all.extend(h1_list) 210 | f1_list_all.extend(f1_list) 211 | loss.backward() 212 | torch.nn.utils.clip_grad_norm_([param for name, param in self.model.named_parameters()], 213 | self.args['gradient_clip']) 214 | self.optim_model.step() 215 | losses.append(loss.item()) 216 | extras = [0, 0] 217 | return np.mean(losses), extras, h1_list_all, f1_list_all 218 | 219 | 220 | def save_ckpt(self, reason="h1"): 221 | model = 
220 |     def save_ckpt(self, reason="h1"):
221 |         model = self.model
222 |         checkpoint = {
223 |             'model_state_dict': model.state_dict()
224 |         }
225 |         model_name = os.path.join(self.args['checkpoint_dir'], "{}-{}.ckpt".format(self.args['experiment_name'],
226 |                                                                                     reason))
227 |         torch.save(checkpoint, model_name)
228 |         print("Best %s, saving model as %s" % (reason, model_name))
229 | 
230 |     def load_ckpt(self, filename):
231 |         checkpoint = torch.load(filename)
232 |         model_state_dict = checkpoint["model_state_dict"]
233 | 
234 |         model = self.model
235 |         #self.logger.info("Load param of {} from {}.".format(", ".join(list(model_state_dict.keys())), filename))
236 |         model.load_state_dict(model_state_dict, strict=False)
237 | 
238 | 
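# A minimal sketch of reusing a checkpoint outside the Trainer, mirroring
# save_ckpt/load_ckpt above. save_ckpt writes a dict with a single
# 'model_state_dict' key to "<checkpoint_dir>/<experiment_name>-<reason>.ckpt";
# the file name below is hypothetical.
#
#     import torch
#     ckpt = torch.load("checkpoint/webqsp-ReaRev-h1.ckpt", map_location="cpu")
#     state_dict = ckpt["model_state_dict"]
#     print(sorted(state_dict.keys())[:5])               # peek at parameter names
#     # model.load_state_dict(state_dict, strict=False)  # as in load_ckpt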
--------------------------------------------------------------------------------
/models/ReaRev/rearev.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.autograd import Variable
4 | import torch.nn.functional as F
5 | import torch.nn as nn
6 | 
7 | from models.base_model import BaseModel
8 | from modules.kg_reasoning.reasongnn import ReasonGNNLayer
9 | from modules.question_encoding.lstm_encoder import LSTMInstruction
10 | from modules.question_encoding.bert_encoder import BERTInstruction
11 | from modules.layer_init import TypeLayer
12 | from modules.query_update import AttnEncoder, Fusion, QueryReform
13 | 
14 | VERY_SMALL_NUMBER = 1e-10
15 | VERY_NEG_NUMBER = -100000000000
16 | 
17 | 
18 | 
19 | class ReaRev(BaseModel):
20 |     def __init__(self, args, num_entity, num_relation, num_word):
21 |         """
22 |         Init ReaRev model.
23 |         """
24 |         super(ReaRev, self).__init__(args, num_entity, num_relation, num_word)
25 |         #self.embedding_def()
26 |         #self.share_module_def()
27 |         self.layers(args)
28 | 
29 | 
30 |         self.loss_type = args['loss_type']
31 |         self.num_iter = args['num_iter']
32 |         self.num_ins = args['num_ins']
33 |         self.num_gnn = args['num_gnn']
34 |         self.alg = args['alg']
35 |         assert self.alg == 'bfs'
36 |         self.lm = args['lm']
37 | 
38 |         self.private_module_def(args, num_entity, num_relation)
39 | 
40 |         self.to(self.device)
41 |         self.lin = nn.Linear(3*self.entity_dim, self.entity_dim)
42 | 
43 |         self.fusion = Fusion(self.entity_dim)
44 |         self.reforms = []
45 |         for i in range(self.num_ins):
46 |             self.add_module('reform' + str(i), QueryReform(self.entity_dim))
47 |         # self.reform_rel = QueryReform(self.entity_dim)
48 |         # self.add_module('reform', QueryReform(self.entity_dim))
49 | 
50 |     def layers(self, args):
51 |         # initialize entity embedding
52 |         word_dim = self.word_dim
53 |         kg_dim = self.kg_dim
54 |         entity_dim = self.entity_dim
55 | 
56 |         #self.lstm_dropout = args['lstm_dropout']
57 |         self.linear_dropout = args['linear_dropout']
58 | 
59 |         self.entity_linear = nn.Linear(in_features=self.ent_dim, out_features=entity_dim)
60 |         # self.relation_linear = nn.Linear(in_features=self.rel_dim, out_features=entity_dim)
61 |         # self.relation_linear_inv = nn.Linear(in_features=self.rel_dim, out_features=entity_dim)
62 |         #self.relation_linear = nn.Linear(in_features=self.rel_dim, out_features=entity_dim)
63 | 
64 |         # dropout
65 |         #self.lstm_drop = nn.Dropout(p=self.lstm_dropout)
66 |         self.linear_drop = nn.Dropout(p=self.linear_dropout)
67 | 
68 |         if self.encode_type:
69 |             self.type_layer = TypeLayer(in_features=entity_dim, out_features=entity_dim,
70 |                                         linear_drop=self.linear_drop, device=self.device)
71 | 
72 |         self.self_att_r = AttnEncoder(self.entity_dim)
73 |         #self.self_att_r_inv = AttnEncoder(self.entity_dim)
74 |         self.kld_loss = nn.KLDivLoss(reduction='none')
75 |         self.bce_loss_logits = nn.BCEWithLogitsLoss(reduction='none')
76 |         self.mse_loss = torch.nn.MSELoss()
77 | 
78 |     def get_ent_init(self, local_entity, kb_adj_mat, rel_features):
79 |         if self.encode_type:
80 |             local_entity_emb = self.type_layer(local_entity=local_entity,
81 |                                                edge_list=kb_adj_mat,
82 |                                                rel_features=rel_features)
83 |         else:
84 |             local_entity_emb = self.entity_embedding(local_entity)  # batch_size, max_local_entity, word_dim
85 |             local_entity_emb = self.entity_linear(local_entity_emb)
86 | 
87 |         return local_entity_emb
88 | 
89 | 
90 |     def get_rel_feature(self):
91 |         """
92 |         Encode relation tokens to vectors.
93 |         """
94 |         if self.rel_texts is None:
95 |             rel_features = self.relation_embedding.weight
96 |             rel_features_inv = self.relation_embedding_inv.weight
97 |             rel_features = self.relation_linear(rel_features)
98 |             rel_features_inv = self.relation_linear(rel_features_inv)
99 |         else:
100 | 
101 |             rel_features = self.instruction.question_emb(self.rel_features)
102 |             rel_features_inv = self.instruction.question_emb(self.rel_features_inv)
103 | 
104 |             rel_features = self.self_att_r(rel_features, (self.rel_texts != self.instruction.pad_val).float())
105 |             rel_features_inv = self.self_att_r(rel_features_inv, (self.rel_texts_inv != self.instruction.pad_val).float())  # pool with the inverse texts' own pad mask
106 |             if self.lm == 'lstm':
107 |                 rel_features = self.self_att_r(rel_features, (self.rel_texts != self.num_relation+1).float())
108 |                 rel_features_inv = self.self_att_r(rel_features_inv, (self.rel_texts_inv != self.num_relation+1).float())
109 | 
110 |         return rel_features, rel_features_inv
111 | 
112 | 
113 |     def private_module_def(self, args, num_entity, num_relation):
114 |         """
115 |         Building modules: LM encoder, GNN, etc.
116 |         """
117 |         # initialize entity embedding
118 |         word_dim = self.word_dim
119 |         kg_dim = self.kg_dim
120 |         entity_dim = self.entity_dim
121 |         self.reasoning = ReasonGNNLayer(args, num_entity, num_relation, entity_dim, self.alg)
122 |         if args['lm'] == 'lstm':
123 |             self.instruction = LSTMInstruction(args, self.word_embedding, self.num_word)
124 |             self.relation_linear = nn.Linear(in_features=entity_dim, out_features=entity_dim)
125 |         else:
126 |             self.instruction = BERTInstruction(args, self.word_embedding, self.num_word, args['lm'])
127 |             #self.relation_linear = nn.Linear(in_features=self.instruction.word_dim, out_features=entity_dim)
128 |         # self.relation_linear = nn.Linear(in_features=entity_dim, out_features=entity_dim)
129 |         # self.relation_linear_inv = nn.Linear(in_features=entity_dim, out_features=entity_dim)
130 | 
131 |     def init_reason(self, curr_dist, local_entity, kb_adj_mat, q_input, query_entities):
132 |         """
133 |         Initializing Reasoning
134 |         """
135 |         # batch_size = local_entity.size(0)
136 |         self.local_entity = local_entity
137 |         self.instruction_list, self.attn_list = self.instruction(q_input)
138 |         rel_features, rel_features_inv = self.get_rel_feature()
139 |         self.local_entity_emb = self.get_ent_init(local_entity, kb_adj_mat, rel_features)
140 |         self.init_entity_emb = self.local_entity_emb
141 |         self.curr_dist = curr_dist
142 |         self.dist_history = []
143 |         self.action_probs = []
144 |         self.seed_entities = curr_dist
145 | 
146 |         self.reasoning.init_reason(
147 |             local_entity=local_entity,
148 |             kb_adj_mat=kb_adj_mat,
149 |             local_entity_emb=self.local_entity_emb,
150 |             rel_features=rel_features,
151 |             rel_features_inv=rel_features_inv,
152 |             query_entities=query_entities)
153 | 
154 | 
155 |     def calc_loss_label(self, curr_dist, teacher_dist, label_valid):
156 |         tp_loss = 
self.get_loss(pred_dist=curr_dist, answer_dist=teacher_dist, reduction='none') 157 | tp_loss = tp_loss * label_valid 158 | cur_loss = torch.sum(tp_loss) / curr_dist.size(0) 159 | return cur_loss 160 | 161 | 162 | def forward(self, batch, training=False): 163 | """ 164 | Forward function: creates instructions and performs GNN reasoning. 165 | """ 166 | 167 | # local_entity, query_entities, kb_adj_mat, query_text, seed_dist, answer_dist = batch 168 | local_entity, query_entities, kb_adj_mat, query_text, seed_dist, true_batch_id, answer_dist = batch 169 | local_entity = torch.from_numpy(local_entity).type('torch.LongTensor').to(self.device) 170 | # local_entity_mask = (local_entity != self.num_entity).float() 171 | query_entities = torch.from_numpy(query_entities).type('torch.FloatTensor').to(self.device) 172 | answer_dist = torch.from_numpy(answer_dist).type('torch.FloatTensor').to(self.device) 173 | seed_dist = torch.from_numpy(seed_dist).type('torch.FloatTensor').to(self.device) 174 | current_dist = Variable(seed_dist, requires_grad=True) 175 | 176 | q_input= torch.from_numpy(query_text).type('torch.LongTensor').to(self.device) 177 | #query_text2 = torch.from_numpy(query_text2).type('torch.LongTensor').to(self.device) 178 | if self.lm != 'lstm': 179 | pad_val = self.instruction.pad_val #tokenizer.convert_tokens_to_ids(self.instruction.tokenizer.pad_token) 180 | query_mask = (q_input != pad_val).float() 181 | 182 | else: 183 | query_mask = (q_input != self.num_word).float() 184 | 185 | 186 | """ 187 | Instruction generations 188 | """ 189 | self.init_reason(curr_dist=current_dist, local_entity=local_entity, 190 | kb_adj_mat=kb_adj_mat, q_input=q_input, query_entities=query_entities) 191 | self.instruction.init_reason(q_input) 192 | for i in range(self.num_ins): 193 | relational_ins, attn_weight = self.instruction.get_instruction(self.instruction.relational_ins, step=i) 194 | self.instruction.instructions.append(relational_ins.unsqueeze(1)) 195 | self.instruction.relational_ins = relational_ins 196 | #relation_ins = torch.cat(self.instruction.instructions, dim=1) 197 | #query_emb = None 198 | self.dist_history.append(self.curr_dist) 199 | 200 | 201 | """ 202 | BFS + GNN reasoning 203 | """ 204 | 205 | for t in range(self.num_iter): 206 | relation_ins = torch.cat(self.instruction.instructions, dim=1) 207 | self.curr_dist = current_dist 208 | for j in range(self.num_gnn): 209 | self.curr_dist, global_rep = self.reasoning(self.curr_dist, relation_ins, step=j) 210 | self.dist_history.append(self.curr_dist) 211 | qs = [] 212 | 213 | """ 214 | Instruction Updates 215 | """ 216 | for j in range(self.num_ins): 217 | reform = getattr(self, 'reform' + str(j)) 218 | q = reform(self.instruction.instructions[j].squeeze(1), global_rep, query_entities, local_entity) 219 | qs.append(q.unsqueeze(1)) 220 | self.instruction.instructions[j] = q.unsqueeze(1) 221 | 222 | 223 | """ 224 | Answer Predictions 225 | """ 226 | pred_dist = self.dist_history[-1] 227 | answer_number = torch.sum(answer_dist, dim=1, keepdim=True) 228 | case_valid = (answer_number > 0).float() 229 | # filter no answer training case 230 | # loss = 0 231 | # for pred_dist in self.dist_history: 232 | loss = self.calc_loss_label(curr_dist=pred_dist, teacher_dist=answer_dist, label_valid=case_valid) 233 | 234 | 235 | pred_dist = self.dist_history[-1] 236 | pred = torch.max(pred_dist, dim=1)[1] 237 | if training: 238 | h1, f1 = self.get_eval_metric(pred_dist, answer_dist) 239 | tp_list = [h1.tolist(), f1.tolist()] 240 | else: 241 | tp_list = None 
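        # Returned values, as consumed by Trainer_KBQA.train_epoch: `loss` is
        # the BCE/KL loss over answerable cases, `pred` the argmax entity per
        # question, `pred_dist` the full distribution over local candidate
        # entities, and `tp_list` holds the per-sample [Hits@1, F1] lists
        # during training (None at eval time).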
242 | return loss, pred, pred_dist, tp_list 243 | 244 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | 5 | import numpy as np 6 | 7 | VERY_SMALL_NUMBER = 1e-10 8 | 9 | class BaseModel(torch.nn.Module): 10 | """ 11 | Base model functions: create embeddings, store relations, compute f1/h1 scores, etc. 12 | """ 13 | 14 | def __init__(self, args, num_entity, num_relation, num_word): 15 | super(BaseModel, self).__init__() 16 | self.num_relation = num_relation 17 | self.num_entity = num_entity 18 | self.num_word = num_word 19 | print('Num Word', self.num_word) 20 | self.kge_frozen = args['kge_frozen'] 21 | self.kg_dim = args['kg_dim'] 22 | #self._parse_args(args) 23 | self.entity_emb_file = args['entity_emb_file'] 24 | self.relation_emb_file = args['relation_emb_file'] 25 | self.relation_word_emb = args['relation_word_emb'] 26 | self.word_emb_file = args['word_emb_file'] 27 | self.entity_dim = args['entity_dim'] 28 | 29 | self.lm = args['lm'] 30 | if self.lm in ['bert']: 31 | #self.word_dim = 768 32 | args['word_dim'] = 768 33 | 34 | self.word_dim = args['word_dim'] 35 | 36 | self.rel_texts = None 37 | 38 | 39 | #self.share_module_def() 40 | #self.model_name = args['model_name'].lower() 41 | self.device = torch.device('cuda' if args['use_cuda'] else 'cpu') 42 | 43 | print("Entity: {}, Relation: {}, Word: {}".format(num_entity, num_relation, num_word)) 44 | 45 | 46 | self.kld_loss = nn.KLDivLoss(reduction='none') 47 | self.bce_loss_logits = nn.BCEWithLogitsLoss(reduction='none') 48 | self.mse_loss = torch.nn.MSELoss() 49 | 50 | for k, v in args.items(): 51 | if k.endswith('dim'): 52 | setattr(self, k, v) 53 | if k.endswith('emb_file') or k.endswith('kge_file'): 54 | if v is None: 55 | setattr(self, k, None) 56 | else: 57 | setattr(self, k, args['data_folder'] + v) 58 | 59 | self.reset_time = 0 60 | 61 | if 'use_inverse_relation' in args: 62 | self.use_inverse_relation = args['use_inverse_relation'] 63 | if 'use_self_loop' in args: 64 | self.use_self_loop = args['use_self_loop'] 65 | self.eps = args['eps'] 66 | 67 | self.embedding_def() 68 | args['word_dim'] = self.word_dim 69 | 70 | def embedding_def(self): 71 | num_entity = self.num_entity 72 | num_relation = self.num_relation 73 | num_word = self.num_word 74 | 75 | if self.lm != 'lstm': 76 | self.word_dim = 768 77 | self.word_embedding = nn.Embedding(num_embeddings=num_word + 1, embedding_dim=self.word_dim, 78 | padding_idx=num_word) 79 | elif self.word_emb_file is not None: 80 | word_emb = np.load(self.word_emb_file) 81 | _ , self.word_dim = word_emb.shape 82 | print('Word emb dim', self.word_dim) 83 | self.word_embedding = nn.Embedding(num_embeddings=num_word + 1, embedding_dim=self.word_dim, 84 | padding_idx=num_word) 85 | self.word_embedding.weight = nn.Parameter( 86 | torch.from_numpy( 87 | np.pad(np.load(self.word_emb_file), ((0, 1), (0, 0)), 'constant')).type( 88 | 'torch.FloatTensor')) 89 | self.word_embedding.weight.requires_grad = False 90 | else: 91 | #self.word_dim = 768 92 | self.word_embedding = nn.Embedding(num_embeddings=num_word + 1, embedding_dim=self.word_dim, 93 | padding_idx=num_word) 94 | 95 | 96 | if self.entity_emb_file is not None: 97 | self.encode_type = False 98 | emb = np.load(self.entity_emb_file) 99 | ent_num , self.ent_dim = emb.shape 100 | # if ent_num != num_entity: 101 | # print('Number of entities in KG 
embeddings do not match: Random Init.') 102 | 103 | self.entity_embedding = nn.Embedding(num_embeddings=num_entity + 1, embedding_dim=self.ent_dim, 104 | padding_idx=num_entity) 105 | if ent_num != num_entity: 106 | print('Number of entities in KG embeddings do not match: Random Init.') 107 | else: 108 | self.entity_embedding.weight = nn.Parameter( 109 | torch.from_numpy(np.pad(emb, ((0, 1), (0, 0)), 'constant')).type( 110 | 'torch.FloatTensor')) 111 | if self.kge_frozen: 112 | self.entity_embedding.weight.requires_grad = False 113 | else: 114 | self.entity_embedding.weight.requires_grad = True 115 | else: 116 | self.ent_dim = self.kg_dim 117 | self.encode_type = True 118 | #self.entity_embedding = nn.Embedding(num_embeddings=num_entity + 1, embedding_dim=self.ent_dim, 119 | #padding_idx=num_entity) 120 | 121 | #print 122 | 123 | 124 | # initialize relation embedding 125 | if self.relation_emb_file is not None: 126 | np_tensor = self.load_relation_file(self.relation_emb_file) 127 | #print('check?', np_tensor.shape) 128 | rel_num, self.rel_dim = np_tensor.shape 129 | self.relation_embedding = nn.Embedding(num_embeddings=num_relation+1, embedding_dim=self.rel_dim) 130 | if rel_num != num_relation: 131 | print('Number of relations in KG embeddings do not match: Random Init.') 132 | else: 133 | self.relation_embedding.weight = nn.Parameter(torch.from_numpy(np_tensor).type('torch.FloatTensor')) 134 | if self.kge_frozen: 135 | self.relation_embedding.weight.requires_grad = False 136 | else: 137 | self.relation_embedding.weight.requires_grad = True 138 | 139 | elif self.relation_word_emb: 140 | self.rel_dim = self.entity_dim 141 | self.relation_embedding = nn.Embedding(num_embeddings=num_relation+1, embedding_dim=self.rel_dim) 142 | self.relation_embedding.weight.requires_grad = True 143 | self.relation_embedding_inv = nn.Embedding(num_embeddings=num_relation+1, embedding_dim=self.rel_dim) 144 | self.relation_embedding_inv.weight.requires_grad = True 145 | pass 146 | else: 147 | self.rel_dim = 2*self.kg_dim 148 | self.relation_embedding = nn.Embedding(num_embeddings=num_relation+1, embedding_dim=self.rel_dim) 149 | self.relation_embedding_inv = nn.Embedding(num_embeddings=num_relation+1, embedding_dim=self.rel_dim) 150 | 151 | # initialize text embeddings 152 | 153 | 154 | 155 | 156 | def load_relation_file(self, filename): 157 | half_tensor = np.load(filename) 158 | num_pad = 0 159 | if self.use_self_loop: 160 | num_pad = 2 161 | if self.use_inverse_relation: 162 | load_tensor = np.concatenate([half_tensor, half_tensor]) 163 | else: 164 | load_tensor = half_tensor 165 | return np.pad(load_tensor, ((0, num_pad), (0, 0)), 'constant') 166 | 167 | def use_rel_texts(self, rel_texts, rel_texts_inv): 168 | self.rel_texts = torch.from_numpy(rel_texts).type('torch.LongTensor').to(self.device) 169 | self.rel_texts_inv = torch.from_numpy(rel_texts_inv).type('torch.LongTensor').to(self.device) 170 | 171 | def encode_rel_texts(self, rel_texts, rel_texts_inv): 172 | self.rel_texts = torch.from_numpy(rel_texts).type('torch.LongTensor').to(self.device) 173 | self.rel_texts_inv = torch.from_numpy(rel_texts_inv).type('torch.LongTensor').to(self.device) 174 | self.instruction.eval() 175 | with torch.no_grad(): 176 | self.rel_features = self.instruction.encode_question(self.rel_texts, store=False) 177 | self.rel_features_inv = self.instruction.encode_question(self.rel_texts_inv, store=False) 178 | self.rel_features.requires_grad = False 179 | self.rel_features_inv.requires_grad = False 180 | 181 | def 
init_hidden(self, num_layer, batch_size, hidden_size): 182 | return self.instruction.init_hidden(num_layer, batch_size, hidden_size) 183 | 184 | def encode_question(self, q_input): 185 | return self.instruction.encode_question(q_input) 186 | 187 | def get_instruction(self, query_hidden_emb, query_mask, states): 188 | return self.instruction.get_instruction(query_hidden_emb, query_mask, states) 189 | 190 | def get_loss_bce(self, pred_dist_score, answer_dist): 191 | answer_dist = (answer_dist > 0).float() * 0.9 # label smooth 192 | # answer_dist = answer_dist * 0.9 # label smooth 193 | loss = self.bce_loss_logits(pred_dist_score, answer_dist) 194 | return loss 195 | 196 | def get_loss_kl(self, pred_dist, answer_dist): 197 | answer_len = torch.sum(answer_dist, dim=1, keepdim=True) 198 | answer_len[answer_len == 0] = 1.0 199 | answer_prob = answer_dist.div(answer_len) 200 | log_prob = torch.log(pred_dist + 1e-8) 201 | loss = self.kld_loss(log_prob, answer_prob) 202 | return loss 203 | 204 | def get_loss(self, pred_dist, answer_dist, reduction='mean'): 205 | if self.loss_type == "bce": 206 | tp_loss = self.get_loss_bce(pred_dist, answer_dist) 207 | if reduction == 'none': 208 | return tp_loss 209 | else: 210 | # mean 211 | return torch.mean(tp_loss) 212 | else: 213 | tp_loss = self.get_loss_kl(pred_dist, answer_dist) 214 | if reduction == 'none': 215 | return tp_loss 216 | else: 217 | # batchmean 218 | return torch.sum(tp_loss) / pred_dist.size(0) 219 | 220 | def f1_and_hits(self, answers, candidate2prob, eps=0.5): 221 | retrieved = [] 222 | correct = 0 223 | cand_list = sorted(candidate2prob, key=lambda x:x[1], reverse=True) 224 | if len(cand_list) == 0: 225 | best_ans = -1 226 | else: 227 | best_ans = cand_list[0][0] 228 | # max_prob = cand_list[0][1] 229 | tp_prob = 0.0 230 | for c, prob in cand_list: 231 | retrieved.append((c, prob)) 232 | tp_prob += prob 233 | if c in answers: 234 | correct += 1 235 | if tp_prob > eps: 236 | break 237 | if len(answers) == 0: 238 | if len(retrieved) == 0: 239 | return 1.0, 1.0, 1.0, 1.0 # precision, recall, f1, hits 240 | else: 241 | return 0.0, 1.0, 0.0, 1.0 # precision, recall, f1, hits 242 | else: 243 | hits = float(best_ans in answers) 244 | if len(retrieved) == 0: 245 | return 1.0, 0.0, 0.0, hits # precision, recall, f1, hits 246 | else: 247 | p, r = correct / len(retrieved), correct / len(answers) 248 | f1 = 2.0 / (1.0 / p + 1.0 / r) if p != 0 and r != 0 else 0.0 249 | return p, r, f1, hits 250 | 251 | 252 | def calc_f1_new(self, curr_dist, dist_ans, h1_vec): 253 | batch_size = curr_dist.size(0) 254 | max_local_entity = curr_dist.size(1) 255 | seed_dist = self.seed_entities #self.dist_history[0] 256 | local_entity = self.local_entity 257 | ignore_prob = (1 - self.eps) / max_local_entity 258 | pad_ent_id = self.num_entity 259 | # hits_list = [] 260 | f1_list = [] 261 | for batch_id in range(batch_size): 262 | if h1_vec[batch_id].item() == 0.0: 263 | f1_list.append(0.0) 264 | # we consider cases which own hit@1 as prior to reduce computation time 265 | continue 266 | candidates = local_entity[batch_id, :].tolist() 267 | probs = curr_dist[batch_id, :].tolist() 268 | answer_prob = dist_ans[batch_id, :].tolist() 269 | seed_entities = seed_dist[batch_id, :].tolist() 270 | answer_list = [] 271 | candidate2prob = [] 272 | for c, p, p_a, s in zip(candidates, probs, answer_prob, seed_entities): 273 | if s > 0: 274 | # ignore seed entities 275 | continue 276 | if c == pad_ent_id: 277 | continue 278 | if p_a > 0: 279 | answer_list.append(c) 280 | if p < 
ignore_prob: 281 | continue 282 | candidate2prob.append((c, p)) 283 | precision, recall, f1, hits = self.f1_and_hits(answer_list, candidate2prob, self.eps) 284 | # hits_list.append(hits) 285 | f1_list.append(f1) 286 | # hits_vec = torch.FloatTensor(hits_list).to(self.device) 287 | f1_vec = torch.FloatTensor(f1_list).to(self.device) 288 | return f1_vec 289 | 290 | def calc_h1(self, curr_dist, dist_ans, eps=0.01): 291 | greedy_option = curr_dist.argmax(dim=-1, keepdim=True) 292 | dist_top1 = torch.zeros_like(curr_dist).scatter_(1, greedy_option, 1.0) 293 | dist_ans = (dist_ans > eps).float() 294 | h1 = torch.sum(dist_top1 * dist_ans, dim=-1) 295 | return (h1 > 0).float() 296 | 297 | def get_eval_metric(self, pred_dist, answer_dist): 298 | with torch.no_grad(): 299 | h1 = self.calc_h1(curr_dist=pred_dist, dist_ans=answer_dist, eps=VERY_SMALL_NUMBER) 300 | f1 = self.calc_f1_new(pred_dist, answer_dist, h1) 301 | return h1, f1 -------------------------------------------------------------------------------- /dataset_load.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import re 4 | from tqdm import tqdm 5 | import torch 6 | from collections import Counter 7 | import random 8 | import warnings 9 | import pickle 10 | warnings.filterwarnings("ignore") 11 | from modules.question_encoding.tokenizers import LSTMTokenizer#, BERTTokenizer 12 | from transformers import AutoTokenizer 13 | import time 14 | 15 | import os 16 | try: 17 | os.environ['TRANSFORMERS_CACHE'] = '/export/scratch/costas/home/mavro016/.cache' 18 | except: 19 | pass 20 | 21 | 22 | class BasicDataLoader(object): 23 | """ 24 | Basic Dataloader contains all the functions to read questions and KGs from json files and 25 | create mappings between global entity ids and local ids that are used during GNN updates. 26 | """ 27 | 28 | def __init__(self, config, word2id, relation2id, entity2id, tokenize, data_type="train"): 29 | self.tokenize = tokenize 30 | self._parse_args(config, word2id, relation2id, entity2id) 31 | self._load_file(config, data_type) 32 | self._load_data() 33 | 34 | 35 | def _load_file(self, config, data_type="train"): 36 | 37 | """ 38 | Loads lines (questions + KG subgraphs) from json files. 39 | """ 40 | 41 | data_file = config['data_folder'] + data_type + ".json" 42 | self.data_file = data_file 43 | print('loading data from', data_file) 44 | self.data_type = data_type 45 | self.data = [] 46 | skip_index = set() 47 | index = 0 48 | 49 | with open(data_file) as f_in: 50 | for line in tqdm(f_in): 51 | if index == config['max_train'] and data_type == "train": break #break if we reach max_question_size 52 | line = json.loads(line) 53 | 54 | if len(line['entities']) == 0: 55 | skip_index.add(index) 56 | continue 57 | self.data.append(line) 58 | self.max_facts = max(self.max_facts, 2 * len(line['subgraph']['tuples'])) 59 | index += 1 60 | 61 | print("skip", skip_index) 62 | print('max_facts: ', self.max_facts) 63 | self.num_data = len(self.data) 64 | self.batches = np.arange(self.num_data) 65 | 66 | def _load_data(self): 67 | 68 | """ 69 | Creates mappings between global entity ids and local entity ids that are used during GNN updates. 
70 | """ 71 | 72 | print('converting global to local entity index ...') 73 | self.global2local_entity_maps = self._build_global2local_entity_maps() 74 | 75 | if self.use_self_loop: 76 | self.max_facts = self.max_facts + self.max_local_entity 77 | 78 | self.question_id = [] 79 | self.candidate_entities = np.full((self.num_data, self.max_local_entity), len(self.entity2id), dtype=int) 80 | self.kb_adj_mats = np.empty(self.num_data, dtype=object) 81 | self.q_adj_mats = np.empty(self.num_data, dtype=object) 82 | self.kb_fact_rels = np.full((self.num_data, self.max_facts), self.num_kb_relation, dtype=int) 83 | self.query_entities = np.zeros((self.num_data, self.max_local_entity), dtype=float) 84 | self.seed_list = np.empty(self.num_data, dtype=object) 85 | self.seed_distribution = np.zeros((self.num_data, self.max_local_entity), dtype=float) 86 | # self.query_texts = np.full((self.num_data, self.max_query_word), len(self.word2id), dtype=int) 87 | self.answer_dists = np.zeros((self.num_data, self.max_local_entity), dtype=float) 88 | self.answer_lists = np.empty(self.num_data, dtype=object) 89 | 90 | self._prepare_data() 91 | 92 | def _parse_args(self, config, word2id, relation2id, entity2id): 93 | 94 | """ 95 | Builds necessary dictionaries and stores arguments. 96 | """ 97 | self.data_eff = config['data_eff'] 98 | self.data_name = config['name'] 99 | 100 | if 'use_inverse_relation' in config: 101 | self.use_inverse_relation = config['use_inverse_relation'] 102 | else: 103 | self.use_inverse_relation = False 104 | if 'use_self_loop' in config: 105 | self.use_self_loop = config['use_self_loop'] 106 | else: 107 | self.use_self_loop = False 108 | 109 | self.rel_word_emb = config['relation_word_emb'] 110 | #self.num_step = config['num_step'] 111 | self.max_local_entity = 0 112 | self.max_relevant_doc = 0 113 | self.max_facts = 0 114 | 115 | print('building word index ...') 116 | self.word2id = word2id 117 | self.id2word = {i: word for word, i in word2id.items()} 118 | self.relation2id = relation2id 119 | self.entity2id = entity2id 120 | self.id2entity = {i: entity for entity, i in entity2id.items()} 121 | self.q_type = config['q_type'] 122 | 123 | if self.use_inverse_relation: 124 | self.num_kb_relation = 2 * len(relation2id) 125 | else: 126 | self.num_kb_relation = len(relation2id) 127 | if self.use_self_loop: 128 | self.num_kb_relation = self.num_kb_relation + 1 129 | print("Entity: {}, Relation in KB: {}, Relation in use: {} ".format(len(entity2id), 130 | len(self.relation2id), 131 | self.num_kb_relation)) 132 | 133 | 134 | def get_quest(self, training=False): 135 | q_list = [] 136 | 137 | sample_ids = self.sample_ids 138 | for sample_id in sample_ids: 139 | tp_str = self.decode_text(self.query_texts[sample_id, :]) 140 | # id2word = self.id2word 141 | # for i in range(self.max_query_word): 142 | # if self.query_texts[sample_id, i] in id2word: 143 | # tp_str += id2word[self.query_texts[sample_id, i]] + " " 144 | q_list.append(tp_str) 145 | return q_list 146 | 147 | def decode_text(self, np_array_x): 148 | if self.tokenize == 'lstm': 149 | id2word = self.id2word 150 | tp_str = "" 151 | for i in range(self.max_query_word): 152 | if np_array_x[i] in id2word: 153 | tp_str += id2word[np_array_x[i]] + " " 154 | else: 155 | tp_str = "" 156 | words = self.tokenizer.convert_ids_to_tokens(np_array_x) 157 | for w in words: 158 | if w not in ['[CLS]', '[SEP]', '[PAD]']: 159 | tp_str += w + " " 160 | return tp_str 161 | 162 | 163 | def _prepare_data(self): 164 | """ 165 | global2local_entity_maps: a map from 
global entity id to local entity id 166 | adj_mats: a local adjacency matrix for each relation. relation 0 is reserved for self-connection. 167 | """ 168 | max_count = 0 169 | for line in self.data: 170 | word_list = line["question"].split(' ') 171 | max_count = max(max_count, len(word_list)) 172 | 173 | 174 | if self.rel_word_emb: 175 | self.build_rel_words(self.tokenize) 176 | else: 177 | self.rel_texts = None 178 | self.ent_texts = None 179 | 180 | 181 | 182 | self.max_query_word = max_count 183 | #self.query_texts = np.full((self.num_data, self.max_query_word), len(self.word2id), dtype=int) 184 | #self.query_texts2 = np.full((self.num_data, self.max_query_word), len(self.word2id), dtype=int) 185 | 186 | #build tokenizers 187 | if self.tokenize == 'lstm': 188 | self.num_word = len(self.word2id) 189 | self.tokenizer = LSTMTokenizer(self.word2id, self.max_query_word) 190 | self.query_texts = np.full((self.num_data, self.max_query_word), self.num_word, dtype=int) 191 | else: 192 | if self.tokenize == 'bert': 193 | tokenizer_name = 'bert-base-uncased' 194 | elif self.tokenize == 'roberta': 195 | tokenizer_name = 'roberta-base' 196 | elif self.tokenize == 'sbert': 197 | tokenizer_name = 'sentence-transformers/all-MiniLM-L6-v2' 198 | elif self.tokenize == 'sbert2': 199 | tokenizer_name = 'sentence-transformers/all-mpnet-base-v2' 200 | elif self.tokenize == 't5': 201 | tokenizer_name = 't5-small' 202 | 203 | self.max_query_word = max_count + 2 #for cls token and sep 204 | #self.tokenizer = AutoTokenizer(self.max_query_word) 205 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 206 | self.num_word = self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token) #self.tokenizer.q_tokenizer.encode("[UNK]")[0] 207 | 208 | self.query_texts = np.full((self.num_data, self.max_query_word), self.num_word, dtype=int) 209 | 210 | 211 | next_id = 0 212 | num_query_entity = {} 213 | for sample in tqdm(self.data): 214 | self.question_id.append(sample["id"]) 215 | # get a list of local entities 216 | g2l = self.global2local_entity_maps[next_id] 217 | #print(g2l) 218 | if len(g2l) == 0: 219 | #print(next_id) 220 | continue 221 | # build connection between question and entities in it 222 | tp_set = set() 223 | seed_list = [] 224 | for j, entity in enumerate(sample['entities']): 225 | # if entity['text'] not in self.entity2id: 226 | # continue 227 | try: 228 | global_entity = self.entity2id[entity['text']] 229 | except: 230 | global_entity = entity #self.entity2id[entity['text']] 231 | if global_entity not in g2l: 232 | continue 233 | local_ent = g2l[global_entity] 234 | self.query_entities[next_id, local_ent] = 1.0 235 | seed_list.append(local_ent) 236 | tp_set.add(local_ent) 237 | 238 | self.seed_list[next_id] = seed_list 239 | num_query_entity[next_id] = len(tp_set) 240 | for global_entity, local_entity in g2l.items(): 241 | if self.data_name != 'cwq': 242 | 243 | if local_entity not in tp_set: # skip entities in question 244 | #print(global_entity) 245 | #print(local_entity) 246 | self.candidate_entities[next_id, local_entity] = global_entity 247 | elif self.data_name == 'cwq': 248 | self.candidate_entities[next_id, local_entity] = global_entity 249 | # if local_entity != 0: # skip question node 250 | # self.candidate_entities[next_id, local_entity] = global_entity 251 | 252 | # relations in local KB 253 | head_list = [] 254 | rel_list = [] 255 | tail_list = [] 256 | for i, tpl in enumerate(sample['subgraph']['tuples']): 257 | sbj, rel, obj = tpl 258 | try: 259 | head = 
g2l[self.entity2id[sbj['text']]] 260 | rel = self.relation2id[rel['text']] 261 | tail = g2l[self.entity2id[obj['text']]] 262 | except: 263 | head = g2l[sbj] 264 | rel = int(rel) 265 | tail = g2l[obj] 266 | head_list.append(head) 267 | rel_list.append(rel) 268 | tail_list.append(tail) 269 | self.kb_fact_rels[next_id, i] = rel 270 | if self.use_inverse_relation: 271 | head_list.append(tail) 272 | rel_list.append(rel + len(self.relation2id)) 273 | tail_list.append(head) 274 | self.kb_fact_rels[next_id, i] = rel + len(self.relation2id) 275 | 276 | if len(tp_set) > 0: 277 | for local_ent in tp_set: 278 | self.seed_distribution[next_id, local_ent] = 1.0 / len(tp_set) 279 | else: 280 | for index in range(len(g2l)): 281 | self.seed_distribution[next_id, index] = 1.0 / len(g2l) 282 | try: 283 | assert np.sum(self.seed_distribution[next_id]) > 0.0 284 | except: 285 | print(next_id, len(tp_set)) 286 | exit(-1) 287 | 288 | #tokenize question 289 | if self.tokenize == 'lstm': 290 | self.query_texts[next_id] = self.tokenizer.tokenize(sample['question']) 291 | else: 292 | tokens = self.tokenizer.encode_plus(text=sample['question'], max_length=self.max_query_word, \ 293 | pad_to_max_length=True, return_attention_mask = False, truncation=True) 294 | self.query_texts[next_id] = np.array(tokens['input_ids']) 295 | 296 | 297 | # construct distribution for answers 298 | answer_list = [] 299 | for answer in sample['answers']: 300 | keyword = 'text' if type(answer['kb_id']) == int else 'kb_id' 301 | answer_ent = self.entity2id[answer[keyword]] 302 | answer_list.append(answer_ent) 303 | if answer_ent in g2l: 304 | self.answer_dists[next_id, g2l[answer_ent]] = 1.0 305 | self.answer_lists[next_id] = answer_list 306 | 307 | if not self.data_eff: 308 | self.kb_adj_mats[next_id] = (np.array(head_list, dtype=int), 309 | np.array(rel_list, dtype=int), 310 | np.array(tail_list, dtype=int)) 311 | 312 | next_id += 1 313 | num_no_query_ent = 0 314 | num_one_query_ent = 0 315 | num_multiple_ent = 0 316 | for i in range(next_id): 317 | ct = num_query_entity[i] 318 | if ct == 1: 319 | num_one_query_ent += 1 320 | elif ct == 0: 321 | num_no_query_ent += 1 322 | else: 323 | num_multiple_ent += 1 324 | print("{} cases in total, {} cases without query entity, {} cases with single query entity," 325 | " {} cases with multiple query entities".format(next_id, num_no_query_ent, 326 | num_one_query_ent, num_multiple_ent)) 327 | 328 | 329 | def build_rel_words(self, tokenize): 330 | """ 331 | Tokenizes relation surface forms. 
332 | """ 333 | 334 | max_rel_words = 0 335 | rel_words = [] 336 | if 'metaqa' in self.data_file: 337 | for rel in self.relation2id: 338 | words = rel.split('_') 339 | max_rel_words = max(len(words), max_rel_words) 340 | rel_words.append(words) 341 | #print(rel_words) 342 | else: 343 | for rel in self.relation2id: 344 | rel = rel.strip() 345 | fields = rel.split('.') 346 | try: 347 | words = fields[-2].split('_') + fields[-1].split('_') 348 | max_rel_words = max(len(words), max_rel_words) 349 | rel_words.append(words) 350 | #print(rel, words) 351 | except: 352 | words = ['UNK'] 353 | rel_words.append(words) 354 | pass 355 | #words = fields[-2].split('_') + fields[-1].split('_') 356 | 357 | self.max_rel_words = max_rel_words 358 | if tokenize == 'lstm': 359 | self.rel_texts = np.full((self.num_kb_relation + 1, self.max_rel_words), len(self.word2id), dtype=int) 360 | self.rel_texts_inv = np.full((self.num_kb_relation + 1, self.max_rel_words), len(self.word2id), dtype=int) 361 | for rel_id,tokens in enumerate(rel_words): 362 | for j, word in enumerate(tokens): 363 | if j < self.max_rel_words: 364 | if word in self.word2id: 365 | self.rel_texts[rel_id, j] = self.word2id[word] 366 | self.rel_texts_inv[rel_id, j] = self.word2id[word] 367 | else: 368 | self.rel_texts[rel_id, j] = len(self.word2id) 369 | self.rel_texts_inv[rel_id, j] = len(self.word2id) 370 | else: 371 | if tokenize == 'bert': 372 | tokenizer_name = 'bert-base-uncased' 373 | elif tokenize == 'roberta': 374 | tokenizer_name = 'roberta-base' 375 | elif tokenize == 'sbert': 376 | tokenizer_name = 'sentence-transformers/all-MiniLM-L6-v2' 377 | elif tokenize == 'sbert2': 378 | tokenizer_name = 'sentence-transformers/all-mpnet-base-v2' 379 | elif tokenize == 't5': 380 | tokenizer_name = 't5-small' 381 | 382 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 383 | pad_val = tokenizer.convert_tokens_to_ids(tokenizer.pad_token) 384 | self.rel_texts = np.full((self.num_kb_relation + 1, self.max_rel_words), pad_val, dtype=int) 385 | self.rel_texts_inv = np.full((self.num_kb_relation + 1, self.max_rel_words), pad_val, dtype=int) 386 | 387 | for rel_id,words in enumerate(rel_words): 388 | 389 | tokens = tokenizer.encode_plus(text=' '.join(words), max_length=self.max_rel_words, \ 390 | pad_to_max_length=True, return_attention_mask = False, truncation=True) 391 | tokens_inv = tokenizer.encode_plus(text=' '.join(words[::-1]), max_length=self.max_rel_words, \ 392 | pad_to_max_length=True, return_attention_mask = False, truncation=True) 393 | self.rel_texts[rel_id] = np.array(tokens['input_ids']) 394 | self.rel_texts_inv[rel_id] = np.array(tokens_inv['input_ids']) 395 | 396 | 397 | 398 | #print(rel_words) 399 | #print(len(rel_words), len(self.relation2id)) 400 | assert len(rel_words) == len(self.relation2id) 401 | #print(self.rel_texts, self.max_rel_words) 402 | 403 | def create_kb_adj_mats(self, sample_id): 404 | 405 | """ 406 | Re-build local adj mats if we have data_eff == True (they are not pre-stored). 
407 | """ 408 | sample = self.data[sample_id] 409 | g2l = self.global2local_entity_maps[sample_id] 410 | 411 | # build connection between question and entities in it 412 | head_list = [] 413 | rel_list = [] 414 | tail_list = [] 415 | for i, tpl in enumerate(sample['subgraph']['tuples']): 416 | sbj, rel, obj = tpl 417 | try: 418 | head = g2l[self.entity2id[sbj['text']]] 419 | rel = self.relation2id[rel['text']] 420 | tail = g2l[self.entity2id[obj['text']]] 421 | except: 422 | head = g2l[sbj] 423 | rel = int(rel) 424 | tail = g2l[obj] 425 | head_list.append(head) 426 | rel_list.append(rel) 427 | tail_list.append(tail) 428 | if self.use_inverse_relation: 429 | head_list.append(tail) 430 | rel_list.append(rel + len(self.relation2id)) 431 | tail_list.append(head) 432 | 433 | return np.array(head_list, dtype=int), np.array(rel_list, dtype=int), np.array(tail_list, dtype=int) 434 | 435 | 436 | def _build_fact_mat(self, sample_ids, fact_dropout): 437 | """ 438 | Creates local adj mats that contain entities, relations, and structure. 439 | """ 440 | batch_heads = np.array([], dtype=int) 441 | batch_rels = np.array([], dtype=int) 442 | batch_tails = np.array([], dtype=int) 443 | batch_ids = np.array([], dtype=int) 444 | #print(sample_ids) 445 | for i, sample_id in enumerate(sample_ids): 446 | index_bias = i * self.max_local_entity 447 | if self.data_eff: 448 | head_list, rel_list, tail_list = self.create_kb_adj_mats(sample_id) #kb_adj_mats[sample_id] 449 | else: 450 | (head_list, rel_list, tail_list) = self.kb_adj_mats[sample_id] 451 | num_fact = len(head_list) 452 | num_keep_fact = int(np.floor(num_fact * (1 - fact_dropout))) 453 | mask_index = np.random.permutation(num_fact)[: num_keep_fact] 454 | 455 | real_head_list = head_list[mask_index] + index_bias 456 | real_tail_list = tail_list[mask_index] + index_bias 457 | real_rel_list = rel_list[mask_index] 458 | batch_heads = np.append(batch_heads, real_head_list) 459 | batch_rels = np.append(batch_rels, real_rel_list) 460 | batch_tails = np.append(batch_tails, real_tail_list) 461 | batch_ids = np.append(batch_ids, np.full(len(mask_index), i, dtype=int)) 462 | if self.use_self_loop: 463 | num_ent_now = len(self.global2local_entity_maps[sample_id]) 464 | ent_array = np.array(range(num_ent_now), dtype=int) + index_bias 465 | rel_array = np.array([self.num_kb_relation - 1] * num_ent_now, dtype=int) 466 | batch_heads = np.append(batch_heads, ent_array) 467 | batch_tails = np.append(batch_tails, ent_array) 468 | batch_rels = np.append(batch_rels, rel_array) 469 | batch_ids = np.append(batch_ids, np.full(num_ent_now, i, dtype=int)) 470 | fact_ids = np.array(range(len(batch_heads)), dtype=int) 471 | head_count = Counter(batch_heads) 472 | # tail_count = Counter(batch_tails) 473 | weight_list = [1.0 / head_count[head] for head in batch_heads] 474 | # entity2fact_index = torch.LongTensor([batch_heads, fact_ids]) 475 | # entity2fact_val = torch.FloatTensor(weight_list) 476 | # entity2fact_mat = torch.sparse.FloatTensor(entity2fact_index, entity2fact_val, torch.Size( 477 | # [len(sample_ids) * self.max_local_entity, len(batch_heads)])) 478 | return batch_heads, batch_rels, batch_tails, batch_ids, fact_ids, weight_list 479 | 480 | 481 | def reset_batches(self, is_sequential=True): 482 | if is_sequential: 483 | self.batches = np.arange(self.num_data) 484 | else: 485 | self.batches = np.random.permutation(self.num_data) 486 | 487 | def _build_global2local_entity_maps(self): 488 | """Create a map from global entity id to local entity of each sample""" 489 | 
global2local_entity_maps = [None] * self.num_data
490 |         total_local_entity = 0.0
491 |         next_id = 0
492 |         for sample in tqdm(self.data):
493 |             g2l = dict()
494 |             self._add_entity_to_map(self.entity2id, sample['entities'], g2l)
495 |             # construct a map from global entity id to local entity id
496 |             self._add_entity_to_map(self.entity2id, sample['subgraph']['entities'], g2l)
497 | 
498 |             global2local_entity_maps[next_id] = g2l
499 |             total_local_entity += len(g2l)
500 |             self.max_local_entity = max(self.max_local_entity, len(g2l))
501 |             next_id += 1
502 |         print('avg local entity: ', total_local_entity / next_id)
503 |         print('max local entity: ', self.max_local_entity)
504 |         return global2local_entity_maps
505 | 
506 | 
507 | 
508 |     @staticmethod
509 |     def _add_entity_to_map(entity2id, entities, g2l):
510 |         #print(entities)
511 |         #print(entity2id)
512 |         for entity_global_id in entities:
513 |             try:
514 |                 ent = entity2id[entity_global_id['text']]
515 |                 if ent not in g2l:
516 |                     g2l[ent] = len(g2l)
517 |             except (KeyError, TypeError):  # entity may be a raw id rather than a {'text': ...} dict
518 |                 if entity_global_id not in g2l:
519 |                     g2l[entity_global_id] = len(g2l)
520 | 
521 |     def deal_q_type(self, q_type=None):
522 |         sample_ids = self.sample_ids
523 |         if q_type is None:
524 |             q_type = self.q_type
525 |         if q_type == "seq":
526 |             q_input = self.query_texts[sample_ids]
527 |         else:
528 |             raise NotImplementedError
529 | 
530 |         return q_input
531 | 
532 | 
533 | 
534 | 
535 | 
536 | class SingleDataLoader(BasicDataLoader):
537 |     """
538 |     Single Dataloader creates training/eval batches during KGQA.
539 |     """
540 |     def __init__(self, config, word2id, relation2id, entity2id, tokenize, data_type="train"):
541 |         super(SingleDataLoader, self).__init__(config, word2id, relation2id, entity2id, tokenize, data_type)
542 | 
543 |     def get_batch(self, iteration, batch_size, fact_dropout, q_type=None, test=False):
544 |         start = batch_size * iteration
545 |         end = min(batch_size * (iteration + 1), self.num_data)
546 |         sample_ids = self.batches[start: end]
547 |         self.sample_ids = sample_ids
548 |         # true_batch_id, sample_ids, seed_dist = self.deal_multi_seed(ori_sample_ids)
549 |         # self.sample_ids = sample_ids
550 |         # self.true_sample_ids = ori_sample_ids
551 |         # self.batch_ids = true_batch_id
552 |         true_batch_id = None
553 |         seed_dist = self.seed_distribution[sample_ids]
554 |         q_input = self.deal_q_type(q_type)
555 |         kb_adj_mats = self._build_fact_mat(sample_ids, fact_dropout=fact_dropout)
556 | 
557 |         if test:
558 |             return self.candidate_entities[sample_ids], \
559 |                    self.query_entities[sample_ids], \
560 |                    kb_adj_mats, \
561 |                    q_input, \
562 |                    seed_dist, \
563 |                    true_batch_id, \
564 |                    self.answer_dists[sample_ids], \
565 |                    self.answer_lists[sample_ids]
566 | 
567 |         return self.candidate_entities[sample_ids], \
568 |                self.query_entities[sample_ids], \
569 |                kb_adj_mats, \
570 |                q_input, \
571 |                seed_dist, \
572 |                true_batch_id, \
573 |                self.answer_dists[sample_ids]
574 | 
575 | 
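# --- illustrative sketch (not in the original file) ---
# The global->local mapping built by _build_global2local_entity_maps assigns
# each sample's entities compact local ids 0..k-1 in first-seen order, so
# every subgraph gets its own dense index space for the GNN. With made-up
# global ids:
#
#     g2l = {}
#     for g in [9042, 17, 9042, 523]:
#         if g not in g2l:
#             g2l[g] = len(g2l)
#     # g2l == {9042: 0, 17: 1, 523: 2}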
588 | """ 589 | 590 | entity2id = load_dict(config['data_folder'] + config['entity2id']) 591 | word2id = load_dict(config['data_folder'] + config['word2id']) 592 | relation2id = load_dict(config['data_folder'] + config['relation2id']) 593 | 594 | if config["is_eval"]: 595 | train_data = None 596 | valid_data = SingleDataLoader(config, word2id, relation2id, entity2id, tokenize, data_type="dev") 597 | test_data = SingleDataLoader(config, word2id, relation2id, entity2id, tokenize, data_type="test") 598 | num_word = test_data.num_word 599 | else: 600 | train_data = SingleDataLoader(config, word2id, relation2id, entity2id, tokenize, data_type="train") 601 | valid_data = SingleDataLoader(config, word2id, relation2id, entity2id, tokenize, data_type="dev") 602 | test_data = SingleDataLoader(config, word2id, relation2id, entity2id, tokenize, data_type="test") 603 | num_word = train_data.num_word 604 | relation_texts = test_data.rel_texts 605 | relation_texts_inv = test_data.rel_texts_inv 606 | entities_texts = None 607 | dataset = { 608 | "train": train_data, 609 | "valid": valid_data, 610 | "test": test_data, #test_data, 611 | "entity2id": entity2id, 612 | "relation2id": relation2id, 613 | "word2id": word2id, 614 | "num_word": num_word, 615 | "rel_texts": relation_texts, 616 | "rel_texts_inv": relation_texts_inv, 617 | "ent_texts": entities_texts 618 | } 619 | return dataset 620 | 621 | 622 | if __name__ == "__main__": 623 | st = time.time() 624 | #args = get_config() 625 | load_data(args) 626 | --------------------------------------------------------------------------------