├── .gitignore
├── requirements.txt
├── README.md
├── main.sh
├── util.py
├── entity_detection.py
├── embedding.py
├── evaluation.py
├── augment_process_dataset.py
├── trim_names.py
├── train_detection.py
├── train_entity.py
├── train_pred.py
└── test_main.py

/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | data.zip
3 | data/
4 | preprocess/
5 | .idea/
6 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | fuzzywuzzy>=0.16.0
2 | scikit-learn>=0.19.1
3 | torchtext==0.3.1
4 | nltk>=3.3
5 | torch==0.4.1
6 | numpy>=1.12.1
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Knowledge Graph Embedding Based Question Answering
2 | Knowledge Graph Embedding Based Question Answering, WSDM 2019
3 | 
4 | ## Installation
5 | - Requirements
6 |   1. fuzzywuzzy
7 |   2. scikit-learn
8 |   3. torchtext
9 |   4. nltk
10 |   5. pytorch
11 |   6. numpy
12 | - Usage
13 |   1. cd KEQA_WSDM19
14 |   2. pip install -r requirements.txt
15 |   3. sh main.sh
16 | 
17 | ## Reference in BibTeX:
18 |     @conference{Huang-etal19Knowledge,
19 |       Title = {Knowledge Graph Embedding Based Question Answering},
20 |       Author = {Xiao Huang and Jingyuan Zhang and Dingcheng Li and Ping Li},
21 |       Booktitle = {ACM International Conference on Web Search and Data Mining},
22 |       Year = {2019}}
23 | 
--------------------------------------------------------------------------------
/main.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | echo "Downloading the SimpleQuestions dataset..."
4 | wget https://www.dropbox.com/s/9lxudhdfpfkihr1/data.zip
5 | unzip data.zip
6 | rm data.zip
7 | 
8 | echo "Preprocessing the raw data..."
9 | python3.6 trim_names.py -f data/freebase-FB2M.txt -n data/FB5M.name.txt
10 | 
11 | echo "Creating the processed, augmented dataset..."
12 | python3.6 augment_process_dataset.py -d data/
13 | 
14 | 
15 | echo "Embedding the Knowledge Graph:"
16 | echo "Computing the embeddings with an existing method takes too long, so we download the precomputed Knowledge Graph embeddings directly..."
17 | wget https://www.dropbox.com/s/o5hd8lnr5c0l6hj/KGembed.zip
18 | unzip KGembed.zip
19 | rm KGembed.zip
20 | mv -f KGembed/* preprocess/
21 | rm -r KGembed
22 | #python3.6 transE_emb.py --learning_rate 0.003 --batch_size 3000 --eval_freq 50
23 | 
24 | 
25 | 
26 | echo "We could run train_detection.py, train_entity.py, and train_pred.py simultaneously"
27 | 
28 | echo "Head Entity Detection (HED) model, train and test the model..."
29 | python3.6 train_detection.py --entity_detection_mode LSTM --fix_embed --gpu 0
30 | 
31 | echo "Entity representation learning..."
32 | python3.6 train_entity.py --qa_mode GRU --fix_embed --gpu 0
33 | python3.6 train_pred.py --qa_mode GRU --fix_embed --gpu 0
34 | 
35 | echo "We have to run train_detection.py, train_entity.py, and train_pred.py first, before running test_main.py..."
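# test_main.py loads the three checkpoints saved by the steps above
# (preprocess/dete_best_model.pt, entity_best_model.pt, pred_best_model.pt),
# so the training scripts must finish first. A minimal sketch of running the
# three independent trainings in parallel instead (assumes one free GPU per job):
#   python3.6 train_detection.py --entity_detection_mode LSTM --fix_embed --gpu 0 &
#   python3.6 train_entity.py --qa_mode GRU --fix_embed --gpu 1 &
#   python3.6 train_pred.py --qa_mode GRU --fix_embed --gpu 2 &
#   wait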
36 | python3.6 test_main.py --gpu 0
37 | 
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | import unicodedata
2 | from nltk.tokenize.treebank import TreebankWordTokenizer
3 | 
4 | tokenizer = TreebankWordTokenizer()
5 | 
6 | def processed_text(text):
7 |     text = text.replace('\\\\', '')
8 |     #stripped = strip_accents(text.lower())
9 |     stripped = text.lower()
10 |     toks = tokenizer.tokenize(stripped)
11 |     return " ".join(toks)
12 | 
13 | def strip_accents(text):
14 |     return ''.join(c for c in unicodedata.normalize('NFKD', text) if unicodedata.category(c) != 'Mn')
15 | 
16 | # Manual corrections for individual MIDs
17 | MID_CORRECTIONS = {
18 |     'm.07s9rl0': 'm.02822', 'm.0bb56b6': 'm.0dn0r', 'm.01g81dw': 'm.01g_bfh',
19 |     'm.0y7q89y': 'm.0wrt1c5', 'm.0b0w7': 'm.0fq0s89', 'm.09rmm6y': 'm.03cnrcc',
20 |     'm.0crsn60': 'm.02pnlqy', 'm.04t1f8y': 'm.04t1fjr', 'm.027z990': 'm.0ghdhcb',
21 |     'm.02xhc2v': 'm.084sq', 'm.02z8b2h': 'm.033vn1', 'm.0w43mcj': 'm.0m0qffc',
22 |     'm.07rqy': 'm.0py_0', 'm.0y9s5rm': 'm.0ybxl2g', 'm.037ltr7': 'm.0qjx99s',
23 | }
24 | 
25 | def www2fb(in_str):
26 |     if in_str.startswith("www.freebase.com"):
27 |         in_str = in_str.replace('www.freebase.com/', '').replace('/', '.')
28 |     return ' '.join(MID_CORRECTIONS.get(tok, tok) for tok in in_str.split())
29 | 
30 | def clean_uri(uri):
31 |     if uri.startswith("<") and uri.endswith(">"):
32 |         return clean_uri(uri[1:-1])  # strip one pair of enclosing angle brackets
33 |     elif uri.startswith("\"") and uri.endswith("\""):
34 |         return clean_uri(uri[1:-1])
35 |     return uri
--------------------------------------------------------------------------------
/entity_detection.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch.nn.functional as F
3 | 
4 | class EntityDetection(nn.Module):
5 |     def __init__(self, config):
6 |         super(EntityDetection, self).__init__()
7 |         self.config = config
8 |         target_size = config.label
9 |         self.embed = nn.Embedding(config.words_num, config.words_dim)
10 |         if config.train_embed == False:
11 |             self.embed.weight.requires_grad = False
12 |         if config.entity_detection_mode.upper() == 'LSTM':
13 |             self.lstm = nn.LSTM(input_size=config.words_dim,
14 |                                 hidden_size=config.hidden_size,
15 |                                 num_layers=config.num_layer,
16 |                                 dropout=config.rnn_dropout,
17 |                                 bidirectional=True)
18 |         elif config.entity_detection_mode.upper() == 'GRU':
19 |             self.gru = nn.GRU(input_size=config.words_dim,
20 |                               hidden_size=config.hidden_size,
21 |                               num_layers=config.num_layer,
22 |                               dropout=config.rnn_dropout,
23 |                               bidirectional=True)
24 |         self.dropout = nn.Dropout(p=config.rnn_fc_dropout)
25 |         self.relu = nn.ReLU()
26 |         self.hidden2tag = nn.Sequential(
27 |             nn.Linear(config.hidden_size * 2, config.hidden_size * 2),
28 |             nn.BatchNorm1d(config.hidden_size * 2),
29 |             self.relu,
30 |             self.dropout,
31 |             nn.Linear(config.hidden_size * 2, target_size)
32 |         )
33 | 
34 | 
35 |     def
forward(self, x): 36 | # x = (sequence length, batch_size, dimension of embedding) 37 | text = x.text 38 | batch_size = text.size()[1] 39 | x = self.embed(text) 40 | # h0 / c0 = (layer*direction, batch_size, hidden_dim) 41 | if self.config.entity_detection_mode.upper() == 'LSTM': 42 | outputs, (ht, ct) = self.lstm(x) 43 | elif self.config.entity_detection_mode.upper() == 'GRU': 44 | outputs, ht = self.gru(x) 45 | else: 46 | print("Wrong Entity Prediction Mode") 47 | exit(1) 48 | tags = self.hidden2tag(outputs.view(-1, outputs.size(2))) 49 | scores = F.log_softmax(tags, dim=1) 50 | return scores 51 | 52 | -------------------------------------------------------------------------------- /embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | class EmbedVector(nn.Module): 6 | def __init__(self, config): 7 | super(EmbedVector, self).__init__() 8 | self.config = config 9 | target_size = config.label 10 | self.embed = nn.Embedding(config.words_num, config.words_dim) 11 | if config.train_embed == False: 12 | self.embed.weight.requires_grad = False 13 | if config.qa_mode.upper() == 'LSTM': 14 | self.lstm = nn.LSTM(input_size=config.words_dim, 15 | hidden_size=config.hidden_size, 16 | num_layers=config.num_layer, 17 | dropout=config.rnn_dropout, 18 | bidirectional=True) 19 | elif config.qa_mode.upper() == 'GRU': 20 | self.gru = nn.GRU(input_size=config.words_dim, 21 | hidden_size=config.hidden_size, 22 | num_layers=config.num_layer, 23 | dropout=config.rnn_dropout, 24 | bidirectional=True) 25 | self.dropout = nn.Dropout(p=config.rnn_fc_dropout) 26 | self.nonlinear = nn.Tanh() 27 | #self.attn = nn.Sequential( 28 | # nn.Linear(config.hidden_size * 2 + config.words_dim, config.hidden_size), 29 | # self.nonlinear, 30 | # nn.Linear(config.hidden_size, 1) 31 | #) 32 | self.hidden2tag = nn.Sequential( 33 | #nn.Linear(config.hidden_size * 2 + config.words_dim, config.hidden_size * 2), 34 | nn.Linear(config.hidden_size * 2, config.hidden_size * 2), 35 | nn.BatchNorm1d(config.hidden_size * 2), 36 | self.nonlinear, 37 | self.dropout, 38 | nn.Linear(config.hidden_size * 2, target_size) 39 | ) 40 | 41 | def forward(self, x): 42 | # x = (sequence length, batch_size, dimension of embedding) 43 | text = x.text 44 | x = self.embed(text) 45 | num_word, batch_size, words_dim = x.size() 46 | # h0 / c0 = (layer*direction, batch_size, hidden_dim) 47 | 48 | if self.config.qa_mode.upper() == 'LSTM': 49 | outputs, (ht, ct) = self.lstm(x) 50 | elif self.config.qa_mode.upper() == 'GRU': 51 | outputs, ht = self.gru(x) 52 | else: 53 | print("Wrong Entity Prediction Mode") 54 | exit(1) 55 | outputs = outputs.view(-1, outputs.size(2)) 56 | #x = x.view(-1, words_dim) 57 | #attn_weights = F.softmax(self.attn(torch.cat((x, outputs), 1)), dim=0) 58 | #attn_applied = torch.bmm(torch.diag(attn_weights[:, 0]).unsqueeze(0), outputs.unsqueeze(0)) 59 | #outputs = torch.cat((x, attn_applied.squeeze(0)), 1) 60 | tags = self.hidden2tag(outputs).view(num_word, batch_size, -1) 61 | scores = nn.functional.normalize(torch.mean(tags, dim=0), dim=1) 62 | 63 | return scores -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | def get_span(label, index2tag, type): 3 | span = [] 4 | st = -1 5 | en = -1 6 | flag = False 7 | tag = [] 8 | for k in range(len(label)): 9 
| if index2tag[label[k]][0] == 'I' and flag == False: 10 | flag = True 11 | st = k 12 | if type: 13 | tag.append(index2tag[label[k]][2:]) 14 | if index2tag[label[k]][0] == 'I' and flag == True: 15 | if type: 16 | tag.append(index2tag[label[k]][2:]) 17 | if index2tag[label[k]][0] != 'I' and flag == True: 18 | flag = False 19 | en = k 20 | if type: 21 | max_tag_counter = Counter(tag) 22 | max_tag = max_tag_counter.most_common()[0][0] 23 | span.append((st, en, max_tag)) 24 | else: 25 | span.append((st,en)) 26 | st = -1 27 | en = -1 28 | tag = [] 29 | if st != -1 and en == -1: 30 | en = len(label) 31 | if type: 32 | max_tag_counter = Counter(tag) 33 | max_tag = max_tag_counter.most_common()[0][0] 34 | span.append((st, en, max_tag)) 35 | else: 36 | span.append((st, en)) 37 | 38 | return span 39 | 40 | def evaluation(gold, pred, index2tag, type): 41 | right = 0 42 | predicted = 0 43 | total_en = 0 44 | #fout = open('log.valid', 'w') 45 | for i in range(len(gold)): 46 | gold_batch = gold[i] 47 | pred_batch = pred[i] 48 | 49 | for j in range(len(gold_batch)): 50 | gold_label = gold_batch[j] 51 | pred_label = pred_batch[j] 52 | gold_span = get_span(gold_label, index2tag, type) 53 | pred_span = get_span(pred_label, index2tag, type) 54 | #fout.write('{}\t{}\n'.format(gold_span, pred_span)) 55 | total_en += len(gold_span) 56 | predicted += len(pred_span) 57 | for item in pred_span: 58 | if item in gold_span: 59 | right += 1 60 | if predicted == 0: 61 | precision = 0 62 | else: 63 | precision = right / predicted 64 | if total_en == 0: 65 | recall = 0 66 | else: 67 | recall = right / total_en 68 | if precision + recall == 0: 69 | f1 = 0 70 | else: 71 | f1 = 2 * precision * recall / (precision + recall) 72 | #fout.flush() 73 | #fout.close() 74 | return precision, recall, f1 75 | 76 | def get_names_for_entities(namespath): 77 | print("getting names map...") 78 | names = {} 79 | with open(namespath, 'r') as f: 80 | for i, line in enumerate(f): 81 | items = line.strip().split("\t") 82 | if len(items) != 2: 83 | print("ERROR: line - {}".format(line)) 84 | continue 85 | entity = items[0] 86 | literal = items[1].strip() 87 | if literal != "": 88 | if names.get(literal) is None: 89 | names[literal] = [(entity)] 90 | else: 91 | names[literal].append(entity) 92 | #print('ERROR: Entities with the same name!') 93 | return names -------------------------------------------------------------------------------- /augment_process_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import re 5 | #import logging 6 | 7 | from fuzzywuzzy import process, fuzz 8 | #from nltk.tokenize.treebank import TreebankWordTokenizer 9 | from util import www2fb, processed_text 10 | #tokenizer = TreebankWordTokenizer() 11 | #logger = logging.getLogger() 12 | #logger.disabled = True 13 | 14 | def get_indices(src_list, pattern_list): 15 | indices = None 16 | for i in range(len(src_list)): 17 | match = 1 18 | for j in range(len(pattern_list)): 19 | if src_list[i+j] != pattern_list[j]: 20 | match = 0 21 | break 22 | if match: 23 | indices = range(i, i + len(pattern_list)) 24 | break 25 | return indices 26 | 27 | def get_ngram(tokens): 28 | ngram = [] 29 | for i in range(1, len(tokens)+1): 30 | for s in range(len(tokens)-i+1): 31 | ngram.append((" ".join(tokens[s: s+i]), s, i+s)) 32 | return ngram 33 | 34 | def reverseLinking(sent, text_candidate): 35 | tokens = sent.split() 36 | label = ["O"] * len(tokens) 37 | text_attention_indices = None 38 | 
exact_match = False 39 | if text_candidate is None or len(text_candidate) == 0: 40 | return '', ' '.join(label), exact_match 41 | # sorted by length 42 | for text in sorted(text_candidate, key=lambda x:len(x), reverse=True): 43 | pattern = r'(^|\s)(%s)($|\s)' % (re.escape(text)) 44 | if re.search(pattern, sent): 45 | text_attention_indices = get_indices(tokens, text.split()) 46 | break 47 | if text_attention_indices != None: 48 | exact_match = True 49 | for i in text_attention_indices: 50 | label[i] = 'I' 51 | else: 52 | try: 53 | v, score = process.extractOne(sent, text_candidate, scorer=fuzz.partial_ratio) 54 | except: 55 | print("Extraction Error with FuzzyWuzzy : {} || {}".format(sent, text_candidate)) 56 | return '', ' '.join(label), exact_match 57 | v = v.split() 58 | n_gram_candidate = get_ngram(tokens) 59 | n_gram_candidate = sorted(n_gram_candidate, key=lambda x: fuzz.ratio(x[0], v), reverse=True) 60 | top = n_gram_candidate[0] 61 | for i in range(top[1], top[2]): 62 | label[i] = 'I' 63 | entity_text = [] 64 | for l, t in zip(label, tokens): 65 | if l == 'I': 66 | entity_text.append(t) 67 | entity_text = " ".join(entity_text) 68 | label = " ".join(label) 69 | return entity_text, label, exact_match 70 | 71 | def augment_dataset(datadir, outdir): 72 | # Get the name dictionary 73 | names_map = {} 74 | with open(os.path.join(outdir, 'names.trimmed.txt'), 'r') as f: 75 | for i, line in enumerate(f): 76 | if i % 100000 == 0: 77 | print("line: {}".format(i)) 78 | 79 | items = line.strip().split("\t") 80 | if len(items) != 2: 81 | print("ERROR: line - {}".format(line)) 82 | continue 83 | entity = items[0] 84 | literal = items[1].strip() 85 | if names_map.get(entity) is None: 86 | names_map[entity] = [(literal)] 87 | else: 88 | names_map[entity].append(literal) 89 | print("creating new datasets...") 90 | entiset = set() 91 | predset = set() 92 | wordset = [] 93 | for f_tuple in [("annotated_fb_data_train", "train"), ("annotated_fb_data_valid", "valid"), 94 | ("annotated_fb_data_test", "test")]: 95 | f = f_tuple[0] 96 | fname = f_tuple[1] 97 | fpath = os.path.join(datadir, f + ".txt") 98 | fpath_numbered = os.path.join(outdir, fname + ".txt") 99 | total_exact = 0 100 | outfile = open(fpath_numbered, 'w') 101 | print("reading from {}".format(fpath)) 102 | 103 | with open(fpath, 'r') as f: 104 | for i, line in enumerate(f): 105 | items = line.strip().split("\t") 106 | if len(items) != 4: 107 | print("ERROR: line - {}".format(line)) 108 | sys.exit(0) 109 | lineid = i + 1 110 | subject = www2fb(items[0]) 111 | predicate = www2fb(items[1]) 112 | object = www2fb(items[2]) 113 | question = processed_text(items[3]) 114 | entiset.add(subject) 115 | entiset.add(object) 116 | predset.add(predicate) 117 | 118 | if names_map.get(subject) is None: 119 | cand_entity_names = None 120 | else: 121 | cand_entity_names = names_map[subject] 122 | 123 | entity_name, label, exact_match = reverseLinking(question, cand_entity_names) 124 | if exact_match: 125 | total_exact += 1 126 | for token in question.split(): 127 | wordset.append(token) 128 | outfile.write('{}\t{}\t{}\t{}\t{}\t{}\t{}\n'.format(lineid, subject, entity_name, predicate, object, question, label)) 129 | outfile.close() 130 | print("Exact Match Entity : {} out of {} : {}".format(total_exact, lineid, total_exact / lineid)) 131 | print("wrote to {}".format(fpath_numbered)) 132 | print('Total entities {}'.format(len(entiset))) 133 | print('Total predicates {}'.format(len(predset))) 134 | print('Total words {}'.format(len(set(wordset)) - 1)) # -1 for '' 
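# A worked example of what reverseLinking() above produces (illustrative only):
# for the question "what country was the film the debt from" with candidate
# names ["the debt"], the exact-match branch returns
#     entity_text = "the debt"
#     label       = "O O O O O I I O"
#     exact_match = True
# i.e. one tag per question token, with 'I' marking the tokens covered by the
# matched entity name; these labels are the supervision for train_detection.py.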
135 | # outfile = open(os.path.join(outdir, 'synthetic.txt'), 'w') 136 | # total_exact = 0 137 | # lineid = 0 138 | # whereset = {'location', 'place', 'geographic', 'region', 'places'} 139 | # whoset = {'composer', 'people', 'artist', 'author', 'publisher', 'directed', 'developer', 'director', 'lyricist', 140 | # 'edited', 'parents', 'instrumentalists', 'produced', 'manufacturer', 'written', 'designers', 'producer'} 141 | # for line in open(os.path.join(outdir, 'transE_valid.txt'), 'r'): 142 | # items = line.strip().split("\t") 143 | # subject = items[0] 144 | # if names_map.get(subject) is not None: 145 | # lineid += 1 146 | # shortest = 10000 147 | # for name in names_map[subject]: 148 | # if len(name.split()) < shortest: 149 | # cand_entity_names = name 150 | # tokens = items[2].replace('.', ' ').replace('_', ' ').split() 151 | # seen = set() 152 | # clean_token = [token for token in tokens if not (token in seen or seen.add(token))] 153 | # flag = True 154 | # for token in clean_token: 155 | # if token in whereset: 156 | # question = 'where is the ' + ' '.join(clean_token) + ' of ' + cand_entity_names 157 | # flag = False 158 | # break 159 | # elif token in whoset: 160 | # question = 'who is the ' + ' '.join(clean_token) + ' of ' + cand_entity_names 161 | # flag = False 162 | # break 163 | # if flag: 164 | # question = 'what is the ' + ' '.join(clean_token) + ' of ' + cand_entity_names 165 | # cand_entity_names = [cand_entity_names] 166 | # entity_name, label, exact_match = reverseLinking(question, cand_entity_names) 167 | # if exact_match: 168 | # total_exact += 1 169 | # outfile.write( 170 | # '{}\t{}\t{}\t{}\t{}\t{}\t{}\n'.format(lineid, subject, entity_name, items[2], items[1], question, 171 | # label)) 172 | # outfile.close() 173 | # print("Exact Match Entity : {} out of {} : {}".format(total_exact, lineid, total_exact / lineid)) 174 | 175 | if __name__ == '__main__': 176 | parser = argparse.ArgumentParser(description='Augment dataset with line ids, shorted names, entity names') 177 | parser.add_argument('-d', '--dataset', dest='dataset', action='store', required=True, 178 | help='path to the dataset directory - contains train, valid, test files') 179 | parser.add_argument('-o', '--output', type=str, default='preprocess', help='output directory for new dataset') 180 | 181 | args = parser.parse_args() 182 | print("Dataset: {}".format(args.dataset)) 183 | print("Index - Names: /{}/names.trimmed.txt".format(args.output)) 184 | print("Output: {}".format(args.output)) 185 | 186 | augment_dataset(args.dataset, args.output) 187 | -------------------------------------------------------------------------------- /trim_names.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | from util import www2fb, processed_text, clean_uri 5 | 6 | # output 'cleanedFB.txt', 'names.trimmed.txt', 'transE_*.txt', 'entity2id.txt', 'relation2id.txt' 7 | 8 | def get_fb_mids_set(cleanfile, fbsubset): 9 | print('get all mids in the Freebase subset...') 10 | lines_seen = set() # holds lines already seen 11 | outfile = open(cleanfile, "w") 12 | mids = [] 13 | for i, line in enumerate(open(fbsubset, "r")): 14 | if i % 1000000 == 0: 15 | print("line: {}".format(i)) 16 | items = line.strip().split("\t") 17 | if len(items) != 3: 18 | print("ERROR: line - {}".format(line)) 19 | entity1 = www2fb(items[0]) 20 | line = "{}\t{}\t{}\n".format(entity1, www2fb(items[2]), www2fb(items[1])) 21 | if line not in lines_seen: # not a duplicate 22 
|             mids.append(entity1) # mids.extend(entity2.split())
23 |             outfile.write(line)
24 |             lines_seen.add(line)
25 |     outfile.close()
26 |     return set(mids)
27 | 
28 | def findsetgrams(dataset):
29 |     grams = [] # all possible grams for head entities
30 |     ground = [] # Ground truth, for evaluation only
31 |     whhowset = [{'what', 'how', 'where', 'who', 'which', 'whom'},
32 |                 {'in which', 'what is', "what 's", 'what are', 'what was', 'what were', 'where is', 'where are',
33 |                  'where was', 'where were', 'who is', 'who was', 'who are', 'how is', 'what did'},
34 |                 {'what kind of', 'what kinds of', 'what type of', 'what types of', 'what sort of'}]
35 |     for fname in ["annotated_fb_data_valid", "annotated_fb_data_test"]:
36 |         for i, line in enumerate(open(os.path.join(dataset, fname + ".txt"), 'r')):
37 |             items = line.strip().split("\t")
38 |             if len(items) != 4:
39 |                 print("ERROR: line - {}".format(line))
40 |                 break
41 |             ground.append(www2fb(items[0]))
42 |             question = processed_text(items[3]).split()
43 |             if len(question) > 2:
44 |                 for j in range(3, 0, -1):
45 |                     if ' '.join(question[0:j]) in whhowset[j - 1]:
46 |                         del question[0:j]
47 |                         continue
48 |             maxlen = len(question)
49 |             for token in question:
50 |                 grams.append(token)
51 |             for j in range(2, maxlen + 1):
52 |                 for token in [question[idx:idx + j] for idx in range(maxlen - j + 1)]:
53 |                     grams.append(' '.join(token))
54 |     return set(grams), set(ground)
55 | 
56 | def get_all_entity_mids(fbpath, entiset):
57 |     print('filtering the Freebase subset based on the selected entities...')
58 |     mids = []
59 |     #mids_dic = {}
60 |     relat = []
61 |     trainfile = open(os.path.join(args.output, 'transE_train.txt'), 'w')
62 |     validfile = open(os.path.join(args.output, 'transE_valid.txt'), 'w')
63 |     testfile = open(os.path.join(args.output, 'transE_test.txt'), 'w')
64 |     with open(fbpath, 'r') as f:
65 |         for i, line in enumerate(f):
66 |             if i % 1000000 == 0:
67 |                 print("line: {}".format(i))
68 |             items = line.strip().split("\t")
69 |             entity1 = items[0]
70 |             if entity1 in entiset: # or entity2 in entiset: # or predicate in predset:
71 |                 predicate = items[2]
72 |                 relat.append(predicate)
73 |                 mids.append(entity1)
74 |                 #if mids_dic.get(entity1) is None:
75 |                 #    mids_dic[entity1] = [(predicate)]
76 |                 #else:
77 |                 #    mids_dic[entity1].append(predicate)
78 |                 #for entity2 in items[1].split():
79 |                 entity2 = items[1].split()[0] # could be a list of entities
80 |                 mids.append(entity2)
81 |                 trainfile.write("{}\t{}\t{}\n".format(entity1, entity2, predicate))
82 |                 j = random.randrange(10)
83 |                 if not j:
84 |                     validfile.write("{}\t{}\t{}\n".format(entity1, entity2, predicate))
85 |                 if j == 1:
86 |                     testfile.write("{}\t{}\t{}\n".format(entity1, entity2, predicate))
87 |     trainfile.close()
88 |     validfile.close()
89 |     testfile.close()
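    # Format note: each transE_*.txt row written above is a Freebase triple
    #     head_mid \t tail_mid \t relation
    # e.g. "m.0abc12\tm.0xyz34\tpeople.person.place_of_birth" (illustrative MIDs).
    # compute_reach_dic() in test_main.py relies on this column order when it
    # collects the relations reachable from each candidate head entity.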
90 |     with open(os.path.join(args.output, 'entity2id.txt'), 'w', encoding='UTF-8', errors='ignore') as outfile:
91 |         for i, entity in enumerate(set(mids)):
92 |             outfile.write("{}\t{}\n".format(entity, i))
93 |             #if mids_dic.get(entity) is None:
94 |             #    outfile.write("{}\t{}\n".format(entity, i))
95 |             #else:
96 |             #    tokens = []
97 |             #    for context in mids_dic[entity]:
98 |             #        tokens.append(context.replace('.', ' ').replace('_', ' '))
99 |             #    seen = set()
100 |             #    outfile.write("{}\t{}\t{}\n".format(entity, i, ' '.join([token for token in tokens if not (token in seen or seen.add(token))])))
101 |         print('Number of entities in transE_*: {}'.format(i + 1))
102 |     outfile.close()
103 |     with open(os.path.join(args.output, 'relation2id.txt'), 'w', encoding='UTF-8', errors='ignore') as outfile:
104 |         for i, predicate in enumerate(set(relat)):
105 |             outfile.write("{}\t{}\n".format(predicate, i))
106 |         print('Number of predicates in transE_*: {}'.format(i + 1))
107 |     outfile.close()
108 | 
109 | 
110 | if __name__ == '__main__':
111 |     parser = argparse.ArgumentParser(description='Preprocess the questions to match the Freebase subset')
112 |     parser.add_argument('-n', '--names', dest='names', action='store', required=True,
113 |                         help='path to the names file (from CFO)')
114 |     parser.add_argument('-f', '--fbsubset', dest='fbsubset', action='store', required=True,
115 |                         help='path to freebase subset file')
116 |     parser.add_argument('-d', '--dataset', type=str, default='data', help='directory containing annotated_fb_data_*')
117 |     parser.add_argument('-o', '--output', type=str, default='preprocess/', help='output directory for new dataset')
118 |     args = parser.parse_args()
119 |     os.makedirs(args.output, exist_ok=True)
120 | 
121 |     cleanfile = os.path.join(args.output, 'cleanedFB.txt')
122 |     fb_mids = get_fb_mids_set(cleanfile, args.fbsubset)
123 |     gramset, groundset = findsetgrams(args.dataset)
124 | 
125 |     print('selecting head entities based on the questions:')
126 |     entiset = set() # selected head entities
127 |     with open(os.path.join(args.dataset, "annotated_fb_data_train.txt"), 'r', encoding='UTF-8', errors='ignore') as f:
128 |         for i, line in enumerate(f):
129 |             items = line.strip().split("\t")
130 |             if len(items) != 4:
131 |                 print("ERROR: line - {}".format(line))
132 |                 break
133 |             entiset.add(www2fb(items[0])) # entiset.add(www2fb(items[2]))
134 |     outfile = open(os.path.join(args.output, 'names.trimmed.txt'), 'w', encoding='UTF-8', errors='ignore') # output file path for trimmed names file
135 |     with open(args.names, 'r', encoding='UTF-8', errors='ignore') as f:
136 |         for i, line in enumerate(f):
137 |             if i % 1000000 == 0:
138 |                 print("line: {}".format(i))
139 |             items = line.strip().split("\t")
140 |             if len(items) != 4:
141 |                 print("ERROR: line - {}".format(line))
142 |             entity = www2fb(clean_uri(items[0]))
143 |             if entity in fb_mids:
144 |                 name = processed_text(clean_uri(items[2]))
145 |                 if name.strip() != "":
146 |                     if entity in entiset:
147 |                         outfile.write("{}\t{}\n".format(entity, name))
148 |                     elif name in gramset:
149 |                         entiset.add(entity)
150 |                         outfile.write("{}\t{}\n".format(entity, name))
151 |                     #name_gram = [name]
152 |                     #tokens = name.split()
153 |                     #maxlen = len(tokens)
154 |                     #if maxlen > 2:
155 |                     #    j = maxlen - 1
156 |                     #    for token in [tokens[idx:idx + j] for idx in range(maxlen - j + 1)]:
157 |                     #        name_gram.append(' '.join(token))
158 |                     #for token in name_gram:
159 |     outfile.close()
160 |     print('{} out of {} entities are selected as head entities'.format(len(entiset), i + 1))
161 |     i = 0
162 |     for entity in groundset:
163 |         if entity in entiset:
164 |             i += 1
165 |     print('recall of head entity selection: {}'.format(float(i) / len(groundset)))
166 |     get_all_entity_mids(cleanfile, entiset)
--------------------------------------------------------------------------------
/train_detection.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import time
4 | import os
5 | import numpy as np
6 | from torchtext import data  # torchtext==0.3.1 (see requirements.txt)
7 | # from torchtext.legacy import data  # use this import on torchtext >= 0.9
8 | import random
9 | from argparse import ArgumentParser
10 | from evaluation import evaluation
11 | from entity_detection import EntityDetection
12 | 
13 | parser = ArgumentParser(description="Joint Prediction")
14 | parser.add_argument('--entity_detection_mode', type=str, required=True,
help='options are GRU, LSTM') 15 | parser.add_argument('--no_cuda', action='store_false', help='do not use cuda', dest='cuda') 16 | parser.add_argument('--gpu', type=int, default=0) # Use -1 for CPU 17 | parser.add_argument('--epochs', type=int, default=30) 18 | parser.add_argument('--batch_size', type=int, default=16) 19 | parser.add_argument('--lr', type=float, default=.0003) 20 | parser.add_argument('--seed', type=int, default=3435) 21 | parser.add_argument('--dev_every', type=int, default=12000) 22 | parser.add_argument('--log_every', type=int, default=2000) 23 | parser.add_argument('--patience', type=int, default=15) 24 | parser.add_argument('--dete_prefix', type=str, default='dete') 25 | parser.add_argument('--words_dim', type=int, default=300) 26 | parser.add_argument('--num_layer', type=int, default=2) 27 | parser.add_argument('--rnn_fc_dropout', type=float, default=0.3) 28 | parser.add_argument('--hidden_size', type=int, default=300) 29 | parser.add_argument('--rnn_dropout', type=float, default=0.3) 30 | parser.add_argument('--clip_gradient', type=float, default=0.6, help='gradient clipping') 31 | parser.add_argument('--vector_cache', type=str, default="data/sq_glove300d.pt") 32 | parser.add_argument('--weight_decay',type=float, default=0) 33 | parser.add_argument('--fix_embed', action='store_false', dest='train_embed') 34 | # added for testing 35 | parser.add_argument('--output', type=str, default='preprocess/') 36 | args = parser.parse_args() 37 | 38 | outfile = open(os.path.join(args.output, 'dete_train.txt'), 'w',encoding='UTF-8',errors='ignore') 39 | for line in open(os.path.join(args.output, 'train.txt'), 'r',encoding='UTF-8',errors='ignore'): 40 | items = line.strip().split("\t") 41 | tokens = items[6].split() 42 | if any(token != tokens[0] for token in tokens): 43 | outfile.write("{}\t{}\n".format(items[5], items[6])) 44 | outfile.close() 45 | 46 | # Set random seed for reproducibility 47 | torch.manual_seed(args.seed) 48 | np.random.seed(args.seed) 49 | random.seed(args.seed) 50 | torch.backends.cudnn.deterministic = True 51 | 52 | if not args.cuda: 53 | args.gpu = -1 54 | if torch.cuda.is_available() and args.cuda: 55 | print("Note: You are using GPU for training") 56 | torch.cuda.set_device(args.gpu) 57 | torch.cuda.manual_seed(args.seed) 58 | if torch.cuda.is_available() and not args.cuda: 59 | print("Warning: You have Cuda but not use it. 
You are using CPU for training.")
60 | 
61 | # Set up the data for training
62 | TEXT = data.Field(lower=True)
63 | ED = data.Field()
64 | train = data.TabularDataset(path=os.path.join(args.output, 'dete_train.txt'), format='tsv', fields=[('text', TEXT), ('ed', ED)])
65 | field = [('id', None), ('sub', None), ('entity', None), ('relation', None), ('obj', None), ('text', TEXT), ('ed', ED)]
66 | dev, test = data.TabularDataset.splits(path=args.output, validation='valid.txt', test='test.txt', format='tsv', fields=field)
67 | TEXT.build_vocab(train, dev, test)
68 | ED.build_vocab(train, dev)
69 | 
70 | match_embedding = 0
71 | if os.path.isfile(args.vector_cache):
72 |     stoi, vectors, dim = torch.load(args.vector_cache)
73 |     TEXT.vocab.vectors = torch.Tensor(len(TEXT.vocab), dim)
74 |     for i, token in enumerate(TEXT.vocab.itos):
75 |         wv_index = stoi.get(token, None)
76 |         if wv_index is not None:
77 |             TEXT.vocab.vectors[i] = vectors[wv_index]
78 |             match_embedding += 1
79 |         else:
80 |             TEXT.vocab.vectors[i] = torch.FloatTensor(dim).uniform_(-0.25, 0.25)
81 | else:
82 |     print("Error: Need word embedding pt file")
83 |     exit(1)
84 | 
85 | print("Embedding match number {} out of {}".format(match_embedding, len(TEXT.vocab)))
86 | 
87 | if args.cuda:
88 |     train_iter = data.Iterator(train, batch_size=args.batch_size, device=torch.device('cuda', args.gpu), train=True,
89 |                                repeat=False, sort=False, shuffle=True, sort_within_batch=False)
90 |     dev_iter = data.Iterator(dev, batch_size=args.batch_size, device=torch.device('cuda', args.gpu), train=False,
91 |                              repeat=False, sort=False, shuffle=False, sort_within_batch=False)
92 | else:
93 |     train_iter = data.Iterator(train, batch_size=args.batch_size, train=True, repeat=False, sort=False, shuffle=True,
94 |                                sort_within_batch=False)
95 |     dev_iter = data.Iterator(dev, batch_size=args.batch_size, train=False, repeat=False, sort=False, shuffle=False,
96 |                              sort_within_batch=False)
97 | 
98 | config = args
99 | config.words_num = len(TEXT.vocab)
100 | config.label = len(ED.vocab)
101 | model = EntityDetection(config)
102 | model.embed.weight.data.copy_(TEXT.vocab.vectors)
103 | if args.cuda:
104 |     model = model.to(torch.device("cuda:{}".format(args.gpu)))
105 |     print("Shift model to GPU")
106 | 
107 | print(config)
108 | print("VOCAB num", len(TEXT.vocab))
109 | print("Train instance", len(train))
110 | print("Dev instance", len(dev))
111 | print("Entity Type", len(ED.vocab))
112 | print(model)
113 | 
114 | parameter = filter(lambda p: p.requires_grad, model.parameters())
115 | optimizer = torch.optim.Adam(parameter, lr=args.lr, weight_decay=args.weight_decay)
116 | criterion = nn.NLLLoss()
117 | 
118 | early_stop = False
119 | best_dev_R = 0
120 | iterations = 0
121 | iters_not_improved = 0
122 | num_dev_in_epoch = (len(train) // args.batch_size // args.dev_every) + 1
123 | patience = args.patience * num_dev_in_epoch # for early stopping
124 | epoch = 0
125 | start = time.time()
126 | log_template = ' '.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{},{:10.6f}%'.split(','))
127 | print(' Time Epoch Iteration Progress (%Epoch) Loss Accuracy')
128 | 
129 | index2tag = np.array(ED.vocab.itos) # ['<unk>' '<pad>' 'O' 'I']
130 | 
131 | while True:
132 |     if early_stop:
133 |         print("Early Stopping.
Epoch: {}, Best Dev Recall: {}".format(epoch, best_dev_R)) 134 | break 135 | epoch += 1 136 | train_iter.init_epoch() 137 | n_correct, n_total = 0, 0 138 | 139 | for batch_idx, batch in enumerate(train_iter): 140 | # Batch size : (Sentence Length, Batch_size) 141 | iterations += 1 142 | model.train() 143 | optimizer.zero_grad() 144 | scores = model(batch) 145 | # Entity Detection 146 | n_correct += torch.sum( 147 | torch.sum((torch.max(scores, 1)[1].view(batch.ed.size()).data == batch.ed.data), dim=0) == batch.ed.size()[ 148 | 0]).item() 149 | loss = criterion(scores, batch.ed.view(-1, 1)[:, 0]) 150 | n_total += batch.batch_size 151 | loss.backward() 152 | # clip the gradient 153 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_gradient) 154 | optimizer.step() 155 | 156 | # evaluate performance on validation set periodically 157 | if iterations % args.dev_every == 0: 158 | model.eval() 159 | dev_iter.init_epoch() 160 | gold_list = [] 161 | pred_list = [] 162 | 163 | for dev_batch_idx, dev_batch in enumerate(dev_iter): 164 | answer = model(dev_batch) 165 | #n_dev_correct += ( 166 | # (torch.max(answer, 1)[1].view(dev_batch.ed.size()).data == dev_batch.ed.data).sum(dim=0) == 167 | # dev_batch.ed.size()[0]).sum() 168 | index_tag = np.transpose(torch.max(answer, 1)[1].view(dev_batch.ed.size()).cpu().data.numpy()) 169 | gold_list.append(np.transpose(dev_batch.ed.cpu().data.numpy())) 170 | pred_list.append(index_tag) 171 | 172 | P, R, F = evaluation(gold_list, pred_list, index2tag, type=False) 173 | print("{} Recall: {:10.6f}% Precision: {:10.6f}% F1 Score: {:10.6f}%".format("Dev", 100. * R, 100. * P, 174 | 100. * F)) 175 | 176 | # update model 177 | if R > best_dev_R: 178 | best_dev_R = R 179 | iters_not_improved = 0 180 | snapshot_path = os.path.join(args.output, args.dete_prefix + '_best_model.pt') 181 | # save model, delete previous 'best_snapshot' files 182 | torch.save(model, snapshot_path) # .state_dict() 183 | else: 184 | iters_not_improved += 1 185 | if iters_not_improved > patience: 186 | early_stop = True 187 | break 188 | 189 | if iterations % args.log_every == 1: 190 | print(log_template.format(time.time() - start, epoch, iterations, 1 + batch_idx, len(train_iter), 191 | 100. * (1 + batch_idx) / len(train_iter), loss.item(), ' ' * 3, 192 | 100. * n_correct / n_total)) 193 | 194 | # Early Stopping. 
Epoch: 119, Best Dev Recall: 0.9513194245975041 195 | -------------------------------------------------------------------------------- /train_entity.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import time 4 | import os 5 | import numpy as np 6 | import random 7 | 8 | from torchtext import data 9 | from argparse import ArgumentParser 10 | from embedding import EmbedVector 11 | from evaluation import get_names_for_entities 12 | from sklearn.metrics.pairwise import euclidean_distances 13 | 14 | parser = ArgumentParser(description="Training") 15 | parser.add_argument('--qa_mode', type=str, required=True, help='options are GRU, LSTM') 16 | parser.add_argument('--embed_dim', type=int, default=250) 17 | parser.add_argument('--no_cuda', action='store_false', help='do not use cuda', dest='cuda') 18 | parser.add_argument('--gpu', type=int, default=0) # Use -1 for CPU 19 | parser.add_argument('--epochs', type=int, default=30) 20 | parser.add_argument('--batch_size', type=int, default=32) 21 | parser.add_argument('--lr', type=float, default=0.0002) 22 | parser.add_argument('--seed', type=int, default=3435) 23 | parser.add_argument('--dev_every', type=int, default=10000) 24 | parser.add_argument('--log_every', type=int, default=2000) 25 | parser.add_argument('--output_channel', type=int, default=300) 26 | parser.add_argument('--patience', type=int, default=10) 27 | parser.add_argument('--best_prefix', type=str, default='entity') 28 | parser.add_argument('--num_layer', type=int, default=2) 29 | parser.add_argument('--rnn_fc_dropout', type=float, default=0.3) 30 | parser.add_argument('--hidden_size', type=int, default=300) 31 | parser.add_argument('--rnn_dropout', type=float, default=0.3) 32 | parser.add_argument('--clip_gradient', type=float, default=0.6, help='gradient clipping') 33 | parser.add_argument('--vector_cache', type=str, default="data/sq_glove300d.pt") 34 | parser.add_argument('--weight_decay',type=float, default=0) 35 | parser.add_argument('--fix_embed', action='store_false', dest='train_embed') 36 | parser.add_argument('--output', type=str, default='preprocess') 37 | args = parser.parse_args() 38 | 39 | ################## Prepare training and validation datasets ################## 40 | # Dictionary and embedding for words 41 | if os.path.isfile(args.vector_cache): 42 | stoi, vectors, words_dim = torch.load(args.vector_cache) 43 | else: 44 | print("Error: Need word embedding pt file") 45 | exit(1) 46 | 47 | mid_dic = {} # Dictionary for MID 48 | for line in open(os.path.join(args.output, 'entity2id.txt'), 'r'): 49 | items = line.strip().split("\t") 50 | mid_dic[items[0]] = int(items[1]) 51 | outfile = open(os.path.join(args.output, 'entity_train.txt'), 'w') 52 | for line in open(os.path.join(args.output, 'train.txt'), 'r'): 53 | items = line.strip().split("\t") 54 | if items[1] in mid_dic: 55 | outfile.write("{}\t{}\n".format(items[5], mid_dic[items[1]])) 56 | outfile.close() 57 | # context = [] 58 | # for token in list(compress(items[5].split(), [obj == 'O' for obj in items[6].split()])): 59 | # if token not in stop_words and stoi.get(token) is not None: 60 | # context.append(token) 61 | # if context: 62 | 63 | entities_emb = np.fromfile(os.path.join(args.output, 'entities_emb.bin'), dtype=np.float32).reshape((len(mid_dic), args.embed_dim)) 64 | mid_emb_list = [] 65 | mids_list = [] 66 | index_names = get_names_for_entities(os.path.join(args.output, 'names.trimmed.txt')) 67 | outfile = 
open(os.path.join(args.output, 'entity_valid.txt'), 'w') 68 | for line in open(os.path.join(args.output, 'valid.txt'), 'r'): 69 | items = line.strip().split("\t") 70 | if items[1] in mid_dic and items[2] in index_names: 71 | mids = [mid for mid in index_names.get(items[2]) if mid in mid_dic] 72 | if len(mids) > 1: 73 | mids_list.append(mids) 74 | outfile.write("{}\t{}\n".format(items[5], mid_dic[items[1]])) 75 | mid_emb = [] 76 | for mid in mids: 77 | mid_emb.append(entities_emb[mid_dic[mid]]) 78 | mid_emb_list.append(np.asarray(mid_emb)) 79 | #if flag: 80 | # outtrain.write("{}\t{}\n".format(items[5], mid_dic[items[1]])) 81 | outfile.close() 82 | del index_names 83 | entities_emb = torch.from_numpy(entities_emb) 84 | 85 | #with open(os.path.join(args.output, entity2id.txt'), 'r') as f: 86 | # for line in f: 87 | # items = line.strip().split("\t") 88 | # if len(items) == 3: 89 | # context = [] 90 | # for token in items[2].split(): 91 | # if token not in stop_words and stoi.get(token) is not None: 92 | # context.append(token) 93 | # if context: 94 | # entity_train.write("{}\t{}\n".format(' '.join(context), items[0])) 95 | 96 | ################## Set random seed for reproducibility ################## 97 | torch.manual_seed(args.seed) 98 | np.random.seed(args.seed) 99 | random.seed(args.seed) 100 | torch.backends.cudnn.deterministic = True 101 | 102 | if not args.cuda: 103 | args.gpu = -1 104 | if torch.cuda.is_available() and args.cuda: 105 | print("Note: You are using GPU for training") 106 | torch.cuda.set_device(args.gpu) 107 | torch.cuda.manual_seed(args.seed) 108 | if torch.cuda.is_available() and not args.cuda: 109 | print("Warning: You have Cuda but not use it. You are using CPU for training.") 110 | 111 | ################## Load the datasets ################## 112 | TEXT = data.Field(lower=True) 113 | ED = data.Field(sequential=False, use_vocab=False) 114 | train, dev = data.TabularDataset.splits(path=args.output, train='entity_train.txt', validation='entity_valid.txt', format='tsv', fields=[('text', TEXT), ('mid', ED)]) 115 | field = [('id', None), ('sub', None), ('entity', None), ('relation', None), ('obj', None), ('text', TEXT), ('ed', None)] 116 | test = data.TabularDataset(path=os.path.join(args.output, 'test.txt'), format='tsv', fields=field) 117 | TEXT.build_vocab(train, dev, test) # training data includes validation data 118 | 119 | 120 | match_embedding = 0 121 | TEXT.vocab.vectors = torch.Tensor(len(TEXT.vocab), words_dim) 122 | for i, token in enumerate(TEXT.vocab.itos): 123 | wv_index = stoi.get(token, None) 124 | if wv_index is not None: 125 | TEXT.vocab.vectors[i] = vectors[wv_index] 126 | match_embedding += 1 127 | else: 128 | TEXT.vocab.vectors[i] = torch.FloatTensor(words_dim).uniform_(-0.25, 0.25) 129 | print("Word embedding match number {} out of {}".format(match_embedding, len(TEXT.vocab))) 130 | del stoi, vectors 131 | 132 | 133 | ################## batch ################## 134 | if args.cuda: 135 | train_iter = data.Iterator(train, batch_size=args.batch_size, device=torch.device('cuda', args.gpu), train=True, 136 | repeat=False, sort=False, shuffle=True, sort_within_batch=False) 137 | dev_iter = data.Iterator(dev, batch_size=args.batch_size, device=torch.device('cuda', args.gpu), train=False, 138 | repeat=False, sort=False, shuffle=False, sort_within_batch=False) 139 | else: 140 | train_iter = data.Iterator(train, batch_size=args.batch_size, train=True, repeat=False, sort=False, shuffle=True, 141 | sort_within_batch=False) 142 | dev_iter = data.Iterator(dev, 
batch_size=args.batch_size, train=False, repeat=False, sort=False, shuffle=False,
143 |                          sort_within_batch=False)
144 | 
145 | config = args
146 | config.words_num = len(TEXT.vocab)
147 | config.label = args.embed_dim
148 | config.words_dim = words_dim
149 | model = EmbedVector(config)
150 | model.embed.weight.data.copy_(TEXT.vocab.vectors)
151 | 
152 | if args.cuda:
153 |     model = model.to(torch.device("cuda:{}".format(args.gpu)))
154 |     print("Shift model to GPU")
155 |     entities_emb = entities_emb.cuda()
156 | 
157 | print(config)
158 | print("VOCAB num", len(TEXT.vocab))
159 | print("Train instance", len(train))
160 | print("Dev instance", len(dev))
161 | print(model)
162 | 
163 | parameter = filter(lambda p: p.requires_grad, model.parameters())
164 | optimizer = torch.optim.Adam(parameter, lr=args.lr, weight_decay=args.weight_decay)
165 | criterion = nn.MSELoss()
166 | 
167 | early_stop = False
168 | best_accu, iterations, iters_not_improved = 0, 0, 0
169 | num_dev_in_epoch = (len(train) // args.batch_size // args.dev_every) + 1
170 | patience = args.patience * num_dev_in_epoch # for early stopping
171 | epoch = 0
172 | start = time.time()
173 | print(' Time Epoch Iteration Progress (%Epoch) Loss')
174 | log_template = ' '.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f}'.split(','))
175 | 
176 | while True:
177 |     if early_stop:
178 |         print("Early Stopping. Epoch: {}, Best Dev Accuracy: {}".format(epoch, best_accu))
179 |         break
180 |     epoch += 1
181 |     train_iter.init_epoch()
182 |     for batch_idx, batch in enumerate(train_iter):
183 |         # Batch size : (Sentence Length, Batch_size)
184 |         iterations += 1
185 |         model.train()
186 |         optimizer.zero_grad()
187 |         loss = criterion(model(batch), entities_emb[batch.mid, :])
188 |         loss.backward()
189 |         # clip the gradient
190 |         torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_gradient)
191 |         optimizer.step()
192 |         # evaluate performance on validation set periodically
193 |         if iterations % args.dev_every == 0:
194 |             model.eval()
195 |             dev_iter.init_epoch()
196 |             baseidx, n_dev_correct = 0, 0
197 |             for dev_batch_idx, dev_batch in enumerate(dev_iter):
198 |                 batch_size = dev_batch.text.size()[1]
199 |                 answer = model(dev_batch).cpu().data.numpy()
200 |                 label = dev_batch.mid.data
201 |                 for devi in range(batch_size):  # correct iff the gold MID is the nearest (Euclidean) candidate among the entities sharing the detected name
202 |                     if label[devi].item() == mid_dic[mids_list[baseidx + devi][
203 |                         euclidean_distances(answer[devi].reshape(1, -1), mid_emb_list[baseidx + devi]).argmin(axis=1)[
204 |                             0]]]:
205 |                         n_dev_correct += 1
206 |                 baseidx = baseidx + batch_size
207 |             curr_accu = n_dev_correct / len(mids_list)
208 |             print('Dev Accuracy: {}'.format(curr_accu))
209 |             # update model
210 |             if curr_accu > best_accu:
211 |                 best_accu = curr_accu
212 |                 iters_not_improved = 0
213 |                 # save model, delete previous 'best_snapshot' files
214 |                 torch.save(model, os.path.join(args.output, args.best_prefix + '_best_model.pt'))
215 |             else:
216 |                 iters_not_improved += 1
217 |                 if iters_not_improved > patience:
218 |                     early_stop = True
219 |                     break
220 | 
221 |         if iterations % args.log_every == 1:
222 |             # print progress message
223 |             print(log_template.format(time.time() - start,
224 |                                       epoch, iterations, 1 + batch_idx, len(train_iter),
225 |                                       100.
* (1 + batch_idx) / len(train_iter), loss.item(), ' ' * 8, ' ' * 12)) 226 | -------------------------------------------------------------------------------- /train_pred.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import time 4 | import os 5 | import numpy as np 6 | import random 7 | 8 | from torchtext import data 9 | from argparse import ArgumentParser 10 | from embedding import EmbedVector 11 | from sklearn.metrics.pairwise import euclidean_distances 12 | 13 | parser = ArgumentParser(description="Training predicate vector learning") 14 | parser.add_argument('--qa_mode', type=str, required=True, help='options are GRU, LSTM') 15 | parser.add_argument('--embed_dim', type=int, default=250) 16 | parser.add_argument('--no_cuda', action='store_false', help='do not use cuda', dest='cuda') 17 | parser.add_argument('--gpu', type=int, default=0) # Use -1 for CPU 18 | parser.add_argument('--epochs', type=int, default=30) 19 | parser.add_argument('--batch_size', type=int, default=32) 20 | parser.add_argument('--lr', type=float, default=0.0003) 21 | parser.add_argument('--seed', type=int, default=3435) 22 | parser.add_argument('--dev_every', type=int, default=10000) 23 | parser.add_argument('--log_every', type=int, default=2000) 24 | parser.add_argument('--patience', type=int, default=12) 25 | parser.add_argument('--best_prefix', type=str, default='pred') 26 | parser.add_argument('--output_channel', type=int, default=300) 27 | parser.add_argument('--num_layer', type=int, default=2) 28 | parser.add_argument('--rnn_fc_dropout', type=float, default=0.3) 29 | parser.add_argument('--hidden_size', type=int, default=300) 30 | parser.add_argument('--rnn_dropout', type=float, default=0.3) 31 | parser.add_argument('--clip_gradient', type=float, default=0.6, help='gradient clipping') 32 | parser.add_argument('--vector_cache', type=str, default="data/sq_glove300d.pt") 33 | parser.add_argument('--weight_decay',type=float, default=0) 34 | parser.add_argument('--fix_embed', action='store_false', dest='train_embed') 35 | parser.add_argument('--output', type=str, default='preprocess') 36 | args = parser.parse_args() 37 | 38 | ################## Prepare training and validation datasets ################## 39 | pre_dic = {} # Dictionary for predicates 40 | for line in open(os.path.join(args.output, 'relation2id.txt'), 'r'): 41 | items = line.strip().split("\t") 42 | pre_dic[items[0]] = int(items[1]) 43 | # Embedding for predicates 44 | predicates_emb = torch.from_numpy(np.fromfile(os.path.join(args.output, 'predicates_emb.bin'), dtype=np.float32).reshape((len(pre_dic), args.embed_dim))) 45 | # Set up the data for training 46 | for filename in ['train.txt', 'valid.txt']: 47 | outfile = open(os.path.join(args.output, 'pred_' + filename), 'w') 48 | for line in open(os.path.join(args.output, filename), 'r'): 49 | items = line.strip().split("\t") 50 | if items[3] in pre_dic: 51 | outfile.write("{}\t{}\n".format(items[5], pre_dic[items[3]])) 52 | # pred_list.append(entities_emb[mid_dic[items[4]], :] - entities_emb[mid_dic[items[1]], :]) 53 | # token = list(compress(items[5].split(), [element == 'O' for element in items[6].split()])) 54 | # if not token: 55 | outfile.close() 56 | 57 | synthetic_flag = True 58 | if synthetic_flag: 59 | names_map = {} 60 | for i, line in enumerate(open(os.path.join(args.output, 'names.trimmed.txt'), 'r')): 61 | items = line.strip().split("\t") 62 | if len(items) != 2: 63 | print("ERROR: line - {}".format(line)) 
64 | continue 65 | entity = items[0] 66 | literal = items[1].strip() 67 | if literal != "" and (names_map.get(entity) is None or len(names_map[entity].split()) > len(literal.split())): 68 | names_map[entity] = literal 69 | seen_fact = [] 70 | for line in open(os.path.join(args.output, 'train.txt'), 'r'): 71 | items = line.strip().split("\t") 72 | names_map[items[1]] = items[2] 73 | seen_fact.append((items[1], items[3])) 74 | seen_fact = set(seen_fact) 75 | whereset = {'location', 'place', 'geographic', 'region', 'places'} 76 | whoset = {'composer', 'people', 'artist', 'author', 'publisher', 'directed', 'developer', 'director', 'lyricist', 77 | 'edited', 'parents', 'instrumentalists', 'produced', 'manufacturer', 'written', 'designers', 'producer'} 78 | outfile = open(os.path.join(args.output, 'pred_train.txt'), 'a') 79 | for line in open(os.path.join(args.output, 'transE_valid.txt'), 'r'): 80 | items = line.strip().split("\t") 81 | if (items[0], items[2]) not in seen_fact and names_map.get(items[0]) is not None: 82 | name = names_map[items[0]] 83 | tokens = items[2].replace('.', ' ').replace('_', ' ').split() 84 | seen = set() 85 | clean_token = [token for token in tokens if not (token in seen or seen.add(token))] 86 | question = 'what is the ' + ' '.join(clean_token) + ' of ' + name 87 | for token in clean_token: 88 | if token in whereset: 89 | question = 'where is ' + ' '.join(clean_token) + ' of ' + name 90 | break 91 | elif token in whoset: 92 | question = 'who is the ' + ' '.join(clean_token) + ' of ' + name 93 | break 94 | outfile.write("{}\t{}\n".format(question, pre_dic[items[2]])) 95 | outfile.close() 96 | del names_map, pre_dic 97 | 98 | ################## Set random seed for reproducibility ################## 99 | torch.manual_seed(args.seed) 100 | np.random.seed(args.seed) 101 | random.seed(args.seed) 102 | torch.backends.cudnn.deterministic = True 103 | 104 | if not args.cuda: 105 | args.gpu = -1 106 | if torch.cuda.is_available() and args.cuda: 107 | print("Note: You are using GPU for training") 108 | torch.cuda.set_device(args.gpu) 109 | torch.cuda.manual_seed(args.seed) 110 | if torch.cuda.is_available() and not args.cuda: 111 | print("Warning: You have Cuda but not use it. 
You are using CPU for training.")
112 | 
113 | # Dictionary and embedding for words
114 | if os.path.isfile(args.vector_cache):
115 |     stoi, vectors, words_dim = torch.load(args.vector_cache)
116 | else:
117 |     print("Error: Need word embedding pt file")
118 |     exit(1)
119 | 
120 | ################## Load the datasets ##################
121 | TEXT = data.Field(lower=True)
122 | ED = data.Field(sequential=False, use_vocab=False)
123 | train, dev = data.TabularDataset.splits(path=args.output, train='pred_train.txt', validation='pred_valid.txt', format='tsv', fields=[('text', TEXT), ('mid', ED)])
124 | field = [('id', None), ('sub', None), ('entity', None), ('relation', None), ('obj', None), ('text', TEXT), ('ed', None)]
125 | test = data.TabularDataset(path=os.path.join(args.output, 'test.txt'), format='tsv', fields=field)
126 | TEXT.build_vocab(train, dev, test)
127 | 
128 | match_embedding = 0
129 | TEXT.vocab.vectors = torch.Tensor(len(TEXT.vocab), words_dim)
130 | for i, token in enumerate(TEXT.vocab.itos):
131 |     wv_index = stoi.get(token, None)
132 |     if wv_index is not None:
133 |         TEXT.vocab.vectors[i] = vectors[wv_index]
134 |         match_embedding += 1
135 |     else:
136 |         TEXT.vocab.vectors[i] = torch.FloatTensor(words_dim).uniform_(-0.25, 0.25)
137 | print("Word embedding match number {} out of {}".format(match_embedding, len(TEXT.vocab)))
138 | 
139 | del stoi, vectors
140 | 
141 | if args.cuda:
142 |     train_iter = data.Iterator(train, batch_size=args.batch_size, device=torch.device('cuda', args.gpu), train=True,
143 |                                repeat=False, sort=False, shuffle=True, sort_within_batch=False)
144 |     dev_iter = data.Iterator(dev, batch_size=args.batch_size, device=torch.device('cuda', args.gpu), train=False,
145 |                              repeat=False, sort=False, shuffle=False, sort_within_batch=False)
146 | else:
147 |     train_iter = data.Iterator(train, batch_size=args.batch_size, train=True, repeat=False, sort=False, shuffle=True,
148 |                                sort_within_batch=False)
149 |     dev_iter = data.Iterator(dev, batch_size=args.batch_size, train=False, repeat=False, sort=False, shuffle=False,
150 |                              sort_within_batch=False)
151 | 
152 | config = args
153 | config.words_num = len(TEXT.vocab)
154 | config.label = args.embed_dim
155 | config.words_dim = words_dim
156 | model = EmbedVector(config)
157 | 
158 | model.embed.weight.data.copy_(TEXT.vocab.vectors)
159 | if args.cuda:
160 |     model = model.to(torch.device("cuda:{}".format(args.gpu)))
161 |     print("Shift model to GPU")
162 |     # Embedding for predicates
163 |     predicates_emb = predicates_emb.cuda()
164 | 
165 | total_num = len(dev)
166 | print(config)
167 | print("VOCAB num", len(TEXT.vocab))
168 | print("Train instance", len(train))
169 | print("Dev instance", total_num)
170 | print(model)
171 | 
172 | parameter = filter(lambda p: p.requires_grad, model.parameters())
173 | optimizer = torch.optim.Adam(parameter, lr=args.lr, weight_decay=args.weight_decay)
174 | criterion = nn.MSELoss()
175 | 
176 | early_stop = False
177 | best_accu = 0
178 | best_loss = total_num
179 | iterations = 0
180 | iters_not_improved = 0
181 | num_dev_in_epoch = (len(train) // args.batch_size // args.dev_every) + 1
182 | patience = args.patience * num_dev_in_epoch # for early stopping
183 | epoch = 0
184 | start = time.time()
185 | print(' Time Epoch Iteration Progress (%Epoch) Loss')
186 | log_template = ' '.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f}'.split(','))
187 | 
188 | while True:
189 |     if early_stop:
190 |         print("Early Stopping.
Epoch: {}, Best Dev accuracy: {}, loss: {},".format(epoch, best_accu, best_loss)) 191 | break 192 | epoch += 1 193 | train_iter.init_epoch() 194 | for batch_idx, batch in enumerate(train_iter): 195 | # Batch size : (Sentence Length, Batch_size) 196 | iterations += 1 197 | model.train() 198 | optimizer.zero_grad() 199 | loss = criterion(model(batch), predicates_emb[batch.mid, :]) 200 | loss.backward() 201 | # clip the gradient 202 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_gradient) 203 | optimizer.step() 204 | 205 | # evaluate performance on validation set periodically 206 | if iterations % args.dev_every == 0: 207 | model.eval() 208 | dev_iter.init_epoch() 209 | n_dev_correct = 0 210 | dev_loss = 0 211 | for dev_batch_idx, dev_batch in enumerate(dev_iter): 212 | batch_size = dev_batch.text.size()[1] 213 | answer = model(dev_batch) 214 | learned_pred = euclidean_distances(answer.cpu().data.numpy(), predicates_emb.cpu()).argmin(axis=1) 215 | n_dev_correct += sum(dev_batch.mid.cpu().data.numpy() == learned_pred) 216 | dev_loss += criterion(answer, predicates_emb[dev_batch.mid, :]).item() * batch_size 217 | 218 | curr_accu = n_dev_correct / total_num 219 | total_loss = dev_loss/total_num 220 | print('Dev loss: {}, accuracy: {}'.format(total_loss, curr_accu)) 221 | 222 | # update model 223 | if curr_accu > best_accu: # total_loss < best_model 224 | best_accu = curr_accu 225 | best_loss = total_loss 226 | iters_not_improved = 0 227 | # save model, delete previous 'best_snapshot' files 228 | torch.save(model, os.path.join(args.output, args.best_prefix + '_best_model.pt')) 229 | else: 230 | iters_not_improved += 1 231 | if iters_not_improved > patience: 232 | early_stop = True 233 | break 234 | 235 | if iterations % args.log_every == 1: 236 | print(log_template.format(time.time() - start, 237 | epoch, iterations, 1 + batch_idx, len(train_iter), 238 | 100. 
238 |                                       100. * (1 + batch_idx) / len(train_iter), loss.item()))
239 | 
--------------------------------------------------------------------------------
/test_main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import random
4 | import os
5 | 
6 | from nltk.corpus import stopwords
7 | from itertools import compress
8 | from evaluation import evaluation, get_span
9 | from argparse import ArgumentParser
10 | from torchtext import data
11 | from sklearn.metrics.pairwise import euclidean_distances
12 | from fuzzywuzzy import fuzz
13 | from util import www2fb, processed_text, clean_uri
14 | 
15 | parser = ArgumentParser(description="Joint Prediction")
16 | parser.add_argument('--no_cuda', action='store_false', help='do not use cuda', dest='cuda')
17 | parser.add_argument('--gpu', type=int, default=0)  # Use -1 for CPU
18 | parser.add_argument('--embed_dim', type=int, default=250)
19 | parser.add_argument('--batch_size', type=int, default=16)
20 | parser.add_argument('--seed', type=int, default=3435)
21 | parser.add_argument('--dete_model', type=str, default='dete_best_model.pt')
22 | parser.add_argument('--entity_model', type=str, default='entity_best_model.pt')
23 | parser.add_argument('--pred_model', type=str, default='pred_best_model.pt')
24 | parser.add_argument('--output', type=str, default='preprocess')
25 | args = parser.parse_args()
26 | args.dete_model = os.path.join(args.output, args.dete_model)
27 | args.entity_model = os.path.join(args.output, args.entity_model)
28 | args.pred_model = os.path.join(args.output, args.pred_model)
29 | 
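# entity_predict runs the trained Head Entity Detection (HED) model over a
# dataset and collects, for every question, the predicted per-token tag
# sequence together with the tokenized question. Positions holding the <pad>
# token (word index 1 in torchtext's default vocabulary) are forced to tag
# index 1, so padding is never decoded as part of an entity span.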
30 | def entity_predict(dataset_iter):
31 |     model.eval()
32 |     dataset_iter.init_epoch()
33 |     gold_list = []
34 |     pred_list = []
35 |     dete_result = []
36 |     question_list = []
37 |     for data_batch_idx, data_batch in enumerate(dataset_iter):
38 |         #batch_size = data_batch.text.size()[1]
39 |         answer = torch.max(model(data_batch), 1)[1].view(data_batch.ed.size())
40 |         answer[(data_batch.text.data == 1)] = 1
41 |         answer = np.transpose(answer.cpu().data.numpy())
42 |         gold_list.append(np.transpose(data_batch.ed.cpu().data.numpy()))
43 |         index_question = np.transpose(data_batch.text.cpu().data.numpy())
44 |         question_array = index2word[index_question]
45 |         dete_result.extend(answer)
46 |         question_list.extend(question_array)
47 |         #for i in range(batch_size):  # If no word is detected as entity, select top 3 possible words
48 |         #    if all([j == 1 or j == idxO for j in answer[i]]):
49 |         #        index = list(range(i, scores.shape[0], batch_size))
50 |         #        FindOidx = [j for j, x in enumerate(answer[i]) if x == idxO]
51 |         #        idx_in_socres = [index[j] for j in FindOidx]
52 |         #        subscores = scores[idx_in_socres]
53 |         #        answer[i][torch.sort(torch.max(subscores, 1)[0], descending=True)[1][0:min(2, len(FindOidx))]] = idxI
54 |         pred_list.append(answer)
55 |     P, R, F = evaluation(gold_list, pred_list, index2tag, type=False)
56 |     print("{} Precision: {:10.6f}% Recall: {:10.6f}% F1 Score: {:10.6f}%".format("Dev", 100. * P, 100. * R, 100. * F))
57 |     return dete_result, question_list
58 | 
59 | def compute_reach_dic(matched_mid):
60 |     reach_dic = {}  # reach_dic[head_id] = [pred_id, ...], the predicates leaving each head entity
61 |     with open(os.path.join(args.output, 'transE_train.txt'), 'r') as f:
62 |         for line in f:
63 |             items = line.strip().split("\t")
64 |             head_id = items[0]
65 |             if head_id in matched_mid and items[2] in pre_dic:
66 |                 if reach_dic.get(head_id) is None:
67 |                     reach_dic[head_id] = [pre_dic[items[2]]]
68 |                 else:
69 |                     reach_dic[head_id].append(pre_dic[items[2]])
70 |     return reach_dic
71 | 
72 | # Set random seed for reproducibility
73 | torch.manual_seed(args.seed)
74 | np.random.seed(args.seed)
75 | random.seed(args.seed)
76 | 
77 | if not args.cuda:
78 |     args.gpu = -1
79 | if torch.cuda.is_available() and args.cuda:
80 |     print("Note: You are using GPU for testing")
81 |     torch.cuda.set_device(args.gpu)
82 |     torch.cuda.manual_seed(args.seed)
83 | if torch.cuda.is_available() and not args.cuda:
84 |     print("Warning: You have CUDA but are not using it. You are using CPU for testing.")
85 | 
86 | 
87 | ######################## Entity Detection ########################
88 | TEXT = data.Field(lower=True)
89 | ED = data.Field()
90 | train = data.TabularDataset(path=os.path.join(args.output, 'dete_train.txt'), format='tsv', fields=[('text', TEXT), ('ed', ED)])
91 | field = [('id', None), ('sub', None), ('entity', None), ('relation', None), ('obj', None), ('text', TEXT), ('ed', ED)]
92 | dev, test = data.TabularDataset.splits(path=args.output, validation='valid.txt', test='test.txt', format='tsv', fields=field)
93 | TEXT.build_vocab(train, dev, test)
94 | ED.build_vocab(train, dev)
95 | total_num = len(test)
96 | print('total number of test examples: {}'.format(total_num))
97 | 
98 | # load the model
99 | if args.gpu == -1:  # Load all tensors onto the CPU
100 |     test_iter = data.Iterator(test, batch_size=args.batch_size, train=False, repeat=False, sort=False, shuffle=False,
101 |                               sort_within_batch=False)
102 |     model = torch.load(args.dete_model, map_location=lambda storage, loc: storage)
103 |     model.config.cuda = False
104 | else:
105 |     test_iter = data.Iterator(test, batch_size=args.batch_size, device=torch.device('cuda', args.gpu), train=False,
106 |                               repeat=False, sort=False, shuffle=False, sort_within_batch=False)
107 |     model = torch.load(args.dete_model, map_location=lambda storage, loc: storage.cuda(args.gpu))
108 | index2tag = np.array(ED.vocab.itos)
109 | idxO = int(np.where(index2tag == 'O')[0][0])  # Index for 'O'
110 | idxI = int(np.where(index2tag == 'I')[0][0])  # Index for 'I'
111 | index2word = np.array(TEXT.vocab.itos)
112 | # run the model on the test set
113 | dete_result, question_list = entity_predict(dataset_iter=test_iter)
114 | del model
115 | 
116 | 
117 | ######################## Find matched names ########################
118 | mid_dic, mid_num_dic = {}, {}  # Dictionaries mapping MID <-> integer id
119 | for line in open(os.path.join(args.output, 'entity2id.txt'), 'r'):
120 |     items = line.strip().split("\t")
121 |     mid_dic[items[0]] = int(items[1])
122 |     mid_num_dic[int(items[1])] = items[0]
123 | pre_dic, pre_num_dic = {}, {}  # Dictionaries mapping predicate <-> integer id
124 | match_pool = []
125 | for line in open(os.path.join(args.output, 'relation2id.txt'), 'r'):
126 |     items = line.strip().split("\t")
127 |     match_pool = match_pool + items[0].replace('.', ' ').replace('_', ' ').split()
128 |     pre_dic[items[0]] = int(items[1])
129 |     pre_num_dic[int(items[1])] = items[0]
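# The pretrained TransE embeddings are stored as raw float32 binaries and are
# reshaped into (num_entities, embed_dim) and (num_predicates, embed_dim)
# matrices whose rows are indexed by the integer ids read above.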
130 | # Embedding for MID
131 | entities_emb = np.fromfile(os.path.join(args.output, 'entities_emb.bin'), dtype=np.float32).reshape((len(mid_dic), args.embed_dim))
132 | predicates_emb = np.fromfile(os.path.join(args.output, 'predicates_emb.bin'), dtype=np.float32).reshape((-1, args.embed_dim))
133 | #names_map = {}
134 | index_names = {}
135 | 
136 | for i, line in enumerate(open(os.path.join(args.output, 'names.trimmed.txt'), 'r')):
137 |     items = line.strip().split("\t")
138 |     entity = items[0]
139 |     literal = items[1].strip()
140 |     if literal != "":
141 |         #if names_map.get(entity) is None or len(names_map[entity].split()) > len(literal.split()):
142 |         #    names_map[entity] = literal
143 |         if index_names.get(literal) is None:
144 |             index_names[literal] = [entity]
145 |         else:
146 |             index_names[literal].append(entity)
147 | for fname in ["train.txt", "valid.txt"]:
148 |     with open(os.path.join(args.output, fname), 'r') as f:
149 |         for line in f:
150 |             items = line.strip().split("\t")
151 |             if items[2] != '' and mid_dic.get(items[1]) is not None:
152 |                 if index_names.get(items[2]) is None:
153 |                     index_names[items[2]] = [items[1]]
154 |                 else:
155 |                     index_names[items[2]].append(items[1])
156 |                 #if names_map.get(items[1]) is None or len(names_map[items[1]].split()) > len(items[2].split()):
157 |                 #    names_map[items[1]] = items[2]
158 | 
159 | 
160 | #for fname in ["train.txt", "valid.txt"]:
161 | #    with open(os.path.join(args.output, fname), 'r') as f:
162 | #        for line in f:
163 | #            items = line.strip().split("\t")
164 | #            match_pool.extend(list(compress(items[5].split(), [element == 'O' for element in items[6].split()])))
165 | head_mid_idx = [[] for i in range(total_num)]  # [[head1, head2, ...], [head1, head2, ...], ...]
166 | match_pool = set(match_pool + stopwords.words('english') + ["'s"])
167 | whhowset = [{'what', 'how', 'where', 'who', 'which', 'whom'},
168 |             {'in which', 'what is', "what 's", 'what are', 'what was', 'what were', 'where is', 'where are',
169 |              'where was', 'where were', 'who is', 'who was', 'who are', 'how is', 'what did'},
170 |             {'what kind of', 'what kinds of', 'what type of', 'what types of', 'what sort of'}]
171 | dete_tokens_list, filter_q = [], []
172 | for i, question in enumerate(question_list):
173 |     question = [token for token in question if token != '']
174 |     pred_span = get_span(dete_result[i], index2tag, type=False)
175 |     tokens_list, dete_tokens, st, en, changed = [], [], 0, 0, 0
176 |     for st, en in pred_span:
177 |         tokens = question[st:en]
178 |         tokens_list.append(tokens)
179 |         if index_names.get(' '.join(tokens)) is not None:  # the detected span exactly matches a known entity name
180 |             dete_tokens.append(' '.join(tokens))
181 |             head_mid_idx[i].append(' '.join(tokens))
182 |     if len(question) > 2:
183 |         for j in range(3, 0, -1):
184 |             if ' '.join(question[0:j]) in whhowset[j - 1]:
185 |                 changed = j
186 |                 del question[0:j]
187 |                 continue
188 |     tokens_list.append(question)
189 |     filter_q.append(' '.join(question[:st - changed] + question[en - changed:]))
190 |     if not head_mid_idx[i]:
191 |         dete_tokens = question
192 |         for tokens in tokens_list:
193 |             grams = []
194 |             maxlen = len(tokens)
195 |             for j in range(maxlen - 1, 1, -1):
196 |                 for token in [tokens[idx:idx + j] for idx in range(maxlen - j + 1)]:
197 |                     grams.append(' '.join(token))
198 |             for gram in grams:
199 |                 if index_names.get(gram) is not None:
200 |                     head_mid_idx[i].append(gram)
201 |                     break
202 |             for j, token in enumerate(tokens):
203 |                 if token not in match_pool:
204 |                     tokens = tokens[j:]
205 |                     break
206 |             if index_names.get(' '.join(tokens)) is not None:
207 |                 head_mid_idx[i].append(' '.join(tokens))
208 |             tokens = tokens[::-1]
209 |             for j, token in enumerate(tokens):
210 |                 if token not in match_pool:
211 |                     tokens = tokens[j:]
212 |                     break
213 |             tokens = tokens[::-1]
214 |             if index_names.get(' '.join(tokens)) is not None:
215 |                 head_mid_idx[i].append(' '.join(tokens))
216 |     dete_tokens_list.append(' '.join(dete_tokens))
217 | 
218 | id_match = set()
219 | match_mid_list = []
220 | tupleset = []
221 | for i, names in enumerate(head_mid_idx):
222 |     tuplelist = []
223 |     for name in names:
224 |         mids = index_names[name]
225 |         match_mid_list.extend(mids)
226 |         for mid in mids:
227 |             if mid_dic.get(mid) is not None:
228 |                 tuplelist.append((mid, name))
229 |     tupleset.extend(tuplelist)
230 |     head_mid_idx[i] = list(set(tuplelist))
231 |     if tuplelist:
232 |         id_match.add(i)
233 | tupleset = set(tupleset)
234 | tuple_topic = []
235 | with open('data/FB5M.name.txt', 'r') as f:
236 |     for i, line in enumerate(f):
237 |         if i % 1000000 == 0:
238 |             print("line: {}".format(i))
239 |         items = line.strip().split("\t")
240 |         if (www2fb(clean_uri(items[0])), processed_text(clean_uri(items[2]))) in tupleset and items[1] == "":
241 |             tuple_topic.append((www2fb(clean_uri(items[0])), processed_text(clean_uri(items[2]))))
242 | tuple_topic = set(tuple_topic)
243 | 
244 | 
245 | ######################## Learn entity representation ########################
246 | head_emb = np.zeros((total_num, args.embed_dim))
247 | TEXT = data.Field(lower=True)
248 | ED = data.Field(sequential=False, use_vocab=False)
249 | train, dev = data.TabularDataset.splits(path=args.output, train='entity_train.txt', validation='entity_valid.txt', format='tsv', fields=[('text', TEXT), ('mid', ED)])
250 | field = [('id', None), ('sub', None), ('entity', None), ('relation', None), ('obj', None), ('text', TEXT), ('ed', None)]
251 | test = data.TabularDataset(path=os.path.join(args.output, 'test.txt'), format='tsv', fields=field)
252 | TEXT.build_vocab(train, dev, test)  # training data includes validation data
253 | 
254 | # load the model
255 | if args.gpu == -1:  # Load all tensors onto the CPU
256 |     test_iter = data.Iterator(test, batch_size=args.batch_size, train=False, repeat=False, sort=False, shuffle=False,
257 |                               sort_within_batch=False)
258 |     model = torch.load(args.entity_model, map_location=lambda storage, loc: storage)
259 |     model.config.cuda = False
260 | else:
261 |     test_iter = data.Iterator(test, batch_size=args.batch_size, device=torch.device('cuda', args.gpu), train=False,
262 |                               repeat=False, sort=False, shuffle=False, sort_within_batch=False)
263 |     model = torch.load(args.entity_model, map_location=lambda storage, loc: storage.cuda(args.gpu))
264 | model.eval()
265 | test_iter.init_epoch()
266 | baseidx = 0
267 | for data_batch_idx, data_batch in enumerate(test_iter):
268 |     batch_size = data_batch.text.size()[1]
269 |     scores = model(data_batch).cpu().data.numpy()
270 |     for i in range(batch_size):
271 |         head_emb[baseidx + i] = scores[i]
272 |     baseidx = baseidx + batch_size
273 | del model
274 | 
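# head_emb now holds, for every test question, the head-entity embedding
# predicted from the question text; the block below repeats the same procedure
# with the predicate model to fill pred_emb.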
275 | ######################## Learn predicate representation ########################
276 | TEXT = data.Field(lower=True)
277 | ED = data.Field(sequential=False, use_vocab=False)
278 | train, dev = data.TabularDataset.splits(path=args.output, train='pred_train.txt', validation='pred_valid.txt', format='tsv', fields=[('text', TEXT), ('mid', ED)])
279 | field = [('id', None), ('sub', None), ('entity', None), ('relation', None), ('obj', None), ('text', TEXT), ('ed', None)]
280 | test = data.TabularDataset(path=os.path.join(args.output, 'test.txt'), format='tsv', fields=field)
281 | TEXT.build_vocab(train, dev, test)
282 | 
283 | # load the model
284 | if args.gpu == -1:  # Load all tensors onto the CPU
285 |     test_iter = data.Iterator(test, batch_size=args.batch_size, train=False, repeat=False, sort=False, shuffle=False,
286 |                               sort_within_batch=False)
287 |     model = torch.load(args.pred_model, map_location=lambda storage, loc: storage)
288 |     model.config.cuda = False
289 | else:
290 |     test_iter = data.Iterator(test, batch_size=args.batch_size, device=torch.device('cuda', args.gpu), train=False,
291 |                               repeat=False, sort=False, shuffle=False, sort_within_batch=False)
292 |     model = torch.load(args.pred_model, map_location=lambda storage, loc: storage.cuda(args.gpu))
293 | model.eval()
294 | test_iter.init_epoch()
295 | baseidx = 0
296 | pred_emb = np.zeros((total_num, args.embed_dim))
297 | for data_batch_idx, data_batch in enumerate(test_iter):
298 |     batch_size = data_batch.text.size()[1]
299 |     scores = model(data_batch).cpu().data.numpy()
300 |     for i in range(batch_size):
301 |         pred_emb[baseidx + i] = scores[i]
302 |     baseidx = baseidx + batch_size
303 | del model
304 | 
305 | #learned_pred = []
306 | #ed_dic = {}
307 | #for i, pred in enumerate(ED.vocab.itos):
308 | #    ed_dic[i] = pred
309 | #for data_batch_idx, data_batch in enumerate(test_iter):
310 | #    batch_size = data_batch.text.size()[1]
311 | #    answer = torch.max(model(data_batch), 1)[1]
312 | #    for devi in range(batch_size):
313 | #        learned_pred.append(pre_dic[ed_dic[answer[devi].item()]])
314 | #del ed_dic
315 | 
316 | ######################## Prediction and evaluation ########################
317 | gt_tail = []  # Ground truth tail entity
318 | gt_pred = []  # Ground truth predicate
319 | gt_head = []  # Ground truth head entity
320 | for line in open(os.path.join(args.output, 'test.txt'), 'r'):
321 |     items = line.strip().split("\t")
322 |     gt_head.append(items[1])
323 |     gt_pred.append(items[3])
324 |     gt_tail.append(items[4])
325 | 
326 | notmatch = list(set(range(0, total_num)).symmetric_difference(id_match))
327 | print('{} out of {} questions have no name match, matching accuracy: {}'.format(len(notmatch), total_num, (total_num - len(notmatch)) / total_num))
328 | 
329 | 
330 | notmatch_idx = euclidean_distances(head_emb[notmatch], entities_emb, squared=True).argsort(axis=1)
331 | for idx, i in enumerate(notmatch):
332 |     for j in notmatch_idx[idx, 0:40]:
333 |         mid = mid_num_dic[j]
334 |         head_mid_idx[i].append((mid, None))
335 |         match_mid_list.append(mid)
336 | 
337 | correct, mid_num = 0, 0
338 | for i, head_ids in enumerate(head_mid_idx):
339 |     mids = set()
340 |     for (head_id, name) in head_ids:
341 |         mids.add(head_id)
342 |     if gt_head[i] in mids:
343 |         correct += 1
344 |     mid_num += len(mids)
345 | print('recall of head entity prediction: {}, num of MIDs per example: {}'.format(correct / total_num, (mid_num + len(notmatch)) / total_num))
346 | 
347 | reach_dic = compute_reach_dic(set(match_mid_list))
348 | learned_pred, learned_fact, learned_head = [-1] * total_num, {}, [-1] * total_num
349 | 
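# Joint scoring: every candidate head entity h with a predicate p reachable
# from h in the KG is scored by
#     score(h, p) = alpha1 * ||e_h - e_hat|| + ||r_p - r_hat||
#                   + alpha3 * ||(e_h + r_p) - (e_hat + r_hat)|| + name terms,
# where e_hat and r_hat are the head-entity and predicate embeddings predicted
# from the question, and the alpha3 term checks TransE consistency of the whole
# fact. Fuzzy string matches against the question subtract from the score, and
# the lowest-scoring (head, predicate) pair wins.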
350 | alpha1, alpha3 = .39, .43
351 | for i, head_ids in enumerate(head_mid_idx):  # head_ids is a list of (mid, name) tuples
352 |     if i % 1000 == 1:
353 |         print('progress: {}'.format(i / total_num), end='\r')
354 |     answers = []
355 |     for (head_id, name) in head_ids:
356 |         mid_score = np.sqrt(np.sum(np.power(entities_emb[mid_dic[head_id]] - head_emb[i], 2)))
357 |         #if name is None and head_id in names_map:
358 |         #    name = names_map[head_id]
359 |         name_score = - .003 * fuzz.ratio(name, dete_tokens_list[i])
360 |         if (head_id, name) in tuple_topic:
361 |             name_score -= .18
362 |         if reach_dic.get(head_id) is not None:
363 |             for pred_id in reach_dic[head_id]:  # reach_dic[head_id] holds numeric predicate ids
364 |                 rel_names = - .017 * fuzz.ratio(pre_num_dic[pred_id].replace('.', ' ').replace('_', ' '), filter_q[i])
365 |                 rel_score = np.sqrt(np.sum(np.power(predicates_emb[pred_id] - pred_emb[i], 2))) + rel_names
366 |                 tail_score = np.sqrt(np.sum(
367 |                     np.power(predicates_emb[pred_id] + entities_emb[mid_dic[head_id]] - head_emb[i] - pred_emb[i], 2)))
368 |                 answers.append((head_id, pred_id, alpha1 * mid_score + rel_score + alpha3 * tail_score + name_score))
369 |     if answers:
370 |         answers.sort(key=lambda x: x[2])
371 |         learned_head[i] = answers[0][0]
372 |         learned_pred[i] = answers[0][1]
373 |         learned_fact[' '.join([learned_head[i], pre_num_dic[learned_pred[i]]])] = i
374 | 
375 | learned_tail = [[] for i in range(total_num)]
376 | for line in open(os.path.join(args.output, 'cleanedFB.txt'), 'r'):
377 |     items = line.strip().split("\t")
378 |     if learned_fact.get(' '.join([items[0], items[2]])) is not None:
379 |         learned_tail[learned_fact[' '.join([items[0], items[2]])]].extend(items[1].split())
380 | # for i, tail_id in enumerate(learned_tail):
381 | #     if not tail_id:
382 | #         learned_tail[i] = mid_num_dic[euclidean_distances(
383 | #             (entities_emb[mid_dic[learned_head[i]]] + predicates_emb[learned_pred[i]]).reshape(1, -1), entities_emb,
384 | #             squared=True).argmin(axis=1)[0]]
385 | 
386 | corr_head, correct, corr_all = 0, 0, 0
387 | for i, tail_id in enumerate(gt_tail):
388 |     if gt_head[i] == learned_head[i]:
389 |         corr_head += 1
390 |         if gt_pred[i] == pre_num_dic[learned_pred[i]]:
391 |             correct += 1
392 |             if tail_id in learned_tail[i]:
393 |                 corr_all += 1
394 | 
395 | print('final accuracy: {}, head acc: {}, all acc: {}'.format(correct / total_num, corr_head / total_num, corr_all / total_num))
396 | 
--------------------------------------------------------------------------------