├── NeuralQA
│   ├── AnswerRerank
│   │   ├── demo.py
│   │   ├── rerank.py
│   │   └── util.py
│   ├── EntityLinking
│   │   ├── change_format.py
│   │   ├── demo.py
│   │   ├── entity_linking.py
│   │   └── util.py
│   ├── MentionDetection
│   │   ├── crf
│   │   │   ├── convert.py
│   │   │   ├── eval.py
│   │   │   ├── output.py
│   │   │   └── stanford-ner
│   │   │       └── NERDemo.java
│   │   └── nn
│   │       ├── args.py
│   │       ├── demo.py
│   │       ├── model.py
│   │       ├── test.py
│   │       ├── tmp_data.txt
│   │       ├── train.py
│   │       └── util
│   │           ├── datasets.py
│   │           └── util.py
│   └── RelationDetection
│       ├── nn
│       │   ├── args.py
│       │   ├── datasets.py
│       │   ├── demo.py
│       │   ├── model.py
│       │   ├── preprocess.py
│       │   ├── test.py
│       │   └── train.py
│       └── siamese
│           ├── args.py
│           ├── datasets.py
│           ├── eval.py
│           ├── model.py
│           ├── train.py
│           └── util.py
├── README.md
├── SiameseNetwork
│   ├── config.py
│   ├── qa_cnn.py
│   ├── qa_lstm.py
│   ├── qa_nn.py
│   ├── train_cnn.py
│   ├── train_lstm.py
│   ├── train_nn.py
│   └── util
│       ├── dataset.py
│       └── util.py
├── WebQA
│   ├── data
│   │   └── indexes
│   │       └── relation_sub_2M.pkl
│   ├── src
│   │   ├── args.py
│   │   ├── main.py
│   │   ├── model.py
│   │   └── simpleQA.py
│   ├── static
│   │   ├── css
│   │   │   └── bootstrap.min.css
│   │   ├── images
│   │   │   ├── favicon.ico
│   │   │   └── help.png
│   │   └── js
│   │       ├── bootstrap.js
│   │       ├── bootstrap.min.js
│   │       └── jquery-3.1.1.min.js
│   ├── templates
│   │   ├── coming_soon.html
│   │   ├── entity.html
│   │   ├── homepage.html
│   │   ├── mention.html
│   │   ├── relation.html
│   │   └── test.html
│   └── util
│       ├── datasets.py
│       └── utils.py
└── demo.gif

--------------------------------------------------------------------------------
/NeuralQA/AnswerRerank/demo.py:
--------------------------------------------------------------------------------
import math
import os
import pickle
from argparse import ArgumentParser
from collections import defaultdict

from util import clean_uri, www2fb, rdf2fb


# Load up reachability graph

def load_index(filename):
    # print("Loading index map from {}".format(filename))
    with open(filename, 'rb') as handler:
        index = pickle.load(handler)
    return index


# Load predicted MIDs and relations for each question in valid/test set
def get_mids(filename, hits):
    # print("Entity Source : {}".format(filename))
    id2mids = defaultdict(list)
    fin = open(filename)
    for line in fin.readlines():
        items = line.strip().split(' %%%% ')
        lineid = items[0]
        cand_mids = items[1:][:hits]
        for mid_entry in cand_mids:
            # TODO: WHY MID_TYPE ONLY ONE? TYPE IS FROM CFO NAME, EITHER type.topic.name OR common.topic.alias
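            # The expected candidate format (hypothetical values; fields inside
            # each ' %%%% '-delimited entry are tab-separated) is:
            #   test-1 %%%% fb:m.0jbrr\tyao ming\ttype.topic.name\t0.93 %%%% ...
            # so each mid_entry below unpacks into (mid, mid_name, mid_type, score).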
            mid, mid_name, mid_type, score = mid_entry.split('\t')
            id2mids[lineid].append((mid, mid_name, mid_type, float(score)))
    return id2mids


def get_rels(filename, hits):
    # print("Relation Source : {}".format(filename))
    id2rels = defaultdict(list)
    fin = open(filename)
    for line in fin.readlines():
        items = line.strip().split(' %%%% ')
        lineid = items[0].strip()
        rel = www2fb(items[1].strip())
        label = items[2].strip()
        score = items[3].strip()
        if len(id2rels[lineid]) < hits:
            id2rels[lineid].append((rel, label, float(score)))
    return id2rels


def get_questions(filename):
    # print("getting questions ...")
    id2questions = {}
    id2goldmids = {}
    fin = open(filename)
    for line in fin.readlines():
        items = line.strip().split('\t')
        lineid = items[0].strip()
        mid = items[1].strip()
        question = items[5].strip()
        rel = items[3].strip()
        id2questions[lineid] = (question, rel)
        id2goldmids[lineid] = mid
    return id2questions, id2goldmids


def get_mid2wiki(filename):
    # print("Loading Wiki")
    mid2wiki = defaultdict(bool)
    fin = open(filename)
    for line in fin.readlines():
        items = line.strip().split('\t')
        sub = rdf2fb(clean_uri(items[0]))
        mid2wiki[sub] = True
    return mid2wiki


def evidence_integration(index_reach, index_degrees, mid2wiki, is_heuristics, input_entity, input_relation):
    id2answers = list()
    mids = input_entity.split("\n")
    rels = input_relation.split("\n")
    # print(rels)

    if is_heuristics:
        for item in mids:
            _, mid, mid_name, mid_type, mid_score = item.strip().split("\t")
            for item2 in rels:
                rel, rel_log_score = item2.strip().split("\t")
                # if this (mid, rel) exists in FB
                if rel in index_reach[mid]:
                    rel_score = math.exp(float(rel_log_score))
                    comb_score = (float(mid_score) ** 0.6) * (rel_score ** 0.1)
                    id2answers.append((mid, rel, mid_name, mid_type, mid_score, rel_score, comb_score,
                                       int(mid2wiki[mid]), int(index_degrees[mid][0])))
                    # 'retrieved' cannot be counted here because candidates may appear under different name_types
                    # if mid == truth_mid and rel == truth_rel:
                    #     retrieved += 1
        id2answers.sort(key=lambda t: (t[6], t[3], t[7], t[8]), reverse=True)
    else:
        # mids/rels hold raw tab-separated lines here; mids[0][0] would be just
        # the first character of the line, so parse out the top candidate instead
        id2answers = [(mids[0].strip().split("\t")[1], rels[0].strip().split("\t")[0])]

    # write to file
    # TODO: CHANGED FOR SWITCH IS_HEURISTICS
    if is_heuristics:
        for answer in id2answers:
            mid, rel, mid_name, mid_type, mid_score, rel_score, comb_score, _, _ = answer
            # six fields, six placeholders (the original string dropped comb_score)
            print("{}\t{}\t{}\t{}\t{}\t{}".format(mid, rel, mid_name, mid_score, rel_score, comb_score))
    else:
        for answer in id2answers:
            mid, rel = answer
            print("{}\t{}".format(mid, rel))
    return id2answers

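# A hedged worked example of the heuristic above (numbers are illustrative,
# not taken from the repo): an entity candidate with mid_score 0.9 combined
# with a relation whose rel_score is 0.5 yields
#   comb_score = 0.9 ** 0.6 * 0.5 ** 0.1 ≈ 0.939 * 0.933 ≈ 0.876,
# so the entity score dominates the ranking (exponent 0.6 vs 0.1); ties are
# then broken by mid_type, Wikipedia presence, and node degree via the sort key.
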
if __name__ == "__main__":
    parser = ArgumentParser(description='Perform evidence integration')
    parser.add_argument('--ent_type', type=str, required=True, help="options are [crf|lstm|gru]")
    parser.add_argument('--rel_type', type=str, required=True, help="options are [lr|cnn|lstm|gru]")
    parser.add_argument('--index_reachpath', type=str, default="indexes/reachability_2M.pkl",
                        help='path to the pickle for the reachability index')
    parser.add_argument('--index_degreespath', type=str, default="indexes/degrees_2M.pkl",
                        help='path to the pickle for the index with the degree counts')
    parser.add_argument('--data_path', type=str, default="data/processed_simplequestions_dataset/test.txt")
    parser.add_argument('--ent_path', type=str, default="entity_linking/results/lstm/test-h100.txt",
                        help='path to the entity linking results')
    parser.add_argument('--rel_path', type=str, default="relation_prediction/nn/results/cnn/test.txt",
                        help='path to the relation prediction results')
    parser.add_argument('--wiki_path', type=str, default="data/fb2w.nt")
    parser.add_argument('--hits_ent', type=int, default=50,
                        help='the hits here has to be <= the hits in entity linking')
    parser.add_argument('--hits_rel', type=int, default=5,
                        help='the hits here has to be <= the hits in relation prediction retrieval')
    parser.add_argument('--no_heuristics', action='store_false', help='do not use heuristics', dest='heuristics')
    parser.add_argument('--output_dir', type=str, default="./results")

    # added for demo
    parser.add_argument('--input_ent_path', type=str, default='el_result.txt')
    parser.add_argument('--input_rel_path', type=str, default='rp_result.txt')
    args = parser.parse_args()
    # print(args)

    ent_type = args.ent_type.lower()
    rel_type = args.rel_type.lower()
    output_dir = os.path.join(args.output_dir, "{}-{}".format(ent_type, rel_type))
    os.makedirs(output_dir, exist_ok=True)

    index_reach = load_index(args.index_reachpath)
    # print(index_reach)
    index_degrees = load_index(args.index_degreespath)
    mid2wiki = get_mid2wiki(args.wiki_path)

    candidate_entity = open(args.input_ent_path).read().strip()
    candidate_relation = open(args.input_rel_path).read().strip()
    test_answers = evidence_integration(index_reach, index_degrees, mid2wiki, args.heuristics, candidate_entity,
                                        candidate_relation)
--------------------------------------------------------------------------------
/NeuralQA/AnswerRerank/rerank.py:
--------------------------------------------------------------------------------
import math
import os
from argparse import ArgumentParser
from collections import defaultdict

from util import get_mid2wiki, get_rels, get_mids, get_questions, www2fb, load_index


def answer_rerank(data_path, ent_path, rel_path, output_dir, index_reach, index_degrees, mid2wiki, is_heuristics,
                  ent_hits, rel_hits):
    id2questions, id2goldmids = get_questions(data_path)
    id2mids = get_mids(ent_path, ent_hits)
    id2rels = get_rels(rel_path, rel_hits)
    file_base_name = os.path.basename(data_path)
    fout = open(os.path.join(output_dir, file_base_name), 'w')

    id2answers = defaultdict(list)
    found, notfound_both, notfound_mid, notfound_rel = 0, 0, 0, 0
    retrieved, retrieved_top1, retrieved_top2, retrieved_top3 = 0, 0, 0, 0
    lineids_found1 = []
    lineids_found2 = []
    lineids_found3 = []

    # for every lineid
    for line_id in id2goldmids:
        if line_id not in id2mids and line_id not in id2rels:
            notfound_both += 1
            continue
        elif line_id not in id2mids:
            notfound_mid += 1
            continue
        elif line_id not in id2rels:
            notfound_rel += 1
            continue
        found += 1
        question, truth_rel = id2questions[line_id]
        truth_rel = www2fb(truth_rel)
        truth_mid = id2goldmids[line_id]
        mids = id2mids[line_id]
        rels = id2rels[line_id]

        if is_heuristics:
            for (mid, mid_name, mid_type, mid_score) in mids:
                for (rel, rel_label, rel_log_score) in rels:
                    # if this (mid, rel) exists in FB
                    if rel in index_reach[mid]:
                        rel_score = math.exp(float(rel_log_score))
                        comb_score = (float(mid_score) ** 0.6) * (rel_score ** 0.1)
                        id2answers[line_id].append((mid, rel, mid_name, mid_type, mid_score, rel_score, comb_score,
                                                    int(mid2wiki[mid]), int(index_degrees[mid][0])))
                        # 'retrieved' cannot be counted here because candidates may appear under different name_types
                        # if mid == truth_mid and rel == truth_rel:
                        #     retrieved += 1
            id2answers[line_id].sort(key=lambda t: (t[6], t[3], t[7], t[8]), reverse=True)
        else:
            id2answers[line_id] = [(mids[0][0], rels[0][0])]

        # write to file
        fout.write("{}".format(line_id))
        if is_heuristics:
            for answer in id2answers[line_id]:
                mid, rel, mid_name, mid_type, mid_score, rel_score, comb_score, _, _ = answer
                # six fields, six placeholders (the original string dropped comb_score)
                fout.write(" %%%% {}\t{}\t{}\t{}\t{}\t{}".format(mid, rel, mid_name, mid_score, rel_score, comb_score))
        else:
            for answer in id2answers[line_id]:
                mid, rel = answer
                fout.write(" %%%% {}\t{}".format(mid, rel))
        fout.write('\n')

        if is_heuristics:
            if len(id2answers[line_id]) >= 1 and id2answers[line_id][0][1] == truth_rel:  # id2answers[line_id][0][0] == truth_mid and
                retrieved_top1 += 1
                retrieved_top2 += 1
                retrieved_top3 += 1
                lineids_found1.append(line_id)
            elif len(id2answers[line_id]) >= 2 and id2answers[line_id][1][0] == truth_mid \
                    and id2answers[line_id][1][1] == truth_rel:
                retrieved_top2 += 1
                retrieved_top3 += 1
                lineids_found2.append(line_id)
            elif len(id2answers[line_id]) >= 3 and id2answers[line_id][2][0] == truth_mid \
                    and id2answers[line_id][2][1] == truth_rel:
                retrieved_top3 += 1
                lineids_found3.append(line_id)
        else:
            if len(id2answers[line_id]) >= 1 and id2answers[line_id][0][0] == truth_mid \
                    and id2answers[line_id][0][1] == truth_rel:
                retrieved_top1 += 1
                retrieved_top2 += 1
                retrieved_top3 += 1
                lineids_found1.append(line_id)
    print()
    print("found: {}".format(found / len(id2goldmids) * 100.0))
    print("retrieved at top 1: {}".format(retrieved_top1 / len(id2goldmids) * 100.0))
    print("retrieved at top 2: {}".format(retrieved_top2 / len(id2goldmids) * 100.0))
    print("retrieved at top 3: {}".format(retrieved_top3 / len(id2goldmids) * 100.0))
    # print("retrieved at inf: {}".format(retrieved / len(id2goldmids) * 100.0))
    fout.close()
    return id2answers


if __name__ == "__main__":
    parser = ArgumentParser(description='Perform evidence integration')
    parser.add_argument('--ent_type', type=str, required=True, help="options are [crf|lstm|gru]")
    parser.add_argument('--rel_type', type=str, required=True, help="options are [lr|cnn|lstm|gru]")
    parser.add_argument('--index_reachpath', type=str, default="../indexes/reachability_2M.pkl",
                        help='path to the pickle for the reachability index')
    parser.add_argument('--index_degreespath', type=str, default="../indexes/degrees_2M.pkl",
                        help='path to the pickle for the index with the degree counts')
    parser.add_argument('--data_path', type=str, default="../data/processed_simplequestions_dataset/test.txt")
    parser.add_argument('--ent_path', type=str, default="../entity_linking/results/crf/test-h100.txt",
                        help='path to the entity linking results')
    parser.add_argument('--rel_path', type=str, default="../relation_prediction/nn/results/cnn/test.txt",
default="../relation_prediction/nn/results/cnn/test.txt", 114 | help='path to the relation prediction results') 115 | parser.add_argument('--wiki_path', type=str, default="../data/fb2w.nt") 116 | parser.add_argument('--hits_ent', type=int, default=50, 117 | help='the hits here has to be <= the hits in entity linking') 118 | parser.add_argument('--hits_rel', type=int, default=5, 119 | help='the hits here has to be <= the hits in relation prediction retrieval') 120 | parser.add_argument('--no_heuristics', action='store_false', help='do not use heuristics', dest='heuristics') 121 | parser.add_argument('--output_dir', type=str, default="./results") 122 | args = parser.parse_args() 123 | print(args) 124 | 125 | ent_type = args.ent_type.lower() 126 | rel_type = args.rel_type.lower() 127 | # assert (ent_type == "crf" or ent_type == "lstm" or ent_type == "gru") 128 | # assert (rel_type == "lr" or rel_type == "cnn" or rel_type == "lstm" or rel_type == "gru") 129 | output_dir = os.path.join(args.output_dir, "{}-{}".format(ent_type, rel_type)) 130 | os.makedirs(output_dir, exist_ok=True) 131 | 132 | index_reach = load_index(args.index_reachpath) 133 | index_degrees = load_index(args.index_degreespath) 134 | mid2wiki = get_mid2wiki(args.wiki_path) 135 | 136 | test_answers = answer_rerank(args.data_path, args.ent_path, args.rel_path, output_dir, index_reach, 137 | index_degrees, mid2wiki, args.heuristics, args.hits_ent, args.hits_rel) 138 | -------------------------------------------------------------------------------- /NeuralQA/AnswerRerank/util.py: -------------------------------------------------------------------------------- 1 | import unicodedata 2 | from nltk.tokenize.treebank import TreebankWordTokenizer 3 | 4 | tokenizer = TreebankWordTokenizer() 5 | 6 | 7 | # Load up reachability graph 8 | 9 | def load_index(filename): 10 | print("Loading index map from {}".format(filename)) 11 | with open(filename, 'rb') as handler: 12 | index = pickle.load(handler) 13 | return index 14 | 15 | 16 | # Load predicted MIDs and relations for each question in valid/test set 17 | def get_mids(filename, hits): 18 | print("Entity Source : {}".format(filename)) 19 | id2mids = defaultdict(list) 20 | fin = open(filename) 21 | for line in fin.readlines(): 22 | items = line.strip().split(' %%%% ') 23 | lineid = items[0] 24 | cand_mids = items[1:][:hits] 25 | for mid_entry in cand_mids: 26 | # TODO:WHY MID_TYPE ONLY ONE! 
            mid, mid_name, mid_type, score = mid_entry.split('\t')
            id2mids[lineid].append((mid, mid_name, mid_type, float(score)))
    return id2mids


def get_rels(filename, hits):
    print("Relation Source : {}".format(filename))
    id2rels = defaultdict(list)
    fin = open(filename)
    for line in fin.readlines():
        items = line.strip().split(' %%%% ')
        lineid = items[0].strip()
        rel = www2fb(items[1].strip())
        label = items[2].strip()
        score = items[3].strip()
        if len(id2rels[lineid]) < hits:
            id2rels[lineid].append((rel, label, float(score)))
    return id2rels


def get_questions(filename):
    print("getting questions ...")
    id2questions = {}
    id2goldmids = {}
    fin = open(filename)
    for line in fin.readlines():
        items = line.strip().split('\t')
        lineid = items[0].strip()
        mid = items[1].strip()
        question = items[5].strip()
        rel = items[3].strip()
        id2questions[lineid] = (question, rel)
        id2goldmids[lineid] = mid
    return id2questions, id2goldmids


def get_mid2wiki(filename):
    print("Loading Wiki")
    mid2wiki = defaultdict(bool)
    fin = open(filename)
    for line in fin.readlines():
        items = line.strip().split('\t')
        sub = rdf2fb(clean_uri(items[0]))
        mid2wiki[sub] = True
    return mid2wiki


def processed_text(text):
    text = text.replace('\\\\', '')
    # stripped = strip_accents(text.lower())
    stripped = text.lower()
    toks = tokenizer.tokenize(stripped)
    return " ".join(toks)


def strip_accents(text):
    return ''.join(c for c in unicodedata.normalize('NFKD', text) if unicodedata.category(c) != 'Mn')


def www2fb(in_str):
    if in_str.startswith("www.freebase.com"):
        in_str = 'fb:%s' % (in_str.split('www.freebase.com/')[-1].replace('/', '.'))
    if in_str == 'fb:m.07s9rl0':
        in_str = 'fb:m.02822'
    if in_str == 'fb:m.0bb56b6':
        in_str = 'fb:m.0dn0r'
    # Manual Correction
    if in_str == 'fb:m.01g81dw':
        in_str = 'fb:m.01g_bfh'
    if in_str == 'fb:m.0y7q89y':
        in_str = 'fb:m.0wrt1c5'
    if in_str == 'fb:m.0b0w7':
        in_str = 'fb:m.0fq0s89'
    if in_str == 'fb:m.09rmm6y':
        in_str = 'fb:m.03cnrcc'
    if in_str == 'fb:m.0crsn60':
        in_str = 'fb:m.02pnlqy'
    if in_str == 'fb:m.04t1f8y':
        in_str = 'fb:m.04t1fjr'
    if in_str == 'fb:m.027z990':
        in_str = 'fb:m.0ghdhcb'
    if in_str == 'fb:m.02xhc2v':
        in_str = 'fb:m.084sq'
    if in_str == 'fb:m.02z8b2h':
        in_str = 'fb:m.033vn1'
    if in_str == 'fb:m.0w43mcj':
        in_str = 'fb:m.0m0qffc'
    if in_str == 'fb:m.07rqy':
        in_str = 'fb:m.0py_0'
    if in_str == 'fb:m.0y9s5rm':
        in_str = 'fb:m.0ybxl2g'
    if in_str == 'fb:m.037ltr7':
        in_str = 'fb:m.0qjx99s'
    return in_str


def clean_uri(uri):
    if uri.startswith("<") and uri.endswith(">"):
        return clean_uri(uri[1:-1])
    elif uri.startswith("\"") and uri.endswith("\""):
        return clean_uri(uri[1:-1])
    return uri


def rdf2fb(in_str):
    if in_str.startswith('http://rdf.freebase.com/ns/'):
        return 'fb:%s' % (in_str.split('http://rdf.freebase.com/ns/')[-1])
--------------------------------------------------------------------------------
/NeuralQA/EntityLinking/change_format.py:
--------------------------------------------------------------------------------
import argparse

# readFile/writeFile live in util.py in this directory; the repo has no
# separate FileUtil module, so import util under that name
import util as FileUtil

parser = argparse.ArgumentParser()
parser.add_argument("--filename")
args = parser.parse_args()
filename = args.filename

context = FileUtil.readFile(filename)
output = []
for i, c in enumerate(context):
    if i % 3 == 0:
        output.append("test-{} %%%% {}".format(int(i / 3 + 1), context[i + 1]))
FileUtil.writeFile(output, filename + ".query")
print("All done!")
--------------------------------------------------------------------------------
/NeuralQA/EntityLinking/demo.py:
--------------------------------------------------------------------------------
import pickle
from argparse import ArgumentParser
from collections import defaultdict

from fuzzywuzzy import fuzz
from nltk.corpus import stopwords

inverted_index = defaultdict(list)
stopword = set(stopwords.words('english'))


def get_ngram(text):
    """Collect the distinct n-grams (n <= 3) of the text, longest first."""
    ngram = []
    tokens = text.split()
    for i in range(len(tokens) + 1):
        for j in range(i):
            if i - j <= 3:  # cap n-gram length at 3
                temp = " ".join(tokens[j:i])
                if temp not in ngram:
                    ngram.append(temp)
    ngram = sorted(ngram, key=lambda x: len(x.split()), reverse=True)
    return ngram


def get_stat_inverted_index(filename):
    """
    Report the number of entries and the longest entry (how many mids it holds).
    """
    with open(filename, "rb") as handler:
        global inverted_index
        inverted_index = pickle.load(handler)
        # a missing key should yield an empty candidate list, not an empty string
        inverted_index = defaultdict(list, inverted_index)
    # print("Total type of text: {}".format(len(inverted_index)))
    max_len = 0
    _entry = ""
    for entry, value in inverted_index.items():
        if len(value) > max_len:
            max_len = len(value)
            _entry = entry
    # print("Max Length of entry is {}, text is {}".format(max_len, _entry))


def entity_linking(pred_mention, top_num):
    C = []
    C_scored = []
    tokens = get_ngram(pred_mention)

    if len(tokens) > 0:
        maxlen = len(tokens[0].split())
        for item in tokens:
            if len(item.split()) < maxlen and len(C) == 0:
                maxlen = len(item.split())
            if len(item.split()) < maxlen and len(C) > 0:
                break
            if item in stopword:
                continue
            C.extend(inverted_index[item])
            # if len(C) > 0:
            #     break
    for mid_text_type in sorted(set(C)):
        score = fuzz.ratio(mid_text_type[1], pred_mention) / 100.0
        # C_scored format: ((mid, text, type), score_based_on_fuzz)
        C_scored.append((mid_text_type, score))

    C_scored.sort(key=lambda t: t[1], reverse=True)
    cand_mids = C_scored[:top_num]
    for mid_text_type, score in cand_mids:
        print("{}\t{}\t{}\t{}\t{}".format(pred_mention, mid_text_type[0], mid_text_type[1], mid_text_type[2], score))


if __name__ == "__main__":
    parser = ArgumentParser(description='Perform entity linking')
    parser.add_argument('--model_type', type=str, required=True, help="options are [crf|lstm|gru]")
    parser.add_argument('--index_ent', type=str, default="indexes/entity_2M.pkl",
                        help='path to the pickle for the inverted entity index')
    parser.add_argument('--data_dir', type=str, default="data/processed_simplequestions_dataset")
    parser.add_argument('--query_dir', type=str, default="entity_detection/crf/query_text")
    parser.add_argument('--hits', type=int, default=10)
    parser.add_argument('--output_dir', type=str, default="./results")

    # added for demo
    # parser.add_argument('--input_mention', type=str, default='yao ming')
    parser.add_argument('--input_path', type=str, default='')
    args = parser.parse_args()
    # print(args)

    input_mentions = open(args.input_path).read().strip()

    get_stat_inverted_index(args.index_ent)
    # one detected mention per line; a plain split() would break a multi-word
    # mention such as "yao ming" into two separate lookups
    for mention in input_mentions.split('\n'):
        entity_linking(mention, args.hits)
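
# A hypothetical invocation of this demo (paths and model type are
# illustrative, not fixed by the repo):
#     python demo.py --model_type lstm --index_ent indexes/entity_2M.pkl \
#         --input_path md_result.txt --hits 10
# where md_result.txt holds one detected mention per line; each candidate is
# printed as "mention\tmid\tname\ttype\tscore".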
--------------------------------------------------------------------------------
/NeuralQA/EntityLinking/entity_linking.py:
--------------------------------------------------------------------------------
import os
from argparse import ArgumentParser
from collections import defaultdict

from fuzzywuzzy import fuzz
from nltk.corpus import stopwords
from tqdm import tqdm
from util import www2fb, get_ngram, get_stat_inverted_index

stopword = set(stopwords.words('english'))
# filled in __main__ from the pickled index; see get_stat_inverted_index in util.py
inverted_index = defaultdict(list)


def entity_linking(predicted_file, gold_file, hits, output):
    predicted = open(predicted_file)
    gold = open(gold_file)
    fout = open(output, 'w')
    total, top1, top3, top5, top10, top20, top50, top100 = 0, 0, 0, 0, 0, 0, 0, 0
    for idx, (line, gold_id) in tqdm(enumerate(zip(predicted.readlines(), gold.readlines()))):
        total += 1
        line = line.strip().split(" %%%% ")
        gold_id = gold_id.strip().split('\t')[1]
        cand_entity, cand_score = [], []
        line_id = line[0]
        if len(line) == 2:
            tokens = get_ngram(line[1])
        else:
            tokens = []

        if len(tokens) > 0:
            maxlen = len(tokens[0].split())  # n-gram lengths run from 3 down to 1
            for item in tokens:
                if len(item.split()) < maxlen and len(cand_entity) == 0:
                    maxlen = len(item.split())
                if len(item.split()) < maxlen and len(cand_entity) > 0:
                    break
                if item in stopword:
                    continue
                cand_entity.extend(inverted_index[item])
                # print(item)  # debug: every name/alias containing this n-gram
                # print(inverted_index[item])
            # if len(cand_entity) > 0:
            #     break
        # print(cand_entity)  # debug
        for mid_text_type in sorted(set(cand_entity)):
            score = fuzz.ratio(mid_text_type[1], line[1]) / 100.0
            cand_score.append((mid_text_type, score))

        cand_score.sort(key=lambda t: t[1], reverse=True)
        cand_mids = cand_score[:hits]
        fout.write("{}".format(line_id))
        for mid_text_type, score in cand_mids:
            fout.write(" %%%% {}\t{}\t{}\t{}".format(mid_text_type[0], mid_text_type[1], mid_text_type[2], score))
        fout.write('\n')
        gold_id = www2fb(gold_id)
        mids_list = [x[0][0] for x in cand_mids]
        if gold_id in mids_list[:1]:
            top1 += 1
        if gold_id in mids_list[:3]:
            top3 += 1
        if gold_id in mids_list[:5]:
            top5 += 1
        if gold_id in mids_list[:10]:
            top10 += 1
        if gold_id in mids_list[:20]:
            top20 += 1
        if gold_id in mids_list[:50]:
            top50 += 1
        if gold_id in mids_list[:100]:
            top100 += 1

    print("total: {}".format(total))
    print("Top1 Entity Linking Accuracy: {}".format(top1 / total))
    print("Top3 Entity Linking Accuracy: {}".format(top3 / total))
    print("Top5 Entity Linking Accuracy: {}".format(top5 / total))
Accuracy: {}".format(top10 / total)) 78 | print("Top20 Entity Linking Accuracy: {}".format(top20 / total)) 79 | print("Top50 Entity Linking Accuracy: {}".format(top50 / total)) 80 | print("Top100 Entity Linking Accuracy: {}".format(top100 / total)) 81 | 82 | 83 | if __name__ == "__main__": 84 | # print(get_ngram("Which team have LeBron played basketball ?")) 85 | parser = ArgumentParser(description='Perform entity linking') 86 | parser.add_argument('--model_type', type=str, required=True, help="options are [crf|lstm|gru]") 87 | parser.add_argument('--index_ent', type=str, default="../indexes/entity_2M.pkl", 88 | help='path to the pickle for the inverted entity index') 89 | parser.add_argument('--data_dir', type=str, default="../data/processed_simplequestions_dataset") 90 | parser.add_argument('--query_dir', type=str, default="../entity_detection/nn/query_text") 91 | parser.add_argument('--hits', type=int, default=100) 92 | parser.add_argument('--output_dir', type=str, default="./results") 93 | args = parser.parse_args() 94 | print(args) 95 | 96 | model_type = args.model_type.lower() 97 | # assert(model_type == "crf" or model_type == "lstm" or model_type == "gru") 98 | output_dir = os.path.join(args.output_dir, model_type) 99 | os.makedirs(output_dir, exist_ok=True) 100 | 101 | get_stat_inverted_index(args.index_ent) 102 | print("valid result:") 103 | entity_linking( 104 | os.path.join(args.query_dir, "query.valid"), 105 | os.path.join(args.data_dir, "valid.txt"), 106 | args.hits, 107 | os.path.join(output_dir, "valid-h{}.txt".format(args.hits))) 108 | 109 | print("test result:") 110 | entity_linking( 111 | os.path.join(args.query_dir, "query.test"), 112 | os.path.join(args.data_dir, "test.txt"), 113 | args.hits, 114 | os.path.join(output_dir, "test-h{}.txt".format(args.hits))) 115 | -------------------------------------------------------------------------------- /NeuralQA/EntityLinking/util.py: -------------------------------------------------------------------------------- 1 | import unicodedata 2 | from nltk.tokenize.treebank import TreebankWordTokenizer 3 | 4 | tokenizer = TreebankWordTokenizer() 5 | 6 | 7 | def readFile(filename): 8 | context = open(filename).readlines() 9 | return [c.strip() for c in context] 10 | 11 | 12 | def writeFile(context, filename, append=False): 13 | if not append: 14 | with open(filename, 'w+') as fout: 15 | for co in context: 16 | fout.write(co + "\n") 17 | else: 18 | with open(filename, 'a+') as fout: 19 | for co in context: 20 | fout.write(co + "\n") 21 | 22 | 23 | def list2str(l, split=" "): 24 | a = "" 25 | for li in l: 26 | a += (str(li) + split) 27 | a = a[:-len(split)] 28 | return a 29 | 30 | 31 | def get_ngram(text): 32 | # ngram = set() 33 | ngram = [] 34 | tokens = text.split() 35 | for i in range(len(tokens) + 1): 36 | for j in range(i): 37 | if i - j <= 3: # 3 ? 
def get_stat_inverted_index(filename):
    """
    Report the number of entries and the longest entry (how many mids it
    holds), and return the loaded inverted index.
    """
    with open(filename, "rb") as handler:
        inverted_index = pickle.load(handler)
        # a missing key should yield an empty candidate list, not an empty string
        inverted_index = defaultdict(list, inverted_index)
    print("Total type of text: {}".format(len(inverted_index)))
    max_len = 0
    _entry = ""
    for entry, value in inverted_index.items():
        if len(value) > max_len:
            max_len = len(value)
            _entry = entry
    print("Max Length of entry is {}, text is {}".format(max_len, _entry))
    # return the index instead of mutating a module-level global: callers that
    # do `from util import get_stat_inverted_index` never see util's global
    return inverted_index


def processed_text(text):
    text = text.replace('\\\\', '')
    # stripped = strip_accents(text.lower())
    stripped = text.lower()
    toks = tokenizer.tokenize(stripped)
    return " ".join(toks)


def strip_accents(text):
    return ''.join(c for c in unicodedata.normalize('NFKD', text) if unicodedata.category(c) != 'Mn')


def www2fb(in_str):
    if in_str.startswith("www.freebase.com"):
        in_str = 'fb:%s' % (in_str.split('www.freebase.com/')[-1].replace('/', '.'))
    if in_str == 'fb:m.07s9rl0':
        in_str = 'fb:m.02822'
    if in_str == 'fb:m.0bb56b6':
        in_str = 'fb:m.0dn0r'
    # Manual Correction
    if in_str == 'fb:m.01g81dw':
        in_str = 'fb:m.01g_bfh'
    if in_str == 'fb:m.0y7q89y':
        in_str = 'fb:m.0wrt1c5'
    if in_str == 'fb:m.0b0w7':
        in_str = 'fb:m.0fq0s89'
    if in_str == 'fb:m.09rmm6y':
        in_str = 'fb:m.03cnrcc'
    if in_str == 'fb:m.0crsn60':
        in_str = 'fb:m.02pnlqy'
    if in_str == 'fb:m.04t1f8y':
        in_str = 'fb:m.04t1fjr'
    if in_str == 'fb:m.027z990':
        in_str = 'fb:m.0ghdhcb'
    if in_str == 'fb:m.02xhc2v':
        in_str = 'fb:m.084sq'
    if in_str == 'fb:m.02z8b2h':
        in_str = 'fb:m.033vn1'
    if in_str == 'fb:m.0w43mcj':
        in_str = 'fb:m.0m0qffc'
    if in_str == 'fb:m.07rqy':
        in_str = 'fb:m.0py_0'
    if in_str == 'fb:m.0y9s5rm':
        in_str = 'fb:m.0ybxl2g'
    if in_str == 'fb:m.037ltr7':
        in_str = 'fb:m.0qjx99s'
    return in_str


def clean_uri(uri):
    if uri.startswith("<") and uri.endswith(">"):
        return clean_uri(uri[1:-1])
    elif uri.startswith("\"") and uri.endswith("\""):
        return clean_uri(uri[1:-1])
    return uri


def rdf2fb(in_str):
    if in_str.startswith('http://rdf.freebase.com/ns/'):
        return 'fb:%s' % (in_str.split('http://rdf.freebase.com/ns/')[-1])
--------------------------------------------------------------------------------
/NeuralQA/MentionDetection/crf/convert.py:
--------------------------------------------------------------------------------
from argparse import ArgumentParser


def convert(filename, output):
    fin = open(filename, 'r')
    fout = open(output, 'w')
    for line in fin.readlines():
        items = line.strip().split('\t')
        sent, label = items[5], items[6]
        for word, tag in zip(sent.strip().split(), label.strip().split()):
            fout.write("{}\t{}\n".format(word, tag))
        fout.write("\n")
    fout.close()
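
# The emitted file is the two-column layout Stanford NER trains on: one
# "token\ttag" pair per line and a blank line between sentences, e.g.
# (hypothetical sentence; I = inside mention, O = outside):
#   who     O
#   wrote   O
#   gone    I
#   girl    I
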
if __name__ == '__main__':
    parser = ArgumentParser(description='Convert dataset to stanford format for training')
    parser.add_argument('--data_dir', type=str, default="../../../data/processed_simplequestions_dataset/train.txt")
    parser.add_argument('--save_path', type=str, default="data/stanford.train")
    args = parser.parse_args()
    convert(args.data_dir, args.save_path)
--------------------------------------------------------------------------------
/NeuralQA/MentionDetection/crf/eval.py:
--------------------------------------------------------------------------------
import sys


def get_span(label):
    """Return the [start, end) spans of consecutive 'I' tags."""
    start = 0
    flag = False
    span = []
    for k, l in enumerate(label):
        if l == 'I' and not flag:
            start = k
            flag = True
        if l != 'I' and flag:
            flag = False
            span.append((start, k))
    if flag:
        # an entity was still open at the end of the sequence; the previous
        # check (start != 0) silently dropped a span beginning at index 0
        span.append((start, len(label)))
    return span


def evaluation(filename):
    fin = open(filename, 'r')
    pred = []
    gold = []
    right = 0
    predicted = 0
    total_en = 0
    for line in fin.readlines():
        if line == '\n':
            gold_span = get_span(gold)
            pred_span = get_span(pred)
            total_en += len(gold_span)
            predicted += len(pred_span)
            for item in pred_span:
                if item in gold_span:
                    right += 1
            gold = []
            pred = []
        else:
            word, gold_label, pred_label = line.strip().split()
            gold.append(gold_label)
            pred.append(pred_label)

    if gold != [] or pred != []:
        gold_span = get_span(gold)
        pred_span = get_span(pred)
        total_en += len(gold_span)
        predicted += len(pred_span)
        for item in pred_span:
            if item in gold_span:
                right += 1

    if predicted == 0:
        precision = 0
    else:
        precision = right / predicted
    if total_en == 0:
        recall = 0
    else:
        recall = right / total_en
    if precision + recall == 0:
        f1 = 0
    else:
        f1 = 2 * precision * recall / (precision + recall)
    print("Precision", precision, "Recall", recall, "F1", f1, "right", right, "predicted", predicted, "total", total_en)


if __name__ == '__main__':
    if len(sys.argv) != 2:
        print("Need to specify the file")
        sys.exit(1)  # previously fell through and crashed on sys.argv[1]
    filename = sys.argv[1]
    evaluation(filename)
--------------------------------------------------------------------------------
/NeuralQA/MentionDetection/crf/output.py:
--------------------------------------------------------------------------------
from argparse import ArgumentParser


def convert(fileName, idFile, outputFile):
    fin = open(fileName)
    fid = open(idFile)
    fout = open(outputFile, "w")
    word_list = []
    pred_query = []
    line_id = []
    for line in fid.readlines():
        line_id.append(line.strip())
    index = 0
    for line in fin.readlines():
        if line == '\n':
            if len(pred_query) == 0:
                # no token was tagged 'I'; fall back to the whole sentence
                pred_query = word_list
            fout.write("{} %%%% {}\n".format(line_id[index], " ".join(pred_query)))
            index += 1
            pred_query = []
            word_list = []
        else:
            word, gold_label, pred_label = line.strip().split()
            word_list.append(word)
            if pred_label == 'I':
                pred_query.append(word)
    if index != len(line_id):
        print("Length Error")
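
# Each emitted line pairs a dataset line id with the predicted mention text,
# e.g. (hypothetical): "valid-42 %%%% gone girl" -- the same format the
# entity linking stage parses back with split(' %%%% ').
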
default="../../data/processed_simplequestions_dataset/lineids_valid.txt") 36 | parser.add_argument('--results_path', type=str, default="query_text/query.valid") 37 | args = parser.parse_args() 38 | convert(args.data_dir, args.valid_line, args.results_path) 39 | -------------------------------------------------------------------------------- /NeuralQA/MentionDetection/crf/stanford-ner/NERDemo.java: -------------------------------------------------------------------------------- 1 | import edu.stanford.nlp.ie.AbstractSequenceClassifier; 2 | import edu.stanford.nlp.ie.crf.*; 3 | import edu.stanford.nlp.io.IOUtils; 4 | import edu.stanford.nlp.ling.CoreLabel; 5 | import edu.stanford.nlp.ling.CoreAnnotations; 6 | import edu.stanford.nlp.sequences.DocumentReaderAndWriter; 7 | import edu.stanford.nlp.util.Triple; 8 | 9 | import java.util.List; 10 | 11 | 12 | /** This is a demo of calling CRFClassifier programmatically. 13 | *
 *  Usage: {@code java -mx400m -cp "*" NERDemo [serializedClassifier [fileName]] }
 *  <p>
 *  If arguments aren't specified, they default to
 *  classifiers/english.all.3class.distsim.crf.ser.gz and some hardcoded sample text.
 *  If run with arguments, it shows some of the ways to get k-best labelings and
 *  probabilities out with CRFClassifier. If run without arguments, it shows some of
 *  the alternative output formats that you can get.
 *  <p>
 *  To use CRFClassifier from the command line:
 *  <p>
 *  {@code java -mx400m edu.stanford.nlp.ie.crf.CRFClassifier -loadClassifier [classifier] -textFile [file] }
 *  <p>
 *  Or if the file is already tokenized and one word per line, perhaps in
 *  a tab-separated value format with extra columns for part-of-speech tag,
 *  etc., use the version below (note the 's' instead of the 'x'):
 *  <p>
 *  {@code java -mx400m edu.stanford.nlp.ie.crf.CRFClassifier -loadClassifier [classifier] -testFile [file] }
 *
 *  @author Jenny Finkel
 *  @author Christopher Manning
 */

public class NERDemo {

    public static void main(String[] args) throws Exception {

        String serializedClassifier = "classifiers/english.all.3class.distsim.crf.ser.gz";

        if (args.length > 0) {
            serializedClassifier = args[0];
        }

        AbstractSequenceClassifier<CoreLabel>