├── common ├── __init__.py ├── container │ ├── __init__.py │ ├── question.py │ ├── uris.py │ ├── answer.py │ ├── answerrow.py │ ├── qapair.py │ ├── linkeditem.py │ ├── answerset.py │ ├── uri.py │ └── sparql.py ├── graph │ ├── __init__.py │ ├── edge.py │ ├── node.py │ ├── paths.py │ ├── path.py │ └── graph.py ├── query │ ├── __init__.py │ └── querybuilder.py ├── utility │ ├── __init__.py │ ├── stats.py │ ├── mylist.py │ └── utility.py └── preprocessing │ ├── __init__.py │ ├── wordhashing.py │ └── preprocessor.py ├── kb ├── __init__.py ├── kb.py └── dbpedia.py ├── linker ├── __init__.py ├── goldLinker.py └── earl.py ├── learning ├── __init__.py ├── classifier │ ├── __init__.py │ ├── classifier.py │ └── svmclassifier.py └── treelstm │ ├── __init__.py │ ├── Constants.py │ ├── download.sh │ ├── tree.py │ ├── metrics.py │ ├── trainer.py │ ├── utils.py │ ├── vocab.py │ ├── config.py │ ├── dataset.py │ ├── model.py │ ├── main.py │ └── preprocess_lcquad.py ├── parser ├── __init__.py ├── lc_quad_linked.py ├── answerparser.py ├── lc_quad.py └── qald.py ├── data ├── .DS_Store ├── QALD │ └── .DS_Store └── LC-QUAD │ ├── .DS_Store │ └── templates.json ├── output ├── .DS_Store ├── qald │ ├── sim.txt │ ├── id.txt │ ├── b.parents │ ├── a.parents │ ├── b.toks │ ├── a.toks │ ├── a.txt │ ├── b.txt │ ├── a.len │ ├── a.tag │ ├── a.pos │ └── a.rels ├── tmp │ ├── a.len │ ├── sim.txt │ ├── id.txt │ ├── a.tag │ ├── b.parents │ ├── a.pos │ ├── a.rels │ ├── a.parents │ ├── b.toks │ ├── b.txt │ ├── a.toks │ └── a.txt ├── double_relation_classifier │ ├── .DS_Store │ └── svm.model └── question_type_classifier │ ├── .DS_Store │ └── svm.model ├── confusion_matrix_qald.png ├── confusion_matrix_lcquad.png ├── confusion_matrix_lcquad_all.png ├── treelstm ├── Constants.py ├── __init__.py ├── metrics.py ├── tree.py ├── utils.py ├── vocab.py └── model.py ├── lcquad_dataset.py ├── config.py ├── requirements.txt ├── README.md ├── result_analysis.py ├── lcquad_answer.py └── question_type_anlaysis.py /common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /kb/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /linker/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /learning/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /common/container/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /common/graph/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /common/query/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /common/utility/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /common/preprocessing/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /learning/classifier/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /learning/treelstm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /parser/__init__.py: -------------------------------------------------------------------------------- 1 | from . import lc_quad_linked -------------------------------------------------------------------------------- /data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sylvia-Liang/QAsparql/HEAD/data/.DS_Store -------------------------------------------------------------------------------- /output/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sylvia-Liang/QAsparql/HEAD/output/.DS_Store -------------------------------------------------------------------------------- /data/QALD/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sylvia-Liang/QAsparql/HEAD/data/QALD/.DS_Store -------------------------------------------------------------------------------- /output/qald/sim.txt: -------------------------------------------------------------------------------- 1 | 1 2 | 1 3 | 1 4 | 1 5 | 1 6 | 1 7 | 1 8 | 1 9 | 1 10 | 1 11 | 1 12 | -------------------------------------------------------------------------------- /output/tmp/a.len: -------------------------------------------------------------------------------- 1 | 46 2 | 46 3 | 46 4 | 46 5 | 46 6 | 46 7 | 46 8 | 46 9 | 46 10 | 46 11 | -------------------------------------------------------------------------------- /data/LC-QUAD/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sylvia-Liang/QAsparql/HEAD/data/LC-QUAD/.DS_Store -------------------------------------------------------------------------------- /confusion_matrix_qald.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sylvia-Liang/QAsparql/HEAD/confusion_matrix_qald.png -------------------------------------------------------------------------------- /confusion_matrix_lcquad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sylvia-Liang/QAsparql/HEAD/confusion_matrix_lcquad.png -------------------------------------------------------------------------------- /confusion_matrix_lcquad_all.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sylvia-Liang/QAsparql/HEAD/confusion_matrix_lcquad_all.png -------------------------------------------------------------------------------- /output/qald/id.txt: -------------------------------------------------------------------------------- 1 | test 2 | test 3 | test 4 | test 5 | test 6 | test 7 | test 8 | test 9 | test 10 | test 11 | test 12 | -------------------------------------------------------------------------------- /output/double_relation_classifier/.DS_Store: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sylvia-Liang/QAsparql/HEAD/output/double_relation_classifier/.DS_Store -------------------------------------------------------------------------------- /output/double_relation_classifier/svm.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sylvia-Liang/QAsparql/HEAD/output/double_relation_classifier/svm.model -------------------------------------------------------------------------------- /output/question_type_classifier/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sylvia-Liang/QAsparql/HEAD/output/question_type_classifier/.DS_Store -------------------------------------------------------------------------------- /output/question_type_classifier/svm.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sylvia-Liang/QAsparql/HEAD/output/question_type_classifier/svm.model -------------------------------------------------------------------------------- /output/qald/b.parents: -------------------------------------------------------------------------------- 1 | 0 1 1 2 | 0 1 1 3 | 0 1 1 4 | 0 1 1 5 | 0 1 1 6 | 0 1 2 2 1 7 | 0 1 1 8 | 0 1 1 9 | 0 1 2 2 1 10 | 0 1 1 11 | 0 1 1 12 | -------------------------------------------------------------------------------- /learning/treelstm/Constants.py: -------------------------------------------------------------------------------- 1 | PAD = 0 2 | UNK = 1 3 | BOS = 2 4 | EOS = 3 5 | 6 | PAD_WORD = '' 7 | UNK_WORD = '' 8 | BOS_WORD = '' 9 | EOS_WORD = '' -------------------------------------------------------------------------------- /treelstm/Constants.py: -------------------------------------------------------------------------------- 1 | PAD = 0 2 | UNK = 1 3 | BOS = 2 4 | EOS = 3 5 | 6 | PAD_WORD = '' 7 | UNK_WORD = '' 8 | BOS_WORD = '' 9 | EOS_WORD = '' 10 | -------------------------------------------------------------------------------- /learning/treelstm/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | 5 | echo "Downloading Glove" 6 | cd glove/ 7 | wget -q -c http://www-nlp.stanford.edu/data/glove.840B.300d.zip 8 | unzip -q glove.840B.300d.zip 9 | -------------------------------------------------------------------------------- /output/tmp/sim.txt: -------------------------------------------------------------------------------- 1 | 1 2 | 1 3 | 1 4 | 1 5 | 1 6 | 1 7 | 1 8 | 1 9 | 1 10 | 1 11 | 1 12 | 1 13 | 1 14 | 1 15 | 1 16 | 1 17 | 1 18 | 1 19 | 1 20 | 1 21 | 1 22 | 1 23 | 1 24 | 1 25 | 1 26 | 1 27 | 1 28 | 1 29 | 1 30 | 1 31 | 1 32 | -------------------------------------------------------------------------------- /output/qald/a.parents: -------------------------------------------------------------------------------- 1 | 6 6 2 5 6 0 6 2 | 6 6 2 5 6 0 6 3 | 6 6 2 5 6 0 6 4 | 6 6 2 5 6 0 6 5 | 6 6 2 5 6 0 6 6 | 6 6 2 5 6 0 6 7 | 6 6 2 5 6 0 6 8 | 6 6 2 5 6 0 6 9 | 6 6 2 5 6 0 6 10 | 6 6 2 5 6 0 6 11 | 6 6 2 5 6 0 6 12 | -------------------------------------------------------------------------------- /common/container/question.py: -------------------------------------------------------------------------------- 1 | class Question: 2 | def __init__(self, raw_question, parser): 3 | self.text = "" 4 | self.raw_question = raw_question 5 | self.text 
= parser(raw_question) 6 | 7 | def __str__(self): 8 | return self.text.encode("ascii", "ignore") 9 | # return self.text 10 | 11 | -------------------------------------------------------------------------------- /output/tmp/id.txt: -------------------------------------------------------------------------------- 1 | test 2 | test 3 | test 4 | test 5 | test 6 | test 7 | test 8 | test 9 | test 10 | test 11 | test 12 | test 13 | test 14 | test 15 | test 16 | test 17 | test 18 | test 19 | test 20 | test 21 | test 22 | test 23 | test 24 | test 25 | test 26 | test 27 | test 28 | test 29 | test 30 | test 31 | test 32 | -------------------------------------------------------------------------------- /treelstm/__init__.py: -------------------------------------------------------------------------------- 1 | from . import Constants 2 | from .dataset import LC_QUAD_Dataset 3 | from .metrics import Metrics 4 | from .model import TreeLSTM 5 | from .trainer import Trainer 6 | from .tree import Tree 7 | from . import utils 8 | from .vocab import Vocab 9 | 10 | __all__ = [Constants, LC_QUAD_Dataset, Metrics, TreeLSTM, Trainer, Tree, Vocab, utils] 11 | -------------------------------------------------------------------------------- /output/tmp/a.tag: -------------------------------------------------------------------------------- 1 | JJ NNS VBP NNP NNP NNP CC $ NN . 2 | JJ NNS VBP NNP NNP NNP CC $ NN . 3 | JJ NNS VBP NNP NNP NNP CC $ NN . 4 | JJ NNS VBP NNP NNP NNP CC $ NN . 5 | JJ NNS VBP NNP NNP NNP CC $ NN . 6 | JJ NNS VBP NNP NNP NNP CC $ NN . 7 | JJ NNS VBP NNP NNP NNP CC $ NN . 8 | JJ NNS VBP NNP NNP NNP CC $ NN . 9 | JJ NNS VBP NNP NNP NNP CC $ NN . 10 | JJ NNS VBP NNP NNP NNP CC $ NN . 11 | -------------------------------------------------------------------------------- /output/tmp/b.parents: -------------------------------------------------------------------------------- 1 | 0 1 1 2 | 0 1 1 3 | 0 1 1 4 | 0 1 2 2 1 5 | 0 1 1 6 | 0 1 1 7 | 0 1 1 8 | 0 1 2 2 1 9 | 0 1 2 2 1 10 | 0 1 2 2 1 11 | 0 1 2 3 3 2 1 12 | 0 1 1 13 | 0 1 1 14 | 0 1 1 15 | 0 1 1 16 | 0 1 2 2 1 17 | 0 1 1 18 | 0 1 1 19 | 0 1 1 20 | 0 1 1 21 | 0 1 2 2 1 22 | 0 1 1 23 | 0 1 1 24 | 0 1 1 25 | 0 1 1 26 | 0 1 2 2 1 27 | 0 1 1 28 | 0 1 1 29 | 0 1 1 30 | 0 1 1 31 | 0 1 2 2 1 32 | -------------------------------------------------------------------------------- /output/qald/b.toks: -------------------------------------------------------------------------------- 1 | wineProduced ?u_0 #ent 2 | wineProduced ?u_0 Sparkling_wine 3 | wineProduced ?u_1 ?u_0 4 | wineProduced ?u_0 Sparkling_wine 5 | wineProduced ?u_0 ?u_1 6 | wineProduced wineProduced ?u_0 ?u_1 Sparkling_wine 7 | wineProduced ?u_1 Sparkling_wine 8 | wineProduced ?u_1 ?u_0 9 | wineProduced wineProduced ?u_1 ?u_0 Sparkling_wine 10 | wineProduced ?u_1 Sparkling_wine 11 | wineProduced ?u_0 ?u_1 12 | -------------------------------------------------------------------------------- /output/qald/a.toks: -------------------------------------------------------------------------------- 1 | Where in France # ent produced ? 2 | Where in France # ent produced ? 3 | Where in France # ent produced ? 4 | Where in France # ent produced ? 5 | Where in France # ent produced ? 6 | Where in France # ent produced ? 7 | Where in France # ent produced ? 8 | Where in France # ent produced ? 9 | Where in France # ent produced ? 10 | Where in France # ent produced ? 11 | Where in France # ent produced ? 
12 | -------------------------------------------------------------------------------- /output/qald/a.txt: -------------------------------------------------------------------------------- 1 | Where in France #ent produced? 2 | Where in France #ent produced? 3 | Where in France #ent produced? 4 | Where in France #ent produced? 5 | Where in France #ent produced? 6 | Where in France #ent produced? 7 | Where in France #ent produced? 8 | Where in France #ent produced? 9 | Where in France #ent produced? 10 | Where in France #ent produced? 11 | Where in France #ent produced? 12 | -------------------------------------------------------------------------------- /output/qald/b.txt: -------------------------------------------------------------------------------- 1 | ?u_0 wineProduced #ent 2 | ?u_0 wineProduced Sparkling_wine 3 | ?u_1 wineProduced ?u_0 4 | ?u_0 wineProduced Sparkling_wine 5 | ?u_0 wineProduced ?u_1 6 | ?u_0 wineProduced Sparkling_wine .?u_0 wineProduced ?u_1 7 | ?u_1 wineProduced Sparkling_wine 8 | ?u_1 wineProduced ?u_0 9 | ?u_1 wineProduced Sparkling_wine .?u_1 wineProduced ?u_0 10 | ?u_1 wineProduced Sparkling_wine 11 | ?u_0 wineProduced ?u_1 12 | -------------------------------------------------------------------------------- /treelstm/metrics.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import torch 4 | 5 | 6 | class Metrics(): 7 | def __init__(self): 8 | pass 9 | 10 | def accuracy_score(self, predictions, labels, vocab_output): 11 | labels = torch.tensor([vocab_output.getIndex(str(int(label))) for label in labels], dtype=torch.float) 12 | correct = (predictions == labels).sum() 13 | total = labels.size(0) 14 | acc = float(correct) / total 15 | return acc -------------------------------------------------------------------------------- /output/qald/a.len: -------------------------------------------------------------------------------- 1 | 66 2 | 66 3 | 66 4 | 54 5 | 54 6 | 54 7 | 54 8 | 54 9 | 54 10 | 54 11 | 54 12 | 54 13 | 54 14 | 54 15 | 54 16 | 54 17 | 54 18 | 54 19 | 54 20 | 54 21 | 54 22 | 54 23 | 54 24 | 54 25 | 54 26 | 54 27 | 54 28 | 54 29 | 54 30 | 54 31 | 54 32 | 54 33 | 54 34 | 54 35 | 54 36 | 54 37 | 54 38 | 54 39 | 54 40 | 54 41 | 54 42 | 54 43 | 54 44 | 54 45 | 54 46 | 54 47 | 54 48 | 54 49 | 54 50 | 54 51 | 54 52 | 54 53 | 54 54 | 54 55 | 54 56 | 54 57 | 54 58 | 54 59 | 54 60 | 54 61 | 54 62 | 54 63 | 54 64 | 54 65 | 54 66 | 54 67 | 54 68 | 54 69 | 54 70 | 54 71 | 54 72 | 54 73 | -------------------------------------------------------------------------------- /output/tmp/a.pos: -------------------------------------------------------------------------------- 1 | ADJ NOUN VERB PROPN PROPN PROPN CCONJ SYM NOUN PUNCT 2 | ADJ NOUN VERB PROPN PROPN PROPN CCONJ SYM NOUN PUNCT 3 | ADJ NOUN VERB PROPN PROPN PROPN CCONJ SYM NOUN PUNCT 4 | ADJ NOUN VERB PROPN PROPN PROPN CCONJ SYM NOUN PUNCT 5 | ADJ NOUN VERB PROPN PROPN PROPN CCONJ SYM NOUN PUNCT 6 | ADJ NOUN VERB PROPN PROPN PROPN CCONJ SYM NOUN PUNCT 7 | ADJ NOUN VERB PROPN PROPN PROPN CCONJ SYM NOUN PUNCT 8 | ADJ NOUN VERB PROPN PROPN PROPN CCONJ SYM NOUN PUNCT 9 | ADJ NOUN VERB PROPN PROPN PROPN CCONJ SYM NOUN PUNCT 10 | ADJ NOUN VERB PROPN PROPN PROPN CCONJ SYM NOUN PUNCT 11 | -------------------------------------------------------------------------------- /common/utility/stats.py: -------------------------------------------------------------------------------- 1 | class Stats(dict): 2 | def __init__(self, *args): 3 | dict.__init__(self, args) 
4 | 5 | def __getitem__(self, key): 6 | if key not in self: 7 | return 0 8 | return dict.__getitem__(self, key) 9 | 10 | def inc(self, key, value=1): 11 | if key not in self: 12 | self[key] = 0 13 | self[key] += value 14 | 15 | def __str__(self): 16 | keys = self.keys() 17 | # keys.sort() 18 | keys = sorted(keys) 19 | return " ".join([key + ":" + str(self[key]) for key in keys]) 20 | -------------------------------------------------------------------------------- /output/tmp/a.rels: -------------------------------------------------------------------------------- 1 | poss nsubj ROOT compound compound attr cc nmod conj punct 2 | poss nsubj ROOT compound compound attr cc nmod conj punct 3 | poss nsubj ROOT compound compound attr cc nmod conj punct 4 | poss nsubj ROOT compound compound attr cc nmod conj punct 5 | poss nsubj ROOT compound compound attr cc nmod conj punct 6 | poss nsubj ROOT compound compound attr cc nmod conj punct 7 | poss nsubj ROOT compound compound attr cc nmod conj punct 8 | poss nsubj ROOT compound compound attr cc nmod conj punct 9 | poss nsubj ROOT compound compound attr cc nmod conj punct 10 | poss nsubj ROOT compound compound attr cc nmod conj punct 11 | -------------------------------------------------------------------------------- /common/container/uris.py: -------------------------------------------------------------------------------- 1 | class URIs(list): 2 | def __init__(self, *args): 3 | super(URIs, self).__init__(*args) 4 | 5 | def __eq__(self, other): 6 | if isinstance(other, URIs): 7 | if len(self) != len(other): 8 | return False 9 | for uri in self: 10 | found = False 11 | for other_uri in other: 12 | if uri.generic_equal(other_uri): 13 | found = True 14 | break 15 | if not found: 16 | return False 17 | return True 18 | return NotImplemented 19 | -------------------------------------------------------------------------------- /common/container/answer.py: -------------------------------------------------------------------------------- 1 | class Answer: 2 | def __init__(self, answer_type, raw_answer, parser): 3 | self.raw_answer = raw_answer 4 | self.answer_type, self.answer = parser(answer_type, raw_answer) 5 | 6 | def __eq__(self, other): 7 | if isinstance(other, Answer): 8 | return self.answer == other.answer 9 | return NotImplemented 10 | 11 | def __str__(self): 12 | if self.answer_type == "bool": 13 | return str(self.answer) 14 | elif self.answer_type == "uri": 15 | # return self.answer.__str__().encode("ascii", "ignore") 16 | return self.answer.__str__() 17 | 18 | # return self.answer.encode("ascii", "ignore") 19 | return self.answer 20 | -------------------------------------------------------------------------------- /output/tmp/a.parents: -------------------------------------------------------------------------------- 1 | 3 3 0 3 4 5 8 5 3 2 | 3 3 0 3 4 5 8 5 3 3 | 3 3 0 3 4 5 8 5 3 4 | 3 3 0 3 4 5 8 5 3 5 | 3 3 0 3 4 5 8 5 3 6 | 3 3 0 3 4 5 8 5 3 7 | 3 3 0 3 4 5 8 5 3 8 | 3 3 0 3 4 5 8 5 3 9 | 3 3 0 3 4 5 8 5 3 10 | 3 3 0 3 4 5 8 5 3 11 | 3 3 0 3 4 5 8 5 3 12 | 3 3 0 3 4 5 8 5 3 13 | 3 3 0 3 4 5 8 5 3 14 | 3 3 0 3 4 5 8 5 3 15 | 3 3 0 3 4 5 8 5 3 16 | 3 3 0 3 4 5 8 5 3 17 | 3 3 0 3 4 5 8 5 3 18 | 3 3 0 3 4 5 8 5 3 19 | 3 3 0 3 4 5 8 5 3 20 | 3 3 0 3 4 5 8 5 3 21 | 3 3 0 3 4 5 8 5 3 22 | 3 3 0 3 4 5 8 5 3 23 | 3 3 0 3 4 5 8 5 3 24 | 3 3 0 3 4 5 8 5 3 25 | 3 3 0 3 4 5 8 5 3 26 | 3 3 0 3 4 5 8 5 3 27 | 3 3 0 3 4 5 8 5 3 28 | 3 3 0 3 4 5 8 5 3 29 | 3 3 0 3 4 5 8 5 3 30 | 3 3 0 3 4 5 8 5 3 31 | 3 3 0 3 4 5 8 5 3 32 | 
-------------------------------------------------------------------------------- /common/container/answerrow.py: -------------------------------------------------------------------------------- 1 | class AnswerRow: 2 | def __init__(self, raw_answers, parser): 3 | self.raw_answers = raw_answers 4 | self.answers = parser(raw_answers) 5 | 6 | def number_of_answer(self): 7 | return len(self.answers) 8 | 9 | def __eq__(self, other): 10 | if isinstance(other, AnswerRow): 11 | if len(self.answers) != len(other.answers): 12 | return False 13 | for answer in self.answers: 14 | found = False 15 | for other_answer in other.answers: 16 | if answer == other_answer: 17 | found = True 18 | if not found: 19 | return False 20 | return True 21 | return NotImplemented 22 | -------------------------------------------------------------------------------- /common/utility/mylist.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | 3 | 4 | class MyList(list): 5 | def __init__(self, *args): 6 | super(MyList, self).__init__(*args) 7 | 8 | def __le__(self, other): 9 | l1 = Counter(self) 10 | l2 = Counter(other) 11 | counts = l1 - l2 12 | return len(counts) == 0 13 | 14 | def __sub__(self, other): 15 | try: 16 | other_ = other 17 | if isinstance(other, set): 18 | other_ = list(other) 19 | l1 = Counter(self) 20 | l2 = Counter(other_) 21 | output = [] 22 | counts = l1 - l2 23 | for item in counts: 24 | output.extend([item] * counts[item]) 25 | 26 | return self.__class__(output) 27 | except: 28 | return [] 29 | -------------------------------------------------------------------------------- /linker/goldLinker.py: -------------------------------------------------------------------------------- 1 | from common.container.linkeditem import LinkedItem 2 | from common.utility.utility import find_mentions 3 | 4 | 5 | class GoldLinker: 6 | def __init__(self): 7 | pass 8 | 9 | def do(self, qapair, force_gold=False, top=5): 10 | entities = [] 11 | relations = [] 12 | for u in qapair.sparql.uris: 13 | question = qapair.question.text 14 | mentions = find_mentions(question, [u]) 15 | surface = "" 16 | if len(mentions) > 0: 17 | surface = question[mentions[0]["start"]:mentions[0]["end"]] 18 | 19 | linked_item = LinkedItem(surface, [u]) 20 | if u.is_entity(): 21 | entities.append(linked_item) 22 | if u.is_ontology(): 23 | relations.append(linked_item) 24 | 25 | return entities, relations 26 | -------------------------------------------------------------------------------- /common/container/qapair.py: -------------------------------------------------------------------------------- 1 | from common.container.question import Question 2 | from common.container.answerset import AnswerSet 3 | from common.container.sparql import SPARQL 4 | 5 | class QApair: 6 | def __init__(self, raw_question, raw_answerset, raw_query, raw_row, id, parser): 7 | self.raw_row = raw_row 8 | self.question = [] 9 | self.sparql = [] 10 | self.id = id 11 | 12 | self.question = Question(raw_question, parser.parse_question) 13 | self.answerset = AnswerSet(raw_answerset, parser.parse_answerset) 14 | self.sparql = SPARQL(raw_query, parser.parse_sparql) 15 | 16 | def question_template(self, entity_relation_list): 17 | question = self.question.text.lower() 18 | for item in entity_relation_list: 19 | question = question.replace(item.label.lower(), item.uri.type) 20 | return question 21 | 22 | def __str__(self): 23 | return "{}\n{}\n{}".format(self.question, self.answerset, self.sparql) 24 | 
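A minimal usage sketch, not part of the repository: it shows how one raw LC-QuAD record (a dict called raw_row here, which is an assumption) could be turned into the QApair container above, using the LC_Qaud_LinkedParser that appears later in this dump.

# Hedged example -- raw_row is assumed to carry the keys that
# parser/lc_quad_linked.py reads ("question", "answers", "sparql_query", "id").
from common.container.qapair import QApair
from parser.lc_quad_linked import LC_Qaud_LinkedParser

parser = LC_Qaud_LinkedParser()
qapair = QApair(raw_row["question"], raw_row.get("answers"),
                raw_row["sparql_query"], raw_row, raw_row["id"], parser)
print(qapair.question.text)      # parsed question text
print(len(qapair.sparql.uris))   # URIs extracted from the gold SPARQL query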
-------------------------------------------------------------------------------- /output/tmp/b.toks: -------------------------------------------------------------------------------- 1 | veneratedIn ?u_0 #ent 2 | veneratedIn ?u_0 #ent 3 | veneratedIn ?u_1 ?u_0 4 | veneratedIn veneratedIn ?u_0 Islam Judaism 5 | veneratedIn ?u_0 Judaism 6 | veneratedIn ?u_0 Islam 7 | veneratedIn ?u_0 ?u_1 8 | veneratedIn veneratedIn ?u_0 Islam Judaism 9 | veneratedIn veneratedIn ?u_0 ?u_1 Judaism 10 | veneratedIn veneratedIn ?u_0 ?u_1 Islam 11 | veneratedIn veneratedIn veneratedIn ?u_0 ?u_1 Islam Judaism 12 | veneratedIn ?u_1 Judaism 13 | veneratedIn ?u_0 Islam 14 | veneratedIn ?u_1 ?u_0 15 | veneratedIn ?u_1 Judaism 16 | veneratedIn veneratedIn ?u_1 ?u_0 Judaism 17 | veneratedIn ?u_0 Judaism 18 | veneratedIn ?u_1 Islam 19 | veneratedIn ?u_1 ?u_0 20 | veneratedIn ?u_0 Judaism 21 | veneratedIn veneratedIn ?u_1 ?u_0 Islam 22 | veneratedIn ?u_1 Judaism 23 | veneratedIn ?u_0 Islam 24 | veneratedIn ?u_0 ?u_1 25 | veneratedIn ?u_1 Judaism 26 | veneratedIn veneratedIn ?u_0 ?u_1 Islam 27 | veneratedIn ?u_0 Judaism 28 | veneratedIn ?u_1 Islam 29 | veneratedIn ?u_0 ?u_1 30 | veneratedIn ?u_0 Judaism 31 | veneratedIn veneratedIn ?u_0 ?u_1 Judaism 32 | -------------------------------------------------------------------------------- /treelstm/tree.py: -------------------------------------------------------------------------------- 1 | # tree object from stanfordnlp/treelstm 2 | class Tree(object): 3 | def __init__(self): 4 | self.parent = None 5 | self.num_children = 0 6 | self.children = list() 7 | 8 | def add_child(self, child): 9 | child.parent = self 10 | self.num_children += 1 11 | self.children.append(child) 12 | 13 | def size(self): 14 | if hasattr(self, '_size'): 15 | return self._size 16 | count = 1 17 | for i in range(self.num_children): 18 | count += self.children[i].size() 19 | self._size = count 20 | return self._size 21 | 22 | def depth(self): 23 | if getattr(self, '_depth'): 24 | return self._depth 25 | count = 0 26 | if self.num_children > 0: 27 | for i in range(self.num_children): 28 | child_depth = self.children[i].depth() 29 | if child_depth > count: 30 | count = child_depth 31 | count += 1 32 | self._depth = count 33 | return self._depth 34 | -------------------------------------------------------------------------------- /learning/treelstm/tree.py: -------------------------------------------------------------------------------- 1 | # tree object from stanfordnlp/treelstm 2 | class Tree(object): 3 | def __init__(self): 4 | self.parent = None 5 | self.num_children = 0 6 | self.children = list() 7 | 8 | def add_child(self,child): 9 | child.parent = self 10 | self.num_children += 1 11 | self.children.append(child) 12 | 13 | def size(self): 14 | if getattr(self,'_size'): 15 | return self._size 16 | count = 1 17 | for i in range(self.num_children): 18 | count += self.children[i].size() 19 | self._size = count 20 | return self._size 21 | 22 | def depth(self): 23 | if getattr(self,'_depth'): 24 | return self._depth 25 | count = 0 26 | if self.num_children>0: 27 | for i in range(self.num_children): 28 | child_depth = self.children[i].depth() 29 | if child_depth>count: 30 | count = child_depth 31 | count += 1 32 | self._depth = count 33 | return self._depth 34 | -------------------------------------------------------------------------------- /common/container/linkeditem.py: -------------------------------------------------------------------------------- 1 | class LinkedItem: 2 | def __init__(self, surface_form, 
uris): 3 | self.surface_form = surface_form 4 | self.uris = uris 5 | 6 | def top_uris(self, top=1): 7 | return self.uris[:int(top * len(self.uris))] 8 | 9 | def contains_uri(self, uri): 10 | """ 11 | Whether the uri exists in the list of uris 12 | :param uri: 13 | :return: Bool 14 | """ 15 | return uri in self.uris 16 | 17 | @staticmethod 18 | def list_contains_uris(linkeditem_list, uris): 19 | """ 20 | Returns the linkedItems that contain any of the uris, 21 | but only one linkedItem per uri 22 | :param linkeditem_list: List of LinkedItem 23 | :param uris: 24 | :return: 25 | """ 26 | output = [] 27 | for uri in sorted(uris, key=lambda x: len(str(x)), reverse=True): 28 | for item in linkeditem_list: 29 | if item not in output and item.contains_uri(uri): 30 | output.append(item) 31 | break 32 | return output -------------------------------------------------------------------------------- /common/container/answerset.py: -------------------------------------------------------------------------------- 1 | class AnswerSet: 2 | def __init__(self, raw_answerset, parser): 3 | self.raw_answerset = raw_answerset 4 | self.answer_rows = [] 5 | self.answer_rows = parser(raw_answerset) 6 | 7 | def number_of_answer(self): 8 | return self.answer_rows[0].number_of_answer() 9 | 10 | def __eq__(self, other): 11 | if isinstance(other, AnswerSet): 12 | if len(self.answer_rows) != len(other.answer_rows): 13 | return False 14 | for answers in self.answer_rows: 15 | found = False 16 | for other_answers in other.answer_rows: 17 | if answers == other_answers: 18 | found = True 19 | break 20 | if not found: 21 | return False 22 | return True 23 | return NotImplemented 24 | 25 | def intersect(self, other): 26 | count = 0 27 | if isinstance(other, AnswerSet): 28 | for answers in self.answer_rows: 29 | for other_answers in other.answer_rows: 30 | if answers == other_answers: 31 | count += 1 32 | break 33 | return count 34 | return NotImplemented 35 | 36 | def __len__(self): 37 | return len(self.answer_rows) 38 | 39 | def __str__(self): 40 | return "\n".join(str(a) for a in self.answer_rows) 41 | -------------------------------------------------------------------------------- /output/tmp/b.txt: -------------------------------------------------------------------------------- 1 | ?u_0 veneratedIn #ent 2 | ?u_0 veneratedIn #ent 3 | ?u_1 veneratedIn ?u_0 4 | ?u_0 veneratedIn Judaism .?u_0 veneratedIn Islam 5 | ?u_0 veneratedIn Judaism 6 | ?u_0 veneratedIn Islam 7 | ?u_0 veneratedIn ?u_1 8 | ?u_0 veneratedIn Judaism .?u_0 veneratedIn Islam 9 | ?u_0 veneratedIn Judaism .?u_0 veneratedIn ?u_1 10 | ?u_0 veneratedIn Islam .?u_0 veneratedIn ?u_1 11 | ?u_0 veneratedIn Judaism .?u_0 veneratedIn Islam .?u_0 veneratedIn ?u_1 12 | ?u_1 veneratedIn Judaism 13 | ?u_0 veneratedIn Islam 14 | ?u_1 veneratedIn ?u_0 15 | ?u_1 veneratedIn Judaism .?u_0 veneratedIn Islam 16 | ?u_1 veneratedIn Judaism .?u_1 veneratedIn ?u_0 17 | ?u_0 veneratedIn Judaism 18 | ?u_1 veneratedIn Islam 19 | ?u_1 veneratedIn ?u_0 20 | ?u_0 veneratedIn Judaism .?u_1 veneratedIn Islam 21 | ?u_1 veneratedIn Islam .?u_1 veneratedIn ?u_0 22 | ?u_1 veneratedIn Judaism 23 | ?u_0 veneratedIn Islam 24 | ?u_0 veneratedIn ?u_1 25 | ?u_1 veneratedIn Judaism .?u_0 veneratedIn Islam 26 | ?u_0 veneratedIn Islam .?u_0 veneratedIn ?u_1 27 | ?u_0 veneratedIn Judaism 28 | ?u_1 veneratedIn Islam 29 | ?u_0 veneratedIn ?u_1 30 | ?u_0 veneratedIn Judaism .?u_1 veneratedIn Islam 31 | ?u_0 veneratedIn Judaism .?u_0 veneratedIn ?u_1 32 | 
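A small illustration, not part of the repository, of how two of the AnswerSet containers shown above could be scored against each other using only the intersect() and __len__() methods they define; `predicted` and `gold` are assumed to be AnswerSet objects built elsewhere.

# Hypothetical helper: precision / recall / F1 over answer rows.
def answerset_f1(predicted, gold):
    if len(predicted) == 0 or len(gold) == 0:
        return 0.0, 0.0, 0.0
    precision = predicted.intersect(gold) / len(predicted)  # predicted rows found in gold
    recall = gold.intersect(predicted) / len(gold)          # gold rows found in predicted
    if precision + recall == 0:
        return precision, recall, 0.0
    return precision, recall, 2 * precision * recall / (precision + recall)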
-------------------------------------------------------------------------------- /learning/treelstm/metrics.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from sklearn.metrics import precision_recall_fscore_support 3 | import torch 4 | 5 | 6 | class Metrics(): 7 | def __init__(self, num_classes): 8 | self.num_classes = num_classes 9 | 10 | def all(self, predictions, labels): 11 | return "\tPearson: {}\tMSE: {}, \tF1: {}".format(self.pearson(predictions, labels), 12 | self.mse(predictions, labels), 13 | self.f1(predictions, labels)) 14 | 15 | def f1(self, predictions, labels): 16 | try: 17 | y_true = list(labels) 18 | y_pred = map(round, predictions) 19 | precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='macro') 20 | return precision, recall, f1 21 | except: 22 | return 0, 0, 0 23 | 24 | def pearson(self, predictions, labels): 25 | x = deepcopy(predictions) 26 | y = deepcopy(labels) 27 | x = (x - x.mean()) / x.std() 28 | y = (y - y.mean()) / y.std() 29 | return torch.mean(torch.mul(x, y)) 30 | 31 | def mse(self, predictions, labels): 32 | x = deepcopy(predictions) 33 | y = deepcopy(labels) 34 | return torch.mean((x - y) ** 2) 35 | -------------------------------------------------------------------------------- /output/tmp/a.toks: -------------------------------------------------------------------------------- 1 | Who is venerated in Judaism and # ent ? 2 | Who is venerated in Judaism and # ent ? 3 | Who is venerated in Judaism and # ent ? 4 | Who is venerated in Judaism and # ent ? 5 | Who is venerated in Judaism and # ent ? 6 | Who is venerated in Judaism and # ent ? 7 | Who is venerated in Judaism and # ent ? 8 | Who is venerated in Judaism and # ent ? 9 | Who is venerated in Judaism and # ent ? 10 | Who is venerated in Judaism and # ent ? 11 | Who is venerated in Judaism and # ent ? 12 | Who is venerated in Judaism and # ent ? 13 | Who is venerated in Judaism and # ent ? 14 | Who is venerated in Judaism and # ent ? 15 | Who is venerated in Judaism and # ent ? 16 | Who is venerated in Judaism and # ent ? 17 | Who is venerated in Judaism and # ent ? 18 | Who is venerated in Judaism and # ent ? 19 | Who is venerated in Judaism and # ent ? 20 | Who is venerated in Judaism and # ent ? 21 | Who is venerated in Judaism and # ent ? 22 | Who is venerated in Judaism and # ent ? 23 | Who is venerated in Judaism and # ent ? 24 | Who is venerated in Judaism and # ent ? 25 | Who is venerated in Judaism and # ent ? 26 | Who is venerated in Judaism and # ent ? 27 | Who is venerated in Judaism and # ent ? 28 | Who is venerated in Judaism and # ent ? 29 | Who is venerated in Judaism and # ent ? 30 | Who is venerated in Judaism and # ent ? 31 | Who is venerated in Judaism and # ent ? 32 | -------------------------------------------------------------------------------- /output/tmp/a.txt: -------------------------------------------------------------------------------- 1 | Who is venerated in #ent and Islam? 2 | Who is venerated in #ent and #ent ? 3 | Who is venerated in #ent and #ent ? 4 | Who is venerated in #ent and #ent ? 5 | Who is venerated in #ent and #ent ? 6 | Who is venerated in #ent and #ent ? 7 | Who is venerated in #ent and #ent ? 8 | Who is venerated in #ent and #ent ? 9 | Who is venerated in #ent and #ent ? 10 | Who is venerated in #ent and #ent ? 11 | Who is venerated in #ent and #ent ? 12 | Who is venerated in #ent and #ent ? 13 | Who is venerated in #ent and #ent ? 
14 | Who is venerated in #ent and #ent ? 15 | Who is venerated in #ent and #ent ? 16 | Who is venerated in #ent and #ent ? 17 | Who is venerated in #ent and #ent ? 18 | Who is venerated in #ent and #ent ? 19 | Who is venerated in #ent and #ent ? 20 | Who is venerated in #ent and #ent ? 21 | Who is venerated in #ent and #ent ? 22 | Who is venerated in #ent and #ent ? 23 | Who is venerated in #ent and #ent ? 24 | Who is venerated in #ent and #ent ? 25 | Who is venerated in #ent and #ent ? 26 | Who is venerated in #ent and #ent ? 27 | Who is venerated in #ent and #ent ? 28 | Who is venerated in #ent and #ent ? 29 | Who is venerated in #ent and #ent ? 30 | Who is venerated in #ent and #ent ? 31 | Who is venerated in #ent and #ent ? 32 | -------------------------------------------------------------------------------- /common/container/uri.py: -------------------------------------------------------------------------------- 1 | class Uri: 2 | def __init__(self, raw_uri, parser, confidence=1.0): 3 | self.raw_uri = raw_uri 4 | self.uri_type, self.uri = parser(raw_uri) 5 | self.__str = u"{}:{}".format(self.uri_type, self.uri[self.uri.rfind("/") + 1:].encode("ascii", "ignore").decode()) 6 | self.__hash = hash(self.__str) 7 | self.confidence = confidence 8 | 9 | def is_generic(self): 10 | return self.uri_type == "g" 11 | 12 | def is_entity(self): 13 | return self.uri_type == "?s" 14 | 15 | def is_ontology(self): 16 | return self.uri_type == "?p" or self.uri_type == "?o" 17 | 18 | def is_type(self): 19 | return self.uri_type == "?t" 20 | 21 | def sparql_format(self, kb): 22 | return kb.uri_to_sparql(self) 23 | 24 | def generic_id(self): 25 | if self.is_generic(): 26 | return int(self.uri[3:]) 27 | return None 28 | 29 | def generic_equal(self, other): 30 | return (self.is_generic() and other.is_generic()) or self == other 31 | 32 | def __eq__(self, other): 33 | if isinstance(other, Uri): 34 | return self.uri == other.uri 35 | return NotImplemented 36 | 37 | def __hash__(self): 38 | return self.__hash 39 | 40 | def __str__(self): 41 | return self.__str 42 | 43 | @staticmethod 44 | def generic_uri(var_num): 45 | return Uri("g", lambda r: ("g", "?u_{}".format(var_num))) 46 | -------------------------------------------------------------------------------- /common/preprocessing/wordhashing.py: -------------------------------------------------------------------------------- 1 | class WordHashing: 2 | def __init__(self): 3 | self.ids = {} 4 | 5 | def to_n_gams(self, text, n=3): 6 | text = text.lower() 7 | words = "".join([x if (('a' <= x <= 'z') or x == ' ') else ' ' for x in text]) 8 | words = words.split() 9 | res = [] 10 | for word in words: 11 | padded_word = "#{}#".format(word) 12 | seq = [] 13 | for i in range(n): 14 | seq.append(padded_word[i:]) 15 | n_tuples = zip(*seq) 16 | seq = ["".join(x) for x in n_tuples] 17 | res.extend(seq) 18 | return res 19 | 20 | def __encode_n_grams(self, input): 21 | ids = [] 22 | for term in input: 23 | if term in self.ids: 24 | ids.append(self.ids[term]) 25 | else: 26 | term_id = len(self.ids) 27 | self.ids[term] = term_id 28 | ids.append(term_id) 29 | return ids 30 | 31 | def hash(self, text, n=3): 32 | return self.__encode_n_grams(self.to_n_gams(text, n)) 33 | 34 | def save(self, path): 35 | with open(path, 'w') as dict_file: 36 | for kv in self.ids.iteritems(): 37 | dict_file.write('{} {}\n'.format(kv[0], kv[1])) 38 | dict_file.close() 39 | 40 | def load(self, path): 41 | self.ids = {} 42 | with open(path) as dict_file: 43 | for line in dict_file: 44 | kv = 
line.strip("\n").split(" ") 45 | self.ids[kv[0]] = int(kv[1]) 46 | -------------------------------------------------------------------------------- /parser/lc_quad_linked.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | from common.container.qapair import QApair 4 | from common.container.uri import Uri 5 | from common.container.uris import URIs 6 | from kb.dbpedia import DBpedia 7 | from parser.answerparser import AnswerParser 8 | 9 | 10 | class LC_Qaud_Linked: 11 | def __init__(self, path="./data/LC-QUAD/linked.json"): 12 | self.raw_data = [] 13 | self.qapairs = [] 14 | self.path = path 15 | self.parser = LC_Qaud_LinkedParser() 16 | 17 | def load(self): 18 | with open(self.path) as data_file: 19 | self.raw_data = json.load(data_file) 20 | 21 | def parse(self): 22 | for raw_row in self.raw_data: 23 | self.qapairs.append( 24 | QApair(raw_row["question"], raw_row.get("answers"), raw_row["sparql_query"], raw_row, raw_row["id"], 25 | self.parser)) 26 | 27 | def print_pairs(self, n=-1): 28 | for item in self.qapairs[0:n]: 29 | print(item) 30 | print("") 31 | 32 | 33 | class LC_Qaud_LinkedParser(AnswerParser): 34 | def __init__(self): 35 | super(LC_Qaud_LinkedParser, self).__init__(DBpedia(one_hop_bloom_file="./data/blooms/spo1.bloom")) 36 | 37 | def parse_question(self, raw_question): 38 | return raw_question 39 | 40 | def parse_answerset(self, raw_answers): 41 | return self.parse_queryresult(raw_answers) 42 | 43 | def parse_sparql(self, raw_query): 44 | raw_query = raw_query.replace("https://", "http://") 45 | uris = URIs([Uri(raw_uri, self.kb.parse_uri) for raw_uri in re.findall('(<[^>]*>|\?[^ ]*)', raw_query)]) 46 | return raw_query, True, uris 47 | -------------------------------------------------------------------------------- /parser/answerparser.py: -------------------------------------------------------------------------------- 1 | from common.container.answer import Answer 2 | from common.container.answerrow import AnswerRow 3 | from common.container.uri import Uri 4 | 5 | 6 | class AnswerParser(object): 7 | def __init__(self, kb): 8 | self.kb = kb 9 | 10 | def parse_queryresult(self, raw_answerset): 11 | answer_rows = [] 12 | if raw_answerset is None: 13 | return answer_rows 14 | if "boolean" in raw_answerset: 15 | return [ 16 | AnswerRow(raw_answerset, lambda x: [Answer("bool", raw_answerset["boolean"], lambda at, ra: (at, ra))])] 17 | if "results" in raw_answerset and "bindings" in raw_answerset["results"] \ 18 | and len(raw_answerset["results"]["bindings"]) > 0: 19 | for raw_answerrow in raw_answerset["results"]["bindings"]: 20 | answer_rows.append(AnswerRow(raw_answerrow, self.__parse_answerrow)) 21 | elif "string" in raw_answerset: 22 | return [ 23 | AnswerRow(raw_answerset, lambda x: [Answer("uri", raw_answerset["string"], self.__parse_answer)])] 24 | 25 | 26 | return answer_rows 27 | 28 | def __parse_answerrow(self, raw_answerrow): 29 | answers = [] 30 | for var_id in raw_answerrow: 31 | answers.append(Answer(raw_answerrow[var_id]["type"], raw_answerrow[var_id]["value"], self.__parse_answer)) 32 | return answers 33 | 34 | def __parse_answer(self, answer_type, raw_answer): 35 | prefix = self.kb.prefix() 36 | if len(prefix) > 0 and raw_answer.startswith(prefix): 37 | raw_answer = self.kb.shorten_prefix() + raw_answer[len(prefix):] 38 | return answer_type, Uri(raw_answer, self.kb.parse_uri) 39 | -------------------------------------------------------------------------------- 
/common/preprocessing/preprocessor.py: -------------------------------------------------------------------------------- 1 | from linker.jerrl import Jerrl 2 | from common.preprocessing.wordhashing import WordHashing 3 | import numpy as np 4 | 5 | 6 | class Preprocessor: 7 | @staticmethod 8 | def qapair_to_hash(question_answer_uris): 9 | jerrl = Jerrl() 10 | hashing = WordHashing() 11 | hashed_qapairs = {} 12 | for data_item in question_answer_uris: 13 | question = data_item["question"] 14 | query = data_item["query"] 15 | counter = 0 16 | for item in jerrl.find_mentions(question, data_item["uris"]): 17 | question = u"{0} {1} {2}".format(question[:item["start"]], str(counter), question[item["end"]:]) 18 | query = query.replace(item["uri"].raw_uri, str(counter)) 19 | counter += 1 20 | 21 | hashed_question = hashing.hash(question) 22 | hashed_sparql = hashing.hash(query) 23 | hashed_qapairs[data_item["id"]] = (hashed_question, hashed_sparql) 24 | 25 | VOCAB_SIZE = 50000 # len(hashing.ids) 26 | DATA_SIZE = len(question_answer_uris) 27 | i = 0 28 | questions = np.zeros([DATA_SIZE, VOCAB_SIZE]) 29 | sparqls = np.zeros([DATA_SIZE, VOCAB_SIZE]) 30 | ids = [] 31 | for kv in hashed_qapairs.iteritems(): 32 | question_unique_counts = np.unique(kv[1][0], return_counts=True) 33 | questions[i, question_unique_counts[0]] = question_unique_counts[1] 34 | 35 | sparql_unique_counts = np.unique(kv[1][1], return_counts=True) 36 | sparqls[i, sparql_unique_counts[0]] = sparql_unique_counts[1] 37 | ids.append(kv[0]) 38 | i += 1 39 | 40 | return questions, sparqls, ids 41 | -------------------------------------------------------------------------------- /learning/classifier/classifier.py: -------------------------------------------------------------------------------- 1 | #from sklearn.externals import joblib 2 | import joblib 3 | from sklearn.model_selection import GridSearchCV 4 | import os 5 | import pickle 6 | 7 | 8 | class Classifier(object): 9 | def __init__(self, model_file_path): 10 | self.model_file_path = model_file_path 11 | if self.model_file_path is not None and os.path.exists(self.model_file_path): 12 | self.load(model_file_path) 13 | else: 14 | self.model = None 15 | 16 | # def __pipeline(self): 17 | # pass 18 | 19 | @property 20 | def is_trained(self): 21 | return self.model is not None 22 | 23 | def save(self, file_path): 24 | joblib.dump(self.model, file_path) 25 | # pickle.dump(self.model, file_path) 26 | 27 | def load(self, file_path): 28 | self.model = joblib.load(file_path) 29 | # with open(file_path, 'rb') as pickle_file: 30 | # self.model = pickle.load(pickle_file) 31 | 32 | def train(self, X_train, y_train): 33 | optimized_classifier = GridSearchCV(self.pipeline, self.parameters, n_jobs=-1, cv=10) 34 | # optimized_classifier = self.pipeline 35 | self.model = optimized_classifier.fit(X_train, y_train) 36 | # print('cv_results_: ', optimized_classifier.cv_results_) 37 | print('best_score_: ', optimized_classifier.best_score_) 38 | print('best_params_: ', optimized_classifier.best_params_) 39 | print('cv_results_: ', optimized_classifier.cv_results_['mean_test_score']) 40 | if self.model_file_path is not None: 41 | self.save(self.model_file_path) 42 | return self.model.best_score_ 43 | 44 | def predict(self, X_test): 45 | if self.is_trained: 46 | return self.model.predict(X_test) 47 | else: 48 | return None 49 | 50 | def predict_proba(self, X_test): 51 | if self.is_trained: 52 | return self.model.predict_proba(X_test) 53 | else: 54 | return None 55 | 56 | 57 | 
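A hedged usage sketch, not part of the repository, of the Classifier wrapper above, assuming the SVMClassifier subclass defined later in this dump and lists X / y of question strings and labels prepared by the caller.

# Hypothetical training/prediction flow; X and y are assumptions,
# the model path is the question-type model shipped under output/.
from learning.classifier.svmclassifier import SVMClassifier

clf = SVMClassifier(model_file_path="./output/question_type_classifier/svm.model")
if not clf.is_trained:
    clf.train(X, y)  # GridSearchCV over self.parameters, then saved to model_file_path
labels = clf.predict(["Who is venerated in Judaism and Islam?"])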
-------------------------------------------------------------------------------- /lcquad_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import requests, json, re, operator 3 | import sys 4 | from parser.lc_quad import LC_Qaud 5 | 6 | 7 | def prepare_dataset(ds): 8 | ds.load() 9 | ds.parse() 10 | return ds 11 | 12 | 13 | def ask_query(uri): 14 | if uri == "": 15 | return 200, json.loads("{\"boolean\": \"True\"}") 16 | uri = uri.replace("https://", "http://") 17 | return query(u'ASK WHERE {{ {} ?u ?x }}'.format(uri)) 18 | 19 | 20 | def query(q): 21 | q = q.replace("https://", "http://") 22 | payload = ( 23 | ('query', q), 24 | ('format', 'application/json')) 25 | 26 | r = requests.get('http://dbpedia.org/sparql', params=payload) 27 | return r.status_code, r.json() 28 | 29 | 30 | def has_answer(t): 31 | if "results" in t and len(t["results"]["bindings"]) > 0: 32 | return True 33 | if "boolean" in t: 34 | return True 35 | return False 36 | 37 | 38 | if __name__ == "__main__": 39 | 40 | with open('data/LC-QUAD/train-data.json', 'r', encoding='utf-8') as f: 41 | train = json.load(f) 42 | 43 | with open('data/LC-QUAD/test-data.json', 'r', encoding='utf-8') as f: 44 | test = json.load(f) 45 | 46 | data = train + test 47 | print('data len: ', len(data)) 48 | 49 | with open("data/LC-QUAD/data.json", "w") as write_file: 50 | json.dump(data, write_file) 51 | 52 | ds = LC_Qaud(path="./data/LC-QUAD/data.json") 53 | tmp = [] 54 | for qapair in prepare_dataset(ds).qapairs: 55 | raw_row = dict() 56 | raw_row["id"] = qapair.id.__str__() 57 | raw_row["question"] = qapair.question.text 58 | raw_row["sparql_query"] = qapair.sparql.query 59 | try: 60 | r = query(qapair.sparql.query) 61 | raw_row["answers"] = r[1] 62 | except Exception as e: 63 | raw_row["answers"] = [] 64 | 65 | tmp.append(raw_row) 66 | 67 | with open('data/LC-QUAD/linked_answer.json', 'w') as jsonFile: 68 | json.dump(tmp, jsonFile) 69 | 70 | print('data len: ', len(tmp)) 71 | -------------------------------------------------------------------------------- /parser/lc_quad.py: -------------------------------------------------------------------------------- 1 | import json, re 2 | from common.container.qapair import QApair 3 | from common.container.uri import Uri 4 | from kb.dbpedia import DBpedia 5 | from parser.answerparser import AnswerParser 6 | # ./data/LC-QUAD/data_v8.json 7 | # {"verbalized_question": "Who are the whose is ?", 8 | # "_id": "f0a9f1ca14764095ae089b152e0e7f12", 9 | # "sparql_template_id": 301, 10 | # "sparql_query": "SELECT DISTINCT ?uri WHERE {?uri . 
?uri }", 11 | # "corrected_question": "Which comic characters are painted by Bill Finger?"} 12 | class LC_Qaud: 13 | # def __init__(self, path="./data/LC-QUAD/data_v8.json"): 14 | def __init__(self, path="./data/LC-QUAD/data.json"): 15 | self.raw_data = [] 16 | self.qapairs = [] 17 | self.path = path 18 | self.parser = LC_QaudParser() 19 | 20 | def load(self): 21 | with open(self.path) as data_file: 22 | self.raw_data = json.load(data_file) 23 | 24 | def parse(self): 25 | parser = LC_QaudParser() 26 | for raw_row in self.raw_data: 27 | sparql_query = raw_row["sparql_query"].replace("DISTINCT COUNT(", "COUNT(DISTINCT ") 28 | self.qapairs.append( 29 | QApair(raw_row["corrected_question"], [], sparql_query, raw_row, raw_row["_id"], self.parser)) 30 | 31 | def print_pairs(self, n=-1): 32 | for item in self.qapairs[0:n]: 33 | print(item) 34 | print("") 35 | 36 | 37 | class LC_QaudParser(AnswerParser): 38 | def __init__(self): 39 | super(LC_QaudParser, self).__init__(DBpedia(one_hop_bloom_file="./data/blooms/spo1.bloom")) 40 | 41 | def parse_question(self, raw_question): 42 | return raw_question 43 | 44 | def parse_sparql(self, raw_query): 45 | uris = [Uri(raw_uri, DBpedia.parse_uri) for raw_uri in re.findall('<[^>]*>', raw_query)] 46 | 47 | return raw_query, True, uris 48 | 49 | def parse_answerset(self, raw_answerset): 50 | return [] 51 | 52 | def parse_answerrow(self, raw_answerrow): 53 | return [] 54 | 55 | def parse_answer(self, answer_type, raw_answer): 56 | return "", None 57 | -------------------------------------------------------------------------------- /learning/classifier/svmclassifier.py: -------------------------------------------------------------------------------- 1 | from learning.classifier.classifier import Classifier 2 | from sklearn.pipeline import Pipeline 3 | from sklearn.ensemble import RandomForestClassifier 4 | from sklearn.feature_extraction.text import TfidfVectorizer 5 | from sklearn.base import BaseEstimator, TransformerMixin 6 | import nltk 7 | from nltk.stem import WordNetLemmatizer 8 | from nltk.tokenize import sent_tokenize, word_tokenize 9 | import spacy 10 | from sklearn.decomposition import TruncatedSVD 11 | from sklearn.feature_selection import SelectKBest, SelectPercentile, f_classif, mutual_info_classif, chi2 12 | 13 | wordnet_lemmatizer = WordNetLemmatizer() 14 | 15 | 16 | class Lemmarize(BaseEstimator, TransformerMixin): 17 | def fit(self, X, y=None): 18 | return self 19 | 20 | def transform(self, X): 21 | new = [] 22 | for sentence in X: 23 | token_words = word_tokenize(sentence) 24 | stem_sentence = [] 25 | for word in token_words: 26 | stem_sentence.append(wordnet_lemmatizer.lemmatize(word, pos="v")) 27 | new.append(" ".join(stem_sentence)) 28 | return new 29 | 30 | 31 | class POS(BaseEstimator, TransformerMixin): 32 | def fit(self, X, y=None): 33 | return self 34 | 35 | def transform(self, X): 36 | spacy.prefer_gpu() 37 | nlp = spacy.load("en_core_web_lg") 38 | new = [] 39 | for sentence in X: 40 | doc = nlp(sentence) 41 | json_doc = doc.to_json() 42 | token = json_doc['tokens'] 43 | tag = [] 44 | for t in token: 45 | tag.append(t['tag']) 46 | new.append(" ".join(tag)) 47 | return new 48 | 49 | 50 | class SVMClassifier(Classifier): 51 | def __init__(self, model_file_path=None): 52 | super(SVMClassifier, self).__init__(model_file_path) 53 | self.pipeline = Pipeline([ 54 | ('lemma', Lemmarize()), 55 | ('tf-idf', TfidfVectorizer(max_df=0.9, min_df=3, max_features=2000, ngram_range=(1,4))), 56 | ('svm', RandomForestClassifier(n_estimators=150, 
max_depth=150, criterion='gini', random_state=42))]) 57 | self.parameters = { 58 | 'svm__max_features': ('sqrt', 'log2') 59 | } 60 | 61 | -------------------------------------------------------------------------------- /learning/treelstm/trainer.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | 3 | import torch 4 | from torch.autograd import Variable as Var 5 | 6 | from learning.treelstm.utils import map_label_to_target 7 | 8 | 9 | class Trainer(object): 10 | def __init__(self, args, model, criterion, optimizer): 11 | super(Trainer, self).__init__() 12 | self.args = args 13 | self.model = model 14 | self.criterion = criterion 15 | self.optimizer = optimizer 16 | self.epoch = 0 17 | 18 | # helper function for training 19 | def train(self, dataset): 20 | self.model.train() 21 | self.optimizer.zero_grad() 22 | loss, k = 0.0, 0 23 | indices = torch.randperm(len(dataset)) 24 | for idx in tqdm(range(len(dataset)), desc='Training epoch ' + str(self.epoch + 1) + ''): 25 | ltree, lsent, rtree, rsent, label = dataset[indices[idx]] 26 | linput, rinput = Var(lsent), Var(rsent) 27 | target = Var(map_label_to_target(label, dataset.num_classes)) 28 | if self.args.cuda: 29 | linput, rinput = linput.cuda(), rinput.cuda() 30 | target = target.cuda() 31 | output = self.model(ltree, linput, rtree, rinput) 32 | err = self.criterion(output, target) 33 | loss += err.data[0] 34 | err.backward() 35 | k += 1 36 | if k % self.args.batchsize == 0: 37 | self.optimizer.step() 38 | self.optimizer.zero_grad() 39 | self.epoch += 1 40 | return loss / len(dataset) 41 | 42 | # helper function for testing 43 | def test(self, dataset): 44 | self.model.eval() 45 | loss = 0 46 | predictions = torch.zeros(len(dataset)) 47 | indices = torch.arange(1, dataset.num_classes + 1, dtype=torch.float) 48 | for idx in tqdm(range(len(dataset)), desc='Testing epoch ' + str(self.epoch) + ''): 49 | ltree, lsent, rtree, rsent, label = dataset[idx] 50 | linput, rinput = Var(lsent, volatile=True), Var(rsent, volatile=True) 51 | target = Var(map_label_to_target(label, dataset.num_classes), volatile=True) 52 | if self.args.cuda: 53 | linput, rinput = linput.cuda(), rinput.cuda() 54 | target = target.cuda() 55 | output = self.model(ltree, linput, rtree, rinput) 56 | err = self.criterion(output, target) 57 | loss += err.data 58 | output = output.data.squeeze().cpu() 59 | predictions[idx] = torch.dot(indices, torch.exp(output)) 60 | return loss / len(dataset), predictions 61 | -------------------------------------------------------------------------------- /treelstm/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import os 5 | import math 6 | 7 | import torch 8 | 9 | from .vocab import Vocab 10 | 11 | 12 | # loading GLOVE word vectors 13 | # if .pth file is found, will load that 14 | # else will load from .txt file & save 15 | def load_word_vectors(path): 16 | if os.path.isfile(path + '.pth') and os.path.isfile(path + '.vocab'): 17 | print('==> File found, loading to memory') 18 | vectors = torch.load(path + '.pth') 19 | vocab = Vocab(filename=path + '.vocab') 20 | return vocab, vectors 21 | # saved file not found, read from txt file 22 | # and create tensors for word vectors 23 | print('==> File not found, preparing, be patient') 24 | count = sum(1 for line in open(path + '.txt', 'r', encoding='utf8', errors='ignore')) 25 | with open(path + '.txt', 'r') as f: 
26 | contents = f.readline().rstrip('\n').split(' ') 27 | dim = len(contents[1:]) 28 | words = [None] * (count) 29 | vectors = torch.zeros(count, dim, dtype=torch.float, device='cpu') 30 | with open(path + '.txt', 'r', encoding='utf8', errors='ignore') as f: 31 | idx = 0 32 | for line in f: 33 | contents = line.rstrip('\n').split(' ') 34 | words[idx] = contents[0] 35 | values = list(map(float, contents[1:])) 36 | vectors[idx] = torch.tensor(values, dtype=torch.float, device='cpu') 37 | idx += 1 38 | with open(path + '.vocab', 'w', encoding='utf8', errors='ignore') as f: 39 | for word in words: 40 | f.write(word + '\n') 41 | vocab = Vocab(filename=path + '.vocab') 42 | torch.save(vectors, path + '.pth') 43 | return vocab, vectors 44 | 45 | # write unique words from a set of files to a new file 46 | def build_vocab(filenames, vocabfile): 47 | vocab = set() 48 | for filename in filenames: 49 | with open(filename, 'r') as f: 50 | for line in f: 51 | tokens = line.rstrip('\n').split(' ') 52 | vocab |= set(tokens) 53 | with open(vocabfile, 'w') as f: 54 | for token in sorted(vocab): 55 | f.write(token + '\n') 56 | 57 | 58 | # mapping from scalar to vector 59 | def map_label_to_target(label, num_classes, vocab_output): 60 | target = torch.zeros(1, dtype=torch.long) 61 | target[0] = vocab_output.getIndex(str(int(label))) 62 | return target 63 | # return torch.tensor(vocab_output.getIndex(str(int(label)))) 64 | 65 | # target = torch.zeros(1, num_classes, dtype=torch.long, device='cpu') 66 | # target[0, ] = 1 67 | # return target 68 | -------------------------------------------------------------------------------- /learning/treelstm/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import os 5 | import math 6 | 7 | import torch 8 | 9 | from learning.treelstm.tree import Tree 10 | from learning.treelstm.vocab import Vocab 11 | 12 | 13 | # loading GLOVE word vectors 14 | # if .pth file is found, will load that 15 | # else will load from .txt file & save 16 | def load_word_vectors(path): 17 | if os.path.isfile(path + '.pth') and os.path.isfile(path + '.vocab'): 18 | print('==> File found, loading to memory') 19 | vectors = torch.load(path + '.pth') 20 | vocab = Vocab(filename=path + '.vocab') 21 | return vocab, vectors 22 | # saved file not found, read from txt file 23 | # and create tensors for word vectors 24 | print('==> File not found, preparing, be patient') 25 | count = sum(1 for line in open(path + '.txt', encoding='utf-8', errors='ignore')) 26 | with open(path + '.txt', 'r') as f: 27 | contents = f.readline().rstrip('\n').split(' ') 28 | dim = len(contents[1:]) 29 | words = [None] * (count) 30 | vectors = torch.zeros(count, dim) 31 | with open(path + '.txt', 'r', encoding='utf-8', errors='ignore') as f: 32 | idx = 0 33 | for line in f: 34 | contents = line.rstrip('\n').split(' ') 35 | words[idx] = contents[0] 36 | vectors[idx] = torch.Tensor(list(map(float, contents[1:]))) 37 | idx += 1 38 | with open(path + '.vocab', 'w', encoding='utf-8', errors='ignore') as f: 39 | for word in words: 40 | f.write(word + '\n') 41 | vocab = Vocab(filename=path + '.vocab') 42 | torch.save(vectors, path + '.pth') 43 | return vocab, vectors 44 | 45 | 46 | # write unique words from a set of files to a new file 47 | def build_vocab(filenames, vocabfile): 48 | vocab = set() 49 | for filename in filenames: 50 | if os.path.exists(filename): 51 | with open(filename, 'r') as f: 52 | for line in 
f: 53 | tokens = line.rstrip('\n').split(' ') 54 | vocab |= set(tokens) 55 | with open(vocabfile, 'w') as f: 56 | for token in sorted(vocab): 57 | f.write(token + '\n') 58 | 59 | 60 | # mapping from scalar to vector 61 | def map_label_to_target(label, num_classes): 62 | target = torch.zeros(1, num_classes) 63 | # if label == -1: 64 | # target[0][0] = 1 65 | # else: 66 | # target[0][1] = 1 67 | ceil = int(math.ceil(label)) 68 | floor = int(math.floor(label)) 69 | if ceil == floor: 70 | target[0][floor - 1] = 1 71 | else: 72 | target[0][floor - 1] = ceil - label 73 | target[0][ceil - 1] = label - floor 74 | return target 75 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | 4 | def parse_args(): 5 | parser = argparse.ArgumentParser( 6 | description='PyTorch TreeLSTM for Sentence Similarity on Dependency Trees') 7 | # data arguments 8 | parser.add_argument('--data', default='learning/treelstm/data/lc-quad/', 9 | help='path to dataset') 10 | parser.add_argument('--save', default='learning/treelstm/checkpoints/', 11 | help='directory to save checkpoints in') 12 | parser.add_argument('--expname', type=str, default='lc_quad', 13 | help='Name to identify experiment') 14 | # model arguments 15 | parser.add_argument('--mem_dim', default=150, type=int, 16 | help='Size of TreeLSTM cell state') 17 | parser.add_argument('--freeze_embed', action='store_true', 18 | help='Freeze word embeddings') 19 | # training arguments 20 | parser.add_argument('--epochs', default=15, type=int, 21 | help='number of total epochs to run') 22 | parser.add_argument('--batchsize', default=12, type=int, 23 | help='batchsize for optimizer updates') 24 | parser.add_argument('--lr', default=1e-2, type=float, 25 | metavar='LR', help='initial learning rate') 26 | parser.add_argument('--wd', default=2.25e-3, type=float, 27 | help='weight decay (default: 1e-4)') 28 | parser.add_argument('--emblr', default=1e-2, type=float, 29 | metavar='EMLR', help='initial embedding learning rate') 30 | parser.add_argument('--sparse', action='store_true', 31 | help='Enable sparsity for embeddings, \ 32 | incompatible with weight decay') 33 | parser.add_argument('--optim', default='adam', 34 | help='optimizer (default: adam)') 35 | 36 | # miscellaneous options 37 | parser.add_argument('--seed', default=42, type=int, 38 | help='random seed (default: 42)') 39 | cuda_parser = parser.add_mutually_exclusive_group(required=False) 40 | cuda_parser.add_argument('--cuda', dest='cuda', action='store_true') 41 | cuda_parser.add_argument('--no-cuda', dest='cuda', action='store_false') 42 | parser.set_defaults(cuda=True) 43 | 44 | args = parser.parse_args() 45 | return args 46 | 47 | config = { 48 | 'general': { 49 | 'http': { 50 | 'timeout': 120 51 | }, 52 | 'dbpedia': { 53 | 'endpoint': 'http://dbpedia.org/sparql', 54 | 'one_hop_bloom_file': './data/blooms/spo1.bloom', 55 | 'two_hop_bloom_file': './data/blooms/spo2.bloom' 56 | } 57 | } 58 | } 59 | 60 | -------------------------------------------------------------------------------- /treelstm/vocab.py: -------------------------------------------------------------------------------- 1 | # vocab object from harvardnlp/opennmt-py 2 | class Vocab(object): 3 | def __init__(self, filename=None, data=None, lower=False): 4 | self.idxToLabel = {} 5 | self.labelToIdx = {} 6 | self.lower = lower 7 | 8 | # Special entries will not be pruned. 
9 | self.special = [] 10 | 11 | if data is not None: 12 | self.addSpecials(data) 13 | if filename is not None: 14 | self.loadFile(filename) 15 | 16 | def size(self): 17 | return len(self.idxToLabel) 18 | 19 | # Load entries from a file. 20 | def loadFile(self, filename): 21 | idx = 0 22 | for line in open(filename, 'r', encoding='utf8', errors='ignore'): 23 | token = line.rstrip('\n') 24 | self.add(token) 25 | idx += 1 26 | 27 | def getIndex(self, key, default=None): 28 | key = key.lower() if self.lower else key 29 | try: 30 | return self.labelToIdx[key] 31 | except KeyError: 32 | return default 33 | 34 | def getLabel(self, idx, default=None): 35 | try: 36 | return self.idxToLabel[idx] 37 | except KeyError: 38 | return default 39 | 40 | # Mark this `label` and `idx` as special 41 | def addSpecial(self, label, idx=None): 42 | idx = self.add(label) 43 | self.special += [idx] 44 | 45 | # Mark all labels in `labels` as specials 46 | def addSpecials(self, labels): 47 | for label in labels: 48 | self.addSpecial(label) 49 | 50 | # Add `label` in the dictionary. Use `idx` as its index if given. 51 | def add(self, label): 52 | label = label.lower() if self.lower else label 53 | if label in self.labelToIdx: 54 | idx = self.labelToIdx[label] 55 | else: 56 | idx = len(self.idxToLabel) 57 | self.idxToLabel[idx] = label 58 | self.labelToIdx[label] = idx 59 | return idx 60 | 61 | # Convert `labels` to indices. Use `unkWord` if not found. 62 | # Optionally insert `bosWord` at the beginning and `eosWord` at the . 63 | def convertToIdx(self, labels, unkWord, bosWord=None, eosWord=None): 64 | vec = [] 65 | 66 | if bosWord is not None: 67 | vec += [self.getIndex(bosWord)] 68 | 69 | unk = self.getIndex(unkWord) 70 | vec += [self.getIndex(label, default=unk) for label in labels] 71 | 72 | if eosWord is not None: 73 | vec += [self.getIndex(eosWord)] 74 | 75 | return vec 76 | 77 | # Convert `idx` to labels. If index `stop` is reached, convert it and return. 78 | def convertToLabels(self, idx, stop): 79 | labels = [] 80 | 81 | for i in idx: 82 | labels += [self.getLabel(i)] 83 | if i == stop: 84 | break 85 | 86 | return labels 87 | -------------------------------------------------------------------------------- /learning/treelstm/vocab.py: -------------------------------------------------------------------------------- 1 | # vocab object from harvardnlp/opennmt-py 2 | class Vocab(object): 3 | def __init__(self, filename=None, data=None, lower=False): 4 | self.idxToLabel = {} 5 | self.labelToIdx = {} 6 | self.lower = lower 7 | 8 | # Special entries will not be pruned. 9 | self.special = [] 10 | 11 | if data is not None: 12 | self.addSpecials(data) 13 | if filename is not None: 14 | self.loadFile(filename) 15 | 16 | def size(self): 17 | return len(self.idxToLabel) 18 | 19 | # Load entries from a file. 
20 | def loadFile(self, filename): 21 | idx = 0 22 | for line in open(filename, encoding='utf-8', errors='ignore'): 23 | token = line.rstrip('\n') 24 | self.add(token) 25 | idx += 1 26 | 27 | def getIndex(self, key, default=None): 28 | key = key.lower() if self.lower else key 29 | try: 30 | return self.labelToIdx[key] 31 | except KeyError: 32 | return default 33 | 34 | def getLabel(self, idx, default=None): 35 | try: 36 | return self.idxToLabel[idx] 37 | except KeyError: 38 | return default 39 | 40 | # Mark this `label` and `idx` as special 41 | def addSpecial(self, label, idx=None): 42 | idx = self.add(label) 43 | self.special += [idx] 44 | 45 | # Mark all labels in `labels` as specials 46 | def addSpecials(self, labels): 47 | for label in labels: 48 | self.addSpecial(label) 49 | 50 | # Add `label` in the dictionary. Use `idx` as its index if given. 51 | def add(self, label): 52 | label = label.lower() if self.lower else label 53 | if label in self.labelToIdx: 54 | idx = self.labelToIdx[label] 55 | else: 56 | idx = len(self.idxToLabel) 57 | self.idxToLabel[idx] = label 58 | self.labelToIdx[label] = idx 59 | return idx 60 | 61 | # Convert `labels` to indices. Use `unkWord` if not found. 62 | # Optionally insert `bosWord` at the beginning and `eosWord` at the . 63 | def convertToIdx(self, labels, unkWord, bosWord=None, eosWord=None): 64 | vec = [] 65 | 66 | if bosWord is not None: 67 | vec += [self.getIndex(bosWord)] 68 | 69 | unk = self.getIndex(unkWord) 70 | vec += [self.getIndex(label, default=unk) for label in labels] 71 | 72 | if eosWord is not None: 73 | vec += [self.getIndex(eosWord)] 74 | 75 | return vec 76 | 77 | # Convert `idx` to labels. If index `stop` is reached, convert it and return. 78 | def convertToLabels(self, idx, stop): 79 | labels = [] 80 | 81 | for i in idx: 82 | labels += [self.getLabel(i)] 83 | if i == stop: 84 | break 85 | 86 | return labels 87 | -------------------------------------------------------------------------------- /common/graph/edge.py: -------------------------------------------------------------------------------- 1 | class Edge: 2 | def __init__(self, source_node, uri, dest_node): 3 | self.source_node = source_node 4 | self.uri = uri 5 | self.dest_node = dest_node 6 | self.source_node.add_outbound(self) 7 | self.dest_node.add_inbound(self) 8 | self.__confidence = ( 9 | self.source_node.confidence if self.source_node is not None else 1) * self.uri.confidence * ( 10 | self.dest_node.confidence if self.dest_node is not None else 1) 11 | self.__hash = ("" if source_node is None else self.source_node.__str__()) + self.uri.__str__() + ( 12 | "" if dest_node is None else self.dest_node.__str__()) 13 | 14 | @property 15 | def confidence(self): 16 | return self.__confidence 17 | 18 | def copy(self, source_node=None, uri=None, dest_node=None): 19 | return Edge(self.source_node if source_node is None else source_node, 20 | self.uri if uri is None else uri, 21 | self.dest_node if dest_node is None else dest_node) 22 | 23 | def has_uri(self, uri): 24 | return self.uri == uri or self.source_node.has_uri(uri) or self.dest_node.has_uri(uri) 25 | 26 | def prepare_remove(self): 27 | self.source_node.remove_outbound(self) 28 | self.dest_node.remove_inbound(self) 29 | 30 | def max_generic_id(self): 31 | s = self.source_node.first_uri_if_only() 32 | if s is not None: 33 | s = s.generic_id() 34 | if s is None: 35 | s = -1 36 | 37 | d = self.dest_node.first_uri_if_only() 38 | if d is not None: 39 | d = d.generic_id() 40 | if d is None: 41 | d = -1 42 | 43 | return max(s, 
d) 44 | 45 | def sparql_format(self, kb): 46 | return u"{} {} {}".format(self.source_node.sparql_format(kb), self.uri.sparql_format(kb), 47 | self.dest_node.sparql_format(kb)) 48 | 49 | def full_path(self): 50 | return "{} --> {} --> {}".format(self.source_node.__str__(), self.uri.__str__(), self.dest_node.__str__()) 51 | 52 | def generic_equal(self, other): 53 | if isinstance(other, Edge): 54 | return self.source_node.generic_equal(other.source_node) \ 55 | and self.dest_node.generic_equal(other.dest_node) \ 56 | and self.uri.generic_equal(other.uri) 57 | return NotImplemented 58 | 59 | def __hash__(self): 60 | return hash(self.full_path()) 61 | 62 | def __eq__(self, other): 63 | if isinstance(other, Edge): 64 | if hasattr(self, "__hash"): 65 | return self.__hash == other.__hash 66 | else: 67 | return self.source_node == other.source_node and self.uri == other.uri and self.dest_node == other.dest_node 68 | 69 | return NotImplemented 70 | 71 | def __str__(self): 72 | return self.uri.__str__() 73 | -------------------------------------------------------------------------------- /learning/treelstm/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | 4 | def parse_args(): 5 | parser = argparse.ArgumentParser(description='PyTorch TreeLSTM for Question-Query Similarity on Dependency Trees') 6 | # 7 | parser.add_argument('--mode', default='train', 8 | help='mode: `train` or `test`') 9 | parser.add_argument('--data', default='./learning/treelstm/data/lc_quad/', 10 | help='path to dataset') 11 | parser.add_argument('--glove', default='./learning/treelstm/data/glove/', 12 | help='directory with GLOVE embeddings') 13 | parser.add_argument('--save', default='learning/treelstm/checkpoints/', 14 | help='directory to save checkpoints in') 15 | parser.add_argument('--load', default='checkpoints/', 16 | help='directory to load checkpoints in') 17 | parser.add_argument('--expname', type=str, default='lc_quad', 18 | help='Name to identify experiment') 19 | 20 | # model arguments 21 | parser.add_argument('--input_dim', default=300, type=int, 22 | help='Size of input word vector') 23 | parser.add_argument('--mem_dim', default=150, type=int, 24 | help='Size of TreeLSTM cell state') 25 | parser.add_argument('--hidden_dim', default=50, type=int, 26 | help='Size of classifier MLP') 27 | parser.add_argument('--num_classes', default=2, type=int, 28 | help='Number of classes in dataset') 29 | # training arguments 30 | parser.add_argument('--epochs', default=15, type=int, 31 | help='number of total epochs to run') 32 | parser.add_argument('--batchsize', default=25, type=int, 33 | help='batchsize for optimizer updates') 34 | parser.add_argument('--lr', default=0.01, type=float, 35 | metavar='LR', help='initial learning rate') 36 | parser.add_argument('--wd', default=0.00225, type=float, 37 | help='weight decay (default: 1e-4)') 38 | parser.add_argument('--sparse', action='store_true', 39 | help='Enable sparsity for embeddings, \ 40 | incompatible with weight decay') 41 | parser.add_argument('--optim', default='adagrad', 42 | help='optimizer (default: adagrad)') 43 | parser.add_argument('--sim', default='nn', 44 | help='similarity (default: nn) nn or cos') 45 | # miscellaneous options 46 | parser.add_argument('--seed', default=123, type=int, 47 | help='random seed (default: 123)') 48 | cuda_parser = parser.add_mutually_exclusive_group(required=False) 49 | cuda_parser.add_argument('--cuda', dest='cuda', action='store_true') 50 | 
cuda_parser.add_argument('--no-cuda', dest='cuda', action='store_false') 51 | parser.set_defaults(cuda=True) 52 | 53 | args = parser.parse_args() 54 | return args 55 | -------------------------------------------------------------------------------- /output/qald/a.tag: -------------------------------------------------------------------------------- 1 | WDT NN VBZ VBN IN DT $ NN CC DT NNP NNP NNP . 2 | WDT NN VBZ VBN IN DT $ NN CC DT NNP NNP NNP . 3 | WDT NN VBZ VBN IN DT $ NN CC DT NNP NNP NNP . 4 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 5 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 6 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 7 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 8 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 9 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 10 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 11 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 12 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 13 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 14 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 15 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 16 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 17 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 18 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 19 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 20 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 21 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 22 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 23 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 24 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 25 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 26 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 27 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 28 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 29 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 30 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 31 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 32 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 33 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 34 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 35 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 36 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 37 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 38 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 39 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 40 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 41 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 42 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 43 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 44 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 45 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 46 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 47 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 48 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 49 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 50 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 51 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 52 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 53 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 54 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 55 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 56 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 57 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 58 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 59 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 60 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 61 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 62 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 63 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 64 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 65 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 66 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 67 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 68 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 69 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 70 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 71 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 72 | WDT NN VBZ VBN IN DT $ NN CC DT $ NN . 
73 | -------------------------------------------------------------------------------- /common/graph/node.py: -------------------------------------------------------------------------------- 1 | from common.container.uri import Uri 2 | import numpy as np 3 | 4 | 5 | class Node: 6 | def __init__(self, uris, mergable=False): 7 | if isinstance(uris, Uri): 8 | self.__uris = set([uris]) 9 | elif isinstance(uris, list): 10 | self.__uris = set() 11 | for uri in uris: 12 | if isinstance(uri, Uri): 13 | self.__uris.add(uri) 14 | elif isinstance(uris, list): 15 | self.__uris.update(uri) 16 | self.uris_hash = self.__str__() 17 | self.mergable = mergable 18 | self.inbound = [] 19 | self.outbound = [] 20 | self.__confidence = np.prod([uri.confidence for uri in self.__uris]) 21 | 22 | @property 23 | def confidence(self): 24 | return self.__confidence 25 | 26 | @property 27 | def uris(self): 28 | return self.__uris 29 | 30 | def is_disconnected(self): 31 | return len(self.inbound) == 0 and len(self.outbound) == 0 32 | 33 | def add_outbound(self, edge): 34 | if edge not in self.outbound: 35 | self.outbound.append(edge) 36 | 37 | def remove_outbound(self, edge): 38 | self.outbound.remove(edge) 39 | 40 | def add_inbound(self, edge): 41 | if edge not in self.inbound: 42 | self.inbound.append(edge) 43 | 44 | def remove_inbound(self, edge): 45 | self.inbound.remove(edge) 46 | 47 | def first_uri_if_only(self): 48 | if len(self.__uris) == 1: 49 | return next(iter(self.__uris)) 50 | return None 51 | 52 | def __are_all_uris_of_type(self, uri_type): 53 | uris_type = set([u.uri_type for u in self.__uris]) 54 | return len(uris_type) == 1 and uris_type.pop() == uri_type 55 | 56 | def are_all_uris_generic(self): 57 | return self.__are_all_uris_of_type("g") 58 | 59 | def are_all_uris_type(self): 60 | return self.__are_all_uris_of_type("?t") 61 | 62 | def replace_uri(self, uri, new_uri): 63 | if uri in self.__uris: 64 | self.__uris.remove(uri) 65 | self.__uris.add(new_uri) 66 | return True 67 | return False 68 | 69 | def has_uri(self, uri): 70 | return uri in self.__uris 71 | 72 | def sparql_format(self, kb): 73 | if len(self.__uris) == 1: 74 | return self.first_uri_if_only().sparql_format(kb) 75 | raise Exception("...") 76 | 77 | def generic_equal(self, other): 78 | return (self.are_all_uris_generic() and other.are_all_uris_generic()) or self == other 79 | 80 | def __hash__(self): 81 | return hash(self.__str__()) 82 | 83 | def __eq__(self, other): 84 | if isinstance(other, Node): 85 | # return self.__uris == other.__uris 86 | return self.uris_hash == other.uris_hash 87 | return NotImplemented 88 | 89 | def __ne__(self, other): 90 | result = self.__eq__(other) 91 | if result is NotImplemented: 92 | return result 93 | return not result 94 | 95 | def __str__(self): 96 | return "\n".join(sorted([uri.__str__() for uri in self.__uris])) 97 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | anytree==2.6.0 2 | asn1crypto==0.24.0 3 | atomicwrites==1.3.0 4 | attrs==19.1.0 5 | backcall==0.1.0 6 | backports.functools-lru-cache==1.5 7 | bitarray==0.9.3 8 | bleach==1.5.0 9 | blis==0.2.4 10 | boto==2.49.0 11 | boto3==1.9.169 12 | botocore==1.12.169 13 | bz2file==0.98 14 | certifi==2019.6.16 15 | cffi==1.11.5 16 | chardet==3.0.4 17 | Click==7.0 18 | cloudpickle==1.1.1 19 | conda==4.6.14 20 | cryptography==2.7 21 | cycler==0.10.0 22 | cymem==2.0.2 23 | Cython==0.29.10 24 | cytoolz==0.9.0.1 25 | 
dask==1.2.2 26 | decorator==4.3.0 27 | docutils==0.14 28 | dominate==2.3.5 29 | en-core-web-lg==2.1.0 30 | en-core-web-md==2.1.0 31 | en-core-web-sm==2.1.0 32 | entrypoints==0.3 33 | fasttext==0.8.22 34 | flake8==3.7.7 35 | Flask==1.0.3 36 | Flask-Bootstrap==3.3.7.1 37 | Flask-Classful==0.14.2 38 | Flask-Login==0.4.1 39 | Flask-SQLAlchemy==2.4.0 40 | Flask-WTF==0.14.2 41 | future==0.17.1 42 | gensim==3.7.3 43 | gevent==1.4.0 44 | greenlet==0.4.15 45 | h5py==2.8.0 46 | html5lib==0.9999999 47 | idna==2.8 48 | imageio==2.5.0 49 | importlib-metadata==0.17 50 | inflect==2.1.0 51 | interruptingcow==0.8 52 | ipython==6.4.0 53 | ipython-genutils==0.2.0 54 | isodate==0.6.0 55 | itsdangerous==1.1.0 56 | jedi==0.12.0 57 | Jinja2==2.10.1 58 | jmespath==0.9.4 59 | joblib==0.13.2 60 | jsonschema==3.0.1 61 | kiwisolver==1.1.0 62 | Markdown==3.1.1 63 | MarkupSafe==1.1.1 64 | matplotlib==3.1.0 65 | mccabe==0.6.1 66 | mkl-fft==1.0.12 67 | mkl-random==1.0.2 68 | mkl-service==2.0.2 69 | more-itertools==7.0.0 70 | murmurhash==1.0.2 71 | networkx==2.3 72 | nibabel==2.4.1 73 | nltk==3.4.5 74 | numpy==1.16.4 75 | olefile==0.46 76 | onnx==1.5.0 77 | packaging==19.0 78 | pandas==0.24.2 79 | parso==0.2.0 80 | pexpect==4.5.0 81 | pickleshare==0.7.4 82 | Pillow==6.0.0 83 | plac==0.9.6 84 | pluggy==0.12.0 85 | preshed==2.0.1 86 | prompt-toolkit==1.0.15 87 | protobuf==3.7.1 88 | ptyprocess==0.5.2 89 | py==1.8.0 90 | pybind11==2.3.0 91 | pybloom-live==3.0.0 92 | pybloomfiltermmap3==0.4.15 93 | pycodestyle==2.5.0 94 | pycosat==0.6.3 95 | pycparser==2.18 96 | pydicom==1.2.2 97 | pyflakes==2.1.1 98 | Pygments==2.2.0 99 | pyOpenSSL==17.5.0 100 | pyparsing==2.4.0 101 | pyrsistent==0.15.2 102 | PySocks==1.6.8 103 | pytest==4.6.2 104 | python-dateutil==2.8.0 105 | pytz==2019.1 106 | PyWavelets==1.0.3 107 | PyYAML==4.2b1 108 | rdflib==4.2.2 109 | requests==2.22.0 110 | ruamel-yaml==0.15.35 111 | s3transfer==0.2.1 112 | scikit-image==0.15.0 113 | scikit-learn==0.21.2 114 | scipy==1.2.1 115 | seaborn==0.9.0 116 | simplegeneric==0.8.1 117 | singledispatch==3.4.0.3 118 | six==1.12.0 119 | smart-open==1.8.4 120 | spacy==2.1.4 121 | SPARQLWrapper==1.8.4 122 | spectrum==0.7.1 123 | spotlight==0.1.5 124 | SQLAlchemy==1.3.4 125 | srsly==0.0.7 126 | subprocess32==3.5.4 127 | tagme==0.1.3 128 | tensorboardX==1.7 129 | tensorflow-tensorboard==1.5.1 130 | thinc==7.0.4 131 | timeout==0.1.2 132 | toolz==0.9.0 133 | torch==0.5.0a0+fb59ce3 134 | torchtext==0.3.1 135 | torchvision==0.2.1 136 | tornado==6.0.2 137 | tqdm==4.32.1 138 | traitlets==4.3.2 139 | ujson==1.35 140 | urllib3==1.25.3 141 | visitor==0.1.3 142 | wasabi==0.2.2 143 | wcwidth==0.1.7 144 | Werkzeug==0.15.4 145 | WTForms==2.2.1 146 | zipp==0.5.1 147 | -------------------------------------------------------------------------------- /learning/treelstm/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from tqdm import tqdm 4 | from copy import deepcopy 5 | 6 | import torch 7 | import torch.utils.data as data 8 | 9 | import learning.treelstm.Constants as Constants 10 | from learning.treelstm.tree import Tree 11 | from learning.treelstm.vocab import Vocab 12 | 13 | 14 | class QGDataset(data.Dataset): 15 | def __init__(self, path, vocab, num_classes): 16 | super(QGDataset, self).__init__() 17 | self.vocab = vocab 18 | self.num_classes = num_classes 19 | 20 | self.lsentences = self.read_sentences(os.path.join(path, 'a.toks')) 21 | self.rsentences = self.read_sentences(os.path.join(path, 'b.toks')) 22 | 23 | 
self.ltrees = self.read_trees(os.path.join(path, 'a.parents')) 24 | self.rtrees = self.read_trees(os.path.join(path, 'b.parents')) 25 | 26 | self.labels = self.read_labels(os.path.join(path, 'sim.txt')) 27 | 28 | self.size = len(self.lsentences) 29 | 30 | def __len__(self): 31 | return self.size 32 | 33 | def __getitem__(self, index): 34 | ltree = deepcopy(self.ltrees[index]) 35 | rtree = deepcopy(self.rtrees[index]) 36 | lsent = deepcopy(self.lsentences[index]) 37 | rsent = deepcopy(self.rsentences[index]) 38 | label = deepcopy(self.labels[index]) 39 | return (ltree, lsent, rtree, rsent, label) 40 | 41 | def read_sentences(self, filename): 42 | with open(filename, 'r') as f: 43 | sentences = [self.read_sentence(line) for line in tqdm(f.readlines())] 44 | return sentences 45 | 46 | def read_sentence(self, line): 47 | indices = self.vocab.convertToIdx(line.split(), Constants.UNK_WORD) 48 | return torch.LongTensor(indices) 49 | 50 | def read_trees(self, filename): 51 | with open(filename, 'r') as f: 52 | trees = [self.read_tree(line) for line in tqdm(f.readlines())] 53 | return trees 54 | 55 | def read_tree(self, line): 56 | parents = list(map(int, line.split())) 57 | trees = dict() 58 | root = None 59 | for i in range(1, len(parents) + 1): 60 | if i - 1 not in trees.keys() and parents[i - 1] != -1: 61 | idx = i 62 | prev = None 63 | while True: 64 | parent = parents[idx - 1] 65 | if parent == -1: 66 | break 67 | tree = Tree() 68 | if prev is not None: 69 | tree.add_child(prev) 70 | trees[idx - 1] = tree 71 | tree.idx = idx - 1 72 | if parent - 1 in trees.keys(): 73 | trees[parent - 1].add_child(tree) 74 | break 75 | elif parent == 0: 76 | root = tree 77 | break 78 | else: 79 | prev = tree 80 | idx = parent 81 | return root 82 | 83 | def read_labels(self, filename): 84 | with open(filename, 'r') as f: 85 | labels = list(map(lambda x: float(x), f.readlines())) 86 | labels = torch.Tensor(labels) 87 | return labels 88 | -------------------------------------------------------------------------------- /common/graph/paths.py: -------------------------------------------------------------------------------- 1 | from common.graph.path import Path 2 | import numpy as np 3 | import itertools 4 | 5 | 6 | class Paths(list): 7 | def __init__(self, *args): 8 | super(Paths, self).__init__(*args) 9 | 10 | @property 11 | def confidence(self): 12 | """ 13 | Cumulative product of paths' confidence 14 | :return: 15 | """ 16 | return np.prod([path.confidence for path in self]) 17 | 18 | def to_where(self, kb=None, ask_query=False): 19 | """ 20 | Transform paths into where clauses 21 | :param kb: 22 | :param ask_query: 23 | :return: 24 | """ 25 | output = [] 26 | sparql_len = [] 27 | 28 | for batch_edges in self: 29 | sparql_where = [edge.sparql_format(kb) for edge in batch_edges] 30 | max_generic_id = max([edge.max_generic_id() for edge in batch_edges]) 31 | if kb is None or ask_query: 32 | output.append({"suggested_id": max_generic_id, "where": sparql_where}) 33 | else: 34 | for L in range(1, len(sparql_where) + 1): 35 | for subset in itertools.combinations(sparql_where, L): 36 | result = kb.query_where(subset, count=True) 37 | if result is not None: 38 | result = int(result["results"]["bindings"][0]["callret-0"]["value"]) 39 | if result > 0: 40 | output.append({"suggested_id": max_generic_id, "where": subset}) 41 | sparql_len.append(len(subset)) 42 | 43 | return output 44 | 45 | def add(self, new_paths, validity_fn): 46 | """ 47 | Append new paths if they pass the validity check 48 | :param new_paths: 49 | 
:param validity_fn: 50 | :return: 51 | """ 52 | for path in new_paths: 53 | if (len(self) == 0 or path not in self) and validity_fn(path): 54 | self.append(path) 55 | 56 | def extend(self, new_edge): 57 | """ 58 | Create a new that contains path of current which the new_edge if possible is appended to each 59 | :param new_edge: 60 | :return: 61 | """ 62 | new_output = [] 63 | if len(self) == 0: 64 | self.append(Path([])) 65 | for item in self: 66 | if item.addable(new_edge): 67 | path = Path() 68 | for edge in item: 69 | if edge.uri == new_edge.uri and \ 70 | edge.source_node.are_all_uris_generic() and \ 71 | edge.dest_node.are_all_uris_generic() and \ 72 | not ( 73 | new_edge.source_node.are_all_uris_generic() and new_edge.dest_node.are_all_uris_generic()): 74 | pass 75 | else: 76 | path.append(edge) 77 | new_output.append(Path(path + [new_edge])) 78 | else: 79 | new_output.append(item) 80 | return Paths(new_output) 81 | 82 | def remove_duplicates(self): 83 | removed_duplicate_paths = [] 84 | paths_str = [str(path) for path in self] 85 | for idx in range(len(self)): 86 | if paths_str[idx] not in paths_str[idx + 1:]: 87 | removed_duplicate_paths.append(self[idx]) 88 | 89 | return Paths(removed_duplicate_paths) 90 | -------------------------------------------------------------------------------- /linker/earl.py: -------------------------------------------------------------------------------- 1 | import ujson as json 2 | from linker.goldLinker import GoldLinker 3 | from common.container.linkeditem import LinkedItem 4 | from common.container.uri import Uri 5 | from kb.dbpedia import DBpedia 6 | from common.utility.utility import closest_string 7 | 8 | 9 | class Earl: 10 | def __init__(self, path="data/LC-QUAD/EARL/output_original.json"): 11 | self.parser = DBpedia.parse_uri 12 | self.gold_linker = GoldLinker() 13 | with open(path, 'r') as data_file: 14 | self.raw_data = json.load(data_file) 15 | self.questions = {} 16 | for item in self.raw_data: 17 | self.questions[item["question"]] = item 18 | 19 | def __force_gold(self, golden_list, surfaces, items): 20 | not_found = [] 21 | intersect = [] 22 | uri_list = [] 23 | for i_item in items: 24 | for i_uri in i_item.uris: 25 | for g_item in golden_list: 26 | if i_uri in g_item.uris: 27 | intersect.append(g_item) 28 | 29 | return intersect 30 | # def __force_gold(self, golden_list, surfaces, items): 31 | # not_found = [] 32 | # for item in golden_list: 33 | # idx = closest_string(item.surface_form, surfaces) 34 | # if idx != -1: 35 | # if item.uris[0] not in items[idx].uris: 36 | # items[idx].uris[len(items[idx].uris) - 1] = item.uris[0] 37 | # surfaces.pop(idx) 38 | # else: 39 | # not_found.append(item) 40 | # 41 | # for item in not_found: 42 | # if len(surfaces) > 0: 43 | # idx = surfaces.keys()[0] 44 | # items[idx].uris[0] = item.uris[0] 45 | # surfaces.pop(idx) 46 | # else: 47 | # items.append(item) 48 | # 49 | # keys = surfaces.keys() 50 | # keys.sort(reverse=True) 51 | # for idx in keys: 52 | # del items[idx] 53 | # 54 | # return items 55 | 56 | def do(self, qapair, force_gold=False, top=50): 57 | if qapair.question.text in self.questions: 58 | item = self.questions[qapair.question.text] 59 | entities = self.__parse(item, "entities", top) 60 | relations = self.__parse(item, "relations", top) 61 | 62 | if force_gold: 63 | gold_entities, gold_relations = self.gold_linker.do(qapair) 64 | entities_surface = {i: item.surface_form for i, item in enumerate(entities)} 65 | relations_surface = {i: item.surface_form for i, item in enumerate(relations)} 
66 | 67 | entities = self.__force_gold(gold_entities, entities_surface, entities) 68 | relations = self.__force_gold(gold_relations, relations_surface, relations) 69 | 70 | return entities, relations 71 | else: 72 | return None, None 73 | 74 | def __parse(self, dataset, name, top): 75 | output = [] 76 | for item in dataset[name]: 77 | uris = [] 78 | for uri in item["uris"]: 79 | uris.append(Uri(uri["uri"], self.parser, uri["confidence"])) 80 | if len(item["surface"])>0: 81 | start_index, length = item["surface"] 82 | surface = dataset["question"][start_index: start_index + length] 83 | else: 84 | surface = "" 85 | output.append(LinkedItem(surface, uris[:top])) 86 | return output 87 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # QAsparql 2 | A Question-Answering system for a Knowledge Graph (DBpedia) 3 | 4 | Abstract: 5 | 6 | The emergence of Linked Data in the form of knowledge graphs in the Resource Description Framework (RDF) data model has been among the first developments since the Semantic Web appeared in 2001. These knowledge graphs are typically enormous and are not easily accessible to end users, because they require specialized knowledge of query languages (such as SPARQL) as well as a deep understanding of the content structure of the underlying knowledge graph. This led to the development of Question-Answering (QA) systems based on RDF data that allow end users to access knowledge graphs and benefit from the information stored in them. While QA systems have progressed rapidly in recent years, there is still room for improvement. 7 | To make knowledge graphs more accessible to end users, we propose a new QA system for translating natural language questions into SPARQL queries. The key idea is to use neural network models to automatically learn and translate a natural language question into a SPARQL query. Our QA system first predicts the type of the question and then constructs the SPARQL query by extracting, ranking and selecting triple patterns from the original question. The final SPARQL query is constructed by combining the selected triple patterns with the predicted question type. The performance of our proposed QA system is empirically evaluated on two renowned benchmarks: the 7th Question Answering over Linked Data Challenge (QALD-7) and the Large-Scale Complex Question Answering Dataset (LC-QuAD). Experimental results show that our QA system outperforms the state-of-the-art systems by 15% on the QALD-7 dataset and by 48% on the LC-QuAD dataset, respectively. The advantage of our approach is that it is generically applicable, since it does not require any domain-specific knowledge. 8 | 9 | 10 | Preprocess: 11 | 1. bash learning/treelstm/download.sh -- download the pre-trained word embedding models FastText and GloVe 12 | 13 | 14 | Whole Process: 15 | 1. python lcquad_dataset.py -- preprocess the LC-QuAD dataset and generate 'linked_answer.json', the LC-QuAD dataset with gold-standard answers 16 | 2. python lcquad_answer.py -- generate the gold answers for the LC-QuAD dataset and produce 'lcquad_gold.json', the LC-QuAD dataset with SPARQL queries generated from the entities and properties extracted from the gold-standard SPARQL query 17 | 3.
python learning/treelstm/preprocess_lcquad.py -- preprocess the LC-QuAD dataset for Tree-LSTM training: split the original LC-QuAD dataset into 'LCQuad_train.json', 'LCQuad_trial.json' and 'LCQuad_test.json' with 70%/20%/10% of the original dataset, respectively. Generate the dependency parse trees and the corresponding input and output files required to train the Tree-LSTM model. 18 | 4. python learning/treelstm/main.py -- train the Tree-LSTM model. The generated checkpoint files are stored in the checkpoints folder and used by lcquad_test.py and qald_test.py 19 | 5. python entity_lcquad_test.py -- generate the phrase mapping for the LC-QuAD test dataset 20 | 6. python entity_qald.py -- generate the phrase mapping for the QALD-7 test dataset 21 | 7. python lcquad_test.py -- test the QA system on the LC-QuAD test dataset 22 | 8. python lcquadall_test.py -- test the QA system on the whole LC-QuAD dataset 23 | 9. python qald_test.py -- test the QA system on the QALD-7 dataset 24 | 10. python question_type_anlaysis.py -- analyze the question-type classification accuracy on the LC-QuAD and QALD-7 datasets 25 | 11. python result_analysis.py -- analyze the final results for the LC-QuAD and QALD-7 datasets 26 | 27 | 28 | -------------------------------------------------------------------------------- /common/container/sparql.py: -------------------------------------------------------------------------------- 1 | class SPARQL: 2 | def __init__(self, raw_query, parser): 3 | self.raw_query = raw_query 4 | self.query, self.supported, self.uris = parser(raw_query) 5 | self.where_clause, self.where_clause_template = self.__extrat_where() 6 | 7 | def __extrat_where(self): 8 | WHERE = "WHERE" 9 | sparql_query = self.query.strip(" {};\t") 10 | idx = sparql_query.find(WHERE) 11 | where_clause_raw = sparql_query[idx + len(WHERE):].strip(" {}") 12 | where_clause_raw = [item.replace(".", "").strip(" .") for item in where_clause_raw.split(" ")] 13 | where_clause_raw = [item for item in where_clause_raw if item != ""] 14 | buffer = [] 15 | where_clause = [] 16 | for item in where_clause_raw: 17 | buffer.append(item) 18 | if len(buffer) == 3: 19 | where_clause.append(buffer) 20 | buffer = [] 21 | if len(buffer) > 0: 22 | where_clause.append(buffer) 23 | 24 | where_clause_template = " ".join([" ".join(item) for item in where_clause]) 25 | for uri in set(self.uris): 26 | where_clause_template = where_clause_template.replace(uri.uri, uri.uri_type) 27 | 28 | return where_clause, where_clause_template 29 | 30 | def query_features(self): 31 | features = {"boolean": ["ask "], 32 | "count": ["count("], 33 | "filter": ["filter("], 34 | "comparison": ["<= ", ">= ", " < ", " > "], 35 | "sort": ["order by"], 36 | "aggregate": ["max(", "min("] 37 | } 38 | 39 | output = set() 40 | if self.where_clause_template.count(" ") > 3: 41 | output.add("compound") 42 | else: 43 | output.add("single") 44 | generic_uris = set() 45 | for uri in self.uris: 46 | if uri.is_generic(): 47 | generic_uris.add(uri) 48 | if len(generic_uris) > 1: 49 | output.add("multivar") 50 | break 51 | if len(generic_uris) <= 1: 52 | output.add("singlevar") 53 | raw_query = self.raw_query.lower() 54 | for feature in features: 55 | for constraint in features[feature]: 56 | if constraint in raw_query: 57 | output.add(feature) 58 | return output 59 | 60 | def __eq__(self, other): 61 | if isinstance(other, SPARQL): 62 | mapping = {} 63 | for line in self.where_clause: 64 | found = False 65 | for other_line in other.where_clause: 66 | match = 0 67 | mapping_buffer = mapping.copy() 68 | for i in range(len(line)): 69 | if line[i] ==
other_line[i]: 70 | match += 1 71 | elif line[i].startswith("?") and other_line[i].startswith("?"): 72 | if line[i] not in mapping_buffer: 73 | mapping_buffer[line[i]] = other_line[i] 74 | match += 1 75 | else: 76 | match += mapping_buffer[line[i]] == other_line[i] 77 | if match == len(line): 78 | found = True 79 | mapping = mapping_buffer 80 | break 81 | if not found: 82 | return False 83 | return True 84 | 85 | def __ne__(self, other): 86 | return not self == other 87 | 88 | def __str__(self): 89 | return self.query.encode("ascii", "ignore") 90 | # return self.query 91 | -------------------------------------------------------------------------------- /learning/treelstm/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable as Var 6 | import sys 7 | path = os.getcwd() 8 | print('path: ', path) 9 | sys.path.insert(0, path) 10 | import learning.treelstm.Constants as Constants 11 | # sys.path.insert(0,'/cluster/home/xlig/qg') 12 | 13 | 14 | # module for childsumtreelstm 15 | class ChildSumTreeLSTM(nn.Module): 16 | def __init__(self, in_dim, mem_dim): 17 | super(ChildSumTreeLSTM, self).__init__() 18 | self.in_dim = in_dim 19 | self.mem_dim = mem_dim 20 | self.ioux = nn.Linear(self.in_dim, 3 * self.mem_dim) 21 | self.iouh = nn.Linear(self.mem_dim, 3 * self.mem_dim) 22 | self.fx = nn.Linear(self.in_dim, self.mem_dim) 23 | self.fh = nn.Linear(self.mem_dim, self.mem_dim) 24 | 25 | def node_forward(self, inputs, child_c, child_h): 26 | child_h_sum = torch.sum(child_h, dim=0, keepdim=True) 27 | 28 | iou = self.ioux(inputs) + self.iouh(child_h_sum) 29 | i, o, u = torch.split(iou, iou.size(1) // 3, dim=1) 30 | i, o, u = F.sigmoid(i), F.sigmoid(o), F.tanh(u) 31 | 32 | f = F.sigmoid( 33 | self.fh(child_h) + 34 | self.fx(inputs).repeat(len(child_h), 1) 35 | ) 36 | fc = torch.mul(f, child_c) 37 | 38 | c = torch.mul(i, u) + torch.sum(fc, dim=0, keepdim=True) 39 | h = torch.mul(o, F.tanh(c)) 40 | return c, h 41 | 42 | def forward(self, tree, inputs): 43 | _ = [self.forward(tree.children[idx], inputs) for idx in range(tree.num_children)] 44 | 45 | if tree.num_children == 0: 46 | child_c = Var(inputs[0].data.new(1, self.mem_dim).fill_(0.)) 47 | child_h = Var(inputs[0].data.new(1, self.mem_dim).fill_(0.)) 48 | else: 49 | child_c, child_h = zip(*map(lambda x: x.state, tree.children)) 50 | child_c, child_h = torch.cat(child_c, dim=0), torch.cat(child_h, dim=0) 51 | 52 | tree.state = self.node_forward(inputs[tree.idx], child_c, child_h) 53 | return tree.state 54 | 55 | 56 | # module for distance-angle similarity 57 | class DASimilarity(nn.Module): 58 | def __init__(self, mem_dim, hidden_dim, num_classes): 59 | super(DASimilarity, self).__init__() 60 | self.mem_dim = mem_dim 61 | self.hidden_dim = hidden_dim 62 | self.num_classes = num_classes 63 | self.wh = nn.Linear(2 * self.mem_dim, self.hidden_dim) 64 | self.wp = nn.Linear(self.hidden_dim, self.num_classes) 65 | 66 | def forward(self, lvec, rvec): 67 | mult_dist = torch.mul(lvec, rvec) 68 | abs_dist = torch.abs(torch.add(lvec, -rvec)) 69 | vec_dist = torch.cat((mult_dist, abs_dist), 1) 70 | 71 | vec_dist = F.dropout(vec_dist, p=0.2, training=self.training) 72 | out = F.sigmoid(self.wh(vec_dist)) 73 | out = F.log_softmax(self.wp(out)) 74 | return out 75 | 76 | 77 | # module for cosine similarity 78 | class CosSimilarity(nn.Module): 79 | def __init__(self, mem_dim): 80 | super(CosSimilarity, 
self).__init__() 81 | self.cos = nn.CosineSimilarity(dim=mem_dim) 82 | 83 | def forward(self, lvec, rvec): 84 | out = self.cos(lvec, rvec) 85 | out = torch.autograd.Variable(torch.FloatTensor([[1 - out.data[0], out.data[0]]]), requires_grad=True) 86 | if torch.cuda.is_available(): 87 | out = out.cuda() 88 | return F.log_softmax(out) 89 | 90 | 91 | # putting the whole model together 92 | class SimilarityTreeLSTM(nn.Module): 93 | def __init__(self, vocab_size, in_dim, mem_dim, similarity, sparsity): 94 | super(SimilarityTreeLSTM, self).__init__() 95 | self.emb = nn.Embedding(vocab_size, in_dim, padding_idx=Constants.PAD, sparse=sparsity) 96 | self.childsumtreelstm = ChildSumTreeLSTM(in_dim, mem_dim) 97 | self.similarity = similarity 98 | 99 | def forward(self, ltree, linputs, rtree, rinputs): 100 | linputs = self.emb(linputs) 101 | rinputs = self.emb(rinputs) 102 | lstate, lhidden = self.childsumtreelstm(ltree, linputs) 103 | rstate, rhidden = self.childsumtreelstm(rtree, rinputs) 104 | output = self.similarity(lstate, rstate) 105 | return output 106 | -------------------------------------------------------------------------------- /common/graph/path.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Path(list): 5 | def __init__(self, *args): 6 | super(Path, self).__init__(*args) 7 | self.__sorted_str = None 8 | 9 | @property 10 | def confidence(self): 11 | """ 12 | Cumulative product of edges' confidence 13 | :return: 14 | """ 15 | return np.prod([edge.confidence for edge in self]) 16 | 17 | def addable(self, candidate_edge): 18 | """ 19 | Whether candidate edge would be connected to the current path 20 | :param candidate_edge: 21 | :return: 22 | """ 23 | if len(self) == 0: 24 | return True 25 | for edge in self: 26 | if edge.uri == candidate_edge.uri or \ 27 | edge.source_node == candidate_edge.source_node or \ 28 | edge.dest_node == candidate_edge.dest_node or \ 29 | edge.source_node == candidate_edge.dest_node or \ 30 | edge.dest_node == candidate_edge.source_node: 31 | return True 32 | return False 33 | 34 | def replace_edge(self, old_edge, new_edge): 35 | """ 36 | Create a new path in which an existing edge is replaced with the new edge 37 | :param old_edge: 38 | :param new_edge: 39 | :return: None if the old_edge is not in the current path 40 | """ 41 | try: 42 | new_path = Path(self) 43 | new_path[new_path.index(old_edge)] = new_edge 44 | return new_path 45 | except ValueError: 46 | return None 47 | 48 | def __eq__(self, other): 49 | if len(self) == len(other): 50 | same_flag = True 51 | for edge in self: 52 | if edge not in other: 53 | same_flag = False 54 | break 55 | if same_flag: 56 | return True 57 | return False 58 | 59 | def generic_equal(self, other): 60 | return self.__generic_equal(other)[0] 61 | 62 | def generic_equal_with_substitutable_id(self, other): 63 | output = self.__generic_equal(other) 64 | if not output[0]: 65 | return False 66 | return len(set(output[1])) == 1 67 | 68 | def __generic_equal(self, other): 69 | """ 70 | Check whether the other path in the same as the current one, perhaps with different generic node id 71 | :param other: 72 | :return: 73 | """ 74 | output = [] 75 | if len(self) == len(other): 76 | for edge1 in self: 77 | edge_found = False 78 | for edge2 in other: 79 | if edge1.generic_equal(edge2): 80 | edge_found = True 81 | if edge1.source_node.generic_equal( 82 | edge2.source_node) and edge1.source_node.are_all_uris_generic() \ 83 | and 
edge2.source_node.are_all_uris_generic(): 84 | output.append( 85 | frozenset( 86 | [edge1.source_node.first_uri_if_only(), edge2.source_node.first_uri_if_only()])) 87 | if edge1.dest_node.generic_equal( 88 | edge2.dest_node) and edge1.dest_node.are_all_uris_generic() \ 89 | and edge2.dest_node.are_all_uris_generic(): 90 | output.append( 91 | frozenset([edge1.dest_node.first_uri_if_only(), edge2.dest_node.first_uri_if_only()])) 92 | if edge1.uri.is_generic() and edge1.uri.generic_equal( 93 | edge2.uri): 94 | output.append( 95 | frozenset([edge1.uri, edge2.uri])) 96 | break 97 | if not edge_found: 98 | return False, [] 99 | return True, output 100 | else: 101 | return False, [] 102 | 103 | def __str__(self): 104 | if self.__sorted_str is None: 105 | self.__sorted_str = '-'.join(sorted([item.full_path() for item in self])) 106 | return self.__sorted_str 107 | -------------------------------------------------------------------------------- /output/qald/a.pos: -------------------------------------------------------------------------------- 1 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET PROPN PROPN PROPN PUNCT 2 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET PROPN PROPN PROPN PUNCT 3 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET PROPN PROPN PROPN PUNCT 4 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 5 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 6 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 7 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 8 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 9 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 10 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 11 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 12 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 13 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 14 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 15 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 16 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 17 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 18 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 19 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 20 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 21 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 22 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 23 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 24 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 25 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 26 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 27 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 28 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 29 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 30 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 31 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 32 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 33 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 34 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 35 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 36 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 37 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 38 | DET NOUN VERB VERB ADP DET SYM NOUN 
CCONJ DET SYM NOUN PUNCT 39 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 40 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 41 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 42 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 43 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 44 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 45 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 46 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 47 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 48 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 49 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 50 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 51 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 52 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 53 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 54 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 55 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 56 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 57 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 58 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 59 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 60 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 61 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 62 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 63 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 64 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 65 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 66 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 67 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 68 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 69 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 70 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 71 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 72 | DET NOUN VERB VERB ADP DET SYM NOUN CCONJ DET SYM NOUN PUNCT 73 | -------------------------------------------------------------------------------- /treelstm/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable as Var 5 | from . import utils 6 | from . 
import Constants 7 | 8 | # module for childsumtreelstm 9 | class ChildSumTreeLSTM(nn.Module): 10 | def __init__(self, in_dim, mem_dim, num_classes, criterion, vocab_output): 11 | super(ChildSumTreeLSTM, self).__init__() 12 | self.in_dim = in_dim 13 | self.mem_dim = mem_dim 14 | self.num_classes = num_classes 15 | 16 | self.ix = nn.Linear(self.in_dim, self.mem_dim) 17 | self.ih = nn.Linear(self.mem_dim, self.mem_dim) 18 | 19 | self.fh = nn.Linear(self.mem_dim, self.mem_dim) 20 | self.fx = nn.Linear(self.in_dim, self.mem_dim) 21 | 22 | self.ux = nn.Linear(self.in_dim, self.mem_dim) 23 | self.uh = nn.Linear(self.mem_dim, self.mem_dim) 24 | 25 | self.ox = nn.Linear(self.in_dim, self.mem_dim) 26 | self.oh = nn.Linear(self.mem_dim, self.mem_dim) 27 | 28 | self.criterion = criterion 29 | self.output_module = None 30 | self.vocab_output = vocab_output 31 | 32 | def set_output_module(self, output_module): 33 | self.output_module = output_module 34 | 35 | def node_forward(self, inputs, child_c, child_h): 36 | child_h_sum = F.torch.sum(torch.squeeze(child_h, 1), 0) 37 | 38 | i = F.sigmoid(self.ix(inputs.cuda()) + self.ih(child_h_sum.cuda())) 39 | o = F.sigmoid(self.ox(inputs.cuda()) + self.oh(child_h_sum.cuda())) 40 | u = F.tanh(self.ux(inputs.cuda()) + self.uh(child_h_sum.cuda())) 41 | 42 | # add extra singleton dimension 43 | fx = F.torch.unsqueeze(self.fx(inputs.cuda()), 1) 44 | f = F.torch.cat([self.fh(child_hi.cuda()) + fx for child_hi in child_h], 0) 45 | f = F.sigmoid(f) 46 | 47 | fc = F.torch.squeeze(F.torch.mul(f.cuda(), child_c.cuda()), 1) 48 | 49 | c = F.torch.mul(i.cuda(), u.cuda()) + F.torch.sum(fc.cuda(), 0) 50 | h = F.torch.mul(o.cuda(), F.tanh(c.cuda())) 51 | return c, h 52 | 53 | def forward(self, tree, inputs, training = False): 54 | for idx in range(tree.num_children): 55 | self.forward(tree.children[idx], inputs, training) 56 | 57 | child_c, child_h = self.get_child_states(tree) 58 | tree.state = self.node_forward(inputs[tree.idx], child_c, child_h) 59 | output = self.output_module.forward(tree.state[1], training) 60 | 61 | return output 62 | 63 | def get_child_states(self, tree): 64 | """ 65 | Get c and h of all children 66 | :param tree: 67 | :return: (tuple) 68 | child_c: (num_children, 1, mem_dim) 69 | child_h: (num_children, 1, mem_dim) 70 | """ 71 | if tree.num_children == 0: 72 | child_c = Var(torch.zeros(1, 1, self.mem_dim)) 73 | child_h = Var(torch.zeros(1, 1, self.mem_dim)) 74 | else: 75 | child_c = Var(torch.Tensor(tree.num_children, 1, self.mem_dim)) 76 | child_h = Var(torch.Tensor(tree.num_children, 1, self.mem_dim)) 77 | for idx in range(tree.num_children): 78 | child_c[idx] = tree.children[idx].state[0] 79 | child_h[idx] = tree.children[idx].state[1] 80 | return child_c, child_h 81 | 82 | class Classifier(nn.Module): 83 | def __init__(self, mem_dim, num_classes, dropout=False): 84 | super(Classifier, self).__init__() 85 | self.mem_dim = mem_dim 86 | self.num_classes = num_classes 87 | self.dropout = dropout 88 | self.l1 = nn.Linear(self.mem_dim, self.num_classes) 89 | self.logsoftmax = nn.LogSoftmax(dim=1) 90 | 91 | def set_dropout(self, dropout): 92 | self.dropout = dropout 93 | 94 | def forward(self, vec, training=False): 95 | if self.dropout: 96 | out = self.logsoftmax(self.l1(F.dropout(vec, training=training, p=0.2))) 97 | else: 98 | out = self.logsoftmax(self.l1(vec)) 99 | return out 100 | 101 | # putting the whole model together 102 | class TreeLSTM(nn.Module): 103 | def __init__(self, in_dim, mem_dim, num_classes, criterion, vocab_output, dropout=False): 104 
| super(TreeLSTM, self).__init__() 105 | self.tree_module = ChildSumTreeLSTM(in_dim, mem_dim, num_classes, criterion, vocab_output) 106 | self.classifier = Classifier(mem_dim, num_classes, dropout) 107 | self.tree_module.set_output_module(self.classifier) 108 | 109 | def set_dropout(self, dropout): 110 | self.classifier.set_dropout(dropout) 111 | 112 | def forward(self, tree, inputs, training = False): 113 | output = self.tree_module.forward(tree, inputs, training) 114 | return output 115 | -------------------------------------------------------------------------------- /common/utility/utility.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import logging.config 4 | import pickle 5 | 6 | 7 | class PersistanceDict(dict): 8 | def __init__(self, *args, **kwargs): 9 | super(PersistanceDict, self).__init__(*args, **kwargs) 10 | 11 | def save(self, file_name): 12 | with open(file_name, 'wb') as f: 13 | pickle.dump(self, f) # , pickle.HIGHEST_PROTOCOL) 14 | 15 | @staticmethod 16 | def load(file_name): 17 | with open(file_name, 'rb') as f: 18 | return pickle.load(f,encoding='latin1') 19 | 20 | 21 | def makedirs(dir): 22 | if not os.path.exists(dir): 23 | os.makedirs(dir) 24 | return None 25 | 26 | 27 | def setup_logging( 28 | default_path='logging.json', 29 | default_level=logging.INFO, 30 | env_key='LOG_CFG' 31 | ): 32 | """Setup logging configuration 33 | 34 | """ 35 | path = default_path 36 | value = os.getenv(env_key, None) 37 | if value: 38 | path = value 39 | if os.path.exists(path): 40 | with open(path, 'rt') as f: 41 | config = json.load(f) 42 | logging.config.dictConfig(config) 43 | else: 44 | # logging.basicConfig(level=default_level) 45 | default_level = logging.WARNING 46 | logging.basicConfig(level=default_level) 47 | 48 | 49 | def closest_string(text, list_of_text): 50 | min = len(text) * 100 51 | idx = -1 52 | for item in list_of_text: 53 | value = __levenshtein(text, list_of_text[item]) 54 | if min > value: 55 | min = value 56 | idx = item 57 | return idx 58 | 59 | 60 | def find_mentions(text, uris): 61 | output = [] 62 | for uri in uris: 63 | s, e, dist = __substring_with_min_levenshtein_distance(str(uri), text) 64 | if dist <= 5: 65 | output.append({"uri": uri, "start": s, "end": e}) 66 | return output 67 | 68 | 69 | def __fuzzy_substring(needle, haystack): 70 | """Calculates the fuzzy match of needle in haystack, 71 | using a modified version of the Levenshtein distance 72 | algorithm. 
73 | The function is modified from the levenshtein function 74 | in the bktree module by Adam Hupp""" 75 | m, n = len(needle), len(haystack) 76 | 77 | # base cases 78 | if m == 1: 79 | # return not needle in haystack 80 | row = [len(haystack)] * len(haystack) 81 | row[haystack.find(needle)] = 0 82 | return row 83 | if not n: 84 | return m 85 | 86 | row1 = [0] * (n + 1) 87 | for i in range(0, m): 88 | row2 = [i + 1] 89 | for j in range(0, n): 90 | cost = (needle[i] != haystack[j]) 91 | 92 | row2.append(min(row1[j + 1] + 1, # deletion 93 | row2[j] + 1, # insertion 94 | row1[j] + cost) # substitution 95 | ) 96 | row1 = row2 97 | return row1 98 | 99 | 100 | def __min_farest(values): 101 | return -min((x, -i) for i, x in enumerate(values))[1] 102 | 103 | 104 | def __min_nearest(values): 105 | return min(enumerate(values), key=lambda p: p[1])[0] 106 | 107 | 108 | def __levenshtein(s1, s2): 109 | if len(s1) < len(s2): 110 | return __levenshtein(s2, s1) 111 | 112 | # len(s1) >= len(s2) 113 | if len(s2) == 0: 114 | return len(s1) 115 | 116 | previous_row = range(len(s2) + 1) 117 | for i, c1 in enumerate(s1): 118 | current_row = [i + 1] 119 | for j, c2 in enumerate(s2): 120 | insertions = previous_row[ 121 | j + 1] + 1 # j+1 instead of j since previous_row and current_row are one character longer 122 | deletions = current_row[j] + 1 # than s2 123 | substitutions = previous_row[j] + (c1 != c2) 124 | current_row.append(min(insertions, deletions, substitutions)) 125 | previous_row = current_row 126 | 127 | return previous_row[-1] 128 | 129 | 130 | def __substring_with_min_levenshtein_distance(n, h): 131 | n = n.lower().replace("_", " ") 132 | h = h.lower() 133 | row = __fuzzy_substring(n, h) 134 | end = min(__min_farest(row), len(h) - 1) 135 | row_rev = __fuzzy_substring(n[::-1], h[::-1]) 136 | start = max(0, len(h) - __min_nearest(row_rev) - 1) 137 | 138 | strip = [" ", "?", ".", ",", "'"] 139 | # stretch the token to be whole word[s] 140 | while h[start] not in strip and start >= 0: 141 | start -= 1 142 | 143 | while h[end - 1] not in strip and end < (len(h) - 1): 144 | end += 1 145 | 146 | # remove invalid chars in head or tail 147 | for i in range(start, end): 148 | if h[start] in strip: 149 | start += 1 150 | else: 151 | break 152 | 153 | for i in range(end, start, -1): 154 | if h[end - 1] in strip: 155 | end -= 1 156 | else: 157 | break 158 | 159 | return start, end, row[end] 160 | -------------------------------------------------------------------------------- /output/qald/a.rels: -------------------------------------------------------------------------------- 1 | det nsubjpass auxpass ROOT prep det nmod pobj cc det compound compound conj punct 2 | det nsubjpass auxpass ROOT prep det nmod pobj cc det compound compound conj punct 3 | det nsubjpass auxpass ROOT prep det nmod pobj cc det compound compound conj punct 4 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 5 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 6 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 7 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 8 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 9 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 10 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 11 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 12 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 13 | det nsubjpass auxpass ROOT 
prep det nmod pobj cc det nmod conj punct 14 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 15 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 16 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 17 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 18 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 19 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 20 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 21 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 22 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 23 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 24 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 25 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 26 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 27 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 28 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 29 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 30 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 31 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 32 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 33 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 34 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 35 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 36 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 37 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 38 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 39 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 40 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 41 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 42 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 43 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 44 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 45 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 46 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 47 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 48 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 49 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 50 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 51 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 52 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 53 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 54 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 55 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 56 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 57 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 58 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 59 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 60 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 61 | det nsubjpass auxpass ROOT 
prep det nmod pobj cc det nmod conj punct 62 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 63 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 64 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 65 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 66 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 67 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 68 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 69 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 70 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 71 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 72 | det nsubjpass auxpass ROOT prep det nmod pobj cc det nmod conj punct 73 | -------------------------------------------------------------------------------- /result_analysis.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import os 4 | import re 5 | import numpy as np 6 | 7 | 8 | def analysis(fiename): 9 | with open("output/{}.json".format(fiename), 'r') as data_file: 10 | data = json.load(data_file) 11 | 12 | na = lf = correct = incorrect = no_path = no_answer = 0 13 | p = [] 14 | r=[] 15 | n_list = 0 16 | n_count = 0 17 | n_ask = 0 18 | p_count = [] 19 | p_list = [] 20 | p_ask = [] 21 | r_count = [] 22 | r_list = [] 23 | r_ask = [] 24 | incor = 0 25 | sp = [] 26 | sr = [] 27 | cp=[] 28 | cr=[] 29 | svp=[] 30 | svr=[] 31 | mvp=[] 32 | mvr=[] 33 | 34 | for i in data: 35 | p.append(i['precision']) 36 | r.append(i['recall']) 37 | if i['precision'] == i['recall'] == 0.0: 38 | incor +=1 39 | 40 | if i['answer'] == "correct": 41 | correct +=1 42 | elif i['answer'] == "-Not_Applicable": 43 | na +=1 44 | elif i['answer'] == "-Linker_failed": 45 | lf +=1 46 | elif i['answer'] == "-incorrect": 47 | incorrect +=1 48 | elif i['answer'] == "-without_path": 49 | no_path +=1 50 | elif i['answer'] == "-no_answer": 51 | no_answer +=1 52 | 53 | if 'ASK' in i['query']: 54 | n_ask +=1 55 | p_ask.append(i['precision']) 56 | r_ask.append(i['recall']) 57 | elif 'COUNT(' in i['query']: 58 | n_count+=1 59 | p_count.append(i['precision']) 60 | r_count.append(i['recall']) 61 | else: 62 | n_list +=1 63 | p_list.append(i['precision']) 64 | r_list.append(i['recall']) 65 | 66 | if 'single' in i['features']: 67 | sp.append(i['precision']) 68 | sr.append(i['recall']) 69 | elif 'compound' in i['features']: 70 | cp.append(i['precision']) 71 | cr.append(i['recall']) 72 | 73 | if 'singlevar' in i['features']: 74 | svp.append(i['precision']) 75 | svr.append(i['recall']) 76 | elif 'multivar' in i['features']: 77 | mvp.append(i['precision']) 78 | mvr.append(i['recall']) 79 | 80 | 81 | print("-- Basic Stats --") 82 | print("- Total Questions: %d" % (correct+incorrect+no_path+no_answer+na+lf)) 83 | print("- Correct: %d" % correct) 84 | print("- Incorrect: %d" % incorrect) 85 | print("- No-Path: %d" % no_path) 86 | print("- No-Answer: %d" % no_answer) 87 | print("- Not_Applicable: %d" % na) 88 | print("- Linker_failed: %d" % lf) 89 | print('- Wrong Answer: %d' % incor) 90 | 91 | print('None in precision: ',sum(i is None for i in p)) 92 | print('None in recall: ', sum(i is None for i in r)) 93 | 94 | p = np.array(p, dtype=np.float64) 95 | r = np.array(r, dtype=np.float64) 96 | mp = np.nanmean(p) 97 | mr = np.nanmean(r) 98 | 99 | print("- Precision: %.4f" % mp) 100 | print("- Recall: 
%.4f" % mr) 101 | print("- F1: %.4f" % ((2*mp*mr)/(mp+mr))) 102 | 103 | p_count = np.array(p_count, dtype=np.float64) 104 | p_list = np.array(p_list, dtype=np.float64) 105 | p_ask = np.array(p_ask, dtype=np.float64) 106 | r_count = np.array(r_count, dtype=np.float64) 107 | r_list = np.array(r_list, dtype=np.float64) 108 | r_ask = np.array(r_ask, dtype=np.float64) 109 | print('List: ', n_list) 110 | a = np.nanmean(p_list) 111 | b = np.nanmean(r_list) 112 | print('precision: %.4f' % a) 113 | print('reacall: %.4f' % b) 114 | print('f1-score: %.4f'% ((2*a*b)/(a+b))) 115 | 116 | print('Count: ', n_count) 117 | a = np.nanmean(p_count) 118 | b = np.nanmean(r_count) 119 | print('precision: %.4f' % a) 120 | print('reacall: %.4f' % b) 121 | print('f1-score: %.4f'% ((2*a*b)/(a+b))) 122 | 123 | print('Ask: ', n_list) 124 | a = np.nanmean(p_ask) 125 | b = np.nanmean(r_ask) 126 | print('precision: %.4f' % a) 127 | print('reacall: %.4f' % b) 128 | print('f1-score: %.4f'% ((2*a*b)/(a+b))) 129 | 130 | sp = np.array(sp, dtype=np.float64) 131 | sr = np.array(sr, dtype=np.float64) 132 | cp=np.array(cp, dtype=np.float64) 133 | cr=np.array(cr, dtype=np.float64) 134 | print('Single: ', len(sp), len(sr)) 135 | a = np.nanmean(sp) 136 | b = np.nanmean(sr) 137 | print('precision: %.4f' % a) 138 | print('reacall: %.4f' % b) 139 | print('f1-score: %.4f'% ((2*a*b)/(a+b))) 140 | 141 | print('Compound: ', len(cp), len(cr)) 142 | a = np.nanmean(cp) 143 | b = np.nanmean(cr) 144 | print('precision: %.4f' % a) 145 | print('reacall: %.4f' % b) 146 | print('f1-score: %.4f'% ((2*a*b)/(a+b))) 147 | 148 | 149 | svp=np.array(svp, dtype=np.float64) 150 | svr=np.array(svr, dtype=np.float64) 151 | mvp=np.array(mvp, dtype=np.float64) 152 | mvr=np.array(mvr, dtype=np.float64) 153 | print('Single Var: ', len(svp), len(svr)) 154 | a = np.nanmean(svp) 155 | b = np.nanmean(svr) 156 | print('precision: %.4f' % a) 157 | print('reacall: %.4f' % b) 158 | print('f1-score: %.4f' % ((2 * a * b) / (a + b))) 159 | 160 | print('Multiple Var: ', len(mvp), len(mvr)) 161 | a = np.nanmean(mvp) 162 | b = np.nanmean(mvr) 163 | print('precision: %.4f' % a) 164 | print('reacall: %.4f' % b) 165 | print('f1-score: %.4f' % ((2 * a * b) / (a + b))) 166 | 167 | 168 | if __name__ == "__main__": 169 | file = "lcquadtestanswer_output" 170 | print('LC-QUAD test: ') 171 | analysis(file) 172 | print('\n'*2) 173 | 174 | file = "qaldanswer_output" 175 | print('QALD-7: ') 176 | analysis(file) 177 | print('\n'*2) 178 | 179 | file = "lcquadanswer_output" 180 | print('LC-QUAD all: ') 181 | analysis(file) 182 | print('\n'*2) 183 | 184 | 185 | -------------------------------------------------------------------------------- /kb/kb.py: -------------------------------------------------------------------------------- 1 | import urllib 2 | from multiprocessing import Pool 3 | from contextlib import closing 4 | import json 5 | import requests 6 | 7 | 8 | def query(args): 9 | endpoint, q, idx = args 10 | payload = {'query': q, 'format': 'application/json'} 11 | try: 12 | query_string = urllib.parse.urlencode(payload) 13 | url = endpoint + '?' 
+ query_string 14 | r = requests.get(url) 15 | except: 16 | return 0, None, idx 17 | return r.status_code, r.json() if r.status_code == 200 else None, idx 18 | 19 | 20 | class KB(object): 21 | def __init__(self, endpoint, default_graph_uri=""): 22 | self.endpoint = endpoint 23 | self.default_graph_uri = default_graph_uri 24 | self.type_uri = "type_uri" 25 | self.server_available = self.check_server() 26 | 27 | def check_server(self): 28 | payload = {'query': 'select distinct ?Concept where {[] a ?Concept} LIMIT 1', 'format': 'application/json'} 29 | try: 30 | query_string = urllib.parse.urlencode(payload) 31 | url = self.endpoint + '?' + query_string 32 | r = requests.get(url) 33 | if r.status_code == 200: 34 | return True 35 | except: 36 | return False 37 | return False 38 | 39 | def query(self, q): 40 | payload = {'query': q, 'format': 'application/json'} 41 | try: 42 | query_string = urllib.parse.urlencode(payload) 43 | url = self.endpoint + '?' + query_string 44 | r = requests.get(url) 45 | except: 46 | return 0, None 47 | 48 | return r.status_code, r.json() if r.status_code == 200 else None 49 | 50 | def sparql_query(self, clauses, return_vars="*", count=False, ask=False): 51 | where = u"WHERE {{ {} }}".format(" .".join(clauses)) 52 | if count: 53 | query = u"{} SELECT COUNT(DISTINCT {}) {}".format(self.query_prefix(), return_vars, where) 54 | elif ask: 55 | query = u"{} ASK {}".format(self.query_prefix(), where) 56 | else: 57 | query = u"{} SELECT DISTINCT {} {}".format(self.query_prefix(), return_vars, where) 58 | 59 | return query 60 | 61 | def query_where(self, clauses, return_vars="*", count=False, ask=False): 62 | query = self.sparql_query(clauses, return_vars, count, ask) 63 | status, response = self.query(query) 64 | if status == 200: 65 | return response 66 | 67 | def parallel_query(self, query_templates): 68 | args = [] 69 | for i in range(len(query_templates)): 70 | args.append( 71 | (self.endpoint, u"{} ASK WHERE {{ {} }}".format(self.query_prefix(), query_templates[i][1]), 72 | query_templates[i][0])) 73 | with closing(Pool(len(query_templates))) as pool: 74 | query_results = pool.map(query, args) 75 | pool.terminate() 76 | results = [] 77 | for i in range(len(query_results)): 78 | if query_results[i][0] == 200: 79 | results.append((query_results[i][2], query_results[i][1]["boolean"])) 80 | return results 81 | 82 | def one_hop_graph(self, entity1_uri, relation_uri, entity2_uri=None): 83 | # print('kb one_hop_graph') 84 | relation_uri = self.uri_to_sparql(relation_uri) 85 | entity1_uri = self.uri_to_sparql(entity1_uri) 86 | if entity2_uri is None: 87 | entity2_uri = "?u1" 88 | else: 89 | entity2_uri = self.uri_to_sparql(entity2_uri) 90 | 91 | query_types = [u"{ent2} {rel} {ent1}", 92 | u"{ent1} {rel} {ent2}", 93 | u"?u1 {type} {rel}"] 94 | where = "" 95 | for i in range(len(query_types)): 96 | where = where + u"UNION {{ values ?m {{ {} }} {{select <1> where {{ {} }} }} }}\n". 
\ 97 | format(i, 98 | query_types[ 99 | i].format( 100 | rel=relation_uri, 101 | ent1=entity1_uri, 102 | ent2=entity2_uri, 103 | type=self.type_uri, 104 | prefix=self.query_prefix())) 105 | where = where[6:] 106 | query = u"""{prefix} 107 | SELECT DISTINCT ?m WHERE {{ {where} }}""".format(prefix=self.query_prefix(), where=where) 108 | 109 | status, response = self.query(query) 110 | if status == 200 and len(response["results"]["bindings"]) > 0: 111 | output = response["results"]["bindings"] 112 | return output 113 | 114 | def two_hop_graph_template(self, entity1_uri, relation1_uri, entity2_uri, relation2_uri): 115 | # print('kb two_hop_graph_template') 116 | query_types = [[0, u"{ent1} {rel1} {ent2} . ?u1 {rel2} {ent1}"], 117 | [1, u"{ent1} {rel1} {ent2} . {ent1} {rel2} ?u1"], 118 | [2, u"{ent1} {rel1} {ent2} . {ent2} {rel2} ?u1"], 119 | [3, u"{ent1} {rel1} {ent2} . ?u1 {rel2} {ent2}"], 120 | [4, u"{ent1} {rel1} {ent2} . ?u1 {type} {rel2}"]] 121 | output = [[item[0], item[1].format(rel1=relation1_uri, ent1=entity1_uri, 122 | ent2=entity2_uri, rel2=relation2_uri, 123 | type=self.type_uri)] for item in query_types] 124 | # print('entity1_uri: ', entity1_uri) 125 | # print('entity2_uri: ', entity2_uri) 126 | # print('relation1_uri: ', relation1_uri) 127 | # print('relation2_uri: ', relation2_uri) 128 | # print('output: ', output) 129 | return output 130 | 131 | def two_hop_graph(self, entity1_uri, relation1_uri, entity2_uri, relation2_uri): 132 | # print('kb two_hop_graph') 133 | relation1_uri = self.uri_to_sparql(relation1_uri) 134 | relation2_uri = self.uri_to_sparql(relation2_uri) 135 | entity1_uri = self.uri_to_sparql(entity1_uri) 136 | entity2_uri = self.uri_to_sparql(entity2_uri) 137 | 138 | queries = self.two_hop_graph_template(entity1_uri, relation1_uri, entity2_uri, relation2_uri) 139 | output = None 140 | if len(queries) > 0: 141 | output = self.parallel_query(queries) 142 | # print('queries: ', queries) 143 | # print('output: ', output) 144 | return output 145 | 146 | @staticmethod 147 | def shorten_prefix(): 148 | return "" 149 | 150 | @staticmethod 151 | def query_prefix(): 152 | return "" 153 | 154 | @staticmethod 155 | def prefix(): 156 | return "" 157 | 158 | @staticmethod 159 | def parse_uri(input_uri): 160 | pass 161 | 162 | @staticmethod 163 | def uri_to_sparql(input_uri): 164 | return input_uri.uri 165 | -------------------------------------------------------------------------------- /lcquad_answer.py: -------------------------------------------------------------------------------- 1 | from parser.lc_quad_linked import LC_Qaud_Linked 2 | from parser.lc_quad import LC_QaudParser 3 | from common.container.sparql import SPARQL 4 | from common.container.answerset import AnswerSet 5 | from common.graph.graph import Graph 6 | from common.utility.stats import Stats 7 | from common.query.querybuilder import QueryBuilder 8 | import common.utility.utility as utility 9 | from linker.goldLinker import GoldLinker 10 | from linker.earl import Earl 11 | from learning.classifier.svmclassifier import SVMClassifier 12 | import json 13 | import argparse 14 | import logging 15 | import sys 16 | import os 17 | import numpy as np 18 | 19 | 20 | def safe_div(x, y): 21 | if y == 0: 22 | return None 23 | return x / y 24 | 25 | 26 | def qg(linker, kb, parser, qapair, force_gold=True): 27 | logger.info(qapair.sparql) 28 | logger.info(qapair.question.text) 29 | 30 | h1_threshold = 9999999 31 | 32 | # Get Answer from KB online 33 | status, raw_answer_true = 
kb.query(qapair.sparql.query.replace("https", "http")) 34 | answerset_true = AnswerSet(raw_answer_true, parser.parse_queryresult) 35 | qapair.answerset = answerset_true 36 | 37 | ask_query = "ASK " in qapair.sparql.query 38 | count_query = "COUNT(" in qapair.sparql.query 39 | sort_query = "order by" in qapair.sparql.raw_query.lower() 40 | entities, ontologies = linker.do(qapair, force_gold=force_gold) 41 | 42 | double_relation = False 43 | relation_uris = [u for u in qapair.sparql.uris if u.is_ontology() or u.is_type()] 44 | if len(relation_uris) != len(set(relation_uris)): 45 | double_relation = True 46 | else: 47 | double_relation = False 48 | 49 | print('ask_query: ', ask_query) 50 | print('count_query: ', count_query) 51 | print('double_relation: ', double_relation) 52 | 53 | if entities is None or ontologies is None: 54 | return "-Linker_failed", [] 55 | 56 | graph = Graph(kb) 57 | queryBuilder = QueryBuilder() 58 | 59 | logger.info("start finding the minimal subgraph") 60 | 61 | graph.find_minimal_subgraph(entities, ontologies, double_relation=double_relation, ask_query=ask_query, 62 | sort_query=sort_query, h1_threshold=h1_threshold) 63 | logger.info(graph) 64 | wheres = queryBuilder.to_where_statement(graph, parser.parse_queryresult, ask_query=ask_query, 65 | count_query=count_query, sort_query=sort_query) 66 | 67 | output_where = [{"query": " .".join(item["where"]), "correct": False, "target_var": "?u_0"} for item in wheres] 68 | for item in list(output_where): 69 | logger.info(item["query"]) 70 | if len(wheres) == 0: 71 | return "-without_path", output_where 72 | correct = False 73 | 74 | for idx in range(len(wheres)): 75 | where = wheres[idx] 76 | 77 | if "answer" in where: 78 | answerset = where["answer"] 79 | target_var = where["target_var"] 80 | else: 81 | target_var = "?u_" + str(where["suggested_id"]) 82 | raw_answer = kb.query_where(where["where"], target_var, count_query, ask_query) 83 | answerset = AnswerSet(raw_answer, parser.parse_queryresult) 84 | 85 | output_where[idx]["target_var"] = target_var 86 | sparql = SPARQL(kb.sparql_query(where["where"], target_var, count_query, ask_query), ds.parser.parse_sparql) 87 | if (answerset == qapair.answerset) != (sparql == qapair.sparql): 88 | print("error") 89 | 90 | if answerset == qapair.answerset: 91 | correct = True 92 | output_where[idx]["correct"] = True 93 | output_where[idx]["target_var"] = target_var 94 | else: 95 | if target_var == "?u_0": 96 | target_var = "?u_1" 97 | else: 98 | target_var = "?u_0" 99 | raw_answer = kb.query_where(where["where"], target_var, count_query, ask_query) 100 | print("Q_H ",) 101 | # print(raw_answer) 102 | print("Q_") 103 | answerset = AnswerSet(raw_answer, parser.parse_queryresult) 104 | 105 | sparql = SPARQL(kb.sparql_query(where["where"], target_var, count_query, ask_query), ds.parser.parse_sparql) 106 | if (answerset == qapair.answerset) != (sparql == qapair.sparql): 107 | print("error") 108 | 109 | if answerset == qapair.answerset: 110 | correct = True 111 | output_where[idx]["correct"] = True 112 | output_where[idx]["target_var"] = target_var 113 | 114 | return "correct" if correct else "-incorrect", output_where 115 | 116 | 117 | if __name__ == "__main__": 118 | logger = logging.getLogger(__name__) 119 | utility.setup_logging() 120 | 121 | ds = LC_Qaud_Linked(path="./data/LC-QUAD/linked_answer.json") 122 | ds.load() 123 | ds.parse() 124 | 125 | if not ds.parser.kb.server_available: 126 | logger.error("Server is not available. 
Please check the endpoint at: {}".format(ds.parser.kb.endpoint)) 127 | sys.exit(0) 128 | 129 | parser = LC_QaudParser() 130 | kb = parser.kb 131 | 132 | stats = Stats() 133 | linker = GoldLinker() 134 | output_file = 'lcquad_gold' 135 | 136 | tmp = [] 137 | output = [] 138 | na_list = [] 139 | 140 | for qapair in ds.qapairs: 141 | print('='*10) 142 | stats.inc("total") 143 | output_row = {"question": qapair.question.text, 144 | "id": qapair.id, 145 | "query": qapair.sparql.query, 146 | "answer": "", 147 | "features": list(qapair.sparql.query_features()), 148 | "generated_queries": []} 149 | 150 | if qapair.answerset is None or len(qapair.answerset) == 0: 151 | stats.inc("query_no_answer") 152 | output_row["answer"] = "-no_answer" 153 | na_list.append(output_row['id']) 154 | else: 155 | result, where = qg(linker, ds.parser.kb, ds.parser, qapair, False) 156 | stats.inc(result) 157 | output_row["answer"] = result 158 | newwhere = [] 159 | for iwhere in where: 160 | if iwhere not in newwhere: 161 | newwhere.append(iwhere) 162 | output_row["generated_queries"] = newwhere 163 | logger.info(result) 164 | 165 | logger.info(stats) 166 | output.append(output_row) 167 | 168 | if stats["total"] % 100 == 0: 169 | with open("output/{}.json".format(output_file), "w") as data_file: 170 | json.dump(output, data_file, sort_keys=True, indent=4, separators=(',', ': ')) 171 | 172 | with open("output/{}.json".format(output_file), "w") as data_file: 173 | json.dump(output, data_file, sort_keys=True, indent=4, separators=(',', ': ')) 174 | print('stats: ', stats) 175 | 176 | with open('na_list_lcquadgold.txt', 'w') as f: 177 | for i in na_list: 178 | f.write("{}\n".format(i)) 179 | -------------------------------------------------------------------------------- /parser/qald.py: -------------------------------------------------------------------------------- 1 | import json, re 2 | from common.container.qapair import QApair 3 | from common.container.uri import Uri 4 | from common.container.answerrow import AnswerRow 5 | from common.container.answer import Answer 6 | from kb.dbpedia import DBpedia 7 | from parser.answerparser import AnswerParser 8 | from xml.dom import minidom 9 | import sys 10 | 11 | 12 | class Qald: 13 | qald_7 = "./data/QALD/qald-7-train-multilingual.json" 14 | qaldtest_7 = "./data/QALD/qald-7-test-multilingual.json" 15 | 16 | def __init__(self, path): 17 | self.raw_data = [] 18 | self.qapairs = [] 19 | self.path = path 20 | self.parser = QaldParser() 21 | 22 | def load(self, path=None): 23 | if path is None: 24 | path = self.path 25 | if path.endswith("json"): 26 | with open(path, encoding='utf-8') as data_file: 27 | self.raw_data = json.load(data_file) 28 | elif path.endswith("xml"): 29 | with open(path) as data_file: 30 | self.raw_data = minidom.parse(data_file).documentElement 31 | 32 | def extend(self, path): 33 | self.load(path) 34 | self.parse() 35 | 36 | def parse(self): 37 | if self.path.endswith("json"): 38 | self.parse_json() 39 | elif self.path.endswith("xml"): 40 | self.parse_xml() 41 | 42 | def parse_json(self): 43 | parser = QaldParser() 44 | for raw_row in self.raw_data["questions"]: 45 | question = "" 46 | query = "" 47 | if "question" in raw_row: 48 | question = raw_row["question"] 49 | elif "body" in raw_row: 50 | # QALD-5 format 51 | question = raw_row["body"] 52 | if "query" in raw_row: 53 | if isinstance(raw_row["query"], dict): 54 | if "sparql" in raw_row["query"]: 55 | query = raw_row["query"]["sparql"] 56 | else: 57 | query = "" 58 | else: 59 | query = raw_row["query"] 
60 | self.qapairs.append(QApair(question, raw_row["answers"], query, raw_row, raw_row["id"], parser)) 61 | 62 | def parse_xml(self): 63 | parser = QaldParser() 64 | data_set = self.raw_data 65 | 66 | raw_rows = data_set.getElementsByTagName("question") 67 | for raw_row in raw_rows: 68 | question = [] 69 | answers = [] 70 | query = "" 71 | question_id = raw_row.getAttribute("id") 72 | 73 | if raw_row.getElementsByTagName("query"): 74 | query = raw_row.getElementsByTagName("query")[0].childNodes[0].data 75 | elif raw_row.getElementsByTagName("pseudoquery"): 76 | query = raw_row.getElementsByTagName("pseudoquery")[0].childNodes[0].data 77 | query = query.replace("\n"," ") 78 | query = re.sub(r" {2,}","",query) 79 | 80 | questions_text = raw_row.getElementsByTagName('string') 81 | questions_keyword = raw_row.getElementsByTagName('keywords') 82 | for i in range(0,len(questions_text)): 83 | lang = questions_text[i].getAttribute("lang") 84 | string = questions_text[i].childNodes 85 | if string: 86 | string = string[0].data 87 | else: 88 | string = "" 89 | if questions_keyword: 90 | keyword = questions_keyword[i].childNodes 91 | if keyword: 92 | keyword = keyword[0].data 93 | else: 94 | keyword = "" 95 | else: 96 | keyword = "" 97 | question.append({u"language":lang, u"string":string, u"keywords": keyword}) 98 | 99 | answer_row = raw_row.getElementsByTagName("answers")[0] 100 | answers_list = answer_row.getElementsByTagName("answer") 101 | for a in range(0,len(answers_list)): 102 | answers.append({u"string": u"{}".format(answers_list[a].childNodes[0].data) }) 103 | self.qapairs.append(QApair(question, answers, query, raw_row, question_id, parser)) 104 | 105 | def print_pairs(self, n=-1): 106 | for item in self.qapairs[0:n]: 107 | print(item) 108 | print("") 109 | 110 | 111 | class QaldParser(AnswerParser): 112 | def __init__(self): 113 | super(QaldParser, self).__init__(DBpedia()) 114 | 115 | def parse_question(self, raw_question): 116 | # print "AA", raw_question 117 | for q in raw_question: 118 | if q["language"] == "en": 119 | return q["string"] 120 | 121 | def parse_sparql(self, raw_query): 122 | if "sparql" in raw_query: 123 | raw_query = raw_query["sparql"] 124 | elif isinstance(raw_query, str) and "where" in raw_query.lower(): 125 | pass 126 | else: 127 | raw_query = "" 128 | if "PREFIX " in raw_query: 129 | # QALD-5 bug! 
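# (Editor's note, explanatory only.) Some QALD-5 queries arrive with mangled
# prefix declarations such as "htp:/w." and "htp:/dbpedia."; the two replace()
# calls below restore them to "http://www." and "http://dbpedia." before any
# parsing. The loop that follows then inlines every "PREFIX xxx: <...>"
# declaration: for an illustrative prefix like dbo:, "dbo:birthPlace" becomes
# "<http://dbpedia.org/ontology/birthPlace", and the later findall/replace pass
# appends the missing ">" so URI extraction only ever sees full <...> URIs.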
130 | raw_query = raw_query.replace("htp:/w.", "http://www.") 131 | raw_query = raw_query.replace("htp:/dbpedia.", "http://dbpedia.") 132 | 133 | for item in re.findall("PREFIX [^:]*: <[^>]*>", raw_query): 134 | prefix = item[7:item.find(" ", 9)] 135 | uri = item[item.find("<"):-1] 136 | raw_query = raw_query.replace(prefix, uri) 137 | idx = raw_query.find("WHERE") 138 | idx2 = raw_query[:idx - 1].rfind(">") 139 | raw_query = raw_query[idx2 + 1:] 140 | for uri in set(re.findall('<[^ ]+', raw_query)): 141 | if uri[-1] != '>': 142 | raw_query = raw_query.replace(uri, uri + ">") 143 | 144 | uris = [Uri(raw_uri, self.kb.parse_uri) for raw_uri in re.findall('<[^>]*>', raw_query)] 145 | supported = not any(substring in raw_query for substring in ["UNION", "FILTER", "OFFSET", "HAVING", "LIMIT"]) 146 | return raw_query, supported, uris 147 | 148 | def parse_answerset(self, raw_answers): 149 | if len(raw_answers) == 0: 150 | return [] 151 | elif len(raw_answers) == 1: 152 | return self.parse_queryresult(raw_answers[0]) 153 | else: 154 | result = [] 155 | for item in raw_answers: 156 | result.append( 157 | AnswerRow(item["string"], 158 | lambda x: [Answer("uri", x, lambda t, y: ("uri", Uri(y, self.kb.parse_uri)))])) 159 | 160 | return result 161 | 162 | def parse_answerrow(self, raw_answerrow): 163 | answers = [Answer(raw_answerrow["AnswerType"], raw_answerrow, self.parse_answer)] 164 | return answers 165 | 166 | def parse_answer(self, answer_type, raw_answer): 167 | if answer_type == "boolean": 168 | return answer_type, str(raw_answer) 169 | else: 170 | if not answer_type in raw_answer: 171 | answer_type = "\"{}\"".format(answer_type) 172 | return raw_answer[answer_type]["type"], Uri(raw_answer[answer_type]["value"], self.kb.parse_uri) 173 | -------------------------------------------------------------------------------- /kb/dbpedia.py: -------------------------------------------------------------------------------- 1 | from kb.kb import KB 2 | import config 3 | from pybloom_live import BloomFilter, ScalableBloomFilter 4 | import os 5 | 6 | 7 | class DBpedia(KB): 8 | def __init__(self, endpoint=config.config['general']['dbpedia']['endpoint'], 9 | one_hop_bloom_file=config.config['general']['dbpedia']['one_hop_bloom_file'], 10 | two_hop_bloom_file=config.config['general']['dbpedia']['two_hop_bloom_file']): 11 | super(DBpedia, self).__init__(endpoint) 12 | self.type_uri = "" 13 | if os.path.exists(one_hop_bloom_file): 14 | with open(one_hop_bloom_file, 'rb') as bloom_file: 15 | self.one_hop_bloom = BloomFilter.fromfile(bloom_file) 16 | else: 17 | self.one_hop_bloom = None 18 | self.two_hop_bloom_file = two_hop_bloom_file 19 | 20 | if os.path.exists(two_hop_bloom_file): 21 | self.two_hop_bloom = dict() 22 | for item in [True, False]: 23 | file_path = two_hop_bloom_file.replace('spo2', 'spo2' + str(item)) 24 | if os.path.exists(file_path): 25 | with open(file_path) as bloom_file: 26 | self.two_hop_bloom[item] = ScalableBloomFilter.fromfile(bloom_file) 27 | else: 28 | self.two_hop_bloom[item] = ScalableBloomFilter(mode=ScalableBloomFilter.LARGE_SET_GROWTH) 29 | else: 30 | self.two_hop_bloom = None 31 | self.two_hop_bloom_counter = 0 32 | 33 | def bloom_query(self, filters): 34 | found = True 35 | for item in filters: 36 | bloom_filter = item.replace("<", "").replace(">", "") 37 | if bloom_filter not in self.one_hop_bloom: 38 | found = False 39 | return found 40 | 41 | def one_hop_graph(self, entity1_uri, relation_uri, entity2_uri=None): 42 | if self.one_hop_bloom is not None: 43 | # 
print('self.one_hop_bloom is not None') 44 | relation_uri = self.uri_to_sparql(relation_uri) 45 | entity1_uri = self.uri_to_sparql(entity1_uri) 46 | if entity2_uri is None: 47 | query_types = [[u"{rel}:{ent1}"], 48 | [u"{ent1}:{rel}"], 49 | [u"{type}:{rel}"]] 50 | else: 51 | entity2_uri = self.uri_to_sparql(entity2_uri) 52 | query_types = [[u"{ent2}:{rel}", u"{rel}:{ent1}"], 53 | [u"{ent1}:{rel}", u"{rel}:{ent2}"], 54 | [u"{type}:{rel}"]] 55 | results = [] 56 | for i in range(len(query_types)): 57 | if self.bloom_query( 58 | [item.format(rel=relation_uri, ent1=entity1_uri, ent2=entity2_uri, type=self.type_uri) for item 59 | in query_types[i]]): 60 | results.append({"m": {"value": i}}) 61 | return results 62 | else: 63 | # print('self.one_hop_bloom is None') 64 | return super(DBpedia, self).one_hop_graph(entity1_uri, relation_uri, entity2_uri) 65 | 66 | def two_hop_graph(self, entity1_uri, relation1_uri, entity2_uri, relation2_uri): 67 | if self.two_hop_bloom is not None: 68 | # print('self.two_hop_bloom is not None') 69 | relation1_uri = self.uri_to_sparql(relation1_uri) 70 | relation2_uri = self.uri_to_sparql(relation2_uri) 71 | entity1_uri = self.uri_to_sparql(entity1_uri) 72 | entity2_uri = self.uri_to_sparql(entity2_uri) 73 | 74 | queries = self.two_hop_graph_template(entity1_uri, relation1_uri, entity2_uri, relation2_uri) 75 | output = [] 76 | 77 | for query in queries: 78 | for item in [True, False]: 79 | if query[1] in self.two_hop_bloom[item]: 80 | output.append([query[0], item]) 81 | break 82 | 83 | if len(queries) != len(output): 84 | output = super(DBpedia, self).parallel_query(queries) 85 | 86 | for idx in range(len(output)): 87 | self.two_hop_bloom[output[idx][1]].add(queries[idx][1]) 88 | self.two_hop_bloom_counter += 1 89 | 90 | if self.two_hop_bloom_counter > 100: 91 | self.two_hop_bloom_counter = 0 92 | for item in [True, False]: 93 | file_path = self.two_hop_bloom_file.replace('spo2', 'spo2' + str(item)) 94 | with open(file_path, 'w') as bloom_file: 95 | self.two_hop_bloom[item].tofile(bloom_file) 96 | return output 97 | else: 98 | # print('self.two_hop_bloom is None') 99 | return super(DBpedia, self).two_hop_graph(entity1_uri, relation1_uri, entity2_uri, relation2_uri) 100 | 101 | def two_hop_graph_template(self, entity1_uri, relation1_uri, entity2_uri, relation2_uri): 102 | query_types = [[0, u"{ent1} {rel1} {ent2} . ?u1 {rel2} {ent1}", u"{rel2}:{ent1}"], 103 | [1, u"{ent1} {rel1} {ent2} . {ent1} {rel2} ?u1", u"{ent1}:{rel2}"], 104 | [2, u"{ent1} {rel1} {ent2} . {ent2} {rel2} ?u1", u"{ent2}:{rel2}"], 105 | [3, u"{ent1} {rel1} {ent2} . ?u1 {rel2} {ent2}", u"{rel2}:{ent2}"], 106 | [4, u"{ent1} {rel1} {ent2} . ?u1 {type} {rel2}", u"{type}:{rel2}"]] 107 | if self.one_hop_bloom is not None: 108 | for item in query_types: 109 | item.append(item[2].format(rel1=relation1_uri, ent1=entity1_uri, 110 | ent2=entity2_uri, rel2=relation2_uri, 111 | type=self.type_uri)) 112 | filtered_query_types = [] 113 | if self.one_hop_bloom is not None: 114 | for item in query_types: 115 | if ("?" 
in item[3]) or self.bloom_query([item[3]]): 116 | filtered_query_types.append(item) 117 | 118 | output = [[item[0], item[1].format(rel1=relation1_uri, ent1=entity1_uri, 119 | ent2=entity2_uri, rel2=relation2_uri, 120 | type=self.type_uri)] for item in filtered_query_types] 121 | return output 122 | else: 123 | return super(DBpedia, self).two_hop_graph_template(entity1_uri, relation1_uri, entity2_uri, relation2_uri) 124 | 125 | @staticmethod 126 | def parse_uri(input_uri): 127 | if isinstance(input_uri, bool): 128 | return "bool", input_uri 129 | raw_uri = input_uri.strip("<>") 130 | if raw_uri.find("/resource/") >= 0: 131 | return "?s", raw_uri 132 | elif raw_uri.find("/ontology/") >= 0: 133 | return "?o", raw_uri 134 | elif raw_uri.find("/property/") >= 0: 135 | return "?p", raw_uri 136 | elif raw_uri.find("rdf-syntax-ns#type") >= 0: 137 | return "?t", raw_uri 138 | elif raw_uri.startswith("?"): 139 | return "g", raw_uri 140 | else: 141 | return "?u", raw_uri 142 | 143 | @staticmethod 144 | def uri_to_sparql(input_uri): 145 | if input_uri.uri_type == "g": 146 | return input_uri.uri 147 | return u"<{}>".format(input_uri.uri) 148 | -------------------------------------------------------------------------------- /question_type_anlaysis.py: -------------------------------------------------------------------------------- 1 | from lcquad_test import Orchestrator 2 | from parser.lc_quad import LC_QaudParser 3 | from learning.classifier.svmclassifier import SVMClassifier 4 | from parser.qald import Qald 5 | from parser.lc_quad_linked import LC_Qaud_Linked 6 | from sklearn.metrics import classification_report 7 | from sklearn.metrics import accuracy_score 8 | import sys 9 | import numpy as np 10 | import json 11 | from sklearn.metrics import confusion_matrix 12 | import matplotlib.pyplot as plt 13 | 14 | 15 | if __name__ == "__main__": 16 | parser = LC_QaudParser() 17 | classifier1 = SVMClassifier('./output/question_type_classifier/svm.model') 18 | classifier2 = SVMClassifier('./output/double_relation_classifier/svm.model') 19 | query_builder = Orchestrator(None, classifier1, classifier2, parser, None, auto_train=True) 20 | 21 | print("train_question_classifier") 22 | 23 | scores = query_builder.train_question_classifier(file_path="./data/LC-QUAD/data.json", test_size=0.8) 24 | print(scores) 25 | y_pred = query_builder.question_classifier.predict(query_builder.X_test) 26 | print(accuracy_score(query_builder.y_test, y_pred)) 27 | print(classification_report(query_builder.y_test, y_pred, digits=3)) 28 | 29 | ds = LC_Qaud_Linked(path="./data/LC-QUAD/linked_test.json") 30 | ds.load() 31 | ds.parse() 32 | 33 | lcquad = [] 34 | lc_y = [] 35 | for qapair in ds.qapairs: 36 | lcquad.append(qapair.question.text) 37 | if "COUNT(" in qapair.sparql.query: 38 | lc_y.append(2) 39 | elif "ASK" in qapair.sparql.query: 40 | lc_y.append(1) 41 | else: 42 | lc_y.append(0) 43 | 44 | lc_y = np.array(lc_y) 45 | print('LIST: ', sum(lc_y==0)) 46 | print('ASK: ', sum(lc_y == 1)) 47 | print('COUNT: ', sum(lc_y == 2)) 48 | np.savetxt('lcquad_question_type.csv', lc_y, delimiter=',') 49 | 50 | lc_pred = query_builder.question_classifier.predict(lcquad) 51 | print('LC-QUAD question_classifier') 52 | print(accuracy_score(lc_y, lc_pred)) 53 | print(classification_report(lc_y, lc_pred, digits=4)) 54 | 55 | classes = ['List', 'Count', 'Boolean'] 56 | cm = confusion_matrix(lc_y, lc_pred) 57 | print('Before Normalization') 58 | print(cm) 59 | 60 | print('Accuracy by class: ') 61 | c_acc = cm.diagonal() / cm.sum(axis=1) 62 | 
print(c_acc) 63 | 64 | cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] 65 | 66 | print('After Normalization') 67 | print(cm) 68 | 69 | fig, ax = plt.subplots() 70 | im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues) 71 | ax.figure.colorbar(im, ax=ax) 72 | # We want to show all ticks... 73 | ax.set(xticks=np.arange(cm.shape[1]), 74 | yticks=np.arange(cm.shape[0]), 75 | # ... and label them with the respective list entries 76 | xticklabels=classes, yticklabels=classes, 77 | ylabel='True label', 78 | xlabel='Predicted label') 79 | 80 | # Rotate the tick labels and set their alignment. 81 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right", 82 | rotation_mode="anchor") 83 | 84 | # Loop over data dimensions and create text annotations. 85 | fmt = '.2f' 86 | thresh = cm.max() / 2. 87 | for i in range(cm.shape[0]): 88 | for j in range(cm.shape[1]): 89 | ax.text(j, i, format(cm[i, j], fmt), 90 | ha="center", va="center", 91 | color="white" if cm[i, j] > thresh else "black") 92 | fig.tight_layout() 93 | plt.savefig('confusion_matrix_lcquad.png') 94 | 95 | q_ds = Qald(Qald.qald_7) 96 | q_ds.load() 97 | q_ds.parse() 98 | 99 | qald = [] 100 | q_y = [] 101 | for qapair in q_ds.qapairs: 102 | qald.append(qapair.question.text) 103 | if "COUNT(" in qapair.sparql.query: 104 | q_y.append(2) 105 | elif "ASK" in qapair.sparql.query: 106 | q_y.append(1) 107 | x = ascii(qapair.sparql.query.replace('\n', ' ').replace('\t', ' ')) 108 | print(x) 109 | else: 110 | q_y.append(0) 111 | 112 | q_y = np.array(q_y) 113 | print('LIST: ', sum(q_y==0)) 114 | print('ASK: ', sum(q_y == 1)) 115 | print('COUNT: ', sum(q_y == 2)) 116 | np.savetxt('qald_question_type.csv', q_y, delimiter=',') 117 | 118 | q_pred = query_builder.question_classifier.predict(qald) 119 | print('QALD question_classifier') 120 | print(accuracy_score(q_y, q_pred)) 121 | print(classification_report(q_y, q_pred, digits=4)) 122 | 123 | classes = ['List', 'Count', 'Boolean'] 124 | cm = confusion_matrix(q_y, q_pred) 125 | print('Before Normalization') 126 | print(cm) 127 | 128 | print('Accuracy by class: ') 129 | c_acc = cm.diagonal() / cm.sum(axis=1) 130 | print(c_acc) 131 | 132 | cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] 133 | 134 | print('After Normalization') 135 | print(cm) 136 | 137 | fig, ax = plt.subplots() 138 | im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues) 139 | ax.figure.colorbar(im, ax=ax) 140 | # We want to show all ticks... 141 | ax.set(xticks=np.arange(cm.shape[1]), 142 | yticks=np.arange(cm.shape[0]), 143 | # ... and label them with the respective list entries 144 | xticklabels=classes, yticklabels=classes, 145 | ylabel='True label', 146 | xlabel='Predicted label') 147 | 148 | # Rotate the tick labels and set their alignment. 149 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right", 150 | rotation_mode="anchor") 151 | 152 | # Loop over data dimensions and create text annotations. 153 | fmt = '.2f' 154 | thresh = cm.max() / 2. 
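# (Editor's note, explanatory only.) thresh is half of the largest cell value
# in the normalized confusion matrix; the nested loops below write each cell's
# value onto the heatmap and flip the annotation colour to white whenever the
# cell is darker than this threshold, keeping the numbers readable on both
# light and dark cells.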
155 | for i in range(cm.shape[0]): 156 | for j in range(cm.shape[1]): 157 | ax.text(j, i, format(cm[i, j], fmt), 158 | ha="center", va="center", 159 | color="white" if cm[i, j] > thresh else "black") 160 | fig.tight_layout() 161 | plt.savefig('confusion_matrix_qald.png') 162 | 163 | ds = LC_Qaud_Linked(path="./data/LC-QUAD/linked_answer.json") 164 | ds.load() 165 | ds.parse() 166 | 167 | lcquad = [] 168 | lc_y = [] 169 | for qapair in ds.qapairs: 170 | lcquad.append(qapair.question.text) 171 | if "COUNT(" in qapair.sparql.query: 172 | lc_y.append(2) 173 | elif "ASK" in qapair.sparql.query: 174 | lc_y.append(1) 175 | else: 176 | lc_y.append(0) 177 | 178 | lc_y = np.array(lc_y) 179 | print('LIST: ', sum(lc_y==0)) 180 | print('ASK: ', sum(lc_y == 1)) 181 | print('COUNT: ', sum(lc_y == 2)) 182 | np.savetxt('lcquad_question_type_all.csv', lc_y, delimiter=',') 183 | 184 | lc_pred = query_builder.question_classifier.predict(lcquad) 185 | print('LC-QUAD question_classifier') 186 | print(accuracy_score(lc_y, lc_pred)) 187 | print(classification_report(lc_y, lc_pred, digits=4)) 188 | 189 | classes = ['List', 'Count', 'Boolean'] 190 | cm = confusion_matrix(lc_y, lc_pred) 191 | print('Before Normalization') 192 | print(cm) 193 | 194 | print('Accuracy by class: ') 195 | c_acc = cm.diagonal() / cm.sum(axis=1) 196 | print(c_acc) 197 | 198 | cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] 199 | 200 | print('After Normalization') 201 | print(cm) 202 | 203 | fig, ax = plt.subplots() 204 | im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues) 205 | ax.figure.colorbar(im, ax=ax) 206 | # We want to show all ticks... 207 | ax.set(xticks=np.arange(cm.shape[1]), 208 | yticks=np.arange(cm.shape[0]), 209 | # ... and label them with the respective list entries 210 | xticklabels=classes, yticklabels=classes, 211 | ylabel='True label', 212 | xlabel='Predicted label') 213 | 214 | # Rotate the tick labels and set their alignment. 215 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right", 216 | rotation_mode="anchor") 217 | 218 | # Loop over data dimensions and create text annotations. 219 | fmt = '.2f' 220 | thresh = cm.max() / 2. 
221 | for i in range(cm.shape[0]): 222 | for j in range(cm.shape[1]): 223 | ax.text(j, i, format(cm[i, j], fmt), 224 | ha="center", va="center", 225 | color="white" if cm[i, j] > thresh else "black") 226 | fig.tight_layout() 227 | plt.savefig('confusion_matrix_lcquad_all.png') 228 | -------------------------------------------------------------------------------- /learning/treelstm/main.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import os 5 | import random 6 | import logging 7 | from tqdm import tqdm 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | from torch.autograd import Variable as Var 13 | 14 | import sys 15 | # IMPORT CONSTANTS 16 | import Constants 17 | # NEURAL NETWORK MODULES/LAYERS 18 | from model import * 19 | # DATA HANDLING CLASSES 20 | from tree import Tree 21 | from vocab import Vocab 22 | # DATASET CLASS FOR SICK DATASET 23 | from dataset import QGDataset 24 | # METRICS CLASS FOR EVALUATION 25 | from metrics import Metrics 26 | # UTILITY FUNCTIONS 27 | from utils import load_word_vectors, build_vocab 28 | # CONFIG PARSER 29 | from learning.treelstm.config import parse_args 30 | # TRAIN AND TEST HELPER FUNCTIONS 31 | from trainer import Trainer 32 | import datetime 33 | from fastText import load_model 34 | 35 | 36 | def main(): 37 | global args 38 | args = parse_args() 39 | # global logger 40 | logger = logging.getLogger(__name__) 41 | logger.setLevel(logging.DEBUG) 42 | formatter = logging.Formatter("[%(asctime)s] %(levelname)s:%(name)s:%(message)s") 43 | # file logger 44 | fh = logging.FileHandler(os.path.join(args.save, args.expname) + '.log', mode='w') 45 | fh.setLevel(logging.INFO) 46 | fh.setFormatter(formatter) 47 | logger.addHandler(fh) 48 | # console logger 49 | ch = logging.StreamHandler() 50 | ch.setLevel(logging.DEBUG) 51 | ch.setFormatter(formatter) 52 | logger.addHandler(ch) 53 | # argument validation 54 | args.cuda = args.cuda and torch.cuda.is_available() 55 | if args.sparse and args.wd != 0: 56 | logger.error('Sparsity and weight decay are incompatible, pick one!') 57 | exit() 58 | logger.debug(args) 59 | args.data = 'learning/treelstm/data/lc_quad/' 60 | args.save = 'learning/treelstm/checkpoints/' 61 | torch.manual_seed(args.seed) 62 | random.seed(args.seed) 63 | if args.cuda: 64 | torch.cuda.manual_seed(args.seed) 65 | torch.backends.cudnn.benchmark = True 66 | if not os.path.exists(args.save): 67 | os.makedirs(args.save) 68 | 69 | train_dir = os.path.join(args.data, 'train/') 70 | dev_dir = os.path.join(args.data, 'dev/') 71 | test_dir = os.path.join(args.data, 'test/') 72 | 73 | # write unique words from all token files 74 | 75 | dataset_vocab_file = os.path.join(args.data, 'dataset.vocab') 76 | if not os.path.isfile(dataset_vocab_file): 77 | token_files_a = [os.path.join(split, 'a.toks') for split in [train_dir, dev_dir, test_dir]] 78 | token_files_b = [os.path.join(split, 'b.toks') for split in [train_dir, dev_dir, test_dir]] 79 | token_files = token_files_a + token_files_b 80 | dataset_vocab_file = os.path.join(args.data, 'dataset.vocab') 81 | build_vocab(token_files, dataset_vocab_file) 82 | 83 | # get vocab object from vocab file previously written 84 | vocab = Vocab(filename=dataset_vocab_file, 85 | data=[Constants.PAD_WORD, Constants.UNK_WORD, Constants.BOS_WORD, Constants.EOS_WORD]) 86 | logger.debug('==> Dataset vocabulary size : %d ' % vocab.size()) 87 | 88 | # load dataset splits 
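# (Editor's note, explanatory only.) Each split is cached as a serialized
# dataset: if dataset_train.pth / dataset_dev.pth / dataset_test.pth already
# exist under args.data they are loaded with torch.load, otherwise a QGDataset
# is built from the corresponding split directory and persisted with
# torch.save, so repeated runs skip the preprocessing step.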
89 | train_file = os.path.join(args.data, 'dataset_train.pth') 90 | if os.path.isfile(train_file): 91 | train_dataset = torch.load(train_file) 92 | else: 93 | train_dataset = QGDataset(train_dir, vocab, args.num_classes) 94 | torch.save(train_dataset, train_file) 95 | logger.debug('==> Size of train data : %d ' % len(train_dataset)) 96 | dev_file = os.path.join(args.data, 'dataset_dev.pth') 97 | if os.path.isfile(dev_file): 98 | dev_dataset = torch.load(dev_file) 99 | else: 100 | dev_dataset = QGDataset(dev_dir, vocab, args.num_classes) 101 | torch.save(dev_dataset, dev_file) 102 | logger.debug('==> Size of dev data : %d ' % len(dev_dataset)) 103 | test_file = os.path.join(args.data, 'dataset_test.pth') 104 | if os.path.isfile(test_file): 105 | test_dataset = torch.load(test_file) 106 | else: 107 | test_dataset = QGDataset(test_dir, vocab, args.num_classes) 108 | torch.save(test_dataset, test_file) 109 | logger.debug('==> Size of test data : %d ' % len(test_dataset)) 110 | 111 | similarity = DASimilarity(args.mem_dim, args.hidden_dim, args.num_classes) 112 | # if args.sim == "cos": 113 | # similarity = CosSimilarity(1) 114 | # else: 115 | # similarity = DASimilarity(args.mem_dim, args.hidden_dim, args.num_classes, dropout=True) 116 | 117 | # initialize model, criterion/loss_function, optimizer 118 | model = SimilarityTreeLSTM( 119 | vocab.size(), 120 | args.input_dim, 121 | args.mem_dim, 122 | similarity, 123 | args.sparse) 124 | criterion = nn.KLDivLoss() # nn.HingeEmbeddingLoss() 125 | 126 | if args.cuda: 127 | model.cuda(), criterion.cuda() 128 | else: 129 | torch.set_num_threads(4) 130 | logger.info("number of available cores: {}".format(torch.get_num_threads())) 131 | if args.optim == 'adam': 132 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd) 133 | elif args.optim == 'adagrad': 134 | optimizer = optim.Adagrad(model.parameters(), lr=args.lr, weight_decay=args.wd) 135 | elif args.optim == 'sgd': 136 | optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.wd) 137 | metrics = Metrics(args.num_classes) 138 | 139 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.25) 140 | 141 | # for words common to dataset vocab and GLOVE, use GLOVE vectors 142 | # for other words in dataset vocab, use random normal vectors 143 | emb_file = os.path.join(args.data, 'dataset_embed.pth') 144 | if os.path.isfile(emb_file): 145 | emb = torch.load(emb_file) 146 | else: 147 | EMBEDDING_DIM = 300 148 | emb = torch.zeros(vocab.size(), EMBEDDING_DIM, dtype=torch.float) 149 | fasttext_model = load_model("data/fasttext/wiki.en.bin") 150 | print('Use Fasttext Embedding') 151 | for word in vocab.labelToIdx.keys(): 152 | word_vector = fasttext_model.get_word_vector(word) 153 | if word_vector.all() != None and len(word_vector) == EMBEDDING_DIM: 154 | emb[vocab.getIndex(word)] = torch.Tensor(word_vector) 155 | else: 156 | emb[vocab.getIndex(word)] = torch.Tensor(EMBEDDING_DIM).uniform_(-1, 1) 157 | # # load glove embeddings and vocab 158 | # args.glove = 'learning/treelstm/data/glove/' 159 | # print('Use Glove Embedding') 160 | # glove_vocab, glove_emb = load_word_vectors(os.path.join(args.glove, 'glove.840B.300d')) 161 | # logger.debug('==> GLOVE vocabulary size: %d ' % glove_vocab.size()) 162 | # emb = torch.Tensor(vocab.size(), glove_emb.size(1)).normal_(-0.05, 0.05) 163 | # # zero out the embeddings for padding and other special words if they are absent in vocab 164 | # for idx, item in enumerate([Constants.PAD_WORD, Constants.UNK_WORD, 
Constants.BOS_WORD, Constants.EOS_WORD]): 165 | # emb[idx].zero_() 166 | # for word in vocab.labelToIdx.keys(): 167 | # if glove_vocab.getIndex(word): 168 | # emb[vocab.getIndex(word)] = glove_emb[glove_vocab.getIndex(word)] 169 | torch.save(emb, emb_file) 170 | # plug these into embedding matrix inside model 171 | if args.cuda: 172 | emb = emb.cuda() 173 | model.emb.weight.data.copy_(emb) 174 | 175 | checkpoint_filename = '%s.pt' % os.path.join(args.save, args.expname) 176 | if args.mode == "test": 177 | checkpoint = torch.load(checkpoint_filename) 178 | model.load_state_dict(checkpoint['model']) 179 | args.epochs = 1 180 | 181 | # create trainer object for training and testing 182 | trainer = Trainer(args, model, criterion, optimizer) 183 | 184 | for epoch in range(args.epochs): 185 | if args.mode == "train": 186 | scheduler.step() 187 | 188 | train_loss = trainer.train(train_dataset) 189 | train_loss, train_pred = trainer.test(train_dataset) 190 | logger.info( 191 | '==> Epoch {}, Train \tLoss: {} {}'.format(epoch, train_loss, 192 | metrics.all(train_pred, train_dataset.labels))) 193 | checkpoint = {'model': trainer.model.state_dict(), 'optim': trainer.optimizer, 194 | 'args': args, 'epoch': epoch, 'scheduler': scheduler} 195 | checkpoint_filename = '%s.pt' % os.path.join(args.save, 196 | args.expname + ',epoch={},train_loss={}'.format(epoch + 1, 197 | train_loss)) 198 | torch.save(checkpoint, checkpoint_filename) 199 | 200 | dev_loss, dev_pred = trainer.test(dev_dataset) 201 | test_loss, test_pred = trainer.test(test_dataset) 202 | logger.info( 203 | '==> Epoch {}, Dev \tLoss: {} {}'.format(epoch, dev_loss, metrics.all(dev_pred, dev_dataset.labels))) 204 | logger.info( 205 | '==> Epoch {}, Test \tLoss: {} {}'.format(epoch, test_loss, metrics.all(test_pred, test_dataset.labels))) 206 | 207 | 208 | if __name__ == "__main__": 209 | main() 210 | -------------------------------------------------------------------------------- /data/LC-QUAD/templates.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "id": 1, 4 | "n_entities": 1, 5 | "template": " SELECT DISTINCT ?uri WHERE {?uri <%(e_to_e_out)s> <%(e_out)s> } ", 6 | "type": "vanilla" 7 | }, 8 | { 9 | "id": 301, 10 | "n_entities": 1, 11 | "template": " SELECT DISTINCT ?uri WHERE {?uri <%(e_to_e_out)s> <%(e_out)s> . ?uri rdf:type class } ", 12 | "type": "vanilla" 13 | }, 14 | { 15 | "id": 2, 16 | "n_entities": 1, 17 | "template": " SELECT DISTINCT ?uri WHERE { <%(e_in)s> <%(e_in_to_e)s> ?uri } ", 18 | "type": "vanilla" 19 | }, 20 | { 21 | "id": 302, 22 | "n_entities": 1, 23 | "template": " SELECT DISTINCT ?uri WHERE { <%(e_in)s> <%(e_in_to_e)s> ?uri } . ?uri rdf:type class ", 24 | "type": "vanilla" 25 | }, 26 | { 27 | "id": 3, 28 | "n_entities": 1, 29 | "template": " SELECT DISTINCT ?uri WHERE { <%(e_in_in)s> <%(e_in_in_to_e_in)s> ?x . ?x <%(e_in_to_e)s> ?uri . ?x rdf:type class} ", 30 | "type": "vanilla" 31 | }, 32 | { 33 | "id": 303, 34 | "n_entities": 1, 35 | "template": " SELECT DISTINCT ?uri WHERE { <%(e_in_in)s> <%(e_in_in_to_e_in)s> ?x . ?x <%(e_in_to_e)s> ?uri } ", 36 | "type": "vanilla" 37 | }, 38 | { 39 | "id": 5, 40 | "n_entities": 1, 41 | "template": " SELECT DISTINCT ?uri WHERE { ?x <%(e_in_to_e_in_out)s> <%(e_in_out)s> . ?x <%(e_in_to_e)s> ?uri } ", 42 | "type": "vanilla" 43 | }, 44 | { 45 | "id": 305, 46 | "n_entities": 1, 47 | "template": " SELECT DISTINCT ?uri WHERE { ?x <%(e_in_to_e_in_out)s> <%(e_in_out)s> . ?x <%(e_in_to_e)s> ?uri . 
?uri rdf:type class} ", 48 | "type": "vanilla" 49 | }, 50 | { 51 | "id": 6, 52 | "n_entities": 1, 53 | "template": "SELECT DISTINCT ?uri WHERE { ?x <%(e_out_to_e_out_out)s> <%(e_out_out)s> . ?uri <%(e_to_e_out)s> ?x } ", 54 | "type": "vanilla" 55 | }, 56 | { 57 | "id": 306, 58 | "n_entities": 1, 59 | "template": "SELECT DISTINCT ?uri WHERE { ?x <%(e_out_to_e_out_out)s> <%(e_out_out)s> . ?uri <%(e_to_e_out)s> ?x } . ?uri rdf:type class", 60 | "type": "vanilla" 61 | }, 62 | { 63 | "id": 7, 64 | "n_entities": 2, 65 | "template": " SELECT DISTINCT ?uri WHERE { ?uri <%(e_to_e_out)s> <%(e_out_1)s> . ?uri <%(e_to_e_out)s> <%(e_out_2)s>} ", 66 | "type": "vanilla" 67 | }, 68 | { 69 | "id": 307, 70 | "n_entities": 2, 71 | "template": " SELECT DISTINCT ?uri WHERE { ?uri <%(e_to_e_out)s> <%(e_out_1)s> . ?uri <%(e_to_e_out)s> <%(e_out_2)s>} ", 72 | "type": "vanilla" 73 | }, 74 | { 75 | "id": 8, 76 | "n_entities": 2, 77 | "template": " SELECT DISTINCT ?uri WHERE {?uri <%(e_to_e_out_1)s> <%(e_out_1)s> . ?uri <%(e_to_e_out_2)s> <%(e_out_2)s> } ", 78 | "type": "vanilla" 79 | }, 80 | { 81 | "id": 308, 82 | "n_entities": 2, 83 | "template": " SELECT DISTINCT ?uri WHERE {?uri <%(e_to_e_out_1)s> <%(e_out_1)s> . ?uri <%(e_to_e_out_2)s> <%(e_out_2)s> . ?uri rdf:type class} ", 84 | "type": "vanilla" 85 | }, 86 | { 87 | "id": 9, 88 | "n_entities": 1, 89 | "template": " SELECT DISTINCT ?uri WHERE { <%(e_in_in)s> <%(e_in_in_to_e_in)s> ?x . ?x <%(e_in_to_e)s> ?uri}", 90 | "type": "vanilla" 91 | }, 92 | { 93 | "id": 309, 94 | "n_entities": 1, 95 | "template": " SELECT DISTINCT ?uri WHERE { <%(e_in_in)s> <%(e_in_in_to_e_in)s> ?x . ?x <%(e_in_to_e)s> . ?uri}", 96 | "type": "vanilla" 97 | }, 98 | { 99 | "id": 11, 100 | "n_entities": 1, 101 | "template": " SELECT DISTINCT ?uri WHERE { ?x <%(e_in_to_e_in_out)s> <%(e_in_out)s> . ?x <%(e_in_to_e)s> ?uri .?x rdf:type class}", 102 | "type": "vanilla" 103 | }, 104 | { 105 | "id": 311, 106 | "n_entities": 1, 107 | "template": " SELECT DISTINCT ?uri WHERE { ?x <%(e_in_to_e_in_out)s> <%(e_in_out)s> . ?x <%(e_in_to_e)s> ?uri .?x rdf:type class}", 108 | "type": "vanilla" 109 | }, 110 | { 111 | "id": 61, 112 | "n_entities": 1, 113 | "template": " SELECT DISTINCT ?uri WHERE { ?x <%(e_in_to_e_in_out)s> <%(e_in_out)s> . ?x <%(e_in_to_e)s> ?uri }", 114 | "type": "vanilla" 115 | }, 116 | { 117 | "id": 15, 118 | "n_entities": 2, 119 | "template": " SELECT DISTINCT ?uri WHERE { <%(e_in_1)s> <%(e_in_to_e)s> ?uri. <%(e_in_2)s> <%(e_in_to_e)s> ?uri} ", 120 | "type": "vanilla" 121 | }, 122 | { 123 | "id": 315, 124 | "n_entities": 2, 125 | "template": " SELECT DISTINCT ?uri WHERE { <%(e_in_1)s> <%(e_in_to_e)s> ?uri. <%(e_in_2)s> <%(e_in_to_e)s> ?uri} . ?uri rdf:type class", 126 | "type": "vanilla" 127 | }, 128 | { 129 | "id": 16, 130 | "n_entities": 2, 131 | "template": " SELECT DISTINCT ?uri WHERE { <%(e_in_1)s> <%(e_in_to_e_1)s> ?uri. <%(e_in_2)s> <%(e_in_to_e_2)s> ?uri} ", 132 | "type": "vanilla" 133 | }, 134 | { 135 | "id": 316, 136 | "n_entities": 2, 137 | "template": " SELECT DISTINCT ?uri WHERE { <%(e_in_1)s> <%(e_in_to_e_1)s> ?uri. <%(e_in_2)s> <%(e_in_to_e_2)s> ?uri} . ?uri rdf:type class", 138 | "type": "vanilla" 139 | }, 140 | { 141 | "id": 101, 142 | "n_entities": 1, 143 | "template": " SELECT DISTINCT COUNT(?uri) WHERE {?uri <%(e_to_e_out)s> <%(e_out)s> } ", 144 | "type": "count" 145 | }, 146 | { 147 | "id": 401, 148 | "n_entities": 1, 149 | "template": " SELECT DISTINCT COUNT(?uri) WHERE {?uri <%(e_to_e_out)s> <%(e_out)s> . 
?uri rdf:type class} ", 150 | "type": "count" 151 | }, 152 | { 153 | "id": 102, 154 | "n_entities": 1, 155 | "template": " SELECT DISTINCT COUNT(?uri) WHERE { <%(e_in)s> <%(e_in_to_e)s> ?uri } ", 156 | "type": "count" 157 | }, 158 | { 159 | "id": 402, 160 | "n_entities": 1, 161 | "template": " SELECT DISTINCT COUNT(?uri) WHERE { <%(e_in)s> <%(e_in_to_e)s> ?uri } . ?uri rdf:type class", 162 | "type": "count" 163 | }, 164 | { 165 | "id": 103, 166 | "n_entities": 1, 167 | "template": " SELECT DISTINCT COUNT(?uri) WHERE { <%(e_in_in)s> <%(e_in_in_to_e_in)s> ?x . ?x <%(e_in_to_e)s> ?uri } ", 168 | "type": "count" 169 | }, 170 | { 171 | "id": 403, 172 | "n_entities": 1, 173 | "template": " SELECT DISTINCT COUNT(?uri) WHERE { <%(e_in_in)s> <%(e_in_in_to_e_in)s> ?x . ?x <%(e_in_to_e)s> ?uri } ", 174 | "type": "count" 175 | }, 176 | { 177 | "id": 105, 178 | "n_entities": 1, 179 | "template": " SELECT DISTINCT COUNT(?uri) WHERE { ?x <%(e_in_to_e_in_out)s> <%(e_in_out)s> . ?x <%(e_in_to_e)s> ?uri. ?x rdf:type class}", 180 | "type": "count" 181 | }, 182 | { 183 | "id": 405, 184 | "n_entities": 1, 185 | "template": " SELECT DISTINCT COUNT(?uri) WHERE { ?x <%(e_in_to_e_in_out)s> <%(e_in_out)s> . ?x <%(e_in_to_e)s> ?uri } . ?uri rdf:type class", 186 | "type": "count" 187 | }, 188 | { 189 | "id": 106, 190 | "n_entities": 1, 191 | "template": "SELECT DISTINCT COUNT(?uri) WHERE { ?x <%(e_out_to_e_out_out)s> <%(e_out_out)s> . ?uri <%(e_to_e_out)s> ?x } ", 192 | "type": "count" 193 | }, 194 | { 195 | "id": 406, 196 | "n_entities": 1, 197 | "template": "SELECT DISTINCT COUNT(?uri) WHERE { ?x <%(e_out_to_e_out_out)s> <%(e_out_out)s> . ?uri <%(e_to_e_out)s> ?x . ?uri rdf:type class} ", 198 | "type": "count" 199 | }, 200 | { 201 | "id": 107, 202 | "n_entities": 2, 203 | "template": " SELECT DISTINCT COUNT(?uri) WHERE { ?uri <%(e_to_e_out)s> <%(e_out_1)s> . ?uri <%(e_to_e_out)s> <%(e_out_2)s>} ", 204 | "type": "count" 205 | }, 206 | { 207 | "id": 407, 208 | "n_entities": 2, 209 | "template": " SELECT DISTINCT COUNT(?uri) WHERE { ?uri <%(e_to_e_out)s> <%(e_out_1)s> . ?uri <%(e_to_e_out)s> <%(e_out_2)s>} . ?uri rdf:type class ", 210 | "type": "count" 211 | }, 212 | { 213 | "id": 108, 214 | "n_entities": 2, 215 | "template": " SELECT DISTINCT COUNT(?uri) WHERE {?uri <%(e_to_e_out_1)s> <%(e_out_1)s> . ?uri <%(e_to_e_out_2)s> <%(e_out_2)s> } ", 216 | "type": "count" 217 | }, 218 | { 219 | "id": 408, 220 | "n_entities": 2, 221 | "template": " SELECT DISTINCT COUNT(?uri) WHERE {?uri <%(e_to_e_out_1)s> <%(e_out_1)s> . ?uri <%(e_to_e_out_2)s> <%(e_out_2)s> } . ?uri rdf:type class", 222 | "type": "count" 223 | }, 224 | { 225 | "id": 111, 226 | "n_entities": 1, 227 | "template": " SELECT DISTINCT COUNT(?uri) WHERE { ?x <%(e_in_to_e_in_out)s> <%(e_in_out)s> . ?x <%(e_in_to_e)s> ?uri }", 228 | "type": "count" 229 | }, 230 | { 231 | "id": 411, 232 | "n_entities": 1, 233 | "template": " SELECT DISTINCT COUNT(?uri) WHERE { ?x <%(e_in_to_e_in_out)s> <%(e_in_out)s> . ?x <%(e_in_to_e)s>. ?uri . 
?x rdf:type class}", 234 | "type": "count" 235 | }, 236 | { 237 | "id": 151, 238 | "n_entities": 1, 239 | "template": "ASK WHERE { <%(uri)s> <%(e_to_e_out)s> <%(e_out)s> }", 240 | "type": "ask" 241 | }, 242 | { 243 | "id": 451, 244 | "n_entities": 1, 245 | "template": "ASK WHERE { <%(uri)s> <%(e_to_e_out)s> <%(e_out)s> }", 246 | "type": "ask" 247 | }, 248 | { 249 | "id": 152, 250 | "n_entities": 1, 251 | "template": "ASK WHERE { <%(e_in)s> <%(e_in_to_e)s> <%(uri)s> }", 252 | "type": "ask" 253 | }, 254 | { 255 | "id": 452, 256 | "n_entities": 1, 257 | "template": "ASK WHERE { <%(e_in)s> <%(e_in_to_e)s> <%(uri)s> }", 258 | "type": "ask" 259 | } 260 | ] -------------------------------------------------------------------------------- /common/graph/graph.py: -------------------------------------------------------------------------------- 1 | from common.graph.node import Node 2 | from common.graph.edge import Edge 3 | from common.container.uri import Uri 4 | from common.container.linkeditem import LinkedItem 5 | from common.utility.mylist import MyList 6 | import itertools 7 | import logging 8 | from tqdm import tqdm 9 | 10 | 11 | class Graph: 12 | def __init__(self, kb, logger=None): 13 | self.kb = kb 14 | self.logger = logger or logging.getLogger(__name__) 15 | self.nodes, self.edges = set(), set() 16 | self.entity_items, self.relation_items = [], [] 17 | self.suggest_retrieve_id = 0 18 | 19 | def create_or_get_node(self, uris, mergable=False): 20 | if isinstance(uris, (int)): 21 | uris = self.__get_generic_uri(uris, 0) 22 | mergable = True 23 | new_node = Node(uris, mergable) 24 | for node in self.nodes: 25 | if node == new_node: 26 | return node 27 | return new_node 28 | 29 | def add_node(self, node): 30 | if node not in self.nodes: 31 | self.nodes.add(node) 32 | 33 | def remove_node(self, node): 34 | self.nodes.remove(node) 35 | 36 | def add_edge(self, edge): 37 | if edge not in self.edges: 38 | self.add_node(edge.source_node) 39 | self.add_node(edge.dest_node) 40 | self.edges.add(edge) 41 | 42 | def remove_edge(self, edge): 43 | edge.prepare_remove() 44 | self.edges.remove(edge) 45 | if edge.source_node.is_disconnected(): 46 | self.remove_node(edge.source_node) 47 | if edge.dest_node.is_disconnected(): 48 | self.remove_node(edge.dest_node) 49 | 50 | def count_combinations(self, entity_items, relation_items, number_of_entities, top_uri): 51 | total = 0 52 | for relation_item in relation_items: 53 | rel_uris_len = len(relation_item.top_uris(top_uri)) 54 | for entity_uris in itertools.product(*[items.top_uris(top_uri) for items in entity_items]): 55 | total += rel_uris_len * len(list(itertools.combinations(entity_uris, number_of_entities))) 56 | return total 57 | 58 | def __one_hop_graph(self, entity_items, relation_items, threshold=None, number_of_entities=1): 59 | top_uri = 1 60 | 61 | total = self.count_combinations(entity_items, relation_items, number_of_entities, top_uri) 62 | if threshold is not None: 63 | while total > threshold: 64 | top_uri -= 0.1 65 | total = self.count_combinations(entity_items, relation_items, number_of_entities, top_uri) 66 | 67 | with tqdm(total=total, disable=self.logger.level >= 10) as pbar: 68 | for relation_item in relation_items: 69 | for relation_uri in relation_item.top_uris(top_uri): 70 | for entity_uris in itertools.product(*[items.top_uris(top_uri) for items in entity_items]): 71 | for entity_uri in itertools.combinations(entity_uris, number_of_entities): 72 | pbar.update(1) 73 | result = self.kb.one_hop_graph(entity_uri[0], relation_uri, 74 | 
entity_uri[1] if len(entity_uri) > 1 else None) 75 | if result is not None: 76 | for item in result: 77 | m = int(item["m"]["value"]) 78 | uri = entity_uri[1] if len(entity_uri) > 1 else 0 79 | if m == 0: 80 | n_s = self.create_or_get_node(uri, True) 81 | n_d = self.create_or_get_node(entity_uri[0]) 82 | e = Edge(n_s, relation_uri, n_d) 83 | self.add_edge(e) 84 | elif m == 1: 85 | n_s = self.create_or_get_node(entity_uri[0]) 86 | n_d = self.create_or_get_node(uri, True) 87 | e = Edge(n_s, relation_uri, n_d) 88 | self.add_edge(e) 89 | elif m == 2: 90 | n_s = self.create_or_get_node(uri) 91 | n_d = self.create_or_get_node(relation_uri) 92 | e = Edge(n_s, Uri(self.kb.type_uri, self.kb.parse_uri), n_d) 93 | self.add_edge(e) 94 | 95 | def find_minimal_subgraph(self, entity_items, relation_items, double_relation=False, ask_query=False, 96 | sort_query=False, h1_threshold=None): 97 | self.entity_items, self.relation_items = MyList(entity_items), MyList(relation_items) 98 | 99 | if double_relation: 100 | self.relation_items.append(self.relation_items[0]) 101 | 102 | # Find subgraphs that are consist of at least one entity and exactly one relation 103 | # self.logger.info("start finding one hop graph") 104 | self.__one_hop_graph(self.entity_items, self.relation_items, number_of_entities=int(ask_query) + 1, 105 | threshold=h1_threshold) 106 | # self.logger.info("finding one hop graph finished") 107 | # print('one hop graph: ') 108 | # for edge in self.edges: 109 | # print(edge.source_node, edge, edge.dest_node) 110 | 111 | if len(self.edges) > 100: 112 | return 113 | 114 | # Extend the existing edges with another hop 115 | # self.logger.info("Extend edges with another hop") 116 | self.__extend_edges(self.edges, relation_items) 117 | 118 | # print('extend_edges: ') 119 | # for edge in self.edges: 120 | # print(edge.source_node, edge, edge.dest_node) 121 | 122 | 123 | def __extend_edges(self, edges, relation_items): 124 | new_edges = set() 125 | total = 0 126 | for relation_item in relation_items: 127 | for relation_uri in relation_item.uris: 128 | total += len(edges) 129 | with tqdm(total=total, disable=self.logger.level >= 10) as pbar: 130 | for relation_item in relation_items: 131 | for relation_uri in relation_item.uris: 132 | for edge in edges: 133 | pbar.update(1) 134 | new_edges.update(self.__extend_edge(edge, relation_uri)) 135 | for e in new_edges: 136 | self.add_edge(e) 137 | 138 | def __extend_edge(self, edge, relation_uri): 139 | output = set() 140 | var_node = None 141 | if edge.source_node.are_all_uris_generic(): 142 | var_node = edge.source_node 143 | if edge.dest_node.are_all_uris_generic(): 144 | var_node = edge.dest_node 145 | ent1 = edge.source_node.first_uri_if_only() 146 | ent2 = edge.dest_node.first_uri_if_only() 147 | if not (var_node is None or ent1 is None or ent2 is None): 148 | result = self.kb.two_hop_graph(ent1, edge.uri, ent2, relation_uri) 149 | if result is not None: 150 | for item in result: 151 | if item[1]: 152 | if item[0] == 0: 153 | n_s = self.create_or_get_node(1, True) 154 | n_d = var_node 155 | e = Edge(n_s, relation_uri, n_d) 156 | output.add(e) 157 | elif item[0] == 1: 158 | n_s = var_node 159 | n_d = self.create_or_get_node(1, True) 160 | e = Edge(n_s, relation_uri, n_d) 161 | output.add(e) 162 | elif item[0] == 2: 163 | n_s = var_node 164 | n_d = self.create_or_get_node(1, True) 165 | e = Edge(n_s, relation_uri, n_d) 166 | output.add(e) 167 | self.suggest_retrieve_id = 1 168 | elif item[0] == 3: 169 | n_s = self.create_or_get_node(1, True) 170 | n_d = 
var_node 171 | e = Edge(n_s, relation_uri, n_d) 172 | output.add(e) 173 | elif item[0] == 4: 174 | n_d = self.create_or_get_node(relation_uri) 175 | n_s = self.create_or_get_node(1, True) 176 | e = Edge(n_s, Uri(self.kb.type_uri, self.kb.parse_uri), n_d) 177 | output.add(e) 178 | return output 179 | 180 | def __get_generic_uri(self, uri, edges): 181 | return Uri.generic_uri(uri) 182 | 183 | def generalize_nodes(self): 184 | """ 185 | if there are nodes which have none-generic uri that is not in the list of possible entity/relation, 186 | such uris will be replaced by a generic uri 187 | :return: None 188 | """ 189 | uris = sum([items.uris for items in self.entity_items] + [items.uris for items in self.relation_items], []) 190 | for node in self.nodes: 191 | for uri in node.uris: 192 | if uri not in uris and not uri.is_generic(): 193 | generic_uri = self.__get_generic_uri(uri, node.inbound + node.outbound) 194 | node.replace_uri(uri, generic_uri) 195 | 196 | def merge_edges(self): 197 | to_be_removed = set() 198 | for edge_1 in self.edges: 199 | for edge_2 in self.edges: 200 | if edge_1 is edge_2 or edge_2 in to_be_removed: 201 | continue 202 | if edge_1 == edge_2: 203 | to_be_removed.add(edge_2) 204 | for item in to_be_removed: 205 | try: 206 | self.remove_edge(item) 207 | except: 208 | print('not remove edge') 209 | 210 | def __str__(self): 211 | return "\n".join([edge.full_path() for edge in self.edges]) 212 | -------------------------------------------------------------------------------- /common/query/querybuilder.py: -------------------------------------------------------------------------------- 1 | from common.container.answerset import AnswerSet 2 | from common.container.linkeditem import LinkedItem 3 | from common.graph.path import Path 4 | from common.graph.paths import Paths 5 | from common.utility.mylist import MyList 6 | 7 | 8 | class QueryBuilder: 9 | def to_where_statement(self, graph, parse_queryresult, ask_query, count_query, sort_query): 10 | graph.generalize_nodes() 11 | graph.merge_edges() 12 | 13 | paths = self.__find_paths_start_with_entities(graph, graph.entity_items, graph.relation_items, graph.edges) 14 | 15 | paths = paths.remove_duplicates() 16 | 17 | # Expand coverage by changing generic ids 18 | new_paths = [] 19 | for path in paths: 20 | to_be_updated_edges = [] 21 | generic_nodes = set() 22 | for edge in path: 23 | if edge.source_node.are_all_uris_generic(): 24 | generic_nodes.add(edge.source_node) 25 | if edge.dest_node.are_all_uris_generic(): 26 | generic_nodes.add(edge.dest_node) 27 | 28 | if edge.source_node.are_all_uris_generic() and not edge.dest_node.are_all_uris_generic(): 29 | to_be_updated_edges.append( 30 | {"type": "source", "node": edge.source_node, "edge": edge}) 31 | if edge.dest_node.are_all_uris_generic() and not edge.source_node.are_all_uris_generic(): 32 | to_be_updated_edges.append( 33 | {"type": "dest", "node": edge.dest_node, "edge": edge}) 34 | 35 | for new_node in generic_nodes: 36 | for edge_info in to_be_updated_edges: 37 | if edge_info["node"] != new_node: 38 | new_path = None 39 | if edge_info["type"] == "source": 40 | new_path = path.replace_edge(edge_info["edge"], 41 | edge_info["edge"].copy(source_node=new_node)) 42 | if edge_info["type"] == "dest": 43 | new_path = path.replace_edge(edge_info["edge"], edge_info["edge"].copy(dest_node=new_node)) 44 | if new_path is not None: 45 | new_paths.append(new_path) 46 | 47 | new_paths = Paths(new_paths).remove_duplicates() 48 | # for new_path in new_paths: 49 | # generic_equal = False 50 
| # if new_path not in paths: 51 | # for path in paths: 52 | # if path.generic_equal_with_substitutable_id(new_path): 53 | # generic_equal = True 54 | # break 55 | # if not generic_equal: 56 | # paths.append(new_path) 57 | 58 | for new_path in new_paths: 59 | paths.append(new_path) 60 | paths = paths.remove_duplicates() 61 | 62 | paths.sort(key=lambda x: x.confidence, reverse=True) 63 | output = paths.to_where(graph.kb, ask_query) 64 | 65 | # Remove queries with no answer 66 | filtered_output = [] 67 | for item in output: 68 | target_var = "?u_" + str(item["suggested_id"]) 69 | raw_answer = graph.kb.query_where(item["where"], return_vars=target_var, 70 | count=count_query, 71 | ask=ask_query) 72 | answerset = AnswerSet(raw_answer, parse_queryresult) 73 | 74 | # Do not include the query if it does not return any answer, except for boolean query 75 | if len(answerset.answer_rows) > 0 or ask_query: 76 | item["target_var"] = target_var 77 | item["answer"] = answerset 78 | filtered_output.append(item) 79 | 80 | # filtered_output_with_no_duplicate_answer = [] 81 | # for n, ii in enumerate(filtered_output): 82 | # duplicate_answer = False 83 | # for item in filtered_output[n + 1:]: 84 | # if item["answer"] == ii["answer"]: 85 | # duplicate_answer = True 86 | # if not duplicate_answer: 87 | # filtered_output_with_no_duplicate_answer.append(ii) 88 | 89 | # return filtered_output_with_no_duplicate_answer 90 | 91 | return filtered_output 92 | 93 | def __find_paths(self, graph, entity_items, relation_items, edges, output_paths=Paths(), used_edges=set()): 94 | new_output_paths = Paths([]) 95 | 96 | if len(relation_items) == 0: 97 | if len(entity_items) > 0: 98 | return Paths() 99 | return output_paths 100 | 101 | used_relations = [] 102 | for relation_item in relation_items: 103 | for relation in relation_item.uris: 104 | used_relations = used_relations + [relation] 105 | for edge in self.find_edges(edges, relation, used_edges): 106 | entities = MyList() 107 | if not (edge.source_node.are_all_uris_generic() or edge.uri.is_type()): 108 | entities.extend(edge.source_node.uris) 109 | if not (edge.dest_node.are_all_uris_generic() or edge.uri.is_type()): 110 | entities.extend(edge.dest_node.uris) 111 | 112 | entity_use = entity_items - LinkedItem.list_contains_uris(entity_items, entities) 113 | relation_use = relation_items - LinkedItem.list_contains_uris(relation_items, used_relations) 114 | edge_use = edges - {edge} 115 | 116 | new_paths = self.__find_paths(graph, 117 | entity_use, 118 | relation_use, 119 | edge_use, 120 | output_paths=output_paths.extend(edge), 121 | used_edges=used_edges | set([edge])) 122 | new_output_paths.add(new_paths, lambda path: len(path) >= len(graph.relation_items)) 123 | 124 | return new_output_paths 125 | 126 | def __find_paths_start_with_entities(self, graph, entity_items, relation_items, edges, output_paths=Paths(), used_edges=set()): 127 | new_output_paths = Paths([]) 128 | for entity_item in entity_items: 129 | for entity in entity_item.uris: 130 | for edge in self.find_edges_by_entity(edges, entity, used_edges): 131 | if not edge.uri.is_type(): 132 | used_relations = [edge.uri] 133 | else: 134 | used_relations = edge.dest_node.uris 135 | entities = MyList() 136 | if not (edge.source_node.are_all_uris_generic() or edge.uri.is_type()): 137 | entities.extend(edge.source_node.uris) 138 | if not (edge.dest_node.are_all_uris_generic() or edge.uri.is_type()): 139 | entities.extend(edge.dest_node.uris) 140 | 141 | entity_use = entity_items - 
LinkedItem.list_contains_uris(entity_items, entities) 142 | relation_use = relation_items - LinkedItem.list_contains_uris(relation_items, used_relations) 143 | edge_use = edges - {edge} 144 | 145 | new_paths = self.__find_paths(graph, 146 | entity_use, 147 | relation_use, 148 | edge_use, 149 | output_paths=output_paths.extend(edge), 150 | used_edges=used_edges | set([edge])) 151 | # new_paths = self.__find_paths(graph, 152 | # entity_items - LinkedItem.list_contains_uris(entity_items, entities), 153 | # relation_items - LinkedItem.list_contains_uris(relation_items, 154 | # used_relations), 155 | # edges - {edge}, 156 | # output_paths=output_paths.extend(edge), 157 | # used_edges=used_edges | set([edge])) 158 | new_output_paths.add(new_paths, lambda path: len(path) >= len(graph.relation_items)) 159 | return new_output_paths 160 | 161 | def find_edges(self, edges, uri, used_edges): 162 | outputs = [edge for edge in edges if edge.uri == uri or (edge.uri.is_type() and edge.dest_node.has_uri(uri))] 163 | if len(used_edges) == 0: 164 | return outputs 165 | connected_edges = [] 166 | for edge in outputs: 167 | found = False 168 | for used_edge in used_edges: 169 | if edge.source_node == used_edge.source_node or edge.source_node == used_edge.dest_node or \ 170 | edge.dest_node == used_edge.source_node or edge.dest_node == used_edge.dest_node: 171 | found = True 172 | break 173 | if found: 174 | connected_edges.append(edge) 175 | 176 | return connected_edges 177 | 178 | def find_edges_by_entity(self, edges, entity_uri, used_edges): 179 | outputs = [edge for edge in edges if 180 | (edge.source_node.has_uri(entity_uri) or edge.dest_node.has_uri(entity_uri))] 181 | if len(used_edges) == 0: 182 | return outputs 183 | connected_edges = [] 184 | for edge in outputs: 185 | found = False 186 | for used_edge in used_edges: 187 | if edge.source_node == used_edge.source_node or edge.source_node == used_edge.dest_node or \ 188 | edge.dest_node == used_edge.source_node or edge.dest_node == used_edge.dest_node: 189 | found = True 190 | break 191 | if found: 192 | connected_edges.append(edge) 193 | 194 | return connected_edges 195 | -------------------------------------------------------------------------------- /learning/treelstm/preprocess_lcquad.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import glob 4 | import json 5 | import anytree 6 | from tqdm import tqdm 7 | import sys 8 | import spacy 9 | path = os.getcwd() 10 | sys.path.insert(0, path) 11 | from common.utility.utility import find_mentions 12 | from parser.lc_quad_linked import LC_Qaud_LinkedParser 13 | # sys.path.append('/cluster/home/xlig/kg/') 14 | # sys.path.insert(0, '/cluster/home/xlig/kg/') 15 | 16 | 17 | def make_dirs(dirs): 18 | for d in dirs: 19 | if not os.path.exists(d): 20 | os.makedirs(d) 21 | 22 | 23 | def dependency_parse(filepath): 24 | spacy.prefer_gpu() 25 | nlp = spacy.load("en_core_web_lg") 26 | 27 | dirpath = os.path.dirname(filepath) 28 | filepre = os.path.splitext(os.path.basename(filepath))[0] 29 | tokpath = os.path.join(dirpath, filepre + '.toks') 30 | parentpath = os.path.join(dirpath, filepre + '.parents') 31 | relpath = os.path.join(dirpath, filepre + '.rels') 32 | pospath = os.path.join(dirpath, filepre + '.pos') 33 | tagpath = os.path.join(dirpath, filepre + '.tag') 34 | lenpath = os.path.join(dirpath, filepre + '.len') 35 | 36 | with open(tokpath, 'w', encoding='utf-8') as tokfile, \ 37 | open(relpath, 'w', encoding='utf-8') as relfile, \ 38 | 
open(parentpath, 'w', encoding='utf-8') as parfile, \ 39 | open(lenpath, 'w', encoding='utf-8') as lenfile, \ 40 | open(tagpath, 'w', encoding='utf-8') as tagfile, \ 41 | open(pospath, 'w', encoding='utf-8') as posfile: 42 | with open(os.path.join(dirpath, 'a.txt'), 'r', encoding='utf-8') as f: 43 | for line in f: 44 | l = line.split(' ') 45 | l = [i for i in l if i != ''] 46 | newline = ' '.join(l) 47 | doc = nlp(newline) 48 | json_doc = doc.to_json() 49 | token = json_doc['tokens'] 50 | pos = [] 51 | tag = [] 52 | dep = [] 53 | tok = [] 54 | parent = [] 55 | length = json_doc['sents'][0]['end'] + 1 56 | for t in token: 57 | if t['pos'] != 'SPACE': 58 | tok.append(doc[t['id']].text) 59 | pos.append(t['pos']) 60 | tag.append(t['tag']) 61 | dep.append(t['dep']) 62 | head = t['head'] 63 | if t['dep'] == 'ROOT': 64 | head = 0 65 | else: 66 | head = head + 1 67 | parent.append(head) 68 | tokfile.write(' '.join(tok) + '\n') 69 | posfile.write(' '.join(pos) + '\n') 70 | tagfile.write(' '.join(tag) + '\n') 71 | relfile.write(' '.join(dep) + '\n') 72 | parfile.writelines(["%s " % str(item) for item in parent]) 73 | parfile.write('\n') 74 | lenfile.write(str(length) + '\n') 75 | 76 | 77 | def query_parse(filepath): 78 | dirpath = os.path.dirname(filepath) 79 | filepre = os.path.splitext(os.path.basename(filepath))[0] 80 | tokpath = os.path.join(dirpath, filepre + '.toks') 81 | parentpath = os.path.join(dirpath, filepre + '.parents') 82 | with open(filepath) as datafile, \ 83 | open(tokpath, 'w') as tokfile, \ 84 | open(parentpath, 'w') as parentfile: 85 | for line in tqdm(datafile): 86 | clauses = line.split(" .") 87 | vars = dict() 88 | root = None 89 | for clause in clauses: 90 | triple = [item.replace("\n", "") for item in clause.split(" ")] 91 | 92 | root_node = anytree.Node(triple[1]) 93 | left_node = anytree.Node(triple[0], root_node) 94 | right_node = anytree.Node(triple[2], root_node) 95 | 96 | leveled = [left_node, root_node, right_node] 97 | for item in triple: 98 | if item.startswith("?u_"): 99 | if item in vars: 100 | children = vars[item].parent.children 101 | if children[0] == vars[item]: 102 | vars[item].parent.children = [root_node, children[1]] 103 | else: 104 | vars[item].parent.children = [children[0], root_node] 105 | vars[item] = [node for node in leveled if node.name == item][0] 106 | break 107 | else: 108 | vars[item] = [node for node in leveled if node.name == item][0] 109 | 110 | if root is None: 111 | root = root_node 112 | 113 | pre_order = [node for node in anytree.iterators.PreOrderIter(root)] 114 | tokens = [node.name for node in pre_order] 115 | for i in range(len(pre_order)): 116 | pre_order[i].index = i + 1 117 | idxs = [node.parent.index if node.parent is not None else 0 for node in pre_order] 118 | 119 | tokfile.write(" ".join(tokens) + "\n") 120 | parentfile.write(" ".join(map(str, idxs)) + "\n") 121 | 122 | 123 | def build_vocab(filepaths, dst_path, lowercase=True): 124 | vocab = set() 125 | for filepath in filepaths: 126 | with open(filepath) as f: 127 | for line in f: 128 | if lowercase: 129 | line = line.lower() 130 | vocab |= set(line.split()) 131 | with open(dst_path, 'w') as f: 132 | for w in sorted(vocab): 133 | f.write(w + '\n') 134 | 135 | 136 | def generalize_question(a, b, parser=None): 137 | # replace entity mention in question with a generic symbol 138 | 139 | if parser is None: 140 | parser = LC_Qaud_LinkedParser() 141 | 142 | _, _, uris = parser.parse_sparql(b) 143 | uris = [uri for uri in uris if uri.is_entity()] 144 | 145 | i = 0 146 | for item 
in find_mentions(a, uris): 147 | a = "{} #en{} {}".format(a[:item["start"]], "t" * (i + 1), a[item["end"]:]) 148 | b = b.replace(item["uri"].raw_uri, "#en{}".format("t" * (i + 1))) 149 | 150 | # remove extra info from the relation's uri and remaining entities 151 | for item in ["http://dbpedia.org/resource/", "http://dbpedia.org/ontology/", 152 | "http://dbpedia.org/property/", "http://www.w3.org/1999/02/22-rdf-syntax-ns#"]: 153 | b = b.replace(item, "") 154 | b = b.replace("<", "").replace(">", "") 155 | 156 | return a, b 157 | 158 | 159 | def split(data, parser=None): 160 | if isinstance(data, str): 161 | with open(data) as datafile: 162 | dataset = json.load(datafile) 163 | else: 164 | dataset = data 165 | 166 | a_list = [] 167 | b_list = [] 168 | id_list = [] 169 | sim_list = [] 170 | for item in tqdm(dataset): 171 | i = item["id"] 172 | a = item["question"] 173 | for query in item["generated_queries"]: 174 | a, b = generalize_question(a, query["query"], parser) 175 | 176 | # Empty query should be ignored 177 | if len(b) < 5: 178 | continue 179 | sim = str(2 if query["correct"] else 1) 180 | 181 | id_list.append(i + '\n') 182 | a_list.append(a.encode('ascii', 'ignore').decode('ascii') + '\n') 183 | b_list.append(b.encode('ascii', 'ignore').decode('ascii') + '\n') 184 | sim_list.append(sim + '\n') 185 | return a_list, b_list, id_list, sim_list 186 | 187 | 188 | def save_split(dst_dir, a_list, b_list, id_list, sim_list): 189 | with open(os.path.join(dst_dir, 'a.txt'), 'w') as afile, \ 190 | open(os.path.join(dst_dir, 'b.txt'), 'w') as bfile, \ 191 | open(os.path.join(dst_dir, 'id.txt'), 'w') as idfile, \ 192 | open(os.path.join(dst_dir, 'sim.txt'), 'w') as simfile: 193 | for i in range(len(a_list)): 194 | idfile.write(id_list[i]) 195 | afile.write(a_list[i]) 196 | bfile.write(b_list[i]) 197 | simfile.write(sim_list[i]) 198 | 199 | 200 | def parse(dirpath, dep_parse=True): 201 | if dep_parse: 202 | dependency_parse(os.path.join(dirpath, 'a.txt')) 203 | query_parse(os.path.join(dirpath, 'b.txt')) 204 | 205 | 206 | if __name__ == '__main__': 207 | print('=' * 80) 208 | print('Preprocessing LC-Quad dataset') 209 | print('=' * 80) 210 | 211 | base_dir = os.path.dirname(os.path.realpath(__file__)) 212 | print('base_dir: ', base_dir) 213 | data_dir = os.path.join(base_dir, 'data') 214 | lc_quad_dir = os.path.join(data_dir, 'lc_quad') 215 | lib_dir = os.path.join(base_dir, 'lib') 216 | train_dir = os.path.join(lc_quad_dir, 'train') 217 | dev_dir = os.path.join(lc_quad_dir, 'dev') 218 | test_dir = os.path.join(lc_quad_dir, 'test') 219 | make_dirs([train_dir, dev_dir, test_dir]) 220 | 221 | # split into separate files 222 | train_filepath = os.path.join(lc_quad_dir, 'LCQuad_train.json') 223 | trail_filepath = os.path.join(lc_quad_dir, 'LCQuad_trial.json') 224 | test_filepath = os.path.join(lc_quad_dir, 'LCQuad_test.json') 225 | 226 | ds = json.load(open("output/lcquad_gold.json")) 227 | 228 | total = len(ds) 229 | train_size = int(.7 * total) 230 | dev_size = int(.2 * total) 231 | test_size = int(.1 * total) 232 | print('Totle: ', total) 233 | print('train_size: ', train_size) 234 | print('dev_size: ', dev_size) 235 | print('test_size: ', test_size) 236 | 237 | json.dump(ds[:train_size], open(train_filepath, "w")) 238 | json.dump(ds[train_size:train_size + dev_size], open(trail_filepath, "w")) 239 | json.dump(ds[train_size + dev_size:], open(test_filepath, "w")) 240 | 241 | parser = LC_Qaud_LinkedParser() 242 | 243 | print('Split train set') 244 | save_split(train_dir, *split(train_filepath, 
parser)) 245 | print('Split dev set') 246 | save_split(dev_dir, *split(trail_filepath, parser)) 247 | print('Split test set') 248 | save_split(test_dir, *split(test_filepath, parser)) 249 | 250 | # parse sentences 251 | print("parse train set") 252 | parse(train_dir) 253 | print("parse dev set") 254 | parse(dev_dir) 255 | print("parse test set") 256 | parse(test_dir) 257 | 258 | # get vocabulary 259 | build_vocab( 260 | glob.glob(os.path.join(lc_quad_dir, '*/*.toks')), 261 | os.path.join(lc_quad_dir, 'vocab.txt')) 262 | build_vocab( 263 | glob.glob(os.path.join(lc_quad_dir, '*/*.toks')), 264 | os.path.join(lc_quad_dir, 'vocab-cased.txt'), 265 | lowercase=False) 266 | --------------------------------------------------------------------------------
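Editor's note on data/LC-QUAD/templates.json: each entry pairs a template id with a SPARQL pattern whose slots (e.g. %(e_to_e_out)s, %(e_out)s) use Python's named %-formatting, so a template can be instantiated with a plain dict mapping slot names to URIs. The snippet below is a minimal illustrative sketch only, not code from the repository; the two DBpedia URIs are hypothetical placeholders chosen for demonstration.

    import json

    # Load the template definitions shipped with the repository.
    with open("data/LC-QUAD/templates.json") as f:
        templates = {t["id"]: t for t in json.load(f)}

    # Template 1 selects all subjects connected to a given object via a given
    # predicate. The URIs below are hypothetical, used only for illustration.
    bindings = {
        "e_to_e_out": "http://dbpedia.org/ontology/author",
        "e_out": "http://dbpedia.org/resource/George_Orwell",
    }
    sparql = templates[1]["template"] % bindings
    print(sparql)
    # -> " SELECT DISTINCT ?uri WHERE {?uri <http://dbpedia.org/ontology/author> <http://dbpedia.org/resource/George_Orwell> } "

The count and ask templates in the same file are filled the same way; only the surrounding clause (SELECT DISTINCT COUNT(?uri) or ASK WHERE) differs.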