├── NeuralQA ├── AnswerRerank │ ├── demo.py │ ├── rerank.py │ └── util.py ├── EntityLinking │ ├── change_format.py │ ├── demo.py │ ├── entity_linking.py │ └── util.py ├── MentionDetection │ ├── crf │ │ ├── convert.py │ │ ├── eval.py │ │ ├── output.py │ │ └── stanford-ner │ │ │ └── NERDemo.java │ └── nn │ │ ├── args.py │ │ ├── demo.py │ │ ├── model.py │ │ ├── test.py │ │ ├── tmp_data.txt │ │ ├── train.py │ │ └── util │ │ ├── datasets.py │ │ └── util.py └── RelationDetection │ ├── nn │ ├── args.py │ ├── datasets.py │ ├── demo.py │ ├── model.py │ ├── preprocess.py │ ├── test.py │ └── train.py │ └── siamese │ ├── args.py │ ├── datasets.py │ ├── eval.py │ ├── model.py │ ├── train.py │ └── util.py ├── README.md ├── SiameseNetwork ├── config.py ├── qa_cnn.py ├── qa_lstm.py ├── qa_nn.py ├── train_cnn.py ├── train_lstm.py ├── train_nn.py └── util │ ├── dataset.py │ └── util.py ├── WebQA ├── data │ └── indexes │ │ └── relation_sub_2M.pkl ├── src │ ├── args.py │ ├── main.py │ ├── model.py │ └── simpleQA.py ├── static │ ├── css │ │ └── bootstrap.min.css │ ├── images │ │ ├── favicon.ico │ │ └── help.png │ └── js │ │ ├── bootstrap.js │ │ ├── bootstrap.min.js │ │ └── jquery-3.1.1.min.js ├── templates │ ├── coming_soon.html │ ├── entity.html │ ├── homepage.html │ ├── mention.html │ ├── relation.html │ └── test.html └── util │ ├── datasets.py │ └── utils.py └── demo.gif /NeuralQA/AnswerRerank/demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from argparse import ArgumentParser 4 | from collections import defaultdict 5 | 6 | import math 7 | 8 | from util import clean_uri, www2fb, rdf2fb 9 | 10 | 11 | # Load up reachability graph 12 | 13 | def load_index(filename): 14 | # print("Loading index map from {}".format(filename)) 15 | with open(filename, 'rb') as handler: 16 | index = pickle.load(handler) 17 | return index 18 | 19 | 20 | # Load predicted MIDs and relations for each question in valid/test set 21 | def get_mids(filename, hits): 22 | # print("Entity Source : {}".format(filename)) 23 | id2mids = defaultdict(list) 24 | fin = open(filename) 25 | for line in fin.readlines(): 26 | items = line.strip().split(' %%%% ') 27 | lineid = items[0] 28 | cand_mids = items[1:][:hits] 29 | for mid_entry in cand_mids: 30 | # TODO:WHY MID_TYPE ONLY ONE! 
TYPE IS FROM CFO NAME EITHER TYPE.TOPIC.NAME OR COMMON.TOPIC.ALIAS 31 | mid, mid_name, mid_type, score = mid_entry.split('\t') 32 | id2mids[lineid].append((mid, mid_name, mid_type, float(score))) 33 | return id2mids 34 | 35 | 36 | def get_rels(filename, hits): 37 | # print("Relation Source : {}".format(filename)) 38 | id2rels = defaultdict(list) 39 | fin = open(filename) 40 | for line in fin.readlines(): 41 | items = line.strip().split(' %%%% ') 42 | lineid = items[0].strip() 43 | rel = www2fb(items[1].strip()) 44 | label = items[2].strip() 45 | score = items[3].strip() 46 | if len(id2rels[lineid]) < hits: 47 | id2rels[lineid].append((rel, label, float(score))) 48 | return id2rels 49 | 50 | 51 | def get_questions(filename): 52 | # print("getting questions ...") 53 | id2questions = {} 54 | id2goldmids = {} 55 | fin = open(filename) 56 | for line in fin.readlines(): 57 | items = line.strip().split('\t') 58 | lineid = items[0].strip() 59 | mid = items[1].strip() 60 | question = items[5].strip() 61 | rel = items[3].strip() 62 | id2questions[lineid] = (question, rel) 63 | id2goldmids[lineid] = mid 64 | return id2questions, id2goldmids 65 | 66 | 67 | def get_mid2wiki(filename): 68 | # print("Loading Wiki") 69 | mid2wiki = defaultdict(bool) 70 | fin = open(filename) 71 | for line in fin.readlines(): 72 | items = line.strip().split('\t') 73 | sub = rdf2fb(clean_uri(items[0])) 74 | mid2wiki[sub] = True 75 | return mid2wiki 76 | 77 | 78 | def evidence_integration(index_reach, index_degrees, mid2wiki, is_heuristics, input_entity, input_relation): 79 | id2answers = list() 80 | mids = input_entity.split("\n") 81 | rels = input_relation.split("\n") 82 | # print(rels) 83 | 84 | if is_heuristics: 85 | for item in mids: 86 | _, mid, mid_name, mid_type, mid_score = item.strip().split("\t") 87 | for item2 in rels: 88 | rel, rel_log_score = item2.strip().split("\t") 89 | # if this (mid, rel) exists in FB 90 | if rel in index_reach[mid]: 91 | rel_score = math.exp(float(rel_log_score)) 92 | comb_score = (float(mid_score) ** 0.6) * (rel_score ** 0.1) 93 | id2answers.append((mid, rel, mid_name, mid_type, mid_score, rel_score, comb_score, 94 | int(mid2wiki[mid]), int(index_degrees[mid][0]))) 95 | # I cannot use retrieved here because I use contain different name_type 96 | # if mid ==truth_mid and rel == truth_rel: 97 | # retrieved += 1 98 | id2answers.sort(key=lambda t: (t[6], t[3], t[7], t[8]), reverse=True) 99 | else: 100 | id2answers = [(mids[0][0], rels[0][0])] 101 | 102 | # write to file 103 | # TODO:CHANGED FOR SWITCH IS_HEURISTICS 104 | if is_heuristics: 105 | for answer in id2answers: 106 | mid, rel, mid_name, mid_type, mid_score, rel_score, comb_score, _, _ = answer 107 | print("{}\t{}\t{}\t{}\t{}".format(mid, rel, mid_name, mid_score, rel_score, comb_score)) 108 | else: 109 | for answer in id2answers: 110 | mid, rel = answer 111 | print("{}\t{}".format(mid, rel)) 112 | return id2answers 113 | 114 | 115 | if __name__ == "__main__": 116 | parser = ArgumentParser(description='Perform evidence integration') 117 | parser.add_argument('--ent_type', type=str, required=True, help="options are [crf|lstm|gru]") 118 | parser.add_argument('--rel_type', type=str, required=True, help="options are [lr|cnn|lstm|gru]") 119 | parser.add_argument('--index_reachpath', type=str, default="indexes/reachability_2M.pkl", 120 | help='path to the pickle for the reachability index') 121 | parser.add_argument('--index_degreespath', type=str, default="indexes/degrees_2M.pkl", 122 | help='path to the pickle for the index with the 
degree counts') 123 | parser.add_argument('--data_path', type=str, default="data/processed_simplequestions_dataset/test.txt") 124 | parser.add_argument('--ent_path', type=str, default="entity_linking/results/lstm/test-h100.txt", 125 | help='path to the entity linking results') 126 | parser.add_argument('--rel_path', type=str, default="relation_prediction/nn/results/cnn/test.txt", 127 | help='path to the relation prediction results') 128 | parser.add_argument('--wiki_path', type=str, default="data/fb2w.nt") 129 | parser.add_argument('--hits_ent', type=int, default=50, 130 | help='the hits here has to be <= the hits in entity linking') 131 | parser.add_argument('--hits_rel', type=int, default=5, 132 | help='the hits here has to be <= the hits in relation prediction retrieval') 133 | parser.add_argument('--no_heuristics', action='store_false', help='do not use heuristics', dest='heuristics') 134 | parser.add_argument('--output_dir', type=str, default="./results") 135 | 136 | # added for demo 137 | parser.add_argument('--input_ent_path', type=str, default='el_result.txt') 138 | parser.add_argument('--input_rel_path', type=str, default='rp_result.txt') 139 | args = parser.parse_args() 140 | # print(args) 141 | 142 | ent_type = args.ent_type.lower() 143 | rel_type = args.rel_type.lower() 144 | output_dir = os.path.join(args.output_dir, "{}-{}".format(ent_type, rel_type)) 145 | os.makedirs(output_dir, exist_ok=True) 146 | 147 | index_reach = load_index(args.index_reachpath) 148 | # print(index_reach) 149 | index_degrees = load_index(args.index_degreespath) 150 | mid2wiki = get_mid2wiki(args.wiki_path) 151 | 152 | candidate_entity = open(args.input_ent_path).read().strip() 153 | candidate_relation = open(args.input_rel_path).read().strip() 154 | test_answers = evidence_integration(index_reach, index_degrees, mid2wiki, args.heuristics, candidate_entity, 155 | candidate_relation) 156 | -------------------------------------------------------------------------------- /NeuralQA/AnswerRerank/rerank.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | from collections import defaultdict 4 | 5 | import math 6 | from util import get_mid2wiki, get_rels, get_mids, get_questions, www2fb, load_index 7 | 8 | 9 | def answer_rerank(data_path, ent_path, rel_path, output_dir, index_reach, index_degrees, mid2wiki, is_heuristics, 10 | ent_hits, rel_hits): 11 | id2questions, id2goldmids = get_questions(data_path) 12 | id2mids = get_mids(ent_path, ent_hits) 13 | id2rels = get_rels(rel_path, rel_hits) 14 | file_base_name = os.path.basename(data_path) 15 | fout = open(os.path.join(output_dir, file_base_name), 'w') 16 | 17 | id2answers = defaultdict(list) 18 | found, notfound_both, notfound_mid, notfound_rel = 0, 0, 0, 0 19 | retrieved, retrieved_top1, retrieved_top2, retrieved_top3 = 0, 0, 0, 0 20 | lineids_found1 = [] 21 | lineids_found2 = [] 22 | lineids_found3 = [] 23 | 24 | # for every lineid 25 | for line_id in id2goldmids: 26 | if line_id not in id2mids and line_id not in id2rels: 27 | notfound_both += 1 28 | continue 29 | elif line_id not in id2mids: 30 | notfound_mid += 1 31 | continue 32 | elif line_id not in id2rels: 33 | notfound_rel += 1 34 | continue 35 | found += 1 36 | question, truth_rel = id2questions[line_id] 37 | truth_rel = www2fb(truth_rel) 38 | truth_mid = id2goldmids[line_id] 39 | mids = id2mids[line_id] 40 | rels = id2rels[line_id] 41 | 42 | if is_heuristics: 43 | for (mid, mid_name, mid_type, mid_score) in 
mids: 44 | for (rel, rel_label, rel_log_score) in rels: 45 | # if this (mid, rel) exists in FB 46 | if rel in index_reach[mid]: 47 | rel_score = math.exp(float(rel_log_score)) 48 | comb_score = (float(mid_score) ** 0.6) * (rel_score ** 0.1) 49 | id2answers[line_id].append((mid, rel, mid_name, mid_type, mid_score, rel_score, comb_score, 50 | int(mid2wiki[mid]), int(index_degrees[mid][0]))) 51 | # I cannot use retrieved here because I use contain different name_type 52 | # if mid ==truth_mid and rel == truth_rel: 53 | # retrieved += 1 54 | id2answers[line_id].sort(key=lambda t: (t[6], t[3], t[7], t[8]), reverse=True) 55 | else: 56 | id2answers[line_id] = [(mids[0][0], rels[0][0])] 57 | 58 | # write to file 59 | fout.write("{}".format(line_id)) 60 | if is_heuristics: 61 | for answer in id2answers[line_id]: 62 | mid, rel, mid_name, mid_type, mid_score, rel_score, comb_score, _, _ = answer 63 | fout.write(" %%%% {}\t{}\t{}\t{}\t{}".format(mid, rel, mid_name, mid_score, rel_score, comb_score)) 64 | else: 65 | for answer in id2answers[line_id]: 66 | mid, rel = answer 67 | fout.write(" %%%% {}\t{}".format(mid, rel)) 68 | fout.write('\n') 69 | 70 | if is_heuristics: 71 | if len(id2answers[line_id]) >= 1 and id2answers[line_id][0][1] == truth_rel: # id2answers[line_id][0][0] == truth_mid and 72 | retrieved_top1 += 1 73 | retrieved_top2 += 1 74 | retrieved_top3 += 1 75 | lineids_found1.append(line_id) 76 | elif len(id2answers[line_id]) >= 2 and id2answers[line_id][1][0] == truth_mid \ 77 | and id2answers[line_id][1][1] == truth_rel: 78 | retrieved_top2 += 1 79 | retrieved_top3 += 1 80 | lineids_found2.append(line_id) 81 | elif len(id2answers[line_id]) >= 3 and id2answers[line_id][2][0] == truth_mid \ 82 | and id2answers[line_id][2][1] == truth_rel: 83 | retrieved_top3 += 1 84 | lineids_found3.append(line_id) 85 | else: 86 | if len(id2answers[line_id]) >= 1 and id2answers[line_id][0][0] == truth_mid \ 87 | and id2answers[line_id][0][1] == truth_rel: 88 | retrieved_top1 += 1 89 | retrieved_top2 += 1 90 | retrieved_top3 += 1 91 | lineids_found1.append(line_id) 92 | print() 93 | print("found: {}".format(found / len(id2goldmids) * 100.0)) 94 | print("retrieved at top 1: {}".format(retrieved_top1 / len(id2goldmids) * 100.0)) 95 | print("retrieved at top 2: {}".format(retrieved_top2 / len(id2goldmids) * 100.0)) 96 | print("retrieved at top 3: {}".format(retrieved_top3 / len(id2goldmids) * 100.0)) 97 | # print("retrieved at inf: {}".format(retrieved / len(id2goldmids) * 100.0)) 98 | fout.close() 99 | return id2answers 100 | 101 | 102 | if __name__ == "__main__": 103 | parser = ArgumentParser(description='Perform evidence integration') 104 | parser.add_argument('--ent_type', type=str, required=True, help="options are [crf|lstm|gru]") 105 | parser.add_argument('--rel_type', type=str, required=True, help="options are [lr|cnn|lstm|gru]") 106 | parser.add_argument('--index_reachpath', type=str, default="../indexes/reachability_2M.pkl", 107 | help='path to the pickle for the reachability index') 108 | parser.add_argument('--index_degreespath', type=str, default="../indexes/degrees_2M.pkl", 109 | help='path to the pickle for the index with the degree counts') 110 | parser.add_argument('--data_path', type=str, default="../data/processed_simplequestions_dataset/test.txt") 111 | parser.add_argument('--ent_path', type=str, default="../entity_linking/results/crf/test-h100.txt", 112 | help='path to the entity linking results') 113 | parser.add_argument('--rel_path', type=str, 
default="../relation_prediction/nn/results/cnn/test.txt", 114 | help='path to the relation prediction results') 115 | parser.add_argument('--wiki_path', type=str, default="../data/fb2w.nt") 116 | parser.add_argument('--hits_ent', type=int, default=50, 117 | help='the hits here has to be <= the hits in entity linking') 118 | parser.add_argument('--hits_rel', type=int, default=5, 119 | help='the hits here has to be <= the hits in relation prediction retrieval') 120 | parser.add_argument('--no_heuristics', action='store_false', help='do not use heuristics', dest='heuristics') 121 | parser.add_argument('--output_dir', type=str, default="./results") 122 | args = parser.parse_args() 123 | print(args) 124 | 125 | ent_type = args.ent_type.lower() 126 | rel_type = args.rel_type.lower() 127 | # assert (ent_type == "crf" or ent_type == "lstm" or ent_type == "gru") 128 | # assert (rel_type == "lr" or rel_type == "cnn" or rel_type == "lstm" or rel_type == "gru") 129 | output_dir = os.path.join(args.output_dir, "{}-{}".format(ent_type, rel_type)) 130 | os.makedirs(output_dir, exist_ok=True) 131 | 132 | index_reach = load_index(args.index_reachpath) 133 | index_degrees = load_index(args.index_degreespath) 134 | mid2wiki = get_mid2wiki(args.wiki_path) 135 | 136 | test_answers = answer_rerank(args.data_path, args.ent_path, args.rel_path, output_dir, index_reach, 137 | index_degrees, mid2wiki, args.heuristics, args.hits_ent, args.hits_rel) 138 | -------------------------------------------------------------------------------- /NeuralQA/AnswerRerank/util.py: -------------------------------------------------------------------------------- 1 | import unicodedata 2 | from nltk.tokenize.treebank import TreebankWordTokenizer 3 | 4 | tokenizer = TreebankWordTokenizer() 5 | 6 | 7 | # Load up reachability graph 8 | 9 | def load_index(filename): 10 | print("Loading index map from {}".format(filename)) 11 | with open(filename, 'rb') as handler: 12 | index = pickle.load(handler) 13 | return index 14 | 15 | 16 | # Load predicted MIDs and relations for each question in valid/test set 17 | def get_mids(filename, hits): 18 | print("Entity Source : {}".format(filename)) 19 | id2mids = defaultdict(list) 20 | fin = open(filename) 21 | for line in fin.readlines(): 22 | items = line.strip().split(' %%%% ') 23 | lineid = items[0] 24 | cand_mids = items[1:][:hits] 25 | for mid_entry in cand_mids: 26 | # TODO:WHY MID_TYPE ONLY ONE! 
TYPE IS FROM CFO NAME EITHER TYPE.TOPIC.NAME OR COMMON.TOPIC.ALIAS 27 | mid, mid_name, mid_type, score = mid_entry.split('\t') 28 | id2mids[lineid].append((mid, mid_name, mid_type, float(score))) 29 | return id2mids 30 | 31 | 32 | def get_rels(filename, hits): 33 | print("Relation Source : {}".format(filename)) 34 | id2rels = defaultdict(list) 35 | fin = open(filename) 36 | for line in fin.readlines(): 37 | items = line.strip().split(' %%%% ') 38 | lineid = items[0].strip() 39 | rel = www2fb(items[1].strip()) 40 | label = items[2].strip() 41 | score = items[3].strip() 42 | if len(id2rels[lineid]) < hits: 43 | id2rels[lineid].append((rel, label, float(score))) 44 | return id2rels 45 | 46 | 47 | def get_questions(filename): 48 | print("getting questions ...") 49 | id2questions = {} 50 | id2goldmids = {} 51 | fin = open(filename) 52 | for line in fin.readlines(): 53 | items = line.strip().split('\t') 54 | lineid = items[0].strip() 55 | mid = items[1].strip() 56 | question = items[5].strip() 57 | rel = items[3].strip() 58 | id2questions[lineid] = (question, rel) 59 | id2goldmids[lineid] = mid 60 | return id2questions, id2goldmids 61 | 62 | 63 | def get_mid2wiki(filename): 64 | print("Loading Wiki") 65 | mid2wiki = defaultdict(bool) 66 | fin = open(filename) 67 | for line in fin.readlines(): 68 | items = line.strip().split('\t') 69 | sub = rdf2fb(clean_uri(items[0])) 70 | mid2wiki[sub] = True 71 | return mid2wiki 72 | 73 | 74 | def processed_text(text): 75 | text = text.replace('\\\\', '') 76 | # stripped = strip_accents(text.lower()) 77 | stripped = text.lower() 78 | toks = tokenizer.tokenize(stripped) 79 | return " ".join(toks) 80 | 81 | 82 | def strip_accents(text): 83 | return ''.join(c for c in unicodedata.normalize('NFKD', text) if unicodedata.category(c) != 'Mn') 84 | 85 | 86 | def www2fb(in_str): 87 | if in_str.startswith("www.freebase.com"): 88 | in_str = 'fb:%s' % (in_str.split('www.freebase.com/')[-1].replace('/', '.')) 89 | if in_str == 'fb:m.07s9rl0': 90 | in_str = 'fb:m.02822' 91 | if in_str == 'fb:m.0bb56b6': 92 | in_str = 'fb:m.0dn0r' 93 | # Manual Correction 94 | if in_str == 'fb:m.01g81dw': 95 | in_str = 'fb:m.01g_bfh' 96 | if in_str == 'fb:m.0y7q89y': 97 | in_str = 'fb:m.0wrt1c5' 98 | if in_str == 'fb:m.0b0w7': 99 | in_str = 'fb:m.0fq0s89' 100 | if in_str == 'fb:m.09rmm6y': 101 | in_str = 'fb:m.03cnrcc' 102 | if in_str == 'fb:m.0crsn60': 103 | in_str = 'fb:m.02pnlqy' 104 | if in_str == 'fb:m.04t1f8y': 105 | in_str = 'fb:m.04t1fjr' 106 | if in_str == 'fb:m.027z990': 107 | in_str = 'fb:m.0ghdhcb' 108 | if in_str == 'fb:m.02xhc2v': 109 | in_str = 'fb:m.084sq' 110 | if in_str == 'fb:m.02z8b2h': 111 | in_str = 'fb:m.033vn1' 112 | if in_str == 'fb:m.0w43mcj': 113 | in_str = 'fb:m.0m0qffc' 114 | if in_str == 'fb:m.07rqy': 115 | in_str = 'fb:m.0py_0' 116 | if in_str == 'fb:m.0y9s5rm': 117 | in_str = 'fb:m.0ybxl2g' 118 | if in_str == 'fb:m.037ltr7': 119 | in_str = 'fb:m.0qjx99s' 120 | return in_str 121 | 122 | 123 | def clean_uri(uri): 124 | if uri.startswith("<") and uri.endswith(">"): 125 | return clean_uri(uri[1:-1]) 126 | elif uri.startswith("\"") and uri.endswith("\""): 127 | return clean_uri(uri[1:-1]) 128 | return uri 129 | 130 | 131 | def rdf2fb(in_str): 132 | if in_str.startswith('http://rdf.freebase.com/ns/'): 133 | return 'fb:%s' % (in_str.split('http://rdf.freebase.com/ns/')[-1]) 134 | -------------------------------------------------------------------------------- /NeuralQA/EntityLinking/change_format.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import FileUtil 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--filename") 7 | args = parser.parse_args() 8 | filename = args.filename 9 | 10 | context = FileUtil.readFile(filename) 11 | output = [] 12 | for i, c in enumerate(context): 13 | if i % 3 == 0: 14 | output.append("test-{} %%%% {}".format(int(i / 3 + 1), context[i + 1])) 15 | FileUtil.writeFile(output, filename + ".query") 16 | print("All done!") 17 | -------------------------------------------------------------------------------- /NeuralQA/EntityLinking/demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from argparse import ArgumentParser 4 | from collections import defaultdict 5 | 6 | from fuzzywuzzy import fuzz 7 | from nltk.corpus import stopwords 8 | from tqdm import tqdm 9 | 10 | from util import www2fb 11 | 12 | inverted_index = defaultdict(list) 13 | stopword = set(stopwords.words('english')) 14 | 15 | 16 | def get_ngram(text): 17 | # ngram = set() 18 | ngram = [] 19 | tokens = text.split() 20 | for i in range(len(tokens) + 1): 21 | for j in range(i): 22 | if i - j <= 3: # todo 23 | # ngram.add(" ".join(tokens[j:i])) 24 | temp = " ".join(tokens[j:i]) 25 | if temp not in ngram: 26 | ngram.append(temp) 27 | # ngram = list(ngram) 28 | ngram = sorted(ngram, key=lambda x: len(x.split()), reverse=True) 29 | return ngram 30 | 31 | 32 | def get_stat_inverted_index(filename): 33 | """ 34 | Get the number of entry and max length of the entry (How many mid in an entry) 35 | """ 36 | with open(filename, "rb") as handler: 37 | global inverted_index 38 | inverted_index = pickle.load(handler) 39 | inverted_index = defaultdict(str, inverted_index) 40 | # print("Total type of text: {}".format(len(inverted_index))) 41 | max_len = 0 42 | _entry = "" 43 | for entry, value in inverted_index.items(): 44 | if len(value) > max_len: 45 | max_len = len(value) 46 | _entry = entry 47 | # print("Max Length of entry is {}, text is {}".format(max_len, _entry)) 48 | 49 | 50 | def entity_linking(pred_mention, top_num): 51 | C = [] 52 | C_scored = [] 53 | tokens = get_ngram(pred_mention) 54 | 55 | if len(tokens) > 0: 56 | maxlen = len(tokens[0].split()) 57 | for item in tokens: 58 | if len(item.split()) < maxlen and len(C) == 0: 59 | maxlen = len(item.split()) 60 | if len(item.split()) < maxlen and len(C) > 0: 61 | break 62 | if item in stopword: 63 | continue 64 | C.extend(inverted_index[item]) 65 | # if len(C) > 0: 66 | # break 67 | for mid_text_type in sorted(set(C)): 68 | score = fuzz.ratio(mid_text_type[1], pred_mention) / 100.0 69 | # C_counts format : ((mid, text, type), score_based_on_fuzz) 70 | C_scored.append((mid_text_type, score)) 71 | 72 | C_scored.sort(key=lambda t: t[1], reverse=True) 73 | cand_mids = C_scored[:top_num] 74 | for mid_text_type, score in cand_mids: 75 | print("{}\t{}\t{}\t{}\t{}".format(pred_mention, mid_text_type[0], mid_text_type[1], mid_text_type[2], score)) 76 | 77 | if __name__ == "__main__": 78 | parser = ArgumentParser(description='Perform entity linking') 79 | parser.add_argument('--model_type', type=str, required=True, help="options are [crf|lstm|gru]") 80 | parser.add_argument('--index_ent', type=str, default="indexes/entity_2M.pkl", 81 | help='path to the pickle for the inverted entity index') 82 | parser.add_argument('--data_dir', type=str, default="data/processed_simplequestions_dataset") 83 | 
parser.add_argument('--query_dir', type=str, default="entity_detection/crf/query_text") 84 | parser.add_argument('--hits', type=int, default=10) 85 | parser.add_argument('--output_dir', type=str, default="./results") 86 | 87 | # added for demo 88 | # parser.add_argument('--input_mention', type=str, default='yao ming') 89 | parser.add_argument('--input_path', type=str, default='') 90 | args = parser.parse_args() 91 | # print(args) 92 | 93 | input_mentions = open(args.input_path).read().strip() 94 | # print("input_manetion:", input_mention) 95 | 96 | get_stat_inverted_index(args.index_ent) 97 | for mention in input_mentions.split(): 98 | entity_linking(mention, args.hits) 99 | -------------------------------------------------------------------------------- /NeuralQA/EntityLinking/entity_linking.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | from collections import defaultdict 4 | 5 | from fuzzywuzzy import fuzz 6 | from nltk.corpus import stopwords 7 | from tqdm import tqdm 8 | from util import www2fb, get_ngram, get_stat_inverted_index 9 | 10 | stopword = set(stopwords.words('english')) 11 | inverted_index = defaultdict(list) 12 | 13 | 14 | def entity_linking(predicted_file, gold_file, hits, output): 15 | predicted = open(predicted_file) 16 | gold = open(gold_file) 17 | fout = open(output, 'w') 18 | total, top1, top3, top5, top10, top20, top50, top100 = 0, 0, 0, 0, 0, 0, 0, 0 19 | for idx, (line, gold_id) in tqdm(enumerate(zip(predicted.readlines(), gold.readlines()))): 20 | total += 1 21 | line = line.strip().split(" %%%% ") 22 | gold_id = gold_id.strip().split('\t')[1] 23 | cand_entity, cand_score = [], [] 24 | line_id = line[0] 25 | if len(line) == 2: 26 | tokens = get_ngram(line[1]) 27 | else: 28 | tokens = [] 29 | 30 | if len(tokens) > 0: 31 | maxlen = len(tokens[0].split()) # 1, 2, 3 32 | # print(maxlen) 33 | for item in tokens: # todo 34 | if len(item.split()) < maxlen and len(cand_entity) == 0: 35 | maxlen = len(item.split()) 36 | if len(item.split()) < maxlen and len(cand_entity) > 0: 37 | break 38 | if item in stopword: 39 | continue 40 | cand_entity.extend(inverted_index[item]) 41 | print(item) 42 | print(inverted_index[item]) # all name/alias contain 'item'(string) 43 | # if len(cand_entity) > 0: 44 | # break 45 | print(cand_entity) 46 | for mid_text_type in sorted(set(cand_entity)): 47 | score = fuzz.ratio(mid_text_type[1], line[1]) / 100.0 48 | cand_score.append((mid_text_type, score)) 49 | 50 | cand_score.sort(key=lambda t: t[1], reverse=True) 51 | cand_mids = cand_score[:hits] 52 | fout.write("{}".format(line_id)) 53 | for mid_text_type, score in cand_mids: 54 | fout.write(" %%%% {}\t{}\t{}\t{}".format(mid_text_type[0], mid_text_type[1], mid_text_type[2], score)) 55 | fout.write('\n') 56 | gold_id = www2fb(gold_id) 57 | mids_list = [x[0][0] for x in cand_mids] 58 | if gold_id in mids_list[:1]: 59 | top1 += 1 60 | if gold_id in mids_list[:3]: 61 | top3 += 1 62 | if gold_id in mids_list[:5]: 63 | top5 += 1 64 | if gold_id in mids_list[:10]: 65 | top10 += 1 66 | if gold_id in mids_list[:20]: 67 | top20 += 1 68 | if gold_id in mids_list[:50]: 69 | top50 += 1 70 | if gold_id in mids_list[:100]: 71 | top100 += 1 72 | 73 | print("total: {}".format(total)) 74 | print("Top1 Entity Linking Accuracy: {}".format(top1 / total)) 75 | print("Top3 Entity Linking Accuracy: {}".format(top3 / total)) 76 | print("Top5 Entity Linking Accuracy: {}".format(top5 / total)) 77 | print("Top10 Entity Linking 
Accuracy: {}".format(top10 / total)) 78 | print("Top20 Entity Linking Accuracy: {}".format(top20 / total)) 79 | print("Top50 Entity Linking Accuracy: {}".format(top50 / total)) 80 | print("Top100 Entity Linking Accuracy: {}".format(top100 / total)) 81 | 82 | 83 | if __name__ == "__main__": 84 | # print(get_ngram("Which team have LeBron played basketball ?")) 85 | parser = ArgumentParser(description='Perform entity linking') 86 | parser.add_argument('--model_type', type=str, required=True, help="options are [crf|lstm|gru]") 87 | parser.add_argument('--index_ent', type=str, default="../indexes/entity_2M.pkl", 88 | help='path to the pickle for the inverted entity index') 89 | parser.add_argument('--data_dir', type=str, default="../data/processed_simplequestions_dataset") 90 | parser.add_argument('--query_dir', type=str, default="../entity_detection/nn/query_text") 91 | parser.add_argument('--hits', type=int, default=100) 92 | parser.add_argument('--output_dir', type=str, default="./results") 93 | args = parser.parse_args() 94 | print(args) 95 | 96 | model_type = args.model_type.lower() 97 | # assert(model_type == "crf" or model_type == "lstm" or model_type == "gru") 98 | output_dir = os.path.join(args.output_dir, model_type) 99 | os.makedirs(output_dir, exist_ok=True) 100 | 101 | get_stat_inverted_index(args.index_ent) 102 | print("valid result:") 103 | entity_linking( 104 | os.path.join(args.query_dir, "query.valid"), 105 | os.path.join(args.data_dir, "valid.txt"), 106 | args.hits, 107 | os.path.join(output_dir, "valid-h{}.txt".format(args.hits))) 108 | 109 | print("test result:") 110 | entity_linking( 111 | os.path.join(args.query_dir, "query.test"), 112 | os.path.join(args.data_dir, "test.txt"), 113 | args.hits, 114 | os.path.join(output_dir, "test-h{}.txt".format(args.hits))) 115 | -------------------------------------------------------------------------------- /NeuralQA/EntityLinking/util.py: -------------------------------------------------------------------------------- 1 | import unicodedata 2 | from nltk.tokenize.treebank import TreebankWordTokenizer 3 | 4 | tokenizer = TreebankWordTokenizer() 5 | 6 | 7 | def readFile(filename): 8 | context = open(filename).readlines() 9 | return [c.strip() for c in context] 10 | 11 | 12 | def writeFile(context, filename, append=False): 13 | if not append: 14 | with open(filename, 'w+') as fout: 15 | for co in context: 16 | fout.write(co + "\n") 17 | else: 18 | with open(filename, 'a+') as fout: 19 | for co in context: 20 | fout.write(co + "\n") 21 | 22 | 23 | def list2str(l, split=" "): 24 | a = "" 25 | for li in l: 26 | a += (str(li) + split) 27 | a = a[:-len(split)] 28 | return a 29 | 30 | 31 | def get_ngram(text): 32 | # ngram = set() 33 | ngram = [] 34 | tokens = text.split() 35 | for i in range(len(tokens) + 1): 36 | for j in range(i): 37 | if i - j <= 3: # 3 ? 
38 | # ngram.add(" ".join(tokens[j:i])) 39 | temp = " ".join(tokens[j:i]) 40 | if temp not in ngram: 41 | ngram.append(temp) 42 | # ngram = list(ngram) 43 | ngram = sorted(ngram, key=lambda x: len(x.split()), reverse=True) 44 | return ngram 45 | 46 | 47 | def get_stat_inverted_index(filename): 48 | """ 49 | Get the number of entry and max length of the entry (How many mid in an entry) 50 | """ 51 | with open(filename, "rb") as handler: 52 | global inverted_index 53 | inverted_index = pickle.load(handler) 54 | inverted_index = defaultdict(str, inverted_index) 55 | print("Total type of text: {}".format(len(inverted_index))) 56 | max_len = 0 57 | _entry = "" 58 | for entry, value in inverted_index.items(): 59 | if len(value) > max_len: 60 | max_len = len(value) 61 | _entry = entry 62 | print("Max Length of entry is {}, text is {}".format(max_len, _entry)) 63 | 64 | 65 | def processed_text(text): 66 | text = text.replace('\\\\', '') 67 | # stripped = strip_accents(text.lower()) 68 | stripped = text.lower() 69 | toks = tokenizer.tokenize(stripped) 70 | return " ".join(toks) 71 | 72 | 73 | def strip_accents(text): 74 | return ''.join(c for c in unicodedata.normalize('NFKD', text) if unicodedata.category(c) != 'Mn') 75 | 76 | 77 | def www2fb(in_str): 78 | if in_str.startswith("www.freebase.com"): 79 | in_str = 'fb:%s' % (in_str.split('www.freebase.com/')[-1].replace('/', '.')) 80 | if in_str == 'fb:m.07s9rl0': 81 | in_str = 'fb:m.02822' 82 | if in_str == 'fb:m.0bb56b6': 83 | in_str = 'fb:m.0dn0r' 84 | # Manual Correction 85 | if in_str == 'fb:m.01g81dw': 86 | in_str = 'fb:m.01g_bfh' 87 | if in_str == 'fb:m.0y7q89y': 88 | in_str = 'fb:m.0wrt1c5' 89 | if in_str == 'fb:m.0b0w7': 90 | in_str = 'fb:m.0fq0s89' 91 | if in_str == 'fb:m.09rmm6y': 92 | in_str = 'fb:m.03cnrcc' 93 | if in_str == 'fb:m.0crsn60': 94 | in_str = 'fb:m.02pnlqy' 95 | if in_str == 'fb:m.04t1f8y': 96 | in_str = 'fb:m.04t1fjr' 97 | if in_str == 'fb:m.027z990': 98 | in_str = 'fb:m.0ghdhcb' 99 | if in_str == 'fb:m.02xhc2v': 100 | in_str = 'fb:m.084sq' 101 | if in_str == 'fb:m.02z8b2h': 102 | in_str = 'fb:m.033vn1' 103 | if in_str == 'fb:m.0w43mcj': 104 | in_str = 'fb:m.0m0qffc' 105 | if in_str == 'fb:m.07rqy': 106 | in_str = 'fb:m.0py_0' 107 | if in_str == 'fb:m.0y9s5rm': 108 | in_str = 'fb:m.0ybxl2g' 109 | if in_str == 'fb:m.037ltr7': 110 | in_str = 'fb:m.0qjx99s' 111 | return in_str 112 | 113 | 114 | def clean_uri(uri): 115 | if uri.startswith("<") and uri.endswith(">"): 116 | return clean_uri(uri[1:-1]) 117 | elif uri.startswith("\"") and uri.endswith("\""): 118 | return clean_uri(uri[1:-1]) 119 | return uri 120 | 121 | 122 | def rdf2fb(in_str): 123 | if in_str.startswith('http://rdf.freebase.com/ns/'): 124 | return 'fb:%s' % (in_str.split('http://rdf.freebase.com/ns/')[-1]) 125 | -------------------------------------------------------------------------------- /NeuralQA/MentionDetection/crf/convert.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | 4 | def convert(filename, output): 5 | fin = open(filename, 'r') 6 | fout = open(output, 'w') 7 | for line in fin.readlines(): 8 | items = line.strip().split('\t') 9 | sent, label = items[5], items[6] 10 | for word, tag in zip(sent.strip().split(), label.strip().split()): 11 | fout.write("{}\t{}\n".format(word, tag)) 12 | fout.write("\n") 13 | fout.close() 14 | 15 | 16 | if __name__ == '__main__': 17 | parser = ArgumentParser(description='Convert dataset to stanford format for training') 18 | 
parser.add_argument('--data_dir', type=str, default="../../../data/processed_simplequestions_dataset/train.txt") 19 | parser.add_argument('--save_path', type=str, default="data/stanford.train") 20 | args = parser.parse_args() 21 | convert(args.data_dir, args.save_path) 22 | -------------------------------------------------------------------------------- /NeuralQA/MentionDetection/crf/eval.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | 4 | def get_span(label): 5 | start, end = 0, 0 6 | flag = False 7 | span = [] 8 | for k, l in enumerate(label): 9 | if l == 'I' and not flag: 10 | start = k 11 | flag = True 12 | if l != 'I' and flag: 13 | flag = False 14 | en = k 15 | span.append((start, en)) 16 | start, end = 0, 0 17 | if start != 0 and end == 0: 18 | end = len(label) + 1 # bug fixed: geoff 19 | span.append((start, end)) 20 | return span 21 | 22 | 23 | def evaluation(filename): 24 | fin = open(filename, 'r') 25 | pred = [] 26 | gold = [] 27 | right = 0 28 | predicted = 0 29 | total_en = 0 30 | for line in fin.readlines(): 31 | if line == '\n': 32 | gold_span = get_span(gold) 33 | pred_span = get_span(pred) 34 | total_en += len(gold_span) 35 | predicted += len(pred_span) 36 | for item in pred_span: 37 | if item in gold_span: 38 | right += 1 39 | gold = [] 40 | pred = [] 41 | else: 42 | word, gold_label, pred_label = line.strip().split() 43 | gold.append(gold_label) 44 | pred.append(pred_label) 45 | 46 | if gold != [] or pred != []: 47 | gold_span = get_span(gold) 48 | pred_span = get_span(pred) 49 | total_en += len(gold_span) 50 | predicted += len(pred_span) 51 | for item in pred_span: 52 | if item in gold_span: 53 | right += 1 54 | 55 | if predicted == 0: 56 | precision = 0 57 | else: 58 | precision = right / predicted 59 | if total_en == 0: 60 | recall = 0 61 | else: 62 | recall = right / total_en 63 | if precision + recall == 0: 64 | f1 = 0 65 | else: 66 | f1 = 2 * precision * recall / (precision + recall) 67 | print("Precision", precision, "Recall", recall, "F1", f1, "right", right, "predicted", predicted, "total", total_en) 68 | 69 | 70 | if __name__ == '__main__': 71 | if len(sys.argv) != 2: 72 | print("Need to specify the file") 73 | filename = sys.argv[1] 74 | evaluation(filename) 75 | -------------------------------------------------------------------------------- /NeuralQA/MentionDetection/crf/output.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | 4 | def convert(fileName, idFile, outputFile): 5 | fin = open(fileName) 6 | fid = open(idFile) 7 | fout = open(outputFile, "w") 8 | word_list = [] 9 | pred_query = [] 10 | line_id = [] 11 | for line in fid.readlines(): 12 | line_id.append(line.strip()) 13 | index = 0 14 | for line in fin.readlines(): 15 | if line == '\n': 16 | if len(pred_query) == 0: 17 | pred_query = word_list 18 | fout.write("{} %%%% {}\n".format(line_id[index], " ".join(pred_query))) 19 | index += 1 20 | pred_query = [] 21 | word_list = [] 22 | else: 23 | word, gold_label, pred_label = line.strip().split() 24 | word_list.append(word) 25 | if pred_label == 'I': 26 | pred_query.append(word) 27 | if (index != len(line_id)): 28 | print("Length Error") 29 | 30 | 31 | if __name__ == "__main__": 32 | parser = ArgumentParser(description='Convert result to query text') 33 | parser.add_argument('--data_dir', type=str, default="stanford-ner/data/stanford.predicted.valid") 34 | parser.add_argument('--valid_line', type=str, 35 | 
default="../../data/processed_simplequestions_dataset/lineids_valid.txt") 36 | parser.add_argument('--results_path', type=str, default="query_text/query.valid") 37 | args = parser.parse_args() 38 | convert(args.data_dir, args.valid_line, args.results_path) 39 | -------------------------------------------------------------------------------- /NeuralQA/MentionDetection/crf/stanford-ner/NERDemo.java: -------------------------------------------------------------------------------- 1 | import edu.stanford.nlp.ie.AbstractSequenceClassifier; 2 | import edu.stanford.nlp.ie.crf.*; 3 | import edu.stanford.nlp.io.IOUtils; 4 | import edu.stanford.nlp.ling.CoreLabel; 5 | import edu.stanford.nlp.ling.CoreAnnotations; 6 | import edu.stanford.nlp.sequences.DocumentReaderAndWriter; 7 | import edu.stanford.nlp.util.Triple; 8 | 9 | import java.util.List; 10 | 11 | 12 | /** This is a demo of calling CRFClassifier programmatically. 13 | *
14 | * Usage: {@code java -mx400m -cp "*" NERDemo [serializedClassifier [fileName]] } 15 | *
16 | * If arguments aren't specified, they default to 17 | * classifiers/english.all.3class.distsim.crf.ser.gz and some hardcoded sample text. 18 | * If run with arguments, it shows some of the ways to get k-best labelings and 19 | * probabilities out with CRFClassifier. If run without arguments, it shows some of 20 | * the alternative output formats that you can get. 21 | *
22 | * To use CRFClassifier from the command line: 23 | *
24 | * {@code java -mx400m edu.stanford.nlp.ie.crf.CRFClassifier -loadClassifier [classifier] -textFile [file] } 25 | *
26 | * Or if the file is already tokenized and one word per line, perhaps in 27 | * a tab-separated value format with extra columns for part-of-speech tag, 28 | * etc., use the version below (note the 's' instead of the 'x'): 29 | *
30 | * {@code java -mx400m edu.stanford.nlp.ie.crf.CRFClassifier -loadClassifier [classifier] -testFile [file] } 31 | *
32 | * 33 | * @author Jenny Finkel 34 | * @author Christopher Manning 35 | */ 36 | 37 | public class NERDemo { 38 | 39 | public static void main(String[] args) throws Exception { 40 | 41 | String serializedClassifier = "classifiers/english.all.3class.distsim.crf.ser.gz"; 42 | 43 | if (args.length > 0) { 44 | serializedClassifier = args[0]; 45 | } 46 | 47 | AbstractSequenceClassifier classifier = CRFClassifier.getClassifier(serializedClassifier); 48 | 49 | /* For either a file to annotate or for the hardcoded text example, this 50 | demo file shows several ways to process the input, for teaching purposes. 51 | */ 52 | 53 | if (args.length > 1) { 54 | 55 | /* For the file, it shows (1) how to run NER on a String, (2) how 56 | to get the entities in the String with character offsets, and 57 | (3) how to run NER on a whole file (without loading it into a String). 58 | */ 59 | 60 | String fileContents = IOUtils.slurpFile(args[1]); 61 | List> out = classifier.classify(fileContents); 62 | for (List sentence : out) { 63 | for (CoreLabel word : sentence) { 64 | System.out.print(word.word() + '/' + word.get(CoreAnnotations.AnswerAnnotation.class) + ' '); 65 | } 66 | System.out.println(); 67 | } 68 | 69 | System.out.println("---"); 70 | out = classifier.classifyFile(args[1]); 71 | for (List sentence : out) { 72 | for (CoreLabel word : sentence) { 73 | System.out.print(word.word() + '/' + word.get(CoreAnnotations.AnswerAnnotation.class) + ' '); 74 | } 75 | System.out.println(); 76 | } 77 | 78 | System.out.println("---"); 79 | List> list = classifier.classifyToCharacterOffsets(fileContents); 80 | for (Triple item : list) { 81 | System.out.println(item.first() + ": " + fileContents.substring(item.second(), item.third())); 82 | } 83 | System.out.println("---"); 84 | System.out.println("Ten best entity labelings"); 85 | DocumentReaderAndWriter readerAndWriter = classifier.makePlainTextReaderAndWriter(); 86 | classifier.classifyAndWriteAnswersKBest(args[1], 10, readerAndWriter); 87 | 88 | System.out.println("---"); 89 | System.out.println("Per-token marginalized probabilities"); 90 | classifier.printProbs(args[1], readerAndWriter); 91 | 92 | // -- This code prints out the first order (token pair) clique probabilities. 93 | // -- But that output is a bit overwhelming, so we leave it commented out by default. 94 | // System.out.println("---"); 95 | // System.out.println("First Order Clique Probabilities"); 96 | // ((CRFClassifier) classifier).printFirstOrderProbs(args[1], readerAndWriter); 97 | 98 | } else { 99 | 100 | /* For the hard-coded String, it shows how to run it on a single 101 | sentence, and how to do this and produce several formats, including 102 | slash tags and an inline XML output format. It also shows the full 103 | contents of the {@code CoreLabel}s that are constructed by the 104 | classifier. And it shows getting out the probabilities of different 105 | assignments and an n-best list of classifications with probabilities. 106 | */ 107 | 108 | String[] example = {"Good afternoon Rajat Raina, how are you today?", 109 | "I go to school at Stanford University, which is located in California." }; 110 | for (String str : example) { 111 | System.out.println(classifier.classifyToString(str)); 112 | } 113 | System.out.println("---"); 114 | 115 | for (String str : example) { 116 | // This one puts in spaces and newlines between tokens, so just print not println. 
117 | System.out.print(classifier.classifyToString(str, "slashTags", false)); 118 | } 119 | System.out.println("---"); 120 | 121 | for (String str : example) { 122 | // This one is best for dealing with the output as a TSV (tab-separated column) file. 123 | // The first column gives entities, the second their classes, and the third the remaining text in a document 124 | System.out.print(classifier.classifyToString(str, "tabbedEntities", false)); 125 | } 126 | System.out.println("---"); 127 | 128 | for (String str : example) { 129 | System.out.println(classifier.classifyWithInlineXML(str)); 130 | } 131 | System.out.println("---"); 132 | 133 | for (String str : example) { 134 | System.out.println(classifier.classifyToString(str, "xml", true)); 135 | } 136 | System.out.println("---"); 137 | 138 | for (String str : example) { 139 | System.out.print(classifier.classifyToString(str, "tsv", false)); 140 | } 141 | System.out.println("---"); 142 | 143 | // This gets out entities with character offsets 144 | int j = 0; 145 | for (String str : example) { 146 | j++; 147 | List> triples = classifier.classifyToCharacterOffsets(str); 148 | for (Triple trip : triples) { 149 | System.out.printf("%s over character offsets [%d, %d) in sentence %d.%n", 150 | trip.first(), trip.second(), trip.third, j); 151 | } 152 | } 153 | System.out.println("---"); 154 | 155 | // This prints out all the details of what is stored for each token 156 | int i=0; 157 | for (String str : example) { 158 | for (List lcl : classifier.classify(str)) { 159 | for (CoreLabel cl : lcl) { 160 | System.out.print(i++ + ": "); 161 | System.out.println(cl.toShorterString()); 162 | } 163 | } 164 | } 165 | 166 | System.out.println("---"); 167 | 168 | } 169 | } 170 | 171 | } 172 | -------------------------------------------------------------------------------- /NeuralQA/MentionDetection/nn/args.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | def get_args(): 4 | parser = ArgumentParser(description="Joint Prediction") 5 | parser.add_argument('--mention_detection_mode', type=str, required=True, help='options are LSTM, GRU') 6 | parser.add_argument('--no_cuda', action='store_false', help='do not use cuda', dest='cuda') 7 | parser.add_argument('--gpu', type=int, default=-1) # Use -1 for CPU 8 | parser.add_argument('--epochs', type=int, default=30) 9 | parser.add_argument('--batch_size', type=int, default=32) 10 | parser.add_argument('--dataset', type=str, default="EntityDetection") 11 | parser.add_argument('--lr', type=float, default=1e-4) 12 | parser.add_argument('--seed', type=int, default=3435) 13 | parser.add_argument('--dev_every', type=int, default=2000) 14 | parser.add_argument('--log_every', type=int, default=1000) 15 | parser.add_argument('--patience', type=int, default=10) 16 | parser.add_argument('--save_path', type=str, default='saved_checkpoints') 17 | parser.add_argument('--specify_prefix', type=str, default='id1') 18 | parser.add_argument('--words_dim', type=int, default=300) 19 | parser.add_argument('--num_layer', type=int, default=2) 20 | parser.add_argument('--rnn_fc_dropout', type=float, default=0.3) 21 | parser.add_argument('--input_size', type=int, default=300) 22 | parser.add_argument('--hidden_size', type=int, default=300) 23 | parser.add_argument('--rnn_dropout', type=float, default=0.3) 24 | parser.add_argument('--clip_gradient', type=float, default=0.6, help='gradient clipping') 25 | parser.add_argument('--vector_cache', type=str, 
default="../../data/sq_glove300d.pt") 26 | parser.add_argument('--weight_decay',type=float, default=0) 27 | parser.add_argument('--fix_embed', action='store_false', dest='train_embed') 28 | parser.add_argument('--hits', type=int, default=100) 29 | # added for testing 30 | parser.add_argument('--trained_model', type=str, default='') 31 | parser.add_argument('--data_dir', type=str, default='../../data/processed_simplequestions_dataset') 32 | parser.add_argument('--results_path', type=str, default='query_text') 33 | # added for demo 34 | parser.add_argument('--input', type=str, default='Which teams have James played basketball?') 35 | args = parser.parse_args() 36 | return args 37 | -------------------------------------------------------------------------------- /NeuralQA/MentionDetection/nn/demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | import torch 6 | from args import get_args 7 | from datasets import SimpleQuestionDataset 8 | from torchtext import data 9 | 10 | os.environ["CUDA_VISIBLE_DEVICES"] = "3" 11 | 12 | np.set_printoptions(threshold=np.nan) 13 | # Set default configuration in : args.py 14 | args = get_args() 15 | 16 | # Set random seed for reproducibility 17 | torch.manual_seed(args.seed) 18 | np.random.seed(args.seed) 19 | random.seed(args.seed) 20 | 21 | print(args) 22 | if not args.cuda: 23 | args.gpu = -1 24 | if torch.cuda.is_available() and args.cuda: 25 | # print("Note: You are using GPU for training") 26 | torch.cuda.set_device(args.gpu) 27 | torch.cuda.manual_seed(args.seed) 28 | if torch.cuda.is_available() and not args.cuda: 29 | print("Warning: You have Cuda but not use it. You are using CPU for training.") 30 | 31 | TEXT = data.Field(lower=True) 32 | ED = data.Field() 33 | 34 | train, dev, test = SimpleQuestionDataset.splits(TEXT, ED, path=args.data_dir) # text_field, label_field 35 | TEXT.build_vocab(train, dev, test) 36 | ED.build_vocab(train, dev, test) 37 | 38 | # load the model 39 | model = torch.load(args.trained_model, map_location=lambda storage, location: storage.cuda(args.gpu)) 40 | 41 | # print(model) 42 | 43 | if args.dataset == 'EntityDetection': 44 | index2tag = np.array(ED.vocab.itos) 45 | # print(index2tag) 46 | else: 47 | print("Wrong Dataset") 48 | exit(1) 49 | 50 | index2word = np.array(TEXT.vocab.itos) 51 | 52 | results_path = os.path.join(args.results_path, args.entity_detection_mode.lower()) 53 | if not os.path.exists(results_path): 54 | os.makedirs(results_path, exist_ok=True) 55 | 56 | sentence = "who wrote the film Brave heart" # args.input which genre of album is harder ... faster 57 | # sentence = args.input 58 | # print(sentence.split()) 59 | # print(index2word[:10]) # '', '' 60 | 61 | sent_idx = list() 62 | for word in sentence.split(): 63 | if word.lower() in index2word: 64 | sent_idx.append(index2word.tolist().index(word.lower())) 65 | else: 66 | sent_idx.append(0) # 1? 
67 | 68 | 69 | # print(sent_idx) 70 | 71 | 72 | def predict(data_name="demo"): 73 | model.eval() 74 | 75 | data_batch = np.array(sent_idx).reshape(-1, 1) 76 | scores = model(data_batch) 77 | # print(scores) 78 | 79 | if args.dataset == 'EntityDetection': 80 | # print(torch.max(scores, 1)[1]) 81 | index_tag = np.transpose(torch.max(scores, 1)[1].cpu().data.numpy()) 82 | tag_array = index2tag[index_tag] 83 | index_question = np.transpose(data_batch).reshape(-1) 84 | question_array = index2word[index_question] 85 | 86 | mentions = list() 87 | mention = "" 88 | for question, label in zip(question_array, tag_array): 89 | # print("{}\t{}\t".format("".join(question), " ".join(label))) 90 | if label == 'I' and question != "" and question != "": 91 | mention += question + " " 92 | if label == 'O' and mention != "": 93 | mentions.append(mention.strip()) 94 | mention = "" 95 | if mention != "": 96 | mentions.append(mention.strip()) 97 | for mention in mentions: 98 | print(mention) 99 | else: 100 | print("Wrong Dataset") 101 | exit() 102 | 103 | 104 | # run the model on the demo set and write the output to a file 105 | predict(data_name="demo") 106 | -------------------------------------------------------------------------------- /NeuralQA/MentionDetection/nn/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torch.autograd import Variable 5 | 6 | 7 | class MentionDetection(nn.Module): 8 | def __init__(self, config): 9 | super(MentionDetection, self).__init__() 10 | self.config = config 11 | self.embed = nn.Embedding(config.words_num, config.words_dim) 12 | if config.train_embed == False: 13 | self.embed.weight.requires_grad = False 14 | if config.mention_detection_mode.upper() == 'LSTM': 15 | self.lstm = nn.LSTM(input_size=config.input_size, 16 | hidden_size=config.hidden_size, 17 | num_layers=config.num_layer, 18 | dropout=config.rnn_dropout, 19 | bidirectional=True) 20 | elif config.mention_detection_mode.upper() == 'GRU': 21 | self.gru = nn.GRU(input_size=config.input_size, 22 | hidden_size=config.hidden_size, 23 | num_layers=config.num_layer, 24 | dropout=config.rnn_dropout, 25 | bidirectional=True) 26 | self.dropout = nn.Dropout(p=config.rnn_fc_dropout) 27 | self.relu = nn.ReLU() 28 | self.hidden2tag = nn.Sequential( 29 | nn.Linear(config.hidden_size * 2, config.hidden_size * 2), 30 | nn.BatchNorm1d(config.hidden_size * 2), 31 | self.relu, 32 | self.dropout, 33 | nn.Linear(config.hidden_size * 2, config.label) 34 | ) 35 | 36 | def forward(self, x): 37 | # x = (sequence length, batch_size, dimension of embedding) 38 | # text = x.text # geoff: demo 39 | 40 | if self.config.cuda: 41 | text = torch.LongTensor(x).cuda() 42 | else: 43 | text = torch.LongTensor(x) 44 | # print("########") 45 | # print(text) 46 | 47 | batch_size = text.size()[1] 48 | x = self.embed(text) 49 | # h0 / c0 = (layer*direction, batch_size, hidden_dim) 50 | if self.config.mention_detection_mode.upper() == 'LSTM': 51 | if self.config.cuda: 52 | h0 = Variable(torch.zeros(self.config.num_layer * 2, batch_size, 53 | self.config.hidden_size).cuda()) 54 | c0 = Variable(torch.zeros(self.config.num_layer * 2, batch_size, 55 | self.config.hidden_size).cuda()) 56 | else: 57 | h0 = Variable(torch.zeros(self.config.num_layer * 2, batch_size, 58 | self.config.hidden_size)) 59 | c0 = Variable(torch.zeros(self.config.num_layer * 2, batch_size, 60 | self.config.hidden_size)) 61 | # output = (sentence length, 
batch_size, hidden_size * num_direction) 62 | # ht = (layer*direction, batch, hidden_dim) 63 | # ct = (layer*direction, batch, hidden_dim) 64 | outputs, (ht, ct) = self.lstm(x, (h0, c0)) 65 | elif self.config.mention_detection_mode.upper() == 'GRU': 66 | if self.config.cuda: 67 | h0 = Variable(torch.zeros(self.config.num_layer * 2, batch_size, 68 | self.config.hidden_size).cuda()) 69 | else: 70 | h0 = Variable(torch.zeros(self.config.num_layer * 2, batch_size, 71 | self.config.hidden_size)) 72 | # output = (sentence length, batch_size, hidden_size * num_direction) 73 | # ht = (layer*direction, batch, hidden_dim) 74 | outputs, ht = self.gru(x, h0) 75 | else: 76 | print("Wrong Mention Detection Mode") 77 | exit(1) 78 | tags = self.hidden2tag(outputs.view(-1, outputs.size(2))) 79 | scores = F.log_softmax(tags) 80 | return scores 81 | -------------------------------------------------------------------------------- /NeuralQA/MentionDetection/nn/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | import torch 6 | from torchtext import data 7 | from tqdm import tqdm 8 | 9 | from args import get_args 10 | from util import evaluation 11 | from datasets import SimpleQuestionDataset 12 | 13 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 14 | 15 | np.set_printoptions(threshold=np.nan) 16 | # Set default configuration in : args.py 17 | args = get_args() 18 | 19 | # Set random seed for reproducibility 20 | torch.manual_seed(args.seed) 21 | np.random.seed(args.seed) 22 | random.seed(args.seed) 23 | 24 | if not args.cuda: 25 | args.gpu = -1 26 | if torch.cuda.is_available() and args.cuda: 27 | print("Note: You are using GPU for training") 28 | torch.cuda.set_device(args.gpu) 29 | torch.cuda.manual_seed(args.seed) 30 | if torch.cuda.is_available() and not args.cuda: 31 | print("Warning: You have Cuda but not use it. 
You are using CPU for training.") 32 | 33 | TEXT = data.Field(lower=True) 34 | ED = data.Field() 35 | 36 | train, dev, test = SimpleQuestionDataset.splits(TEXT, ED, path=args.data_dir) # text_field, label_field 37 | TEXT.build_vocab(train, dev, test) 38 | ED.build_vocab(train, dev, test) 39 | 40 | train_iter = data.Iterator(train, batch_size=args.batch_size, device=args.gpu, train=True, repeat=False, 41 | sort=False, shuffle=True) 42 | dev_iter = data.Iterator(dev, batch_size=args.batch_size, device=args.gpu, train=False, repeat=False, 43 | sort=False, shuffle=False) 44 | test_iter = data.Iterator(test, batch_size=args.batch_size, device=args.gpu, train=False, repeat=False, 45 | sort=False, shuffle=False) 46 | 47 | # load the model 48 | if not torch.cuda.is_available(): 49 | model = torch.load(args.trained_model) 50 | else: 51 | model = torch.load(args.trained_model, map_location=lambda storage, location: storage.cuda(args.gpu)) 52 | 53 | print(model) 54 | 55 | if args.dataset == 'EntityDetection': 56 | index2tag = np.array(ED.vocab.itos) 57 | # print(index2tag) 58 | else: 59 | print("Wrong Dataset") 60 | exit(1) 61 | 62 | index2word = np.array(TEXT.vocab.itos) 63 | 64 | results_path = os.path.join(args.results_path, args.entity_detection_mode.lower()) 65 | if not os.path.exists(results_path): 66 | os.makedirs(results_path, exist_ok=True) 67 | 68 | 69 | def convert(fileName, idFile, outputFile): 70 | fin = open(fileName) 71 | fid = open(idFile) 72 | fout = open(outputFile, "w") 73 | # holiday barbie doll by bob mackie # h8583 O O O O O O I I I O O O I I 74 | # valid-10837 %%%% holiday barbie doll %%%% # h8583 75 | for line, line_id in tqdm(zip(fin.readlines(), fid.readlines())): 76 | query_list = [] 77 | query_text = [] 78 | line = line.strip().split('\t') 79 | sent = line[0].strip().split() 80 | pred = line[1].strip().split() 81 | for token, label in zip(sent, pred): 82 | if label == 'I': 83 | query_text.append(token) 84 | if label == 'O': 85 | query_text = list(filter(lambda x: x != '', query_text)) 86 | if len(query_text) != 0: 87 | query_list.append(" ".join(list(filter(lambda x: x != '', query_text)))) 88 | query_text = [] 89 | query_text = list(filter(lambda x: x != '', query_text)) 90 | if len(query_text) != 0: 91 | query_list.append(" ".join(list(filter(lambda x: x != '', query_text)))) 92 | if len(query_list) == 0: 93 | query_list.append(" ".join(list(filter(lambda x: x != '', sent)))) 94 | fout.write(" %%%% ".join([line_id.strip()] + query_list) + "\n") 95 | 96 | 97 | def predict(dataset_iter=test_iter, dataset=test, data_name="test"): 98 | print("Dataset: {}".format(data_name)) 99 | model.eval() 100 | dataset_iter.init_epoch() 101 | 102 | n_correct = 0 103 | fname = "{}.txt".format(data_name) 104 | temp_file = 'tmp' + fname 105 | results_file = open(temp_file, 'w') 106 | 107 | gold_list = [] 108 | pred_list = [] 109 | 110 | for data_batch_idx, data_batch in enumerate(dataset_iter): 111 | scores = model(data_batch) 112 | if args.dataset == 'EntityDetection': 113 | n_correct += ((torch.max(scores, 1)[1].view(data_batch.ed.size()).data == data_batch.ed.data).sum(dim=0) \ 114 | == data_batch.ed.size()[0]).sum() 115 | index_tag = np.transpose(torch.max(scores, 1)[1].view(data_batch.ed.size()).cpu().data.numpy()) 116 | tag_array = index2tag[index_tag] 117 | index_question = np.transpose(data_batch.text.cpu().data.numpy()) 118 | question_array = index2word[index_question] 119 | gold_list.append(np.transpose(data_batch.ed.cpu().data.numpy())) 120 | gold_array = 
index2tag[np.transpose(data_batch.ed.cpu().data.numpy())] 121 | pred_list.append(index_tag) 122 | for question, label, gold in zip(question_array, tag_array, gold_array): 123 | results_file.write("{}\t{}\t{}\n".format(" ".join(question), " ".join(label), " ".join(gold))) 124 | # print("{}\t{}\t{}\n".format(" ".join(question), " ".join(label), " ".join(gold))) 125 | else: 126 | print("Wrong Dataset") 127 | exit() 128 | 129 | if args.dataset == 'EntityDetection': 130 | P, R, F = evaluation(gold_list, pred_list, index2tag, type=False) 131 | print("{} Precision: {:10.6f}% Recall: {:10.6f}% F1 Score: {:10.6f}%".format("Dev", 100. * P, 100. * R, 132 | 100. * F)) 133 | else: 134 | print("Wrong dataset") 135 | exit() 136 | results_file.flush() 137 | results_file.close() 138 | convert(temp_file, os.path.join(args.data_dir, "lineids_{}.txt".format(data_name)), 139 | os.path.join(results_path, "query.{}".format(data_name))) 140 | os.remove(temp_file) 141 | 142 | 143 | # run the model on the dev set and write the output to a file 144 | predict(dataset_iter=dev_iter, dataset=dev, data_name="valid2") 145 | 146 | # run the model on the test set and write the output to a file 147 | # predict(dataset_iter=test_iter, dataset=test, data_name="test") 148 | -------------------------------------------------------------------------------- /NeuralQA/MentionDetection/nn/tmp_data.txt: -------------------------------------------------------------------------------- 1 | who was the trump ocean club international hotel and tower named after O O O I I I I I I I O O O O O I I I I I I I O O 2 | where was sasha vujačić born O O I I O O O I I O 3 | what is a region that dead combo was released in O O O O O I I O O O O O O O O I I O O O 4 | what is a film directed by wiebke von carolsfeld ? O O O O O O I I I O O O O O O O I I I O 5 | what country was music for stock exchange released in O O O I I I I O O O O O I I I I O O 6 | where is adler school of professional psychology located ? O O I I I I I O O O O I I I I I O O 7 | where was john miltern born O O I I O O O I I O 8 | what city is vancouver millionaires from ? O O O I I O O O O O I I O O 9 | what was seymour parker gilbert 's profession ? O O I I I O O O O O I I I O O O 10 | what does ( 12385 ) 1994 uo orbit O O I I I I I O O O I I I I I O 11 | who is the singer of only women bleed O O O O O I I I O O O O O I I I 12 | in what french city did antoine de févin die O O O O O I I I O O O O O O I I I O 13 | who published rama O O I O O I 14 | who was an advisor for irving langmuir ? O O O O O I I O O O O O O I I O 15 | what is the language of the film bon voyage ? O O O O O O O I I O O O O O O O O I I O 16 | which country was the hunyadi family from O O O I I I O O O O O I I O 17 | what major cities does u.s. route 2 run through O O O O I I I O O O O O O I I I O O 18 | who was a child of mithibai jinnah O O O O O I I O O O O O I I 19 | whats a version of the single titled star O O O O O O O I O O O O O O O I 20 | what is a song by john rutter ? O O O O O I I O O O O O O I I O 21 | what job does jamie hewlett have O O O I I O O O O I I O 22 | what 's an example of an album O O O O O O I O O O O O O I 23 | what is the film tempo di uccidere about O O O O I I I O O O O O I I I O 24 | what country is ghost house from O O O I I O O O O I I O 25 | which country was the yamakinkarudu movie produced O O O O I O O O O O O I O O 26 | what 's the time zone in sub-saharan africa O O O O O O I I O O O O O O I I 27 | what author wrote the book liquor ? 
O O O O O I O O O O O O I O 28 | what is the release type of the album wake ? O O O O O O O O I O O O O O O O O O I O 29 | who is the chid of fritz leiber ? O O O O O I I O O O O O O I I O 30 | what artist creates riot grrrl music O O O I I O O O O I I O 31 | what country is rafael sorkin from O O O I I O O O O I I O 32 | what was marcy rae 's profession ? O O I I O O O O O I I O O O 33 | -------------------------------------------------------------------------------- /NeuralQA/MentionDetection/nn/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | import time 6 | import torch 7 | import torch.nn as nn 8 | from args import get_args 9 | from datasets import SimpleQuestionDataset 10 | from model import MentionDetection 11 | from torchtext import data 12 | from util import evaluation 13 | 14 | np.set_printoptions(threshold=np.nan) 15 | # Set default configuration in : args.py 16 | args = get_args() 17 | 18 | # Set random seed for reproducibility 19 | torch.manual_seed(args.seed) 20 | np.random.seed(args.seed) 21 | random.seed(args.seed) 22 | torch.backends.cudnn.deterministic = True 23 | 24 | if not args.cuda: 25 | args.gpu = -1 26 | if torch.cuda.is_available() and args.cuda: 27 | print("Note: You are using GPU for training") 28 | torch.cuda.set_device(args.gpu) 29 | torch.cuda.manual_seed(args.seed) 30 | if torch.cuda.is_available() and not args.cuda: 31 | print("Warning: You have Cuda but not use it. You are using CPU for training.") 32 | 33 | # Set up the data for training 34 | TEXT = data.Field(lower=True) 35 | ED = data.Field() 36 | 37 | train, dev, test = SimpleQuestionDataset.splits(TEXT, ED, args.data_dir) 38 | TEXT.build_vocab(train, dev, test) 39 | ED.build_vocab(train, dev, test) 40 | 41 | match_embedding = 0 42 | if os.path.isfile(args.vector_cache): 43 | stoi, vectors, dim = torch.load(args.vector_cache) # stoi? 
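# The cache loaded above (args.vector_cache, e.g. sq_glove300d.pt) is expected to be a
# (stoi, vectors, dim) triple: `stoi` maps a token string to its row in the pretrained embedding
# matrix `vectors`, and `dim` is the embedding dimensionality. The loop below copies the
# pretrained vector for every vocabulary token present in `stoi` and falls back to a
# uniform(-0.25, 0.25) initialization for tokens without a pretrained vector.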
44 | TEXT.vocab.vectors = torch.Tensor(len(TEXT.vocab), dim) 45 | for i, token in enumerate(TEXT.vocab.itos): 46 | wv_index = stoi.get(token, None) 47 | if wv_index is not None: 48 | TEXT.vocab.vectors[i] = vectors[wv_index] 49 | match_embedding += 1 50 | else: 51 | TEXT.vocab.vectors[i] = torch.FloatTensor(dim).uniform_(-0.25, 0.25) 52 | else: 53 | print("Error: Need word embedding pt file") 54 | exit(1) 55 | 56 | print("Embedding match number {} out of {}".format(match_embedding, len(TEXT.vocab))) 57 | 58 | train_iter = data.Iterator(train, batch_size=args.batch_size, device=args.gpu, train=True, repeat=False, 59 | sort=False, shuffle=True) 60 | dev_iter = data.Iterator(dev, batch_size=args.batch_size, device=args.gpu, train=False, repeat=False, 61 | sort=False, shuffle=False) 62 | test_iter = data.Iterator(test, batch_size=args.batch_size, device=args.gpu, train=False, repeat=False, 63 | sort=False, shuffle=False) 64 | 65 | config = args 66 | config.words_num = len(TEXT.vocab) 67 | 68 | if args.dataset == 'EntityDetection': 69 | config.label = len(ED.vocab) 70 | model = MentionDetection(config) 71 | else: 72 | print("Error Dataset") 73 | exit() 74 | 75 | model.embed.weight.data.copy_(TEXT.vocab.vectors) 76 | if args.cuda: 77 | model.cuda() 78 | print("Shift model to GPU") 79 | 80 | print(config) 81 | print("VOCAB num", len(TEXT.vocab)) 82 | print("Train instance", len(train)) 83 | print("Dev instance", len(dev)) 84 | print("Test instance", len(test)) 85 | print("Entity Type", len(ED.vocab)) 86 | print(model) 87 | 88 | parameter = filter(lambda p: p.requires_grad, model.parameters()) 89 | optimizer = torch.optim.Adam(parameter, lr=args.lr, weight_decay=args.weight_decay) 90 | criterion = nn.NLLLoss() 91 | 92 | early_stop = False 93 | best_dev_F = 0 94 | best_dev_P = 0 95 | best_dev_R = 0 96 | iterations = 0 97 | iters_not_improved = 0 98 | num_dev_in_epoch = (len(train) // args.batch_size // args.dev_every) + 1 99 | patience = args.patience * num_dev_in_epoch # for early stopping 100 | epoch = 0 101 | start = time.time() 102 | header = ' Time Epoch Iteration Progress (%Epoch) Loss Dev/Loss Accuracy Dev/Accuracy' 103 | dev_log_template = ' '.join( 104 | '{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{:8.6f},{:12.4f},{:12.4f}'.split(',')) 105 | log_template = ' '.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{},{},{}'.split(',')) 106 | save_path = os.path.join(args.save_path, args.entity_detection_mode.lower()) 107 | os.makedirs(save_path, exist_ok=True) 108 | print(header) 109 | 110 | if args.dataset == 'EntityDetection': 111 | index2tag = np.array(ED.vocab.itos) 112 | else: 113 | print("Wrong Dataset") 114 | exit(1) 115 | 116 | while True: 117 | if early_stop: 118 | print("Early Stopping. 
Epoch: {}, Best Dev F1: {}".format(epoch, best_dev_F)) 119 | break 120 | epoch += 1 121 | train_iter.init_epoch() 122 | n_correct, n_total = 0, 0 123 | n_correct_ed, n_correct_ner, n_correct_rel = 0, 0, 0 124 | 125 | for batch_idx, batch in enumerate(train_iter): 126 | # Batch size : (Sentence Length, Batch_size) 127 | iterations += 1 128 | model.train() 129 | optimizer.zero_grad() 130 | scores = model(batch) 131 | # Entity Detection 132 | if args.dataset == 'EntityDetection': 133 | n_correct += ((torch.max(scores, 1)[1].view(batch.ed.size()).data == batch.ed.data).sum(dim=0) \ 134 | == batch.ed.size()[0]).sum() 135 | loss = criterion(scores, batch.ed.view(-1, 1)[:, 0]) 136 | else: 137 | print("Wrong Dataset") 138 | exit() 139 | 140 | n_total += batch.batch_size 141 | loss.backward() 142 | # clip the gradient 143 | torch.nn.utils.clip_grad_norm(model.parameters(), args.clip_gradient) 144 | optimizer.step() 145 | 146 | # evaluate performance on validation set periodically 147 | if iterations % args.dev_every == 0: 148 | model.eval() 149 | dev_iter.init_epoch() 150 | n_dev_correct = 0 151 | n_dev_correct_rel = 0 152 | 153 | gold_list = [] 154 | pred_list = [] 155 | 156 | for dev_batch_idx, dev_batch in enumerate(dev_iter): 157 | answer = model(dev_batch) 158 | if args.dataset == 'EntityDetection': 159 | n_dev_correct += ( 160 | (torch.max(answer, 1)[1].view(dev_batch.ed.size()).data == dev_batch.ed.data).sum(dim=0) \ 161 | == dev_batch.ed.size()[0]).sum() 162 | index_tag = np.transpose(torch.max(answer, 1)[1].view(dev_batch.ed.size()).cpu().data.numpy()) 163 | gold_list.append(np.transpose(dev_batch.ed.cpu().data.numpy())) 164 | pred_list.append(index_tag) 165 | else: 166 | print("Wrong Dataset") 167 | exit() 168 | 169 | if args.dataset == 'EntityDetection': 170 | P, R, F = evaluation(gold_list, pred_list, index2tag, type=False) 171 | print("{} Precision: {:10.6f}% Recall: {:10.6f}% F1 Score: {:10.6f}%".format("Dev", 100. * P, 100. * R, 172 | 100. * F)) 173 | else: 174 | print("Wrong dataset") 175 | exit() 176 | 177 | # update model 178 | if args.dataset == 'EntityDetection': 179 | if F > best_dev_F: 180 | best_dev_F = F 181 | best_dev_P = P 182 | best_dev_R = R 183 | iters_not_improved = 0 184 | snapshot_path = os.path.join(save_path, args.specify_prefix + '_best_model_cpu.pt') 185 | # save model, delete previous 'best_snapshot' files 186 | torch.save(model, snapshot_path) 187 | else: 188 | iters_not_improved += 1 189 | if iters_not_improved > patience: 190 | early_stop = True 191 | break 192 | else: 193 | print("Wrong dataset") 194 | exit() 195 | 196 | if iterations % args.log_every == 1: 197 | # print progress message 198 | print(log_template.format(time.time() - start, 199 | epoch, iterations, 1 + batch_idx, len(train_iter), 200 | 100. * (1 + batch_idx) / len(train_iter), loss.data[0], ' ' * 8, 201 | 100. 
* n_correct / n_total, ' ' * 12)) 202 | 203 | print('Time of train model: %f' % (time.time() - start)) 204 | -------------------------------------------------------------------------------- /NeuralQA/MentionDetection/nn/util/datasets.py: -------------------------------------------------------------------------------- 1 | from torchtext import data 2 | 3 | 4 | class SimpleQuestionDataset(data.TabularDataset): 5 | @classmethod 6 | def splits(cls, text_field, label_field, path, 7 | train='train.txt', validation='valid.txt', test='test.txt'): 8 | return super(SimpleQuestionDataset, cls).splits( 9 | path=path, train=train, validation=validation, test=test, 10 | format='TSV', fields=[('id', None), ('sub', None), ('entity', None), ('relation', None), 11 | ('obj', None), ('text', text_field), ('ed', label_field)] 12 | ) 13 | -------------------------------------------------------------------------------- /NeuralQA/MentionDetection/nn/util/util.py: -------------------------------------------------------------------------------- 1 | def get_span(label): 2 | start, end = 0, 0 3 | flag = False 4 | span = [] 5 | for k, l in enumerate(label): 6 | if l == 'I' and not flag: 7 | start = k 8 | flag = True 9 | if l != 'I' and flag: 10 | flag = False 11 | en = k 12 | span.append((start, en)) 13 | start, end = 0, 0 14 | if start != 0 and end == 0: 15 | end = len(label) + 1 # bug fixed: geoff 16 | span.append((start, end)) 17 | return span 18 | 19 | 20 | def evaluation(gold, pred, index2tag, type): 21 | right = 0 22 | predicted = 0 23 | total_en = 0 24 | # fout = open('log.valid', 'w') 25 | for i in range(len(gold)): 26 | gold_batch = gold[i] 27 | pred_batch = pred[i] 28 | for j in range(len(gold_batch)): 29 | gold_label = gold_batch[j] 30 | pred_label = pred_batch[j] 31 | gold_span = get_span(gold_label, index2tag, type) 32 | pred_span = get_span(pred_label, index2tag, type) 33 | # fout.write('{}\n{}'.format(gold_span, pred_span)) 34 | total_en += len(gold_span) 35 | predicted += len(pred_span) 36 | for item in pred_span: 37 | if item in gold_span: 38 | right += 1 39 | if predicted == 0: 40 | precision = 0 41 | else: 42 | precision = right / predicted 43 | if total_en == 0: 44 | recall = 0 45 | else: 46 | recall = right / total_en 47 | if precision + recall == 0: 48 | f1 = 0 49 | else: 50 | f1 = 2 * precision * recall / (precision + recall) 51 | # fout.flush() 52 | # fout.close() 53 | return precision, recall, f1 54 | -------------------------------------------------------------------------------- /NeuralQA/RelationDetection/nn/args.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | 4 | def get_args(): 5 | parser = ArgumentParser(description="Relation Detection") 6 | parser.add_argument('--relation_detection_mode', required=True, type=str, help='options are CNN, GRU, LSTM') 7 | parser.add_argument('--no_cuda', action='store_false', help='do not use cuda', dest='cuda') 8 | parser.add_argument('--gpu', type=int, default=-1) # Use -1 for CPU 9 | parser.add_argument('--epochs', type=int, default=30) 10 | parser.add_argument('--batch_size', type=int, default=32) 11 | parser.add_argument('--dataset', type=str, default="RelationDetection") 12 | parser.add_argument('--mode', type=str, default='static') 13 | parser.add_argument('--lr', type=float, default=1e-4) 14 | parser.add_argument('--seed', type=int, default=3435) 15 | parser.add_argument('--dev_every', type=int, default=2000) 16 | parser.add_argument('--log_every', type=int, 
default=1000) 17 | parser.add_argument('--patience', type=int, default=10) 18 | parser.add_argument('--save_path', type=str, default='saved_checkpoints') 19 | parser.add_argument('--specify_prefix', type=str, default='id1') 20 | parser.add_argument('--output_channel', type=int, default=300) 21 | parser.add_argument('--words_dim', type=int, default=300) 22 | parser.add_argument('--num_layer', type=int, default=2) 23 | parser.add_argument('--rnn_dropout', type=float, default=0.3, help='dropout in rnn') 24 | parser.add_argument('--input_size', type=int, default=300) 25 | parser.add_argument('--hidden_size', type=int, default=300) 26 | parser.add_argument('--rnn_fc_dropout', type=float, default=0.3, help='dropout before fully connected layer in RNN') 27 | parser.add_argument('--clip_gradient', type=float, default=0.6, help='gradient clipping') 28 | parser.add_argument('--vector_cache', type=str, default="../../data/sq_glove300d.pt") 29 | parser.add_argument('--weight_decay', type=float, default=0) 30 | parser.add_argument('--cnn_dropout', type=float, default=0.5, help='dropout before fully connected layer in CNN') 31 | parser.add_argument('--fix_embed', action='store_false', dest='train_embed') 32 | parser.add_argument('--hits', type=int, default=30) # 5 33 | # added for testing 34 | parser.add_argument('--data_dir', type=str, default='../../data/processed_simplequestions_dataset/') 35 | parser.add_argument('--trained_model', type=str, default='saved_checkpoints/cnn/id1_best_model_cpu2.pt') 36 | parser.add_argument('--results_path', type=str, default='results') 37 | # added for demo 38 | parser.add_argument('--input', type=str, default='') 39 | 40 | args = parser.parse_args() 41 | return args 42 | -------------------------------------------------------------------------------- /NeuralQA/RelationDetection/nn/datasets.py: -------------------------------------------------------------------------------- 1 | from torchtext import data 2 | 3 | 4 | class SimpleQuestionsDataset(data.TabularDataset): 5 | @classmethod 6 | def splits(cls, text_field, label_field, path, # train.txt 7 | train='train_relation', validation='valid_relation', test='test_relation'): 8 | return super(SimpleQuestionsDataset, cls).splits( 9 | path, '', train, validation, test, 10 | format='TSV', fields=[('id', None), ('sub', None), ('entity', None), ('relation', label_field), 11 | ('obj', None), ('text', text_field), ('ed', None)] 12 | ) 13 | -------------------------------------------------------------------------------- /NeuralQA/RelationDetection/nn/demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch 4 | import numpy as np 5 | 6 | from datasets import SimpleQuestionsDataset 7 | from torchtext import data 8 | from args import get_args 9 | 10 | 11 | np.set_printoptions(threshold=np.nan) 12 | # Set default configuration in : args.py 13 | args = get_args() 14 | 15 | # Set random seed for reproducibility 16 | torch.manual_seed(args.seed) 17 | np.random.seed(args.seed) 18 | random.seed(args.seed) 19 | 20 | if not args.cuda: 21 | args.gpu = -1 22 | if torch.cuda.is_available() and args.cuda: 23 | # print("Note: You are using GPU for training") 24 | torch.cuda.set_device(args.gpu) 25 | torch.cuda.manual_seed(args.seed) 26 | if torch.cuda.is_available() and not args.cuda: 27 | print("Warning: You have Cuda but not use it. 
You are using CPU for training.") 28 | 29 | TEXT = data.Field(lower=True) 30 | RELATION = data.Field(sequential=False) 31 | 32 | train, dev, test = SimpleQuestionsDataset.splits(TEXT, RELATION, args.data_dir) 33 | TEXT.build_vocab(train, dev, test) 34 | RELATION.build_vocab(train, dev) 35 | 36 | # load the model 37 | model = torch.load(args.trained_model) 38 | # print(model) 39 | 40 | if args.dataset == 'RelationDetection': 41 | index2tag = np.array(RELATION.vocab.itos) 42 | print(len(index2tag)) 43 | # print(index2tag[:5]) 44 | else: 45 | print("Wrong Dataset") 46 | exit(1) 47 | 48 | index2word = np.array(TEXT.vocab.itos).tolist() 49 | # print(len(index2word)) 50 | 51 | results_path = os.path.join(args.results_path, args.relation_detection_mode.lower()) 52 | if not os.path.exists(results_path): 53 | os.makedirs(results_path, exist_ok=True) 54 | 55 | sentence = "who is the child of Obama" # args.input obama 56 | print(model.embed(torch.LongTensor([12]))) 57 | # sentence = args.input 58 | # print(sentence.split()) 59 | # print(index2word.index('what')) 60 | # print(index2word[:10]) # '', '' 61 | 62 | 63 | sent_idx = list() 64 | for word in sentence.split(): 65 | if word.lower() in index2word: 66 | sent_idx.append(index2word.index(word.lower())) 67 | else: 68 | sent_idx.append(0) # 1? 69 | print(sent_idx) 70 | 71 | 72 | def predict(data_name="test"): 73 | # print("Dataset: {}".format(data_name)) 74 | model.eval() 75 | 76 | data_batch = np.array(sent_idx).reshape(-1, 1) 77 | # print(data_batch) 78 | 79 | scores = model(data_batch) # 80 | # print(scores) 81 | 82 | if args.dataset == 'RelationDetection': 83 | # Get top k 84 | top_k_scores, top_k_indices = torch.topk(scores, k=args.hits, dim=1, sorted=True) # shape: (batch_size, k) 85 | top_k_scores_array = top_k_scores.cpu().data.numpy() 86 | top_k_indices_array = top_k_indices.cpu().data.numpy() 87 | top_k_relatons_array = index2tag[top_k_indices_array] 88 | for i, (relations_row, scores_row) in enumerate(zip(top_k_relatons_array, top_k_scores_array)): 89 | for j, (rel, score) in enumerate(zip(relations_row, scores_row)): 90 | print("{}\t{}".format(rel, score)) 91 | else: 92 | print("Wrong Dataset") 93 | exit() 94 | 95 | 96 | # run the model on the test set and write the output to a file 97 | predict(data_name="demo") 98 | -------------------------------------------------------------------------------- /NeuralQA/RelationDetection/nn/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torch.autograd import Variable 5 | 6 | 7 | class RelationDetection(nn.Module): 8 | def __init__(self, config): 9 | super(RelationDetection, self).__init__() 10 | self.config = config 11 | target_size = config.rel_label 12 | self.embed = nn.Embedding(config.words_num, config.words_dim) 13 | if config.train_embed == False: 14 | self.embed.weight.requires_grad = False 15 | if config.relation_prediction_mode.upper() == "GRU": 16 | self.gru = nn.GRU(input_size=config.input_size, 17 | hidden_size=config.hidden_size, 18 | num_layers=config.num_layer, 19 | dropout=config.rnn_dropout, 20 | bidirectional=True) 21 | self.dropout = nn.Dropout(p=config.rnn_fc_dropout) 22 | self.relu = nn.ReLU() 23 | self.hidden2tag = nn.Sequential( 24 | nn.Linear(config.hidden_size * 2, config.hidden_size * 2), 25 | nn.BatchNorm1d(config.hidden_size * 2), 26 | self.relu, 27 | self.dropout, 28 | nn.Linear(config.hidden_size * 2, target_size) 29 | ) 30 | if 
config.relation_prediction_mode.upper() == "LSTM": 31 | self.lstm = nn.LSTM(input_size=config.input_size, 32 | hidden_size=config.hidden_size, 33 | num_layers=config.num_layer, 34 | dropout=config.rnn_dropout, 35 | bidirectional=True) 36 | self.dropout = nn.Dropout(p=config.rnn_fc_dropout) 37 | self.relu = nn.ReLU() 38 | self.hidden2tag = nn.Sequential( 39 | nn.Linear(config.hidden_size * 2, config.hidden_size * 2), 40 | nn.BatchNorm1d(config.hidden_size * 2), 41 | self.relu, 42 | self.dropout, 43 | nn.Linear(config.hidden_size * 2, target_size) 44 | ) 45 | if config.relation_prediction_mode.upper() == "CNN": 46 | input_channel = 1 47 | Ks = 3 48 | self.conv1 = nn.Conv2d(input_channel, config.output_channel, (2, config.words_dim), padding=(1, 0)) 49 | self.conv2 = nn.Conv2d(input_channel, config.output_channel, (3, config.words_dim), padding=(2, 0)) 50 | self.conv3 = nn.Conv2d(input_channel, config.output_channel, (4, config.words_dim), padding=(3, 0)) 51 | self.dropout = nn.Dropout(config.cnn_dropout) 52 | self.fc1 = nn.Linear(Ks * config.output_channel, target_size) 53 | 54 | def forward(self, x): 55 | # x = (sequence length, batch_size, dimension of embedding) 56 | # text = x.text # todo 57 | 58 | if self.config.cuda: # geoff: demo 59 | text = torch.LongTensor(x).cuda() 60 | else: 61 | text = torch.LongTensor(x) 62 | 63 | batch_size = text.size()[1] 64 | x = self.embed(text) 65 | if self.config.relation_prediction_mode.upper() == "LSTM": 66 | # h0 / c0 = (layer*direction, batch_size, hidden_dim) 67 | if self.config.cuda: 68 | h0 = Variable(torch.zeros(self.config.num_layer * 2, batch_size, 69 | self.config.hidden_size).cuda()) 70 | c0 = Variable(torch.zeros(self.config.num_layer * 2, batch_size, 71 | self.config.hidden_size).cuda()) 72 | else: 73 | h0 = Variable(torch.zeros(self.config.num_layer * 2, batch_size, 74 | self.config.hidden_size)) 75 | c0 = Variable(torch.zeros(self.config.num_layer * 2, batch_size, 76 | self.config.hidden_size)) 77 | # output = (sentence length, batch_size, hidden_size * num_direction) 78 | # ht = (layer*direction, batch, hidden_dim) 79 | # ct = (layer*direction, batch, hidden_dim) 80 | outputs, (ht, ct) = self.lstm(x, (h0, c0)) 81 | print(outputs) 82 | print("&&&&&&&&&&&&&") 83 | print(ht[-2:]) 84 | tags = self.hidden2tag(ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1)) 85 | scores = F.log_softmax(tags) 86 | return scores 87 | elif self.config.relation_prediction_mode.upper() == "GRU": 88 | if self.config.cuda: 89 | h0 = Variable(torch.zeros(self.config.num_layer * 2, batch_size, 90 | self.config.hidden_size).cuda()) 91 | else: 92 | h0 = Variable(torch.zeros(self.config.num_layer * 2, batch_size, 93 | self.config.hidden_size)) 94 | outputs, ht = self.gru(x, h0) 95 | 96 | tags = self.hidden2tag(ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1)) 97 | scores = F.log_softmax(tags) 98 | return scores 99 | elif self.config.relation_prediction_mode.upper() == "CNN": 100 | x = x.transpose(0, 1).contiguous().unsqueeze(1) # (batch, channel_input, sent_len, embed_dim) 101 | # (batch, embed_dim, conv_sent_len) 102 | x = [F.relu(self.conv1(x)).squeeze(3), F.relu(self.conv2(x)).squeeze(3), F.relu(self.conv3(x)).squeeze(3)] 103 | # (batch, channel_output, ~=sent_len) * Ks 104 | x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x] # max-over-time pooling 105 | # (batch, channel_output) * Ks 106 | x = torch.cat(x, 1) # (batch, channel_output * Ks) 107 | x = self.dropout(x) 108 | logit = self.fc1(x) # (batch, target_size) 109 | scores = 
F.log_softmax(logit) 110 | return scores 111 | else: 112 | print("Unknown Mode") 113 | exit(1) 114 | -------------------------------------------------------------------------------- /NeuralQA/RelationDetection/nn/preprocess.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | 4 | def preprocess(mention, query): # "you", "how are you" 5 | new_query = query.replace(mention, "") 6 | print(new_query) 7 | return new_query 8 | 9 | 10 | def replace_mention(filepath, output_path): 11 | with open(filepath) as f: 12 | with open(output_path, "w") as fw: 13 | not_match = 0 14 | for idx, line in enumerate(f): 15 | tokens = line.split(" ") 16 | qid = tokens[0] 17 | sub = tokens[1] 18 | mention = tokens[2] 19 | relation = tokens[3] # qid = 20 | obj = tokens[4] 21 | query = tokens[5] 22 | tag = tokens[6] 23 | if mention not in query: 24 | not_match += 1 25 | new_query = query.replace(mention, "") 26 | fw.write( 27 | qid + "\t" + sub + "\t" + mention + "\t" + relation + "\t" + obj + "\t" + new_query + "\t" + tag) 28 | print("num", not_match) 29 | 30 | 31 | def get_mid2wiki(filepath): 32 | print("Loading Wiki") 33 | mid2wiki = defaultdict(bool) 34 | fin = open(filepath) 35 | idx = 0 36 | for line in fin.readlines(): 37 | items = line.strip().split('\t') 38 | if len(items) != 3: 39 | continue 40 | else: 41 | idx += 1 42 | url = items[2] 43 | print(idx, url[1:-3]) 44 | # sub = rdf2fb(clean_uri(items[0])) 45 | # mid2wiki[sub] = True 46 | return mid2wiki 47 | 48 | 49 | if __name__ == "__main__": 50 | # filepath = "../../data/processed_simplequestions_dataset/train.txt" 51 | # replace_mention(filepath) 52 | # get_mid2wiki("../../data/fb2w.nt") 53 | # result = preprocess("second battle of fort fisher", 54 | # "which military was involved in the second battle of fort fisher in China") 55 | replace_mention("../../data/processed_simplequestions_dataset/test.txt", "test_relation") 56 | -------------------------------------------------------------------------------- /NeuralQA/RelationDetection/nn/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | import torch 6 | from torchtext import data 7 | 8 | from args import get_args 9 | from datasets import SimpleQuestionsDataset 10 | 11 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 12 | 13 | np.set_printoptions(threshold=np.nan) 14 | # Set default configuration in : args.py 15 | args = get_args() 16 | 17 | # Set random seed for reproducibility 18 | torch.manual_seed(args.seed) 19 | np.random.seed(args.seed) 20 | random.seed(args.seed) 21 | 22 | if not args.cuda: 23 | args.gpu = -1 24 | if torch.cuda.is_available() and args.cuda: 25 | print("Note: You are using GPU for training") 26 | torch.cuda.set_device(args.gpu) 27 | torch.cuda.manual_seed(args.seed) 28 | if torch.cuda.is_available() and not args.cuda: 29 | print("Warning: You have Cuda but not use it. 
You are using CPU for training.") 30 | 31 | TEXT = data.Field(lower=True) 32 | RELATION = data.Field(sequential=False) 33 | 34 | train, dev, test = SimpleQuestionsDataset.splits(TEXT, RELATION, args.data_dir) 35 | TEXT.build_vocab(train, dev, test) 36 | RELATION.build_vocab(train, dev) 37 | 38 | train_iter = data.Iterator(train, batch_size=args.batch_size, device=args.gpu, train=True, repeat=False, 39 | sort=False, shuffle=True) 40 | dev_iter = data.Iterator(dev, batch_size=args.batch_size, device=args.gpu, train=False, repeat=False, 41 | sort=False, shuffle=False) 42 | test_iter = data.Iterator(test, batch_size=args.batch_size, device=args.gpu, train=False, repeat=False, 43 | sort=False, shuffle=False) 44 | 45 | # load the model 46 | model = torch.load(args.trained_model) 47 | # model = torch.load(args.trained_model, map_location=lambda storage, location: storage.cuda(args.gpu)) 48 | 49 | print(model) 50 | 51 | if args.dataset == 'RelationDetection': 52 | index2tag = np.array(RELATION.vocab.itos) 53 | else: 54 | print("Wrong Dataset") 55 | exit(1) 56 | 57 | index2word = np.array(TEXT.vocab.itos) 58 | 59 | results_path = os.path.join(args.results_path, args.relation_detection_mode.lower()) 60 | if not os.path.exists(results_path): 61 | os.makedirs(results_path, exist_ok=True) 62 | 63 | 64 | def predict(dataset_iter=test_iter, dataset=test, data_name="test"): 65 | print("Dataset: {}".format(data_name)) 66 | model.eval() 67 | dataset_iter.init_epoch() 68 | 69 | n_correct = 0 70 | fname = "{}.txt".format(data_name) 71 | results_file = open(os.path.join(results_path, fname), 'w') 72 | n_retrieved = 0 73 | 74 | fid = open(os.path.join(args.data_dir, "lineids_{}.txt".format(data_name))) 75 | sent_id = [x.strip() for x in fid.readlines()] 76 | 77 | for data_batch_idx, data_batch in enumerate(dataset_iter): 78 | scores = model(data_batch) 79 | if args.dataset == 'RelationDetection': 80 | n_correct += ( 81 | torch.max(scores, 1)[1].view(data_batch.relation.size()).data == data_batch.relation.data).sum() 82 | # Get top k 83 | top_k_scores, top_k_indices = torch.topk(scores, k=args.hits, dim=1, sorted=True) # shape: (batch_size, k) 84 | top_k_scores_array = top_k_scores.cpu().data.numpy() 85 | top_k_indices_array = top_k_indices.cpu().data.numpy() 86 | top_k_relatons_array = index2tag[top_k_indices_array] 87 | for i, (relations_row, scores_row) in enumerate(zip(top_k_relatons_array, top_k_scores_array)): 88 | index = (data_batch_idx * args.batch_size) + i 89 | example = data_batch.dataset.examples[index] 90 | for j, (rel, score) in enumerate(zip(relations_row, scores_row)): 91 | if (rel == example.relation): 92 | label = 1 93 | n_retrieved += 1 94 | else: 95 | label = 0 96 | results_file.write( 97 | "{} %%%% {} %%%% {} %%%% {}\n".format(sent_id[index], rel, label, score)) 98 | else: 99 | print("Wrong Dataset") 100 | exit() 101 | 102 | if args.dataset == 'RelationDetection': 103 | P = 1. * n_correct / len(dataset) 104 | print("{} Precision: {:10.6f}%".format(data_name, 100. * P)) 105 | print("no. retrieved: {} out of {}".format(n_retrieved, len(dataset))) 106 | retrieval_rate = 100. 
* n_retrieved / len(dataset) 107 | print("{} Retrieval Rate {:10.6f}".format(data_name, retrieval_rate)) 108 | else: 109 | print("Wrong dataset") 110 | exit() 111 | 112 | 113 | # run the model on the dev set and write the output to a file 114 | predict(dataset_iter=dev_iter, dataset=dev, data_name="valid") 115 | 116 | # run the model on the test set and write the output to a file 117 | predict(dataset_iter=test_iter, dataset=test, data_name="test") 118 | -------------------------------------------------------------------------------- /NeuralQA/RelationDetection/nn/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import datetime 4 | import numpy as np 5 | import time 6 | import torch 7 | import torch.nn as nn 8 | from torchtext import data 9 | 10 | from model import RelationDetection 11 | from datasets import SimpleQuestionsDataset 12 | from args import get_args 13 | start = time.time() 14 | 15 | np.set_printoptions(threshold=np.nan) 16 | # Set default configuration in : args.py 17 | args = get_args() 18 | 19 | # Set random seed for reproducibility 20 | torch.manual_seed(args.seed) 21 | np.random.seed(args.seed) 22 | random.seed(args.seed) 23 | torch.backends.cudnn.deterministic = True 24 | 25 | if not args.cuda: 26 | args.gpu = -1 27 | if torch.cuda.is_available() and args.cuda: 28 | print("Note: You are using GPU for training") 29 | torch.cuda.set_device(args.gpu) 30 | torch.cuda.manual_seed(args.seed) 31 | if torch.cuda.is_available() and not args.cuda: 32 | print("Warning: You have Cuda but not use it. You are using CPU for training.") 33 | 34 | # Set up the data for training 35 | TEXT = data.Field(lower=True) 36 | RELATION = data.Field(sequential=False) 37 | 38 | train, dev, test = SimpleQuestionsDataset.splits(TEXT, RELATION, args.data_dir) 39 | TEXT.build_vocab(train, dev, test) 40 | RELATION.build_vocab(train, dev) # bug 41 | 42 | match_embedding = 0 43 | if os.path.isfile(args.vector_cache): 44 | stoi, vectors, dim = torch.load(args.vector_cache) # todo 45 | print(stoi) 46 | TEXT.vocab.vectors = torch.Tensor(len(TEXT.vocab), dim) 47 | for i, token in enumerate(TEXT.vocab.itos): 48 | wv_index = stoi.get(token, None) 49 | if wv_index is not None: 50 | TEXT.vocab.vectors[i] = vectors[wv_index] 51 | match_embedding += 1 52 | else: 53 | TEXT.vocab.vectors[i] = torch.FloatTensor(dim).uniform_(-0.25, 0.25) 54 | else: 55 | print("Error: Need word embedding pt file") 56 | exit(1) 57 | 58 | print("Embedding match number {} out of {}".format(match_embedding, len(TEXT.vocab))) 59 | 60 | train_iter = data.Iterator(train, batch_size=args.batch_size, device=args.gpu, train=True, repeat=False, 61 | sort=False, shuffle=True) 62 | dev_iter = data.Iterator(dev, batch_size=args.batch_size, device=args.gpu, train=False, repeat=False, 63 | sort=False, shuffle=False) 64 | test_iter = data.Iterator(test, batch_size=args.batch_size, device=args.gpu, train=False, repeat=False, 65 | sort=False, shuffle=False) 66 | 67 | config = args 68 | config.words_num = len(TEXT.vocab) 69 | print("text vocabulary size:", config.words_num) 70 | print("relation vocabulary size:", RELATION.vocab) 71 | 72 | if args.dataset == 'RelationDetection': 73 | config.rel_label = len(RELATION.vocab) 74 | model = RelationDetection(config) 75 | else: 76 | print("Error Dataset") 77 | exit() 78 | 79 | model.embed.weight.data.copy_(TEXT.vocab.vectors) 80 | if args.cuda: 81 | model.cuda() 82 | print("Shift model to GPU") 83 | 84 | print(config) 85 | print("VOCAB 
num", len(TEXT.vocab)) 86 | print("Train instance", len(train)) 87 | print("Dev instance", len(dev)) 88 | print("Test instance", len(test)) 89 | print("Relation Type", len(RELATION.vocab)) 90 | print(model) 91 | print(args.train_embed) 92 | 93 | parameter = filter(lambda p: p.requires_grad, model.parameters()) 94 | optimizer = torch.optim.Adam(parameter, lr=args.lr, weight_decay=args.weight_decay) 95 | 96 | criterion = nn.NLLLoss() 97 | early_stop = False 98 | best_dev_P = 0 99 | iterations = 0 100 | iters_not_improved = 0 101 | num_dev_in_epoch = (len(train) // args.batch_size // args.dev_every) + 1 102 | patience = args.patience * num_dev_in_epoch # for early stopping 103 | epoch = 0 104 | start = time.time() 105 | header = ' Time Epoch Iteration Progress (%Epoch) Loss Dev/Loss Accuracy Dev/Accuracy' 106 | dev_log_template = ' '.join( 107 | '{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{:8.6f},{:12.4f},{:12.4f}'.split(',')) 108 | log_template = ' '.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{},{},{}'.split(',')) 109 | save_path = os.path.join(args.save_path, args.relation_detection_mode.lower()) 110 | os.makedirs(save_path, exist_ok=True) 111 | print(header) 112 | 113 | if args.dataset == 'RelationDetection': 114 | index2tag = np.array(RELATION.vocab.itos) 115 | else: 116 | print("Wrong Dataset") 117 | exit(1) 118 | 119 | while True: 120 | if early_stop: 121 | print("Early Stopping. Epoch: {}, Best Dev Acc: {}".format(epoch, best_dev_P)) 122 | break 123 | epoch += 1 124 | train_iter.init_epoch() 125 | n_correct, n_total = 0, 0 126 | n_correct_ed, n_correct_ner, n_correct_rel = 0, 0, 0 127 | 128 | for batch_idx, batch in enumerate(train_iter): 129 | # Batch size : (Sentence Length, Batch_size) 130 | iterations += 1 131 | model.train() 132 | optimizer.zero_grad() 133 | scores = model(batch) 134 | if args.dataset == 'RelationDetection': 135 | n_correct += (torch.max(scores, 1)[1].view(batch.relation.size()).data == batch.relation.data).sum() 136 | loss = criterion(scores, batch.relation) 137 | else: 138 | print("Wrong Dataset") 139 | exit() 140 | 141 | n_total += batch.batch_size 142 | loss.backward() 143 | optimizer.step() 144 | 145 | # evaluate performance on validation set periodically 146 | if iterations % args.dev_every == 0: 147 | model.eval() 148 | dev_iter.init_epoch() 149 | n_dev_correct = 0 150 | n_dev_correct_rel = 0 151 | 152 | gold_list = [] 153 | pred_list = [] 154 | 155 | for dev_batch_idx, dev_batch in enumerate(dev_iter): 156 | answer = model(dev_batch) 157 | 158 | if args.dataset == 'RelationDetection': 159 | n_dev_correct += ( 160 | torch.max(answer, 1)[1].view(dev_batch.relation.size()).data == dev_batch.relation.data).sum() 161 | else: 162 | print("Wrong Dataset") 163 | exit() 164 | 165 | if args.dataset == 'RelationDetection': 166 | P = 1. * n_dev_correct / len(dev) 167 | print("{} Precision: {:10.6f}%".format("Dev", 100. 
* P)) 168 | else: 169 | print("Wrong dataset") 170 | exit() 171 | 172 | # update model 173 | if args.dataset == 'RelationDetection': 174 | if P > best_dev_P: 175 | best_dev_P = P 176 | iters_not_improved = 0 177 | snapshot_path = os.path.join(save_path, args.specify_prefix + '_best_model_cpu3.pt') 178 | torch.save(model, snapshot_path) 179 | else: 180 | iters_not_improved += 1 181 | if iters_not_improved > patience: 182 | early_stop = True 183 | break 184 | else: 185 | print("Wrong dataset") 186 | exit() 187 | 188 | if iterations % args.log_every == 1: 189 | # print progress message 190 | print(log_template.format(time.time() - start, 191 | epoch, iterations, 1 + batch_idx, len(train_iter), 192 | 100. * (1 + batch_idx) / len(train_iter), loss.data[0], ' ' * 8, 193 | 100. * n_correct / n_total, ' ' * 12)) 194 | 195 | print('Time of train model: %f' % (time.time() - start)) 196 | # TypeError: unsupported operand type(s) for -: 'datetime.datetime' and 'float' 197 | -------------------------------------------------------------------------------- /NeuralQA/RelationDetection/siamese/args.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | 4 | def get_args(): 5 | parser = ArgumentParser(description="Relation Detection") 6 | parser.add_argument('--relation_detection_mode', required=True, type=str, help='options are CNN, GRU, LSTM') 7 | parser.add_argument('--no_cuda', action='store_false', help='do not use cuda', dest='cuda') 8 | parser.add_argument('--gpu', type=int, default=3) # Use -1 for CPU 9 | parser.add_argument('--epochs', type=int, default=30) 10 | parser.add_argument('--batch_size', type=int, default=128) # 32 11 | parser.add_argument('--dataset', type=str, default="RelationPrediction") 12 | parser.add_argument('--mode', type=str, default='static') 13 | parser.add_argument('--lr', type=float, default=1e-4) 14 | parser.add_argument('--seed', type=int, default=3435) 15 | parser.add_argument('--dev_every', type=int, default=2000) 16 | parser.add_argument('--log_every', type=int, default=1000) 17 | parser.add_argument('--patience', type=int, default=10) 18 | parser.add_argument('--save_path', type=str, default='saved_checkpoints') 19 | parser.add_argument('--specify_prefix', type=str, default='id1') 20 | parser.add_argument('--output_channel', type=int, default=100) # 300 21 | parser.add_argument('--words_dim', type=int, default=100) # 300 22 | parser.add_argument('--num_layer', type=int, default=2) 23 | parser.add_argument('--rnn_dropout', type=float, default=0.3, help='dropout in rnn') 24 | parser.add_argument('--input_size', type=int, default=100) # 300 25 | parser.add_argument('--hidden_size', type=int, default=100) # 300 26 | parser.add_argument('--rnn_fc_dropout', type=float, default=0.3, help='dropout before fully connected layer in RNN') 27 | parser.add_argument('--clip_gradient', type=float, default=0.6, help='gradient clipping') 28 | parser.add_argument('--vector_cache', type=str, default="../../data/sq_glove300d.pt") 29 | parser.add_argument('--weight_decay', type=float, default=0) 30 | parser.add_argument('--cnn_dropout', type=float, default=0.5, help='dropout before fully connected layer in CNN') 31 | parser.add_argument('--fix_embed', action='store_false', dest='train_embed') 32 | parser.add_argument('--hits', type=int, default=30) # 5 33 | parser.add_argument('--neg_size', type=int, default=50) 34 | # added for testing 35 | parser.add_argument('--data_dir', type=str, 
default='../../data/processed_simplequestions_dataset/') 36 | parser.add_argument('--trained_model', type=str, default='') 37 | parser.add_argument('--results_path', type=str, default='results') 38 | # added for demo 39 | parser.add_argument('--input', type=str, default='') 40 | 41 | parser.add_argument('--vector_file', type=str, default="../../data/word2vec/gigaxin_ldc_vectors.min5.en") 42 | parser.add_argument('--index_relation', type=str, default="../../indexes/relation_sub_2M.pkl") 43 | args = parser.parse_args() 44 | return args 45 | -------------------------------------------------------------------------------- /NeuralQA/RelationDetection/siamese/datasets.py: -------------------------------------------------------------------------------- 1 | import util 2 | 3 | 4 | class Dataset(object): 5 | def __init__(self, ques_list, rela_list, label, max_sent_len, word_dict): 6 | self.ques = ques_list # [[]] ? 7 | self.rela = rela_list 8 | self.word_dict = word_dict 9 | self.label = label 10 | self.size = len(label) 11 | self.max_sent_len = max_sent_len 12 | self.ques_idx, self.rela_idx = self.get_voc_idx(self.ques, self.rela) 13 | 14 | def get_voc_idx(self, ques, rela): 15 | # pad sentence 16 | pad = lambda x: util.pad_sentences(x, self.max_sent_len) 17 | pad_lst = lambda x: list(map(pad, x)) 18 | self.ques_pad = list(map(pad, ques)) 19 | self.rela_pad = list(map(pad_lst, rela)) 20 | # Represent sentences as list(nparray) of ints 21 | idx_func = lambda word: self.word_dict[word] if word in self.word_dict else self.word_dict[""] 22 | u_idx_func = lambda words: list(map(idx_func, words)) 23 | v_idx_func = lambda words_list: list(map(u_idx_func, words_list)) 24 | return list(map(u_idx_func, self.ques_pad)), list(map(v_idx_func, self.rela_pad)) 25 | -------------------------------------------------------------------------------- /NeuralQA/RelationDetection/siamese/eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding:utf-8 -*- 3 | import random 4 | import os 5 | import matplotlib 6 | import numpy as np 7 | import time 8 | 9 | matplotlib.use('Agg') # Must be before importing matplotlib.pyplot or pylab! 
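# 'Agg' is matplotlib's non-interactive, file-only backend, so the attention heatmaps below can be
# rendered and saved on a machine without a display; switching backends after pyplot has already
# been imported may be ignored, which is why the backend is selected before the import below.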
10 | import matplotlib.pyplot as plt 11 | from matplotlib.font_manager import * 12 | 13 | plt.rcParams['font.sans-serif'] = ['YaHei Consolas Hybrid'] # set font family 14 | 15 | 16 | # zh_font = FontProperties(fname='../img/font/yahei_consolas_hybrid.ttf') # load chinese font for matplob 17 | 18 | 19 | def remove_dirs(rootdir='../img/'): 20 | filelist = os.listdir(rootdir) 21 | for f in filelist: 22 | filepath = os.path.join(rootdir, f) 23 | if os.path.isfile(filepath): 24 | os.remove(filepath) 25 | # print(filepath + " removed") 26 | # elif os.path.isdir(filepath): 27 | # shutil.rmtree(filepath, True) 28 | # print("dir " + filepath + " removed") 29 | 30 | 31 | def plot_attention(data, x_label=None, y_label=None, rootdir='imgs/'): 32 | ''' 33 | Plot the attention model heatmap 34 | Args: 35 | data: attn_matrix with shape [ty, tx], cutted before 'PAD' 36 | x_label: list of size tx, encoder tags 37 | y_label: list of size ty, decoder tags 38 | ''' 39 | fig, ax = plt.subplots(figsize=(20, 8)) # set figure size 40 | heatmap = ax.pcolor(data, cmap=plt.cm.Blues, alpha=0.9) 41 | # Set axis labels 42 | if x_label != None and y_label != None: 43 | x_label = [x.decode('utf-8') for x in x_label] 44 | y_label = [y.decode('utf-8') for y in y_label] 45 | xticks = [x + 0.5 for x in range(0, len(x_label))] # range(0, len(x_label)) 46 | ax.set_xticks(xticks, minor=False) # major ticks 47 | ax.set_xticklabels(x_label, minor=False, rotation=90) # labels should be 'unicode' , fontproperties=zh_font 48 | yticks = [y + 0.5 for y in range(0, len(y_label))] # range(0, len(y_label)) 49 | ax.set_yticks(yticks, minor=False) 50 | ax.set_yticklabels(y_label, minor=False) # labels should be 'unicode' , fontproperties=zh_font 51 | # ax.grid(True) 52 | # Save Figure 53 | plt.title(u'Attention Heatmap') 54 | timestamp = int(time.time()) 55 | file_name = rootdir + str(timestamp) + "_" + str(random.randint(0, 1000)) + ".png" 56 | fig.savefig(file_name) # save the figure to file 57 | plt.close(fig) # close the figure 58 | 59 | 60 | # drop self attention 61 | def plot_attention2(data, x_label=None, rootdir='../img/self_'): 62 | fig, ax = plt.subplots(figsize=(20, 2)) 63 | heatmap = ax.pcolor(data, cmap=plt.cm.Blues, alpha=0.9) 64 | if x_label != None: 65 | x_label = [x.decode('utf-8') for x in x_label] 66 | xticks = [x + 0.5 for x in range(0, len(x_label))] 67 | ax.set_xticks(xticks, minor=False) 68 | # ax.set_xticklabels(x_label, minor=False, fontproperties=zh_font) 69 | ax.set_yticks([0], minor=False) 70 | # ax.set_yticklabels("", minor=False, fontproperties=zh_font) 71 | # ax.grid(True) 72 | plt.title(u'Self Attention Heatmap') 73 | timestamp = int(time.time()) 74 | file_name = rootdir + str(timestamp) + "_" + str(random.randint(0, 1000)) + ".png" 75 | fig.savefig(file_name) # save the figure to file 76 | plt.close(fig) # close the figure 77 | 78 | 79 | if __name__ == "__main__": 80 | remove_dirs('img/') 81 | -------------------------------------------------------------------------------- /NeuralQA/RelationDetection/siamese/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torch.autograd import Variable 5 | 6 | 7 | class Attention(nn.Module): 8 | def __init__(self, cuda, mem_dim): 9 | super(Attention, self).__init__() 10 | self.cudaFlag = cuda 11 | self.mem_dim = mem_dim 12 | self.WQ = nn.Linear(mem_dim, mem_dim) # bias=False 13 | # self.WV = nn.Linear(mem_dim, mem_dim) 14 | # self.WP = 
nn.Linear(mem_dim, 1) 15 | 16 | def forward(self, query_h, doc_h): 17 | # doc_h = torch.unsqueeze(doc_h, 0) 18 | # ha = F.tanh(self.WQ(query_h) + self.WV(doc_h).expand_as(query_h)) # tan(W1a + W2b) 19 | # p = F.softmax(self.WP(ha).squeeze()) 20 | # weighted = p.unsqueeze(1).expand_as( 21 | # query_h) * query_h 22 | # v = weighted.sum(dim=0) 23 | p = F.softmax(torch.transpose(torch.mm(query_h, doc_h.unsqueeze(1)), 0, 1)) # dot 24 | weighted = torch.transpose(p, 0, 1).expand_as(query_h) * query_h 25 | v = weighted.sum(dim=0) 26 | return v, p 27 | 28 | 29 | class RelationDetection(nn.Module): 30 | def __init__(self, config): 31 | super(RelationDetection, self).__init__() 32 | self.config = config 33 | target_size = config.rel_label 34 | self.embed = nn.Embedding(config.words_num, config.words_dim) 35 | # self.attention = Attention(cuda, mem_dim) 36 | if config.train_embed == False: 37 | self.embed.weight.requires_grad = False 38 | if config.relation_detection_mode.upper() == "GRU": 39 | self.gru = nn.GRU(input_size=config.input_size, 40 | hidden_size=config.hidden_size, 41 | num_layers=config.num_layer, 42 | dropout=config.rnn_dropout, 43 | bidirectional=True) 44 | self.dropout = nn.Dropout(p=config.rnn_fc_dropout) 45 | self.relu = nn.ReLU() 46 | self.hidden2tag = nn.Sequential( 47 | nn.Linear(config.hidden_size * 2, config.hidden_size * 2), 48 | nn.BatchNorm1d(config.hidden_size * 2), 49 | self.relu, 50 | self.dropout, 51 | nn.Linear(config.hidden_size * 2, target_size) 52 | ) 53 | if config.relation_detection_mode.upper() == "LSTM": 54 | self.lstm = nn.LSTM(input_size=config.input_size, 55 | hidden_size=config.hidden_size, 56 | num_layers=config.num_layer, 57 | dropout=config.rnn_dropout, 58 | bidirectional=True) 59 | self.dropout = nn.Dropout(p=config.rnn_fc_dropout) 60 | self.relu = nn.ReLU() 61 | self.hidden2tag = nn.Sequential( 62 | nn.Linear(config.hidden_size * 2, config.hidden_size * 2), 63 | nn.BatchNorm1d(config.hidden_size * 2), 64 | self.relu, 65 | self.dropout, 66 | nn.Linear(config.hidden_size * 2, target_size) 67 | ) 68 | if config.relation_detection_mode.upper() == "CNN": 69 | input_channel = 1 70 | Ks = 3 71 | self.conv1 = nn.Conv2d(input_channel, config.output_channel, (2, config.words_dim), padding=(1, 0)) 72 | self.conv2 = nn.Conv2d(input_channel, config.output_channel, (3, config.words_dim), padding=(2, 0)) 73 | self.conv3 = nn.Conv2d(input_channel, config.output_channel, (4, config.words_dim), padding=(3, 0)) 74 | self.dropout = nn.Dropout(config.cnn_dropout) 75 | self.fc1 = nn.Linear(Ks * config.output_channel, target_size) 76 | 77 | def forward(self, ques, rela_list): 78 | batch_size = ques.size()[0] 79 | ques = self.embed(ques) # (batch_size, sent_len, embed_dim) 80 | rela_list = [self.embed(rela) for rela in rela_list] # (num_classes, batch_size, sent_len, embed_dim) 81 | rela_output = list() 82 | if self.config.relation_detection_mode.upper() == "LSTM": 83 | # h0 / c0 = (layer*direction, batch_size, hidden_dim) 84 | if self.config.cuda: 85 | h0 = Variable(torch.zeros(self.config.num_layer * 2, batch_size, 86 | self.config.hidden_size).cuda()) 87 | c0 = Variable(torch.zeros(self.config.num_layer * 2, batch_size, 88 | self.config.hidden_size).cuda()) 89 | else: 90 | h0 = Variable(torch.zeros(self.config.num_layer * 2, batch_size, 91 | self.config.hidden_size)) 92 | c0 = Variable(torch.zeros(self.config.num_layer * 2, batch_size, 93 | self.config.hidden_size)) 94 | # output = (sentence length, batch_size, hidden_size * num_direction) 95 | # ht = (layer*direction, 
batch, hidden_dim) 96 | # ct = (layer*direction, batch, hidden_dim) 97 | outputs1, (ht1, ct1) = self.lstm(ques, (h0, c0)) 98 | # cross attention 99 | 100 | # query_cross_alphas = Var(torch.Tensor(query_state.size(0), target_state.size(0))) 101 | # target_cross_alphas = Var(torch.Tensor(target_state.size(0), query_state.size(0))) 102 | # q_to_t = Var(torch.Tensor(query_state.size(0), self.mem_dim)) 103 | # t_to_q = Var(torch.Tensor(target_state.size(0), self.mem_dim)) 104 | # for rela in rela_list: 105 | # outputs2, (ht2, ct2) = self.lstm(rela, (h0, c0)) 106 | # for i in range(query_state.size(0)): 107 | # q_to_t[i], query_cross_alphas[i] = self.attention(target_state, query_state[i,]) 108 | tags = self.hidden2tag(ht1[-2:].transpose(0, 1).contiguous().view(batch_size, -1)) 109 | scores = F.log_softmax(tags) 110 | return scores 111 | elif self.config.relation_detection_mode.upper() == "GRU": 112 | if self.config.cuda: 113 | h0 = Variable(torch.zeros(self.config.num_layer * 2, batch_size, 114 | self.config.hidden_size).cuda()) 115 | else: 116 | h0 = Variable(torch.zeros(self.config.num_layer * 2, batch_size, 117 | self.config.hidden_size)) 118 | outputs, ht = self.gru(ques, h0) 119 | 120 | tags = self.hidden2tag(ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1)) 121 | scores = F.log_softmax(tags) 122 | return scores 123 | elif self.config.relation_detection_mode.upper() == "CNN": 124 | ques = ques.contiguous().unsqueeze(1) 125 | ques = [F.relu(self.conv1(ques)).squeeze(3), F.relu(self.conv2(ques)).squeeze(3), 126 | F.relu(self.conv3(ques)).squeeze(3)] 127 | ques = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in ques] # max-over-time pooling 128 | ques = torch.cat(ques, 1) # (batch, channel_output * Ks) 129 | ques = self.dropout(ques) 130 | # logit = self.fc1(ques) # (batch, target_size) 131 | ques = ques.unsqueeze(1) # (batch, 1, channel_output * Ks) 132 | for rela in rela_list: 133 | rela = rela.contiguous().unsqueeze(1) # rela.transpose(0, 1) 134 | rela = [F.relu(self.conv1(rela)).squeeze(3), F.relu(self.conv2(rela)).squeeze(3), 135 | F.relu(self.conv3(rela)).squeeze(3)] 136 | rela = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in rela] 137 | rela = torch.cat(rela, 1) 138 | rela = self.dropout(rela) 139 | rela = rela.unsqueeze(1) 140 | rela_output.append(rela) 141 | rela = torch.cat(rela_output, 1).transpose(0, 1).contiguous() 142 | dot = torch.sum(torch.mul(ques, rela), 2) 143 | sqrt_ques = torch.sqrt(torch.sum(torch.pow(ques, 2), 2)) 144 | sqrt_rela = torch.sqrt(torch.sum(torch.pow(rela, 2), 2)) 145 | # print(sqrt_ques, sqrt_rela) # 32,1 32,51 146 | epsilon = 1e-6 # 1e-6 147 | scores = dot / (sqrt_ques * sqrt_rela + epsilon) # torch.max(a, b)??? 
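# The score above is the cosine similarity between the pooled question vector and each of the
# (neg_size + 1) candidate relation vectors, yielding a (batch, neg_size + 1) score matrix
# (the debug note above reports shapes (32, 1) and (32, 51), i.e. a batch of 32 with 50 negatives).
# `epsilon` keeps the division well-defined when a norm is zero; clamping the denominator from
# below would be an equivalent safeguard. In siamese/train.py these scores are passed to
# nn.CrossEntropyLoss with the index of the gold relation among the candidates as the target class.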
148 | return scores 149 | else: 150 | print("Unknown Mode") 151 | exit(1) 152 | -------------------------------------------------------------------------------- /NeuralQA/RelationDetection/siamese/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import math 5 | import numpy as np 6 | import time 7 | import torch 8 | import torch.nn as nn 9 | from tqdm import tqdm 10 | from torch.autograd import Variable as Var 11 | import util 12 | from args import get_args 13 | from datasets import Dataset 14 | from model import RelationDetection 15 | 16 | 17 | def pull_batch(question_list, relation_list, label_list, batch_idx): 18 | batch_size = config.batch_size 19 | if (batch_idx + 1) * batch_size < len(question_list): 20 | question_list = question_list[batch_idx * batch_size:(batch_idx + 1) * batch_size] 21 | relation_list = relation_list[batch_idx * batch_size:(batch_idx + 1) * batch_size] 22 | label_list = label_list[batch_idx * batch_size:(batch_idx + 1) * batch_size] 23 | else: # last batch 24 | question_list = question_list[batch_idx * batch_size:] 25 | relation_list = relation_list[batch_idx * batch_size:] 26 | label_list = label_list[batch_idx * batch_size:] 27 | return torch.LongTensor(question_list), torch.LongTensor(relation_list), torch.LongTensor(label_list) 28 | 29 | 30 | start = time.time() 31 | np.set_printoptions(threshold=np.nan) 32 | # Set default configuration in : args.py 33 | args = get_args() 34 | 35 | # Set random seed for reproducibility 36 | torch.manual_seed(args.seed) 37 | np.random.seed(args.seed) 38 | random.seed(args.seed) 39 | torch.backends.cudnn.deterministic = True 40 | 41 | if not args.cuda: 42 | args.gpu = -1 43 | if torch.cuda.is_available() and args.cuda: 44 | print("Note: You are using GPU for training") 45 | torch.cuda.set_device(args.gpu) 46 | torch.cuda.manual_seed(args.seed) 47 | if torch.cuda.is_available() and not args.cuda: 48 | print("Warning: You have Cuda but not use it. 
You are using CPU for training.") 49 | 50 | print ("loading word embedding...") 51 | word_dict, embedding = util.get_pretrained_word_vector(args.vector_file, (288694, 100)) 52 | print ("vocabulary size: %d" % len(word_dict)) 53 | 54 | print ("loading train data...") 55 | train_path = "../../data/processed_simplequestions_dataset/train.txt" 56 | valid_path = "../../data/processed_simplequestions_dataset/valid.txt" 57 | test_path = "../../data/processed_simplequestions_dataset/test.txt" 58 | x_u, x_r, y_train, max_ques, max_rela = util.load_data(args.index_relation, train_path, args.neg_size) 59 | train_set = Dataset(x_u, x_r, y_train, max_ques, word_dict) # todo 60 | print (np.array(train_set.ques_idx).shape, np.array(train_set.rela_idx).shape, np.array( 61 | train_set.label).shape) 62 | 63 | print ("loading dev data...") 64 | x_u, x_r, y_valid, max_ques, max_rela = util.load_data(args.index_relation, valid_path, args.neg_size) 65 | dev_set = Dataset(x_u, x_r, y_valid, max_ques, word_dict) 66 | print (np.array(dev_set.ques_idx).shape, np.array(dev_set.rela_idx).shape, np.array(dev_set.label).shape) 67 | 68 | # print ("loading test data...") 69 | # x_u, x_r, y_test, max_ques, max_rela = util.load_data(args.index_relation, valid_path) 70 | # dev_dataset = Dataset(x_u, x_r, y_test, max_ques, word_dict) 71 | # print (np.array(dev_dataset.ques_idx).shape, np.array(dev_dataset.rela_idx).shape, np.array(dev_dataset.label).shape) 72 | 73 | config = args 74 | config.words_num = 288694 + 1 75 | print("text vocabulary size:", config.words_num) 76 | 77 | if args.dataset == 'RelationDetection': 78 | config.rel_label = args.neg_size + 1 # num of classes 79 | model = RelationDetection(config) 80 | else: 81 | print("Error Dataset") 82 | exit() 83 | 84 | model.embed.weight.data.copy_(torch.FloatTensor(embedding)) 85 | 86 | if args.cuda: 87 | model.cuda() 88 | print("Shift model to GPU") 89 | 90 | print(config) 91 | print("VOCAB num", config.words_num) 92 | print("Train instance", len(y_train)) 93 | print("Dev instance", len(y_valid)) 94 | # print("Test instance", len(y_test)) 95 | print(model) 96 | 97 | parameter = filter(lambda p: p.requires_grad, model.parameters()) 98 | optimizer = torch.optim.Adam(parameter, lr=args.lr, weight_decay=args.weight_decay) 99 | 100 | criterion = nn.CrossEntropyLoss() # nn.NLLLoss() 101 | early_stop = False 102 | best_dev_P = 0 103 | iterations = 0 104 | iters_not_improved = 0 105 | num_train_iter = int(math.ceil(train_set.size * 1.0 / args.batch_size)) 106 | num_dev_iter = int(math.ceil(dev_set.size * 1.0 / config.batch_size)) 107 | 108 | num_dev_in_epoch = (len(y_train) // args.batch_size // args.dev_every) + 1 109 | patience = args.patience * num_dev_in_epoch # for early stopping 110 | epoch = 0 111 | start = time.time() 112 | header = ' Time Epoch Iteration Progress (%Epoch) Loss Dev/Loss Accuracy Dev/Accuracy' 113 | dev_log_template = ' '.join( 114 | '{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{:8.6f},{:12.4f},{:12.4f}'.split(',')) 115 | log_template = ' '.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{},{},{}'.split(',')) 116 | save_path = os.path.join(args.save_path, args.relation_detection_mode.lower()) 117 | os.makedirs(save_path, exist_ok=True) 118 | print(header) 119 | 120 | while epoch <= args.epochs: 121 | if early_stop: 122 | print("Early Stopping. 
Epoch: {}, Best Dev Acc: {}".format(epoch, best_dev_P)) 123 | break 124 | epoch += 1 125 | n_correct, n_total = 0, 0 126 | 127 | for train_step in tqdm(range(num_train_iter), 128 | desc='Training epoch ' + str(epoch) + ''): 129 | ques_batch, rela_batch, label_batch = pull_batch(train_set.ques_idx, train_set.rela_idx, 130 | train_set.label, train_step) # tensor 131 | 132 | # Batch size : (Sentence Length, Batch_size) 133 | iterations += 1 134 | model.train() 135 | optimizer.zero_grad() 136 | scores = model(ques_batch, rela_batch) 137 | if args.dataset == 'RelationDetection': 138 | # print(torch.max(scores, 1)[1]) 139 | n_correct += (torch.max(scores, 1)[1].data == label_batch).sum() 140 | loss = criterion(scores, Var(label_batch)) # volatile=True 141 | else: 142 | print("Wrong Dataset") 143 | exit() 144 | 145 | n_total += args.batch_size 146 | loss.backward() 147 | optimizer.step() 148 | 149 | # evaluate performance on validation set periodically 150 | if iterations % args.dev_every == 0: 151 | model.eval() 152 | n_dev_correct = 0 153 | for dev_step in range(num_dev_iter): 154 | ques_batch, rela_batch, label_batch = pull_batch(dev_set.ques_idx, dev_set.rela_idx, 155 | dev_set.label, dev_step) 156 | dev_score = model(ques_batch, rela_batch) 157 | # target = Var(torch.LongTensor([int(label)]), volatile=True) 158 | if args.dataset == 'RelationDetection': 159 | n_dev_correct += (torch.max(dev_score, 1)[1].data == label_batch).sum() 160 | loss = criterion(dev_score, Var(label_batch, volatile=True)) 161 | else: 162 | print("Wrong Dataset") 163 | exit() 164 | 165 | if args.dataset == 'RelationDetection': 166 | P = 1. * n_dev_correct / dev_set.size 167 | print("{} Precision: {:10.6f}%".format("Dev", 100. * P)) 168 | else: 169 | print("Wrong dataset") 170 | exit() 171 | 172 | # update model 173 | if args.dataset == 'RelationDetection': 174 | if P > best_dev_P: 175 | best_dev_P = P 176 | iters_not_improved = 0 177 | snapshot_path = os.path.join(save_path, args.specify_prefix + '_dssm_best_model_cpu.pt') 178 | torch.save(model, snapshot_path) 179 | else: 180 | iters_not_improved += 1 181 | if iters_not_improved > patience: 182 | early_stop = True 183 | break 184 | else: 185 | print("Wrong dataset") 186 | exit() 187 | 188 | if iterations % args.log_every == 1: 189 | # print progress message 190 | print(log_template.format(time.time() - start, 191 | epoch, iterations, 1 + train_step, num_train_iter, 192 | 100. * (1 + train_step) / num_train_iter, loss.data[0], ' ' * 8, 193 | 100. * n_correct / n_total, ' ' * 12)) 194 | 195 | print('Time of train model: %f' % (time.time() - start)) 196 | # TypeError: unsupported operand type(s) for -: 'datetime.datetime' and 'float' 197 | -------------------------------------------------------------------------------- /NeuralQA/RelationDetection/siamese/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import re 4 | from argparse import ArgumentParser 5 | from collections import defaultdict 6 | import random 7 | import numpy as np 8 | 9 | 10 | def clean_str(string): 11 | """ 12 | Tokenization/string cleaning for all datasets except for SST. 
13 | Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py 14 | """ 15 | string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string) 16 | string = re.sub(r"\'s", " \'s", string) 17 | string = re.sub(r"\'ve", " \'ve", string) 18 | string = re.sub(r"n\'t", " n\'t", string) 19 | string = re.sub(r"\'re", " \'re", string) 20 | string = re.sub(r"\'d", " \'d", string) 21 | string = re.sub(r"\'ll", " \'ll", string) 22 | string = re.sub(r",", " , ", string) 23 | string = re.sub(r"!", " ! ", string) 24 | string = re.sub(r"\(", " \( ", string) 25 | string = re.sub(r"\)", " \) ", string) 26 | string = re.sub(r"\?", " \? ", string) 27 | string = re.sub(r"\s{2,}", " ", string) 28 | return string.strip().lower() 29 | 30 | 31 | def get_questions(filename): 32 | print("getting questions ...") 33 | id2questions = {} 34 | id2goldrelas = {} 35 | qids = list() 36 | fin = open(filename) 37 | for line in fin.readlines(): 38 | items = line.strip().split('\t') 39 | lineid = items[0].strip() 40 | # mid = items[1].strip() 41 | question = items[5].strip() 42 | rel = items[3].strip() 43 | qids.append(lineid) 44 | id2questions[lineid] = question 45 | id2goldrelas[lineid] = rel 46 | return qids, id2questions, id2goldrelas 47 | 48 | 49 | # Load predicted MIDs and relations for each question in valid/test set 50 | def get_mids(filename, hits=100): 51 | print("Entity Source : {}".format(filename)) 52 | id2mids = defaultdict(list) 53 | qids = list() 54 | fin = open(filename) 55 | for line in fin.readlines(): 56 | items = line.strip().split(' %%%% ') 57 | lineid = items[0] 58 | cand_mids = items[1:] # [:hits] 59 | qids.append(lineid) 60 | for mid_entry in cand_mids: 61 | mid, mid_name, mid_type, score = mid_entry.split('\t') 62 | id2mids[lineid].append(mid) 63 | return qids, id2mids 64 | 65 | 66 | def pad_sentences(sentence, length, padding_word=""): 67 | num_padding = length - len(sentence) 68 | new_sentence = sentence + [padding_word] * num_padding 69 | return new_sentence 70 | 71 | 72 | def get_pretrained_word_vector(file_path, dim): 73 | word_dict = {} 74 | embedding_matrix = np.zeros((dim[0] + 1, dim[1])) 75 | with open(file_path) as f: 76 | for idx, lines in enumerate(f): 77 | word_dict[lines.split()[0]] = idx 78 | embedding_matrix[idx] = np.array([float(item) for item in lines.split()[1:]]) 79 | word_dict[''] = dim[0] # 80 | return word_dict, embedding_matrix 81 | 82 | 83 | def rela2idx(relation): 84 | rela_split = relation[3:].replace('_', '.').split('.') 85 | return rela_split 86 | 87 | 88 | def load_index(filename): 89 | print("Loading index map from {}".format(filename)) 90 | with open(filename, 'rb') as handler: 91 | index = pickle.load(handler) 92 | return index 93 | 94 | 95 | def preprocess(index_reach, data_path, output_dir, ent_path, hits_ent=100): 96 | _, id2mids = get_mids(ent_path, hits_ent) 97 | qids, id2questions, id2goldrelas = get_questions(data_path) 98 | results_file = open(os.path.join(output_dir, "rela.top50el.test"), 'w') 99 | hit, max_cad_rela = 0, 0 # test top100 entity: 312 100 | for qid in qids: 101 | cand_relas = set() # set 102 | question = id2questions[qid] 103 | gold_rela = id2goldrelas[qid] 104 | cand_mids = id2mids[qid] 105 | for mid in cand_mids: 106 | link_rel = index_reach[mid] 107 | cand_relas = cand_relas | set(link_rel) 108 | max_cad_rela = max(max_cad_rela, len(cand_relas)) 109 | if gold_rela in cand_relas: 110 | hit += 1 111 | for cand_rela in set(cand_relas): 112 | results_file.write("{} %%%% {} %%%% {} %%%% {}\n".format(qid, question, gold_rela, 
cand_rela)) 113 | print(max_cad_rela) 114 | print(hit, len(qids), float(hit / len(qids))) 115 | 116 | 117 | def load_data(rela_voc_path, data_path, neg_size=50): 118 | qids, id2questions, id2goldrelas = get_questions(data_path) 119 | ques_list, rela_list = list(), list() 120 | max_ques_len, max_rela_len = 0, 0 121 | label = list() 122 | rela_voc = load_index(rela_voc_path) 123 | 124 | for qid in qids: 125 | temp_list = list() 126 | question = id2questions[qid] 127 | gold_rela = id2goldrelas[qid][3:] # fb: 128 | gold_rela_split = clean_str(gold_rela).split(" ") 129 | temp_list.append(gold_rela_split) 130 | while len(temp_list) < (neg_size + 1): 131 | rand_idx = random.randint(0, len(rela_voc) - 1) 132 | rela_split = clean_str(rela_voc[rand_idx][3:]).split(" ") 133 | if rela_split not in temp_list: 134 | temp_list.append(rela_split) 135 | random.shuffle(temp_list) 136 | rela_list.append(temp_list) 137 | 138 | ques_split = clean_str(question).split(" ") 139 | max_ques_len = max(max_ques_len, len(ques_split)) 140 | ques_list.append(ques_split) 141 | 142 | for rela in temp_list: 143 | if rela == gold_rela_split: 144 | idx = temp_list.index(rela) 145 | label.append(idx) 146 | print("max_ques_len:", max_ques_len) 147 | # print("max_rela_len:", max_rela_len) # 17 148 | print(len(qids), len(label)) 149 | print(ques_list[1], rela_list[1]) 150 | print(label[:5]) 151 | return ques_list, rela_list, label, max_ques_len, max_rela_len 152 | 153 | 154 | if __name__ == "__main__": 155 | parser = ArgumentParser(description='Perform evidence integration') 156 | parser.add_argument('--index_reachpath', type=str, default="../../indexes/reachability_2M.pkl", 157 | help='path to the pickle for the reachability index') 158 | parser.add_argument('--index_relation', type=str, default="../../indexes/relation_sub_2M.pkl", 159 | help='path to the pickle for the relation index') 160 | parser.add_argument('--data_path', type=str, default="../../data/processed_simplequestions_dataset/test.txt") 161 | # parser.add_argument('--ent_path', type=str, default="../../entity_linking/results/nn/test-h100.txt", 162 | # help='path to the entity linking results') 163 | parser.add_argument('--hits_ent', type=int, default=50, 164 | help='the hits here has to be <= the hits in entity linking') 165 | parser.add_argument('--output_dir', type=str, default="./results") 166 | args = parser.parse_args() 167 | print(args) 168 | 169 | relation_voc = load_index(args.index_relation) 170 | # index_reach = load_index(args.index_reachpath) 171 | load_data(relation_voc, args.data_path) 172 | # print "process valid data:" 173 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NJU_KBQA 2 | WebQA:基于Freebase知识库的线上问答系统
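A minimal sketch of driving the WebQA demo programmatically, mirroring `WebQA/src/main.py` (the working directory, data/model locations, and the example question are assumptions; the web demo itself is started by running `main.py`, which serves the Flask app on port 4001):

```python
# Sketch only: assumes it is run with WebQA/src on the path and the index pickles /
# trained models referenced in WebQA/src/args.py available under data/ and model/.
from simpleQA import SimpleQA   # same import main.py uses

qa = SimpleQA()
qa.setup()                                                # load KB indexes and trained models, as in main.py
result = qa.get_answer("where was neil armstrong born")   # hypothetical example question
print(result['answer'][:10])                              # main.py keeps the top-10 answers for display
```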
3 | NeuralQA: a neural-network question answering baseline system
4 | SiameseNetwork: models for short-text matching
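The SiameseNetwork models (`qa_cnn.py`, `qa_lstm.py`, `qa_nn.py`) all end in the same scoring head: the question and each candidate relation are encoded into vectors, compared by cosine similarity floored at a small epsilon, and trained with softmax cross-entropy over `100 * cosine`. A schematic NumPy sketch of that head, with random vectors standing in for the CNN/LSTM/NN encoder outputs:

```python
import numpy as np


def cosine_scores(q_vec, rel_vecs, epsilon=1e-5):
    """Cosine scoring head shared by qa_cnn/qa_lstm/qa_nn (single example, schematic).

    q_vec:    (d,)   encoded question
    rel_vecs: (k, d) encoded candidate relations
    """
    dot = rel_vecs @ q_vec                                    # (k,) dot products
    norms = np.linalg.norm(rel_vecs, axis=1) * np.linalg.norm(q_vec)
    cosine = np.maximum(dot / np.maximum(norms, epsilon), epsilon)
    logits = 100.0 * cosine                                   # sharpened logits, as in the TF models
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()                                      # softmax over the k candidates
    return cosine, probs, int(np.argmax(cosine))              # prediction = argmax of cosine


# Toy usage: random stand-ins for encoder outputs (dimension and candidate count are arbitrary).
rng = np.random.default_rng(0)
cos, probs, pred = cosine_scores(rng.normal(size=100), rng.normal(size=(5, 100)))
print(pred, probs.round(3))
```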
5 | ![image](https://github.com/geofftong/NJU_KBQA/blob/master/demo.gif) 6 | -------------------------------------------------------------------------------- /SiameseNetwork/config.py: -------------------------------------------------------------------------------- 1 | # 75910/10845/21687 2 | train_file = "../data/SimpleQuestions/output/simple.train.top20el.top300relation" 3 | dev_file = "../data/SimpleQuestions/output/simple.valid.top20el.top300relation" 4 | test_file = "../data/SimpleQuestions/output/simple.test.top20el.top300relation" 5 | train_output_file = "../data/SimpleQuestions/output/simple.train.top20el.top100relation.score" 6 | dev_output_file = "../data/SimpleQuestions/output/simple.valid.top20el.top100relation.score" 7 | test_output_file = "../data/SimpleQuestions/output/simple.test.top20el.top300relation.score" 8 | word2vec_file = "../data/word2vec/gigaxin_ldc_vectors.min5.en" 9 | neg_sample = 50 # train 50 10 | num_classes = 300 # 20 valid/test 11 | max_sent_len = 36 12 | model_name = "model" 13 | voc_size = 288694 14 | emb_size = 100 15 | filter_sizes = [2, 3, 4] 16 | num_filters = 50 17 | dropout_keep_prob = 0.85 # 1.0 18 | embeddings_trainable = True 19 | epoch_num = 30 # 30 epoches 20 | batch_size = 16 # 32 21 | l2_reg_lambda = 0.1 22 | early_step = 3 # epoch step of no improving in dev set 23 | 24 | # lstm 25 | hidden_units = 128 26 | -------------------------------------------------------------------------------- /SiameseNetwork/qa_cnn.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import config 4 | 5 | 6 | class QaCNN(object): 7 | def __init__(self, sequence_length, num_classes, filter_sizes, 8 | num_filters, init_embeddings, 9 | embeddings_trainable=False, l2_reg_lambda=0.0): 10 | self.batch_size = config.batch_size 11 | self.embedding_size = np.shape(init_embeddings)[1] 12 | # Keeping track of l2 regularization loss (optional) 13 | l2_loss = tf.constant(0.0) 14 | # self.train_flag = tf.placeholder(dtype=tf.bool, name='bool') 15 | with tf.name_scope('input'): 16 | u_shape = [None, sequence_length] # self.batch_size 17 | v_shape = [None, num_classes, sequence_length] # train:100 test:300 18 | self.input_x_u = tf.placeholder(tf.int32, u_shape, name="input_x_u") 19 | self.input_x_r = tf.placeholder(tf.int32, v_shape, name="input_x_r") 20 | self.input_y = tf.placeholder(tf.int64, [None], name="input_y") # input_y: batch_size, 21 | 22 | # Embedding layer 23 | with tf.name_scope("embedding"): 24 | W = tf.Variable(init_embeddings, trainable=embeddings_trainable, dtype=tf.float32, name='W') 25 | self.embedded_u = tf.nn.embedding_lookup(W, self.input_x_u) # bs x seq_len x emb_size 26 | print ("DEBUG: embedded_u -> %s" % self.embedded_u) 27 | self.embedded_r = tf.nn.embedding_lookup(W, self.input_x_r) # bs x neg_size x seq_len x emb_size 28 | print ("DEBUG: embedded_r -> %s" % self.embedded_r) 29 | self.embedded_u_expanded = tf.expand_dims(self.embedded_u, -1) # bs x seq_len x emb_size x 1 30 | print ("DEBUG: embedded_u_expanded -> %s" % self.embedded_u_expanded) 31 | self.embedded_r_expanded = tf.expand_dims(self.embedded_r, -1) # bs x neg_size x seq_len x emb_size x 1 32 | print ("DEBUG: embedded_r_expanded -> %s" % self.embedded_r_expanded) 33 | 34 | # Create a convolution + maxpooling layer for each filter size 35 | pooled_outputs_u = [] 36 | pooled_outputs_r = [] 37 | for i, filter_size in enumerate(filter_sizes): # [2, 3, 4] 38 | with tf.name_scope("conv-maxpool-%s" % 
filter_size): # 50 39 | # Convolution layer 40 | filter_shape = [filter_size, self.embedding_size, 1, num_filters] 41 | W = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1), 42 | name='W') 43 | b = tf.Variable(tf.constant(0.1, shape=[num_filters]), 44 | name='b') 45 | l2_loss += tf.nn.l2_loss(W) 46 | l2_loss += tf.nn.l2_loss(b) 47 | conv_u = tf.nn.conv2d( 48 | self.embedded_u_expanded, 49 | W, 50 | strides=[1, 1, 1, 1], 51 | padding="VALID", 52 | name="conv-u") 53 | 54 | # Apply nonlinearity 55 | h_u = tf.nn.sigmoid(tf.nn.bias_add(conv_u, b), name="activation-u") 56 | 57 | # Maxpooling over outputs 58 | pooled_u = tf.nn.max_pool( 59 | h_u, 60 | ksize=[1, sequence_length - filter_size + 1, 1, 1], 61 | strides=[1, 1, 1, 1], 62 | padding="VALID", 63 | name="pool-u") 64 | pooled_outputs_u.append(pooled_u) # 1 x num_filters 65 | 66 | # Pass each element in x_r through the same layer 67 | pooled_outputs_r_classes = [] 68 | 69 | # num_classes = tf.where(self.train_flag, config.negative_size, config.max_candidate_relation) 70 | 71 | for j in range(num_classes): 72 | embedded_r = self.embedded_r_expanded[:, j, :, :, :] 73 | conv_r_j = tf.nn.conv2d( 74 | embedded_r, 75 | W, 76 | strides=[1, 1, 1, 1], 77 | padding="VALID", 78 | name="conv-r-%s" % j) 79 | 80 | h_r_j = tf.nn.sigmoid(tf.nn.bias_add(conv_r_j, b), name="activation-r-%s" % j) 81 | 82 | pooled_r_j = tf.nn.max_pool( 83 | h_r_j, 84 | ksize=[1, sequence_length - filter_size + 1, 1, 1], 85 | strides=[1, 1, 1, 1], 86 | padding="VALID", 87 | name="pool-r-%s" % j) 88 | pooled_outputs_r_classes.append(pooled_r_j) 89 | # print "DEBUG: pooled_outputs_r_classes -> %s" % pooled_outputs_r_classes 90 | 91 | # out_tensor: batch_size x 1 x num_class x num_filters 92 | out_tensor = tf.concat(pooled_outputs_r_classes, 2) 93 | print ("DEBUG: out_tensor -> %s" % out_tensor) 94 | pooled_outputs_r.append(out_tensor) 95 | 96 | # Combine all the pooled features 97 | num_filters_total = num_filters * len(filter_sizes) 98 | print ("DEBUG: pooled_outputs_u -> %s" % pooled_outputs_u) 99 | self.h_pool_u = tf.concat(pooled_outputs_u, 3) 100 | print ("DEBUG: h_pool_u -> %s" % self.h_pool_u) 101 | # batch_size x 1 x num_filters_total 102 | self.h_pool_flat_u = tf.reshape(self.h_pool_u, [-1, 1, num_filters_total]) 103 | print ("DEBUG: h_pool_flat_u -> %s" % self.h_pool_flat_u) 104 | 105 | print ("DEBUG: pooled_outputs_r -> %s" % pooled_outputs_r) 106 | self.h_pool_r = tf.concat(pooled_outputs_r, 3) 107 | print ("DEBUG: h_pool_r -> %s" % self.h_pool_r) 108 | # h_pool_flat_r: batch_size x num_classes X num_filters_total 109 | self.h_pool_flat_r = tf.reshape(self.h_pool_r, [-1, num_classes, num_filters_total]) 110 | print ("DEBUG: h_pool_flat_r -> %s" % self.h_pool_flat_r) 111 | 112 | # # Add dropout layer to avoid overfitting 113 | # with tf.name_scope("dropout"): 114 | # self.h_features = tf.concat([self.h_pool_flat_u, self.h_pool_flat_r], 1) 115 | # print "DEBUG: h_features -> %s" % self.h_features 116 | # self.h_features_dropped = tf.nn.dropout(self.h_features, config.dropout_keep_prob) 117 | # 118 | # self.h_dropped_u = self.h_features_dropped[:, :1, :] 119 | # self.h_dropped_r = self.h_features_dropped[:, 1:, :] 120 | # print "DEBUG: h_dropped_u -> %s" % self.h_dropped_u 121 | # print "DEBUG: h_dropped_r -> %s" % self.h_dropped_r 122 | 123 | # cosine layer - final scores and predictions 124 | with tf.name_scope("cosine_layer"): 125 | self.dot = tf.reduce_sum(tf.multiply(self.h_pool_flat_u, self.h_pool_flat_r), 2) 126 | print ("DEBUG: dot -> %s" % self.dot) 127 
| self.sqrt_u = tf.sqrt(tf.reduce_sum(self.h_pool_flat_u ** 2, 2)) 128 | print ("DEBUG: sqrt_u -> %s" % self.sqrt_u) 129 | self.sqrt_r = tf.sqrt(tf.reduce_sum(self.h_pool_flat_r ** 2, 2)) 130 | print ("DEBUG: sqrt_r -> %s" % self.sqrt_r) 131 | epsilon = 1e-5 132 | self.cosine = tf.maximum(self.dot / (tf.maximum(self.sqrt_u * self.sqrt_r, epsilon)), epsilon) 133 | print ("DEBUG: cosine -> %s" % self.cosine) 134 | self.score = tf.nn.softmax(self.cosine) # TODO 135 | print ("DEBUG: score -> %s" % self.score) 136 | self.predictions = tf.argmax(self.cosine, 1, name="predictions") 137 | print ("DEBUG: predictions -> %s" % self.predictions) 138 | 139 | # softmax regression - loss and prediction 140 | with tf.name_scope("loss"): 141 | losses = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.input_y, logits=100 * self.cosine) 142 | self.loss = tf.reduce_mean(losses) + l2_reg_lambda * l2_loss 143 | 144 | # Calculate Accuracy 145 | with tf.name_scope("accuracy"): 146 | correct_predictions = tf.equal(self.predictions, self.input_y) 147 | self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, "float"), name="accuracy") 148 | self.correct_num = tf.reduce_sum(tf.cast(correct_predictions, "float"), name="correct_num") 149 | -------------------------------------------------------------------------------- /SiameseNetwork/qa_lstm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from tensorflow import nn 4 | from tensorflow.contrib import rnn 5 | 6 | 7 | class SiameseLSTM(object): 8 | def BiRNN(self, x, x_lens, n_steps, n_hidden, biRnnScopeName, dropoutName): 9 | x = tf.nn.dropout(x, 0.5, name=dropoutName) 10 | # Forward direction cell 11 | lstm_fw_cell = rnn.LSTMCell(n_hidden) 12 | # Backward direction cell 13 | lstm_bw_cell = rnn.LSTMCell(n_hidden) 14 | # need scope to identified different cell 15 | outputs, output_states = nn.bidirectional_dynamic_rnn(lstm_fw_cell, lstm_bw_cell, x, sequence_length=x_lens, 16 | dtype=tf.float32, scope=biRnnScopeName) 17 | return outputs, output_states 18 | 19 | def __init__(self, max_length, init_embeddings, num_classes, hidden_units, 20 | embeddings_trainable=False, l2_reg_lambda=0): # vocab_size, embedding_size 21 | with tf.name_scope('input'): 22 | u_shape = [None, max_length] # self.batch_size 23 | v_shape = [None, num_classes, max_length] # train:100 test:300 24 | self.input_x_u = tf.placeholder(tf.int32, u_shape, name="input_x_u") 25 | self.input_x_r = tf.placeholder(tf.int32, v_shape, name="input_x_r") 26 | self.input_y = tf.placeholder(tf.int64, [None], name="input_y") # input_y: batch_size, 27 | self.u_lens = tf.placeholder(tf.int32, [None]) 28 | self.v_lens = tf.placeholder(tf.int32, [None]) 29 | self.n_steps = max_length 30 | self.hidden_size = hidden_units 31 | self.dropout_keep_prob = tf.placeholder(tf.float32, name="dropout_keep_prob") 32 | self.embedding_size = np.shape(init_embeddings)[1] 33 | l2_loss = tf.constant(0.0, name="l2_loss") # optional: l2 regularization loss 34 | 35 | with tf.name_scope("embedding"): 36 | # self.W = tf.Variable(tf.constant(0.0, shape=[vocab_size, embedding_size]), trainable=embeddings_trainable, 37 | # name="W") 38 | self.W = tf.Variable(init_embeddings, trainable=embeddings_trainable, dtype=tf.float32, name='W') 39 | self.embedded_u = tf.nn.embedding_lookup(self.W, self.input_x_u) # (batch_size, max_len, dim) 40 | print("DEBUG: embedded_u -> %s" % self.embedded_u) 41 | self.embedded_v = tf.nn.embedding_lookup(self.W, 
self.input_x_r) # (batch_size, num_classes, max_len, dim) 42 | print ("DEBUG: embedded_v -> %s" % self.embedded_v) 43 | 44 | # Create a convolution + maxpool layer for each filter size 45 | with tf.name_scope("output"): 46 | self.output1, _ = self.BiRNN(self.embedded_u, self.u_lens, self.n_steps, self.hidden_size, 47 | "relation_1", "relation_dropout_1") # batch_size * dim 48 | outputs_fw, outputs_bw = self.output1 # ?, max_time, dim 49 | self.question_embedding = tf.concat([outputs_fw[:, -1, :], outputs_bw[:, 0, :]], 1) 50 | self.question_embedding = tf.expand_dims(self.question_embedding, 1) 51 | print ("DEBUG: question_embedding -> %s" % self.question_embedding ) # ?, 2*dim 52 | 53 | outputs_v_classes = [] 54 | for j in range(num_classes): 55 | embedded_v = self.embedded_v[:, j, :, :] 56 | self.output2, _ = self.BiRNN(embedded_v, self.v_lens, self.n_steps, self.hidden_size, 57 | "relation_2_%d" % j, 58 | "relation_dropout_2") 59 | outputs_fw, outputs_bw = self.output2 60 | relation_embedding = tf.concat([outputs_fw[:, -1, :], outputs_bw[:, 0, :]], 1) 61 | # print "DEBUG: relation_embedding_temp -> %s" % relation_embedding 62 | relation_embedding_expand = tf.expand_dims(relation_embedding, 1) 63 | outputs_v_classes.append(relation_embedding_expand) 64 | self.relation_embedding = tf.concat(outputs_v_classes, 1) 65 | print ("DEBUG: relation_embedding -> %s" % self.relation_embedding) 66 | 67 | # cosine layer - final scores and predictions 68 | with tf.name_scope("cosine_layer"): 69 | self.dot = tf.reduce_sum(tf.multiply(self.question_embedding, self.relation_embedding), 2) 70 | print ("DEBUG: dot -> %s" % self.dot) 71 | self.sqrt_u = tf.sqrt(tf.reduce_sum(self.question_embedding ** 2, 2)) 72 | print ("DEBUG: sqrt_u -> %s" % self.sqrt_u) 73 | self.sqrt_r = tf.sqrt(tf.reduce_sum(self.relation_embedding ** 2, 2)) 74 | print ("DEBUG: sqrt_r -> %s" % self.sqrt_r) 75 | epsilon = 1e-5 76 | self.cosine = tf.maximum(self.dot / (tf.maximum(self.sqrt_u * self.sqrt_r, epsilon)), epsilon) 77 | print ("DEBUG: cosine -> %s" % self.cosine) 78 | self.score = tf.nn.softmax(self.cosine) # TODO 79 | print ("DEBUG: score -> %s" % self.score) 80 | self.predictions = tf.argmax(self.cosine, 1, name="predictions") 81 | print ("DEBUG: predictions -> %s" % self.predictions) 82 | 83 | # softmax regression - loss and prediction 84 | with tf.name_scope("loss"): 85 | losses = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.input_y, logits=100 * self.cosine) 86 | self.loss = tf.reduce_mean(losses) + l2_reg_lambda * l2_loss 87 | 88 | # Calculate Accuracy 89 | with tf.name_scope("accuracy"): 90 | correct_predictions = tf.equal(self.predictions, self.input_y) 91 | self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, "float"), name="accuracy") 92 | self.correct_num = tf.reduce_sum(tf.cast(correct_predictions, "float"), name="correct_num") 93 | -------------------------------------------------------------------------------- /SiameseNetwork/qa_nn.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import config 4 | 5 | 6 | class QaNN(object): 7 | def __init__(self, sequence_length, init_embeddings, num_classes, 8 | embeddings_trainable=False, l2_reg_lambda=0.0): 9 | self.L1 = 300 10 | self.L2 = 120 11 | self.batch_size = config.batch_size 12 | self.embedding_size = np.shape(init_embeddings)[1] 13 | # print self.embedding_size 14 | # Keeping track of l2 regularization loss (optional) 15 | l2_loss = tf.constant(0.0) 16 | 
with tf.name_scope('input'): 17 | u_shape = [None, sequence_length] # self.batch_size 18 | v_shape = [None, num_classes, sequence_length] # train:100 test:300 19 | self.input_x_u = tf.placeholder(tf.int32, u_shape, name="input_x_u") 20 | self.input_x_r = tf.placeholder(tf.int32, v_shape, name="input_x_r") 21 | self.input_y = tf.placeholder(tf.int64, [None], name="input_y") # input_y: batch_size, 22 | 23 | # Embedding layer 24 | with tf.name_scope("embedding"): 25 | W = tf.Variable(init_embeddings, trainable=embeddings_trainable, dtype=tf.float32, name='W') 26 | self.embedded_u = tf.nn.embedding_lookup(W, self.input_x_u) # batch_size x sent_len x embedding_size 27 | self.embedded_u = tf.reduce_sum(self.embedded_u, 1) # batch_size x embedding_size 28 | print ("DEBUG: embedded_u -> %s" % self.embedded_u) 29 | self.embedded_r = tf.nn.embedding_lookup(W, self.input_x_r) 30 | self.embedded_r = tf.reduce_sum(self.embedded_r, 2) # batch_size x neg_size x embedding_size 31 | self.embedded_r = tf.reshape(self.embedded_r, [-1, self.embedding_size]) 32 | print ("DEBUG: embedded_r -> %s" % self.embedded_r) 33 | 34 | with tf.name_scope('L1'): 35 | l1_par_range = np.sqrt(6.0 / (self.embedding_size + self.L1)) 36 | weight1 = tf.Variable(tf.random_uniform([self.embedding_size, self.L1], -l1_par_range, l1_par_range)) 37 | bias1 = tf.Variable(tf.random_uniform([self.L1], -l1_par_range, l1_par_range)) 38 | query_l1 = tf.matmul(self.embedded_u, weight1) + bias1 39 | doc_l1 = tf.matmul(self.embedded_r, weight1) + bias1 40 | self.query_l1 = tf.nn.relu(query_l1) 41 | print("DEBUG: query_l1 -> %s" % self.query_l1) 42 | self.doc_l1 = tf.nn.relu(doc_l1) 43 | print ("DEBUG: doc_l1 -> %s" % self.doc_l1) 44 | 45 | with tf.name_scope('L2'): 46 | l2_par_range = np.sqrt(6.0 / (self.L1 + self.L2)) 47 | weight2 = tf.Variable(tf.random_uniform([self.L1, self.L2], -l2_par_range, l2_par_range)) 48 | bias2 = tf.Variable(tf.random_uniform([self.L2], -l2_par_range, l2_par_range)) 49 | query_l2 = tf.matmul(self.query_l1, weight2) + bias2 50 | doc_l2 = tf.matmul(self.doc_l1, weight2) + bias2 51 | self.query_l2 = tf.nn.relu(query_l2) 52 | self.query_l2 = tf.expand_dims(self.query_l2, 1) 53 | print ("DEBUG: query_l2 -> %s" % self.query_l2) 54 | self.doc_l2 = tf.nn.relu(doc_l2) 55 | self.doc_l2 = tf.reshape(self.doc_l2, [-1, num_classes, self.L2]) 56 | print ("DEBUG: doc_l2 -> %s" % self.doc_l2) 57 | # doc_y = tf.nn.dropout(doc_y, 0.5) 58 | 59 | # cosine layer - final scores and predictions 60 | with tf.name_scope("cosine_layer"): 61 | self.dot = tf.reduce_sum(tf.multiply(self.query_l2, self.doc_l2), 2) 62 | print ("DEBUG: dot -> %s" % self.dot) 63 | self.sqrt_u = tf.sqrt(tf.reduce_sum(self.query_l2 ** 2, 2)) 64 | print ("DEBUG: sqrt_u -> %s" % self.sqrt_u) 65 | self.sqrt_r = tf.sqrt(tf.reduce_sum(self.doc_l2 ** 2, 2)) 66 | print ("DEBUG: sqrt_r -> %s" % self.sqrt_r) 67 | epsilon = 1e-5 68 | self.cosine = tf.maximum(self.dot / (tf.maximum(self.sqrt_u * self.sqrt_r, epsilon)), epsilon) 69 | print("DEBUG: cosine -> %s" % self.cosine) 70 | self.score = tf.nn.softmax(self.cosine) # TODO: score:[0, 2] 71 | print("DEBUG: score -> %s" % self.score) 72 | self.predictions = tf.argmax(self.cosine, 1, name="predictions") 73 | print("DEBUG: predictions -> %s" % self.predictions) 74 | 75 | # softmax regression - loss and prediction 76 | with tf.name_scope("loss"): 77 | losses = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.input_y, logits=100 * self.cosine) 78 | self.loss = tf.reduce_mean(losses) + l2_reg_lambda * l2_loss 79 | 80 | # 
Calculate Accuracy 81 | with tf.name_scope("accuracy"): 82 | correct_predictions = tf.equal(self.predictions, self.input_y) 83 | self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, "float"), name="accuracy") 84 | self.correct_num = tf.reduce_sum(tf.cast(correct_predictions, "float"), name="correct_num") -------------------------------------------------------------------------------- /SiameseNetwork/train_cnn.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import config 4 | import math 5 | import numpy as np 6 | import sys 7 | import tensorflow as tf 8 | import time 9 | from qa_cnn import QaCNN 10 | from tqdm import tqdm 11 | from util import util 12 | from util.dataset import Dataset 13 | 14 | 15 | def pull_batch(question_list, relation_list, label_list, batch_idx): 16 | batch_size = config.batch_size 17 | if (batch_idx + 1) * batch_size < len(question_list): 18 | question_list = question_list[batch_idx * batch_size:(batch_idx + 1) * batch_size] 19 | relation_list = relation_list[batch_idx * batch_size:(batch_idx + 1) * batch_size] 20 | label_list = label_list[batch_idx * batch_size:(batch_idx + 1) * batch_size] 21 | else: # last batch 22 | question_list = question_list[batch_idx * batch_size:] 23 | relation_list = relation_list[batch_idx * batch_size:] 24 | label_list = label_list[batch_idx * batch_size:] 25 | return question_list, relation_list, label_list 26 | 27 | 28 | def train(train_set, dev_set, max_len, U): 29 | net = QaCNN(sequence_length=max_len, 30 | num_classes=config.neg_sample, 31 | filter_sizes=config.filter_sizes, 32 | num_filters=config.num_filters, 33 | init_embeddings=U, 34 | embeddings_trainable=config.embeddings_trainable, 35 | l2_reg_lambda=config.l2_reg_lambda) 36 | global_step = tf.Variable(0, name="global_step", trainable=True) 37 | optimizer = tf.train.AdamOptimizer() 38 | grads_and_vars = optimizer.compute_gradients(net.loss) 39 | train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step) 40 | 41 | with tf.Session() as sess: 42 | sess.run(tf.global_variables_initializer()) 43 | 44 | # Summaries for loss and accuracy 45 | loss_summary = tf.summary.scalar("loss", net.loss) 46 | acc_summary = tf.summary.scalar("accuracy", net.accuracy) 47 | 48 | # Train Summaries 49 | train_summary_op = tf.summary.merge([loss_summary, acc_summary]) 50 | # train_summary_dir = os.path.join(out_dir, "summaries", "train") 51 | # train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph) 52 | 53 | # Dev summaries 54 | dev_summary_op = tf.summary.merge([loss_summary, acc_summary]) 55 | # dev_summary_dir = os.path.join(out_dir, "summaries", "dev") 56 | # dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph) 57 | 58 | # Checkpoint directory. 
Tensorflow assumes this directory already exists so we need to create it 59 | checkpoint_dir = os.path.abspath(os.path.join(os.path.curdir, "checkpoints")) 60 | checkpoint_prefix = os.path.join(checkpoint_dir, "model") 61 | if not os.path.exists(checkpoint_dir): 62 | os.makedirs(checkpoint_dir) 63 | saver = tf.train.Saver(tf.global_variables()) 64 | 65 | best_dev_loss, stop_step = sys.maxint, 0 66 | for epoch_idx in range(config.epoch_num): 67 | start = time.time() 68 | train_loss, dev_loss, train_correct_num, dev_correct_num = 0, 0, 0, 0 69 | 70 | for train_step in tqdm(range(int(math.ceil(train_set.size * 1.0 / config.batch_size))), 71 | desc='Training epoch ' + str(epoch_idx + 1) + ''): 72 | ques_batch, rela_batch, label_batch = pull_batch(train_set.ques_idx, train_set.rela_idx, 73 | train_set.label, train_step) 74 | feed_dict = { 75 | net.input_x_u: ques_batch, 76 | net.input_x_r: rela_batch, 77 | net.input_y: label_batch, 78 | } 79 | _, summaries, score, loss, accuracy, correct_num, prediction, real_label = sess.run( 80 | [train_op, train_summary_op, net.score, net.loss, net.accuracy, net.correct_num, net.predictions, 81 | net.input_y], feed_dict) 82 | train_loss += loss 83 | train_correct_num += correct_num 84 | if train_step == 0: 85 | train_score = score # score:[0, 2] 86 | else: 87 | train_score = np.concatenate((train_score, score), axis=0) 88 | 89 | for dev_step in tqdm(range(int(math.ceil(dev_set.size * 1.0 / config.batch_size))), 90 | desc='Deving epoch ' + str(epoch_idx + 1) + ''): 91 | ques_batch, rela_batch, label_batch = pull_batch(dev_set.ques_idx, dev_set.rela_idx, 92 | dev_set.label, dev_step) 93 | feed_dict = { 94 | net.input_x_u: ques_batch, 95 | net.input_x_r: rela_batch, 96 | net.input_y: label_batch, 97 | } 98 | loss, accuracy, correct_num, score, summaries = sess.run( 99 | [net.loss, net.accuracy, net.correct_num, net.score, dev_summary_op], feed_dict) 100 | dev_loss += loss 101 | dev_correct_num += correct_num 102 | if dev_step == 0: 103 | dev_score = score 104 | else: 105 | dev_score = np.concatenate((dev_score, score), axis=0) 106 | end = time.time() 107 | print( 108 | "epoch {}, time {}, train loss {:g}, train acc {:g}, dev loss {:g}, dev acc {:g}".format( 109 | epoch_idx, end - start, train_loss / train_set.size, train_correct_num / train_set.size, 110 | dev_loss / dev_set.size, dev_correct_num / dev_set.size)) 111 | 112 | if dev_loss < best_dev_loss: 113 | stop_step = 0 114 | best_dev_loss = dev_loss 115 | print('saving new best result...') 116 | # print np.array(dev_set.rela).shape, dev_score.shape 117 | 118 | saver_path = saver.save(sess, "%s.ckpt" % checkpoint_prefix) 119 | print(saver_path) 120 | 121 | util.save_data(config.train_file, config.train_output_file, train_score.tolist(), train_set.rela) 122 | util.save_data(config.dev_file, config.dev_output_file, dev_score.tolist(), dev_set.rela) 123 | else: 124 | stop_step += 1 125 | if stop_step >= config.early_step: 126 | print('early stopping') 127 | break 128 | 129 | 130 | def test(test_set, max_len, U): 131 | net = QaCNN(sequence_length=max_len, 132 | num_classes=config.num_classes, 133 | filter_sizes=config.filter_sizes, 134 | num_filters=config.num_filters, 135 | init_embeddings=U, 136 | embeddings_trainable=config.embeddings_trainable, 137 | l2_reg_lambda=config.l2_reg_lambda) 138 | # saver = tf.train.import_meta_graph("save/model.ckpt.meta") 139 | saver = tf.train.Saver() 140 | with tf.Session() as sess: 141 | sess.run(tf.global_variables_initializer()) 142 | 143 | saver.restore(sess, 
tf.train.latest_checkpoint("checkpoints/")) 144 | test_loss, test_correct_num = 0, 0 145 | start = time.time() 146 | for test_step in tqdm(range(int(math.ceil(test_set.size * 1.0 / config.batch_size))), 147 | desc='Testing epoch ' + ''): 148 | ques_batch, rela_batch, label_batch = pull_batch(test_set.ques_idx, test_set.rela_idx, 149 | test_set.label, test_step) 150 | feed_dict = { 151 | net.input_x_u: ques_batch, 152 | net.input_x_r: rela_batch, 153 | net.input_y: label_batch, 154 | } 155 | loss, accuracy, correct_num, score = sess.run( 156 | [net.loss, net.accuracy, net.correct_num, net.score], feed_dict) 157 | test_loss += loss 158 | test_correct_num += correct_num 159 | if test_step == 0: 160 | test_score = score 161 | else: 162 | test_score = np.concatenate((test_score, score), axis=0) 163 | util.save_data(config.test_file, config.test_output_file, test_score.tolist(), test_set.rela) 164 | end = time.time() 165 | print("time {}, test loss {:g}, train acc {:g}".format(end - start, test_loss / test_set.size, 166 | test_correct_num / test_set.size)) 167 | 168 | 169 | if __name__ == "__main__": 170 | start = time.time() 171 | 172 | print("loading word embedding...") 173 | word_dict, embedding = util.get_pretrained_word_vector(config.word2vec_file, (config.voc_size, config.emb_size)) 174 | print("vocabulary size: %d" % len(word_dict)) 175 | 176 | print("loading train data...") 177 | x_u, x_r, y, _ = util.load_data(config.train_file, True, config.neg_sample) 178 | train_dataset = Dataset(x_u, x_r, y, config.max_sent_len, word_dict) 179 | print(np.array(train_dataset.ques_idx).shape, np.array(train_dataset.rela_idx).shape, np.array( 180 | train_dataset.label).shape) 181 | 182 | print("loading dev data...") 183 | x_u, x_r, y, _ = util.load_data(config.dev_file, True, config.neg_sample) 184 | dev_dataset = Dataset(x_u, x_r, y, config.max_sent_len, word_dict) 185 | print(np.array(dev_dataset.ques_idx).shape, np.array(dev_dataset.rela_idx).shape, np.array(dev_dataset.label).shape) 186 | 187 | print("loading test data...") 188 | x_u, x_r, y, _ = util.load_data(config.test_file, False, config.num_classes) 189 | test_dataset = Dataset(x_u, x_r, y, config.max_sent_len, word_dict) 190 | print(np.array(test_dataset.ques_idx).shape, np.array(test_dataset.rela_idx).shape, np.array( 191 | test_dataset.label).shape) 192 | 193 | print("training...") 194 | train(train_dataset, dev_dataset, config.max_sent_len, embedding) 195 | 196 | print("testing...") 197 | test(test_dataset, config.max_sent_len, embedding) 198 | 199 | end = time.time() 200 | print('total time: %s' % str(end - start)) 201 | -------------------------------------------------------------------------------- /SiameseNetwork/train_nn.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import config 4 | import math 5 | import numpy as np 6 | import sys 7 | import tensorflow as tf 8 | import time 9 | from qa_nn import QaNN 10 | from tqdm import tqdm 11 | from util import util 12 | from util.dataset import Dataset 13 | 14 | 15 | def pull_batch(question_list, relation_list, label_list, batch_idx): 16 | batch_size = config.batch_size 17 | if (batch_idx + 1) * batch_size < len(question_list): 18 | question_list = question_list[batch_idx * batch_size:(batch_idx + 1) * batch_size] 19 | relation_list = relation_list[batch_idx * batch_size:(batch_idx + 1) * batch_size] 20 | label_list = label_list[batch_idx * batch_size:(batch_idx + 1) * batch_size] 21 | else: # last batch 22 | question_list = 
question_list[batch_idx * batch_size:] 23 | relation_list = relation_list[batch_idx * batch_size:] 24 | label_list = label_list[batch_idx * batch_size:] 25 | return question_list, relation_list, label_list 26 | 27 | 28 | def train_nn(train_set, dev_set, max_len, U): 29 | nn = QaNN(sequence_length=max_len, 30 | init_embeddings=U, 31 | num_classes=config.neg_sample, 32 | embeddings_trainable=config.embeddings_trainable, 33 | l2_reg_lambda=config.l2_reg_lambda) 34 | global_step = tf.Variable(0, name="global_step", trainable=True) 35 | optimizer = tf.train.AdamOptimizer() 36 | grads_and_vars = optimizer.compute_gradients(nn.loss) 37 | train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step) 38 | 39 | with tf.Session() as sess: 40 | sess.run(tf.global_variables_initializer()) 41 | 42 | # Output directory for models and summaries 43 | # timestamp = str(int(time.time())) 44 | # out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp)) 45 | # print("Writing to {}\n".format(out_dir)) 46 | 47 | # Summaries for loss and accuracy 48 | loss_summary = tf.summary.scalar("loss", nn.loss) 49 | acc_summary = tf.summary.scalar("accuracy", nn.accuracy) 50 | 51 | # Train Summaries 52 | train_summary_op = tf.summary.merge([loss_summary, acc_summary]) 53 | # train_summary_dir = os.path.join(out_dir, "summaries", "train") 54 | # train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph) 55 | 56 | # Dev summaries 57 | dev_summary_op = tf.summary.merge([loss_summary, acc_summary]) 58 | # dev_summary_dir = os.path.join(out_dir, "summaries", "dev") 59 | # dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph) 60 | 61 | # Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it 62 | checkpoint_dir = os.path.abspath(os.path.join(os.path.curdir, "checkpoints")) 63 | checkpoint_prefix = os.path.join(checkpoint_dir, "model") 64 | if not os.path.exists(checkpoint_dir): 65 | os.makedirs(checkpoint_dir) 66 | saver = tf.train.Saver(tf.global_variables()) 67 | 68 | best_dev_loss, stop_step = sys.maxint, 0 69 | for epoch_idx in range(config.epoch_num): 70 | start = time.time() 71 | train_loss, dev_loss, train_accuracy, dev_accuracy, train_correct_num, dev_correct_num = 0, 0, 0, 0, 0, 0 72 | 73 | for train_step in tqdm(range(int(math.ceil(train_set.size * 1.0 / config.batch_size))), 74 | desc='Training epoch ' + str(epoch_idx + 1) + ''): 75 | ques_batch, rela_batch, label_batch = pull_batch(train_set.ques_idx, train_set.rela_idx, 76 | train_set.label, train_step) 77 | feed_dict = { 78 | nn.input_x_u: ques_batch, 79 | nn.input_x_r: rela_batch, 80 | nn.input_y: label_batch, 81 | } 82 | _, summaries, score, loss, accuracy, correct_num = sess.run( 83 | [train_op, train_summary_op, nn.score, nn.loss, nn.accuracy, nn.correct_num], feed_dict) 84 | train_loss += loss 85 | train_accuracy += accuracy 86 | train_correct_num += correct_num 87 | if train_step == 0: 88 | train_score = score # score:[0, 2] 89 | else: 90 | train_score = np.concatenate((train_score, score), axis=0) 91 | 92 | for dev_step in tqdm(range(int(math.ceil(dev_set.size * 1.0 / config.batch_size))), 93 | desc='Deving epoch ' + str(epoch_idx + 1) + ''): 94 | ques_batch, rela_batch, label_batch = pull_batch(dev_set.ques_idx, dev_set.rela_idx, 95 | dev_set.label, dev_step) 96 | feed_dict = { 97 | nn.input_x_u: ques_batch, 98 | nn.input_x_r: rela_batch, 99 | nn.input_y: label_batch, 100 | } 101 | loss, accuracy, correct_num, score, summaries = sess.run( 102 | 
[nn.loss, nn.accuracy, nn.correct_num, nn.score, dev_summary_op], feed_dict) 103 | dev_loss += loss 104 | dev_accuracy += accuracy 105 | dev_correct_num += correct_num 106 | if dev_step == 0: 107 | dev_score = score 108 | else: 109 | dev_score = np.concatenate((dev_score, score), axis=0) 110 | end = time.time() 111 | print( 112 | "epoch {}, time {}, train loss {:g}, train acc {:g}, dev loss {:g}, dev acc {:g}".format( 113 | epoch_idx, end - start, train_loss / train_set.size, train_correct_num / train_set.size, 114 | dev_loss / dev_set.size, dev_correct_num / dev_set.size)) 115 | 116 | if dev_loss < best_dev_loss: 117 | stop_step = 0 118 | best_dev_loss = dev_loss 119 | print('saving new best result...') 120 | # print np.array(dev_set.rela).shape, dev_score.shape 121 | 122 | saver_path = saver.save(sess, "%s.ckpt" % checkpoint_prefix) 123 | print(saver_path) 124 | 125 | util.save_data(config.train_file, config.train_output_file, train_score.tolist(), train_set.rela) 126 | util.save_data(config.dev_file, config.dev_output_file, dev_score.tolist(), dev_set.rela) 127 | else: 128 | stop_step += 1 129 | if stop_step >= config.early_step: 130 | print('early stopping') 131 | break 132 | 133 | 134 | def test_nn(test_set, max_len, U): 135 | nn = QaNN(sequence_length=max_len, 136 | init_embeddings=U, 137 | num_classes=config.num_classes, 138 | embeddings_trainable=config.embeddings_trainable, 139 | l2_reg_lambda=config.l2_reg_lambda) 140 | # saver = tf.train.import_meta_graph("save/model.ckpt.meta") 141 | saver = tf.train.Saver() 142 | with tf.Session() as sess: 143 | sess.run(tf.global_variables_initializer()) 144 | saver.restore(sess, tf.train.latest_checkpoint("checkpoints/")) 145 | test_loss, test_correct_num = 0, 0 146 | start = time.time() 147 | for test_step in tqdm(range(int(math.ceil(test_set.size * 1.0 / config.batch_size))), 148 | desc='Testing epoch ' + ''): 149 | ques_batch, rela_batch, label_batch = pull_batch(test_set.ques_idx, test_set.rela_idx, 150 | test_set.label, test_step) 151 | feed_dict = { 152 | nn.input_x_u: ques_batch, 153 | nn.input_x_r: rela_batch, 154 | nn.input_y: label_batch, 155 | } 156 | loss, accuracy, correct_num, score = sess.run( 157 | [nn.loss, nn.accuracy, nn.correct_num, nn.score], feed_dict) 158 | test_loss += loss 159 | test_correct_num += correct_num 160 | if test_step == 0: 161 | test_score = score 162 | else: 163 | test_score = np.concatenate((test_score, score), axis=0) 164 | end = time.time() 165 | util.save_data(config.test_file, config.test_output_file, test_score.tolist(), test_set.rela) 166 | print("time {}, test loss {:g}, train acc {:g}".format(end - start, test_loss / test_set.size, 167 | test_correct_num / test_set.size)) 168 | 169 | 170 | if __name__ == "__main__": 171 | start = time.time() 172 | 173 | print("loading word embedding...") 174 | word_dict, embedding = util.get_pretrained_word_vector(config.word2vec_file, (config.voc_size, config.emb_size)) 175 | print("vocabulary size: %d" % len(word_dict)) 176 | 177 | # print "loading train data..." 178 | # x_u, x_r, y, _ = util.load_data(config.train_file, True, config.neg_sample) 179 | # train_dataset = Dataset(x_u, x_r, y, config.max_sent_len, word_dict) 180 | # print np.array(train_dataset.ques_idx).shape, np.array(train_dataset.rela_idx).shape, np.array( 181 | # train_dataset.label).shape 182 | # 183 | # print "loading dev data..." 
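# NOTE: in this version the train/dev loading around this point and the
# train_nn(...) call further below are commented out, so running the script only
# evaluates test_nn() against a saved checkpoint restored from checkpoints/.
# sys.maxint (used as the initial best_dev_loss in train_nn) exists only on
# Python 2; float("inf") is the Python 3 equivalent.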
184 | # x_u, x_r, y, _ = util.load_data(config.dev_file, True, config.neg_sample) 185 | # dev_dataset = Dataset(x_u, x_r, y, config.max_sent_len, word_dict) 186 | # print np.array(dev_dataset.ques_idx).shape, np.array(dev_dataset.rela_idx).shape, np.array(dev_dataset.label).shape 187 | 188 | print("loading test data...") 189 | x_u, x_r, y, _ = util.load_data(config.test_file, False, config.num_classes) 190 | print(np.array(x_u).shape, np.array(x_r).shape, np.array(y).shape) 191 | print(x_u[0], x_r[0], y[0]) 192 | test_dataset = Dataset(x_u, x_r, y, config.max_sent_len, word_dict) 193 | print(np.array(test_dataset.ques_idx).shape, np.array(test_dataset.rela_idx).shape, np.array( 194 | test_dataset.label).shape) 195 | 196 | # print "training..." 197 | # train_nn(train_dataset, dev_dataset, config.max_sent_len, embedding) 198 | 199 | print("testing...") 200 | test_nn(test_dataset, config.max_sent_len, embedding) 201 | 202 | end = time.time() 203 | print('total time: %s' % str(end - start)) 204 | -------------------------------------------------------------------------------- /SiameseNetwork/util/dataset.py: -------------------------------------------------------------------------------- 1 | import util 2 | 3 | 4 | class Dataset(object): 5 | def __init__(self, ques_list, rela_list, label, max_sent_len, word_dict): 6 | self.ques = ques_list # [[]] ? 7 | self.rela = rela_list 8 | self.ques_lens = [len(sent) for sent in self.ques] 9 | self.rela_lens = [len(sent) for sent in self.rela] 10 | self.word_dict = word_dict 11 | self.label = label 12 | self.size = len(label) 13 | self.max_sent_len = max_sent_len 14 | self.ques_idx, self.rela_idx = self.get_voc_idx(self.ques, self.rela) 15 | 16 | def get_voc_idx(self, ques, rela): 17 | # pad sentence 18 | pad = lambda x: util.pad_sentences(x, self.max_sent_len) 19 | pad_lst = lambda x: map(pad, x) 20 | self.ques_pad = map(pad, ques) 21 | self.rela_pad = map(pad_lst, rela) 22 | # Represent sentences as list(nparray) of ints 23 | idx_func = lambda word: self.word_dict[word] if self.word_dict.has_key(word) else self.word_dict["unk"] 24 | u_idx_func = lambda words: map(idx_func, words) 25 | v_idx_func = lambda words_list: map(u_idx_func, words_list) 26 | return map(u_idx_func, self.ques_pad), map(v_idx_func, self.rela_pad) 27 | -------------------------------------------------------------------------------- /WebQA/src/args.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from argparse import ArgumentParser 4 | 5 | 6 | def get_kb_args(): 7 | parser = ArgumentParser(description='Simple QA model: knowledge base') 8 | parser.add_argument('--idx_entity', type=str, default='data/indexes/entity_2M.pkl') 9 | parser.add_argument('--idx_reachability', type=str, default='data/indexes/reachability_2M.pkl') 10 | parser.add_argument('--idx_name', type=str, default='data/indexes/names_2M.pkl') 11 | parser.add_argument('--idx_freebase', type=str, default='data/indexes/fb_graph.pkl') 12 | parser.add_argument('--idx_degree', type=str, default='data/indexes/degrees_2M.pkl') 13 | parser.add_argument('--wiki_path', type=str, default='data/fb2w.nt') 14 | parser.add_argument('--_path', type=str, default='data/fb2w.nt') 15 | parser.add_argument('--name_path', type=str, default='data/indexes/names_only_2M.pkl') 16 | parser.add_argument('--alias_path', type=str, default='data/indexes/alias_only_2M.pkl') 17 | 18 | args = parser.parse_args() 19 | return args 20 | 21 | 22 | def get_entity_args(): 23 | parser = 
ArgumentParser(description='Simple QA model: mention detection') 24 | parser.add_argument('--results_path', type=str, default='results') 25 | parser.add_argument('--epochs', type=int, default=30) 26 | parser.add_argument('--trained_model', type=str, default='model/mention_detection/lstm_id1_model_cpu.pt') 27 | parser.add_argument('--batch_size', type=int, default=32) 28 | parser.add_argument('--dim_hidden', type=int, default=200) 29 | parser.add_argument('--n_layers', type=int, default=2) 30 | parser.add_argument('--lr', type=float, default=1e-4) 31 | parser.add_argument('--test', action='store_true', dest='test', help='get the testing set result') 32 | parser.add_argument('--rnn_type', type=str, default='lstm') # or use 'gru' 33 | parser.add_argument('--dim_embed', type=int, default=300) 34 | parser.add_argument('--dev', action='store_true', dest='dev', help='get the development set result') 35 | parser.add_argument('--not_bidirectional', action='store_false', dest='birnn') 36 | parser.add_argument('--clip_gradient', type=float, default=0.6, help='gradient clipping') 37 | parser.add_argument('--log_every', type=int, default=50) 38 | parser.add_argument('--dev_every', type=int, default=100) 39 | parser.add_argument('--save_every', type=int, default=4500) 40 | parser.add_argument('--dropout_prob', type=float, default=0.3) 41 | parser.add_argument('--patience', type=int, default=5, help="number of epochs to wait before early stopping") 42 | parser.add_argument('--no_cuda', action='store_false', help='do not use CUDA', dest='cuda') 43 | parser.add_argument('--gpu', type=int, default=-1, help='GPU device to use') # use -1 for CPU 44 | parser.add_argument('--seed', type=int, default=1111, help='random seed for reproducing results') 45 | parser.add_argument('--save_path', type=str, default='saved_checkpoints') 46 | parser.add_argument('--data_dir', type=str, default='data/processed_dataset') 47 | parser.add_argument('--data_cache', type=str, default=os.path.join(os.getcwd(), 'data/cache')) # ? 
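# (presumably) --data_cache holds the torchtext-processed dataset cache, while
# --vector_cache below stores the serialized word vectors (sq_glove300d.pt)
# selected by --word_vectors.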
48 | parser.add_argument('--vector_cache', type=str, 49 | default=os.path.join(os.getcwd(), 'data/cache/sq_glove300d.pt')) 50 | parser.add_argument('--word_vectors', type=str, default='glove.42B') 51 | parser.add_argument('--train_embed', action='store_false', dest='fix_emb') # fine-tune the word embeddings 52 | parser.add_argument('--resume_snapshot', type=str, default=None) 53 | 54 | args = parser.parse_args() 55 | return args 56 | 57 | 58 | def get_relation_args(): 59 | parser = ArgumentParser(description='Simple QA model: relation detection') 60 | parser.add_argument('--batch_size', type=int, default=32) 61 | parser.add_argument('--rnn_type', type=str, default='gru', help="use 'gru' or 'lstm'") 62 | parser.add_argument('--dim_embed', type=int, default=300) 63 | parser.add_argument('--dim_hidden', type=int, default=200) 64 | parser.add_argument('--n_layers', type=int, default=2) 65 | parser.add_argument('--lr', type=float, default=1e-4) 66 | parser.add_argument('--lr_weight_decay', type=float, default=0.0) 67 | # parser.add_argument('--patience', type=int, default=3, help="number of epochs to wait before early stopping") 68 | parser.add_argument('--seed', type=int, default=1111, help='random seed for reproducing results') 69 | parser.add_argument('--save_path', type=str, default='saved_checkpoints') 70 | parser.add_argument('--not_bidirectional', action='store_false', dest='birnn') 71 | parser.add_argument('--clip_gradient', type=float, default=0.6, help='gradient clipping') 72 | parser.add_argument('--no_cuda', action='store_false', help='do not use CUDA', dest='cuda') 73 | parser.add_argument('--gpu', type=int, default=-1, help='GPU device to use') # use -1 for CPU 74 | parser.add_argument('--dropout_prob', type=float, default=0.3) 75 | parser.add_argument('--data_dir', type=str, default='data/processed_dataset') 76 | parser.add_argument('--data_cache', type=str, default=os.path.join(os.getcwd(), 'data/cache')) 77 | parser.add_argument('--vector_cache', type=str, 78 | default=os.path.join(os.getcwd(), 'data/cache/sq_glove300d.pt')) 79 | parser.add_argument('--word_vectors', type=str, default='glove.42B') 80 | parser.add_argument('--train_embed', action='store_false', dest='fix_emb') # fine-tune the word embeddings 81 | parser.add_argument('--trained_model', type=str, default='model/relation_detection/cnn_id1_model_cpu.pt') 82 | parser.add_argument('--hits', type=int, default=1000, help="number of top results to output") 83 | 84 | args = parser.parse_args() 85 | return args 86 | 87 | 88 | 89 | -------------------------------------------------------------------------------- /WebQA/src/main.py: -------------------------------------------------------------------------------- 1 | # coding:utf8 2 | import datetime 3 | 4 | import flask 5 | from flask import render_template, jsonify 6 | from simpleQA import SimpleQA 7 | import json 8 | import random 9 | import logging 10 | 11 | app = flask.Flask(__name__) 12 | # 日志系统配置 13 | handler = logging.FileHandler('log/app.log', encoding='UTF-8') 14 | logging_format = logging.Formatter( 15 | '%(asctime)s - %(levelname)s - %(filename)s - %(funcName)s - %(lineno)s - %(message)s') 16 | handler.setFormatter(logging_format) 17 | app.logger.addHandler(handler) 18 | 19 | start = datetime.datetime.now() 20 | qa = SimpleQA() 21 | qa.setup() 22 | end = datetime.datetime.now() 23 | print('Time of loading knowledge base and trained models: %f' % (end - start).seconds) 24 | 25 | 26 | @app.route('/', methods=['GET', 'POST']) 27 | def homepage(): 28 | if 
flask.request.method == 'GET': 29 | result = {} 30 | return render_template("homepage.html", result=result) 31 | elif flask.request.method == 'POST': # and flask.request.form.get('query', None) == "SEARCH" 32 | question = flask.request.form['input'] 33 | if str(question): 34 | # app.logger.info(question) # 记录用户输入 35 | result = qa.get_answer(question) 36 | if len(result['answer']) > 10: # top 10 37 | result['answer'] = result['answer'][:10] 38 | return render_template("homepage.html", result=result) 39 | else: 40 | return render_template("homepage.html", warning="Are you kidding me ?") 41 | 42 | 43 | @app.route('/entity', methods=['GET', 'POST']) 44 | def entity(): 45 | if flask.request.method == 'GET': 46 | return render_template('entity.html') 47 | else: 48 | question = flask.request.form['input'] 49 | if str(question): 50 | # app.logger.info(question) # 记录用户输入 51 | result = qa.entity_linking(question) 52 | if len(result['answer']) > 10: # top 10 53 | result['answer'] = result['answer'][:10] 54 | return render_template("entity.html", result=result) 55 | else: 56 | return render_template("entity.html", warning="Are you kidding me ?") 57 | 58 | 59 | @app.route('/relation', methods=['GET', 'POST']) 60 | def relation(): 61 | if flask.request.method == 'GET': 62 | return render_template('relation.html') 63 | else: 64 | question = flask.request.form['input'] 65 | if str(question): 66 | # app.logger.info(question) # 记录用户输入 67 | result = qa.relation_detection(question) 68 | if len(result['answer']) > 10: # top 10 69 | result['answer'] = result['answer'][:10] 70 | return render_template("relation.html", result=result) 71 | else: 72 | return render_template("relation.html", warning="Are you kidding me ?") 73 | 74 | 75 | @app.route('/mention', methods=['GET', 'POST']) 76 | def mention(): 77 | if flask.request.method == 'GET': 78 | return render_template('mention.html') 79 | else: 80 | question = flask.request.form['InputTextBox'] 81 | if str(question): 82 | if str(flask.request.form['input']) != "recall": 83 | result = qa.mention_detection(question) 84 | return render_template("mention.html", result=result) 85 | else: 86 | result = qa.mention_detection2(question) 87 | return render_template("mention.html", result=result) 88 | else: 89 | return render_template("mention.html", warning="Are you kidding me ?") 90 | 91 | 92 | @app.route('/coming_soon', methods=['GET', 'POST']) 93 | def coming_soon(): 94 | if flask.request.method == 'GET': 95 | return render_template('coming_soon.html') 96 | else: 97 | return render_template('coming_soon.html') 98 | 99 | 100 | @app.route('/mydict', methods=['GET', 'POST']) 101 | def mydict(): 102 | with open("data/data.json") as f: 103 | json_dict = json.load(f) 104 | rand = random.randint(0, len(json_dict)) 105 | print(json_dict[rand]) 106 | # print(rand) 107 | # print(json_dict[rand]) 108 | return jsonify(json_dict[rand]) 109 | 110 | 111 | @app.route('/report', methods=['GET', 'POST']) 112 | def report(): 113 | question = flask.request.form['input'] 114 | with open("data/report.log", "a") as f: 115 | f.write(question + "\n") 116 | return 117 | 118 | if __name__ == '__main__': 119 | # homepage() 120 | app.run(host='0.0.0.0', port=4001, debug=False, processes=1) # 114.212.190.231 121 | -------------------------------------------------------------------------------- /WebQA/static/images/favicon.ico: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/geofftong/NJU_KBQA/0b6725ff28212c51298bf3fa5b3b59059defcefe/WebQA/static/images/favicon.ico -------------------------------------------------------------------------------- /WebQA/static/images/help.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geofftong/NJU_KBQA/0b6725ff28212c51298bf3fa5b3b59059defcefe/WebQA/static/images/help.png -------------------------------------------------------------------------------- /WebQA/templates/entity.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 28 | KBQA 29 | 55 | 56 | 57 | 58 | 59 | 121 | 122 |
[entity.html page body: an "Entity Linking" heading, a question input form, and a Jinja2 results block. When result['flag'] is true it renders a table of candidate entities (columns: id, mid, entity name, entity type, entity score; rows from result['answer'] as {{loop.index}}, {{item[0]}} through {{item[3]}}); otherwise it shows "No entity found!", and {{warning}} when a warning is passed.]
217 | 218 | 255 | 256 | 257 | -------------------------------------------------------------------------------- /WebQA/templates/mention.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 29 | KBQA 30 | 56 | 57 | 58 | 59 | 60 | 122 | 123 |
[mention.html page body: a "Mention Detection" heading, a question input form, and a Jinja2 results block. When result['flag'] is true it renders a table of detected mentions (columns: id, mentions; rows indexed by {{loop.index}}); otherwise it shows "No answer found!", and {{warning}} when a warning is passed.]
219 | 220 | 257 | 258 | 259 | 260 | -------------------------------------------------------------------------------- /WebQA/templates/relation.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 27 | KBQA 28 | 54 | 55 | 56 | 57 | 58 | 120 | 121 |
122 |
123 |

124 | Relation Detection 125 |

126 |
127 |
128 |
129 |
130 |
131 |
132 | 134 | 135 | 136 | 137 |
138 | 139 | 148 | 149 | 152 | 153 | {% if result %} 154 | 157 | 162 |
163 |
164 |
165 | Question: {{result['answer'][0][2]}} 166 |
167 |
168 |
169 | {% if result['flag'] %} 170 |
171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | {% for item in result['answer']%} 181 | 182 | 183 | 184 | 185 | 186 | {% endfor %} 187 | 188 |
id relationrelation score
{{loop.index}}{{item[0]}}{{item[1]}}
189 | {% else %} 190 |
191 |
192 | No relation found! 193 |
194 |
195 | {% endif %} 196 | 197 | {% endif %} 198 | {% if warning %} 199 | 200 | {{warning}} 201 | 202 | {% endif %} 203 |
204 | 205 |
206 | 207 |
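The three templates above share one pattern: the view passes a result dict, and the page renders a candidate table when result['flag'] is set, a fallback message otherwise, and {{warning}} when no question was submitted. Below is a minimal, self-contained sketch of that pattern; the markup and the sample values are illustrative only and do not reproduce the project's actual Bootstrap layout:

from flask import Flask, render_template_string

TABLE_SKETCH = """
{% if result %}
  {% if result['flag'] %}
    <table>
      <tr><th>id</th><th>mid</th><th>entity name</th><th>entity type</th><th>entity score</th></tr>
      {% for item in result['answer'] %}
        <tr><td>{{ loop.index }}</td><td>{{ item[0] }}</td><td>{{ item[1] }}</td>
            <td>{{ item[2] }}</td><td>{{ item[3] }}</td></tr>
      {% endfor %}
    </table>
  {% else %}
    <p>No entity found!</p>
  {% endif %}
{% endif %}
{% if warning %}<p>{{ warning }}</p>{% endif %}
"""

app = Flask(__name__)

with app.app_context():
    # made-up candidate entry, shaped like the tuples the entity-linking view returns
    demo_result = {"flag": True,
                   "answer": [("fb:m.0d_fw", "margaret mitchell", "type.topic.name", 186.3)]}
    print(render_template_string(TABLE_SKETCH, result=demo_result))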
--------------------------------------------------------------------------------
/WebQA/templates/test.html:
--------------------------------------------------------------------------------
[Standalone scratch page, title 搜索商品记录存储在本地 ("search records are stored locally"); only the title is recoverable from this dump.]
--------------------------------------------------------------------------------
/WebQA/util/datasets.py:
--------------------------------------------------------------------------------
from torchtext import data


class SimpleQADataset(data.TabularDataset):
    # Entity-detection view of the SimpleQuestions TSV files: the question text is the input,
    # the token-level entity tags (column 'ed') are the labels.
    @classmethod
    def splits(cls, text_field, label_field, path,
               train='train.txt', validation='valid.txt', test='test.txt'):
        return super(SimpleQADataset, cls).splits(
            path=path, train=train, validation=validation, test=test,
            format='TSV', fields=[('id', None), ('sub', None), ('entity', None), ('relation', None),
                                  ('obj', None), ('text', text_field), ('ed', label_field)]
        )


class SimpleQaRelationDataset(data.TabularDataset):
    # Relation-detection view of the same train/valid/test files: here the 'relation' column
    # is the label and the entity-tag column is ignored.
    @classmethod
    def splits(cls, text_field, label_field, path,
               train='train.txt', validation='valid.txt', test='test.txt'):
        return super(SimpleQaRelationDataset, cls).splits(
            path=path, root='', train=train, validation=validation, test=test,
            format='TSV', fields=[('id', None), ('sub', None), ('entity', None), ('relation', label_field),
                                  ('obj', None), ('text', text_field), ('ed', None)]
        )
--------------------------------------------------------------------------------
/WebQA/util/utils.py:
--------------------------------------------------------------------------------
import json
import pickle
from collections import defaultdict

import unicodedata
from fuzzywuzzy import fuzz
from nltk.corpus import stopwords
from nltk.tokenize.treebank import TreebankWordTokenizer

stopwords = set(stopwords.words('english'))
tokenizer = TreebankWordTokenizer()


def get_mid2wiki(filename):
    # Mark every Freebase MID that has a Wikipedia page and remember its URL.
    mid2wiki = defaultdict(bool)
    mid2url = defaultdict()
    fin = open(filename)
    for line in fin.readlines():
        items = line.strip().split('\t')
        if len(items) != 3:
            continue
        else:
            sub = rdf2fb(clean_uri(items[0]))
            mid2wiki[sub] = True
            url = items[2][1:-3]
            mid2url[sub] = url
    return mid2wiki, mid2url


def convert_json(filepath):
    # Turn the processed TSV test file into a list of {id, relation, question} dicts
    # (the data.json built from this is what the /mydict route serves).
    result = list()
    with open(filepath) as f:
        for line in f:
            test_data = dict()
            tokens = line.rstrip("\n").split("\t")
            test_data["id"] = tokens[0]
            test_data["relation"] = tokens[3]
            test_data["question"] = tokens[5]
            result.append(test_data)
    return result


def get_mid2name_alias(name_path, alias_path):
    print("loading data from: {}".format(name_path))
    with open(name_path, 'rb') as f:
        names = pickle.load(f)
    with open(alias_path, 'rb') as f:
        alias = pickle.load(f)
    return names, alias


def tokenize_text(text):
    tokens = tokenizer.tokenize(text)
    return tokens


def get_index(index_path):
    print("loading data from: {}".format(index_path))
    with open(index_path, 'rb') as f:
        index = pickle.load(f)
    return index


def strip_accents(text):
    return ''.join(c for c in unicodedata.normalize('NFKD', text) if unicodedata.category(c) != 'Mn')


def get_ngram(text):
    # Collect all n-grams of up to three tokens, longest first (candidate entity mentions).
    ngram = list()
    tokens = text.split()
    for i in range(len(tokens) + 1):
        for j in range(i):
            if i - j <= 3:  # todo: 3? (n-grams of at most three tokens)
                temp = " ".join(tokens[j:i])
                if temp not in ngram:
                    ngram.append(temp)
    ngram = sorted(ngram, key=lambda x: len(x.split()), reverse=True)
    return ngram


def pick_name(question, names_list):
    # Fuzzy-match every candidate name against the question and return the best one.
    max_score = -1  # was None, which cannot be compared with an int under Python 3
    predict_name = None
    for name in names_list:
        score = fuzz.ratio(name, question)
        if score > max_score:
            max_score = score
            predict_name = name
    return predict_name


def www2fb(in_str):
    out_str = 'fb:%s' % (in_str.split('www.freebase.com/')[-1].replace('/', '.'))
    return out_str


def rdf2fb(in_str):
    out_str = 'fb:%s' % (in_str.split('http://rdf.freebase.com/ns/')[-1])
    return out_str


class ins(object):
    # lightweight wrapper holding a single question
    def __init__(self, question):
        self.question = question


def get_span(label):
    # Turn a sequence of I/O tags into (start, end) index pairs with an exclusive end.
    start, end = 0, 0
    flag = False
    span = []
    for k, l in enumerate(label):
        if l == 'I' and not flag:
            start = k
            flag = True
        if l != 'I' and flag:
            flag = False
            end = k
            span.append((start, end))
            start, end = 0, 0
    if flag:  # a mention that runs to the end of the sequence
        end = len(label)
        span.append((start, end))
    return span


def get_names(fb_names, cand_mids):
    names = list()
    for mid in cand_mids:
        if mid in fb_names:
            names.append(fb_names[mid][0])  # todo: only the first listed name is kept
    return names


def clean_uri(uri):
    if uri.startswith("<") and uri.endswith(">"):
        return clean_uri(uri[1:-1])
    elif uri.startswith("\"") and uri.endswith("\""):
        return clean_uri(uri[1:-1])
    return uri


if __name__ == '__main__':
    result = convert_json("../data/processed_dataset/test.txt")
    # Writing JSON data
    with open('data.json', 'w') as f:
        json.dump(result, f)

    # # Reading data back
    # with open('data.json', 'r') as f:
    #     data = json.load(f)
--------------------------------------------------------------------------------
/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geofftong/NJU_KBQA/0b6725ff28212c51298bf3fa5b3b59059defcefe/demo.gif
--------------------------------------------------------------------------------
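A few sanity checks for the helpers in WebQA/util/utils.py above (a sketch: it assumes WebQA/ is the working directory so that util is importable, and that fuzzywuzzy and nltk's stopwords corpus are installed, since utils.py imports them at load time; the question string and the expected outputs in the comments are illustrative):

from util.utils import get_ngram, get_span, www2fb, clean_uri

question = "who wrote gone with the wind"

print(get_ngram(question)[:3])
# n-grams of up to three tokens, longest first, e.g. ['who wrote gone', 'wrote gone with', ...]

print(get_span(['O', 'O', 'I', 'I', 'I', 'O']))
# [(2, 5)]  start inclusive, end exclusive, so tokens 2..4 form the mention

print(www2fb("www.freebase.com/m/0d_fw"))                 # 'fb:m.0d_fw'
print(clean_uri("<http://rdf.freebase.com/ns/m.0d_fw>"))  # 'http://rdf.freebase.com/ns/m.0d_fw'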