├── .gitignore ├── LuceneSearch.py ├── README.txt ├── clean_utils.py ├── createLuceneIndex.py ├── data_utils.py ├── evaluate ├── calculate_bleu.py ├── calculate_bleu.sh ├── compute_bool_acc.py ├── compute_count_acc.py ├── compute_count_acc2.py ├── compute_count_accuracy.py ├── compute_precision_active_set.py ├── compute_precision_active_set.sh ├── compute_recall.py ├── compute_recall.sh ├── compute_recall_active_set.py ├── compute_recall_per_state.sh ├── load_wikidata_wfn.py ├── multi_bleu.pl ├── parse_active_set.py ├── postprocess_bool.py ├── run_calculate_bleu.sh ├── run_compute_recall.sh ├── split_count_op.py └── words2number.py ├── hierarchy_model.py ├── load_wikidata2.py ├── params.py ├── params_test.py ├── prepare_data_for_hred.py ├── question_parser_lucene2.py ├── read_data.py ├── relation_linker ├── annoy_index_rel │ ├── glove_embedding_of_vocab.ann │ ├── index2rel.pkl │ └── index2word.pkl ├── annoy_index_rel_noisy │ ├── glove_embedding_of_vocab.ann │ ├── index2rel.pkl │ └── index2word.pkl ├── build_annoy_index_over_relation_words.py ├── create_relation_annoy_index.py ├── perform_relation_identification.py ├── predicates_bw.tsv └── predicates_fw.tsv ├── requirements.txt ├── run.sh ├── run_model.py ├── run_test.py ├── run_test.sh ├── run_test_jobs.sh ├── seq2seq.py ├── stopwords.pkl ├── stopwords_histogram.txt ├── text_util.py ├── type_linker ├── annoy_index_type │ ├── glove_embedding_of_vocab.ann │ ├── index2type.pkl │ └── type_names.json ├── create_type_annoy_index.py └── perform_type_identification.py ├── utils ├── cp_files.sh ├── find_overlap.py ├── get_cont_chunks.py ├── get_nounphrases.py └── search_entities.py ├── vocabs ├── response_vocab.pkl └── vocab.pkl ├── wikidata_entities_with_digitnames.pkl └── words2number.py /.gitignore: -------------------------------------------------------------------------------- 1 | check_missing_memory.py 2 | check_mem_preselection_failures.py 3 | transe_dir/* 4 | lucene_dir/* 5 | browse.py 6 | *.pyc 7 | hierarchy_model_old.py* 8 | prepare_data_old/* 9 | prepare_data_backup/* 10 | prepare_data/* 11 | test_codes 12 | hierarchy_model_old.py* 13 | run_test_twostep_old.* 14 | run_model_old.py* 15 | run_old.sh* 16 | copy/* 17 | new_model_softmax_kvmem/* 18 | model_softmax_kvmem/* 19 | obsolete/* 20 | model_softmax_decoder/* 21 | new_model_softmax_decoder/* 22 | model_softmax_test/* 23 | to_delete.sh 24 | knowledge_graph.py* 25 | -------------------------------------------------------------------------------- /LuceneSearch.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | import re 3 | import nltk 4 | from nltk.corpus import stopwords 5 | import json 6 | import string 7 | import lucene 8 | from lucene import * 9 | from java.io import File 10 | from org.apache.lucene.analysis.standard import StandardAnalyzer 11 | from org.apache.lucene.index import DirectoryReader, IndexReader 12 | from org.apache.lucene.index import Term 13 | from org.apache.lucene.search import BooleanClause, BooleanQuery, PhraseQuery, TermQuery 14 | from org.apache.lucene.queryparser.classic import QueryParser 15 | from org.apache.lucene.store import SimpleFSDirectory 16 | from org.apache.lucene.search import IndexSearcher 17 | from org.apache.lucene.util import Version 18 | import unicodedata 19 | import unidecode 20 | stop = set(stopwords.words('english')) 21 | string.punctuation='!"#$&\'()*+,-./:;<=>?@[\]^_`{|}~ ' 22 | regex = re.compile('[%s]' % re.escape(string.punctuation)) 23 | 24 | class LuceneSearch(): 25 | def 
__init__(self,lucene_index_dir='lucene_index/', num_docs_to_return=100): 26 | lucene.initVM(vmargs=['-Djava.awt.headless=true']) 27 | directory = SimpleFSDirectory(File(lucene_index_dir)) 28 | self.searcher = IndexSearcher(DirectoryReader.open(directory)) 29 | self.num_docs_to_return =num_docs_to_return 30 | self.ireader = IndexReader.open(directory) 31 | 32 | def strict_search(self, value, value_orig=None): 33 | value_words = set(value.split(' ')) 34 | if value_orig is None: 35 | value_orig = re.sub(' +', ' ', regex.sub(' ', value)).strip() 36 | else: 37 | value_orig = re.sub(' +', ' ', regex.sub(' ', value_orig)).strip() 38 | value = re.sub(' +', ' ', regex.sub(' ', value.lower())).strip() 39 | query = BooleanQuery() 40 | query.add(TermQuery(Term("wiki_name",value)), BooleanClause.Occur.SHOULD) 41 | query.add(TermQuery(Term("wiki_name",value_orig)), BooleanClause.Occur.SHOULD) 42 | query.add(TermQuery(Term("wiki_name_orig",value)), BooleanClause.Occur.SHOULD) 43 | query.add(TermQuery(Term("wiki_name_orig",value_orig)), BooleanClause.Occur.SHOULD) 44 | #print "0. query ",query 45 | scoreDocs = self.searcher.search(query, self.num_docs_to_return).scoreDocs 46 | return scoreDocs 47 | 48 | def qid_search(self, value): 49 | query = TermQuery(Term("wiki_id",value)) 50 | scoreDocs = self.searcher.search(query, self.num_docs_to_return).scoreDocs 51 | return scoreDocs 52 | 53 | def search(self, words, words_orig, stopwords=[], min_length=0, slop=2, remove_digits=False, any_one_word_occur=False): 54 | words_without_digits = re.sub(r'\w*\d\w*', '', " ".join(words)).strip().split(" ") 55 | if remove_digits and len(words_without_digits)>0: 56 | words = words_without_digits 57 | words = [x for x in words if x.lower() not in stopwords and len(x)>min_length] 58 | words_orig = [x for x in words_orig if x.lower() not in stopwords and len(x)>min_length] 59 | 60 | if len(words)==0: 61 | return [] 62 | query = BooleanQuery() 63 | query1 = PhraseQuery() 64 | query1.setSlop(slop) 65 | query2 = PhraseQuery() 66 | query2.setSlop(slop) 67 | query3 = PhraseQuery() 68 | query3.setSlop(slop) 69 | for word in words: 70 | query2.add(Term("wiki_name_analyzed_nopunct", word)) 71 | query3.add(Term("wiki_name_analyzed_nopunct_nostop", word)) 72 | for word in words_orig: 73 | query1.add(Term("wiki_name_analyzed", word)) 74 | query.add(query1, BooleanClause.Occur.SHOULD) 75 | query.add(query2, BooleanClause.Occur.SHOULD) 76 | query.add(query3, BooleanClause.Occur.SHOULD) 77 | #print "1. query ", query 78 | scoreDocs = self.searcher.search(query, self.num_docs_to_return).scoreDocs 79 | if len(scoreDocs)>0: 80 | #self.printDocs(scoreDocs) 81 | return scoreDocs 82 | query = BooleanQuery() 83 | for word in words: 84 | query_word = BooleanQuery() 85 | query_word.add(TermQuery(Term("wiki_name_analyzed_nopunct", word)), BooleanClause.Occur.SHOULD) 86 | query_word.add(TermQuery(Term("wiki_name_analyzed_nopunct_nostop", word)), BooleanClause.Occur.SHOULD) 87 | query.add(query_word, BooleanClause.Occur.MUST) 88 | #print "2. query ", query 89 | scoreDocs = self.searcher.search(query, self.num_docs_to_return).scoreDocs 90 | if len(scoreDocs)>0: 91 | return scoreDocs 92 | query = BooleanQuery() 93 | for word in words_orig: 94 | query.add(TermQuery(Term("wiki_name_analyzed", word)), BooleanClause.Occur.MUST) 95 | #print "3. 
query ", query 96 | scoreDocs = self.searcher.search(query, self.num_docs_to_return).scoreDocs 97 | if len(stopwords)>0 and any_one_word_occur: 98 | query = BooleanQuery() 99 | for word in words_orig: 100 | query.add(TermQuery(Term("wiki_name_analyzed", word)), BooleanClause.Occur.SHOULD) 101 | return scoreDocs 102 | 103 | def relaxed_search(self, value, text=None): 104 | value_orig = value.strip() 105 | value = re.sub(' +', ' ', regex.sub(' ', value.lower())).strip() 106 | if text is not None: 107 | text = re.sub(' +', ' ', regex.sub(' ', text.lower())).strip() 108 | words = nltk.word_tokenize(value) 109 | words_set = set(words) 110 | words_orig = nltk.word_tokenize(value_orig) 111 | if len(' '.join(words))==0: 112 | return [] 113 | if len(words)==0: 114 | return [] 115 | scoreDocs = self.strict_search(value, value_orig) 116 | if len(scoreDocs)>0: 117 | wiki_entities = self.get_wiki_entities(scoreDocs, words_set, text) 118 | if len(wiki_entities)>0: 119 | return wiki_entities 120 | scoreDocs = self.search(words, words_orig, []) 121 | if len(scoreDocs)>0: 122 | wiki_entities = self.get_wiki_entities(scoreDocs, words_set, text) 123 | if len(wiki_entities)>0: 124 | return wiki_entities 125 | scoreDocs = self.search(words, words_orig, stop) 126 | if len(scoreDocs)>0: 127 | wiki_entities = self.get_wiki_entities(scoreDocs, words_set, text) 128 | if len(wiki_entities)>0: 129 | return wiki_entities 130 | scoreDocs = self.search(words, words_orig, stop, 1) 131 | if len(scoreDocs)>0: 132 | wiki_entities = self.get_wiki_entities(scoreDocs, words_set, text) 133 | if len(wiki_entities)>0: 134 | return wiki_entities 135 | return [] 136 | 137 | def more_relaxed_search(self, value, text): 138 | wiki_entities = self.relaxed_search(value, text) 139 | if len(wiki_entities)==0: 140 | value_orig = value.strip() 141 | value = re.sub(' +', ' ', regex.sub(' ', value.lower())).strip() 142 | if text is not None: 143 | text = re.sub(' +', ' ', regex.sub(' ', text.lower())).strip() 144 | words = nltk.word_tokenize(value) 145 | words_set = set(words) 146 | words_orig = nltk.word_tokenize(value_orig) 147 | scoreDocs = self.search(words, words_orig, stop, 1, 3) 148 | if len(scoreDocs)>0: 149 | return self.get_wiki_entities(scoreDocs, words_set, text) 150 | else: 151 | scoreDocs = self.search(words, words_orig, stop, 1, 3, True) 152 | if len(scoreDocs)>0: 153 | return self.get_wiki_entities(scoreDocs, words_set, text) 154 | else: 155 | scoreDocs = self.search(words, words_orig, stop, 1, 3, True, True) 156 | if len(scoreDocs)>0: 157 | return self.get_wiki_entities(scoreDocs, words_set, text) 158 | else: 159 | return [] 160 | else: 161 | return wiki_entities 162 | 163 | def get_wiki_entities(self, scoreDocs, value_words, text=None): 164 | if len(scoreDocs)>100: 165 | return [] 166 | entities = [] 167 | for scoreDoc in scoreDocs: 168 | doc = self.searcher.doc(scoreDoc.doc) 169 | wiki_id = doc['wiki_id'] 170 | doc = doc['wiki_name_analyzed_nopunct'] 171 | #print doc 172 | doc_words = set(doc.strip().split(' ')) #re.sub(' +', ' ', regex.sub(' ', doc.lower())).strip().split(' ')) 173 | if text is None or doc.strip() in text: 174 | if wiki_id not in entities: 175 | entities.append(wiki_id) 176 | #print 'searching for ', value_words, '::',doc+"("+wiki_id+"), " 177 | ''' 178 | extra_words = doc_words - value_words 179 | extra_words = extra_words - stop 180 | #print 'searching for ', value_words, ':: doc',doc_words ,' extra ', extra_words 181 | if len(extra_words)<2: 182 | entities.append(wiki_id) 183 | try: 184 | print 
'searching for ', value_words, '::',doc+"("+wiki_id+"), "
185 |             except:
186 |                 continue
187 |         '''
188 |         return entities
189 | 
190 |     def printDocs(self, scoreDocs):
191 |         for scoreDoc in scoreDocs:
192 |             doc = self.searcher.doc(scoreDoc.doc)
193 |             for f in doc.getFields():
194 |                 print f.name(),':', f.stringValue(),', '
195 | 
196 |         print ''
197 |         print '-------------------------------------\n'
198 | 
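A quick usage sketch of the class above (a minimal, hypothetical example: it assumes the index built by createLuceneIndex.py already exists under lucene_index/; lucene.initVM is called inside the constructor):

    from LuceneSearch import LuceneSearch
    ls = LuceneSearch('lucene_index/', num_docs_to_return=100)
    # relaxed_search cascades from exact wiki_name lookups to slop-2 phrase
    # queries to per-word AND queries, and returns Wikidata item ids:
    print ls.relaxed_search('barack obama')   # e.g. ['Q76', ...]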
-------------------------------------------------------------------------------- /README.txt: --------------------------------------------------------------------------------
1 | Step 1: Download https://drive.google.com/file/d/1ccZSys8u4F_mqNJ97OOlSLe3fjpFLhdv/view?usp=sharing, extract it, and rename the extracted folder to lucene_dir
2 | 
3 | Step 2: Download the files ent_embed.pkl.npy, rel_embed.pkl.npy, id_ent_map.pickle, id_rel_map.pickle from https://zenodo.org/record/4052427#.X2_hWXRKhQI and place them in a directory named transe_dir
4 | 
5 | Step 3: Download https://drive.google.com/file/d/0B7XkCwpI5KDYNlNUTTlSS21pQmM/edit?usp=sharing and place it in a folder named glove_dir
6 | 
7 | Step 4: Download the wikidata JSONs from https://zenodo.org/record/4052427#.X2_hWXRKhQI and put them in a folder named wikidata_dir
8 | 
9 | Step 5: Put the correct (complete) paths to wikidata_dir, lucene_dir, transe_dir and glove_dir in params.py and params_test.py
10 | 
11 | Step 6: In both params.py and params_test.py, set param['type_of_loss']="decoder"
12 | 
13 | Step 7: Create a folder, say 'Target_Model_decoder', where the decoder model will be dumped, and create two folders inside it, 'dump' and 'model' (e.g. 'mkdir Target_Model_decoder/dump' and 'mkdir Target_Model_decoder/model')
14 | 
15 | Step 8: Put the params.py and params_test.py from Step 6 inside the Target_Model_decoder folder
16 | 
17 | Step 9: Create another version of params.py and params_test.py, this time setting param['type_of_loss']="kvmem"
18 | 
19 | Step 10: Create a folder, say 'Target_Model_kvmem', where the kvmem model will be dumped, and create two folders inside it, 'dump' and 'model' (e.g. 'mkdir Target_Model_kvmem/dump' and 'mkdir Target_Model_kvmem/model')
20 | 
21 | Step 11: Download train_preprocessed.zip from https://drive.google.com/file/d/1HmLOGTV_v18grW_hXpu_s6MdogEJDM_a/view?usp=sharing, extract it, and put the contents (preprocessed pickle files of the train data) into Target_Model_decoder/dump and Target_Model_kvmem/dump
22 | 
23 | Step 12: Download valid_preprocessed.zip from https://drive.google.com/file/d/1uoBUjjidyDks0pEUehxX-ofB5B_trdpP/view?usp=sharing, extract it, and put the contents (preprocessed pickle files of the valid data) into Target_Model_decoder/dump and Target_Model_kvmem/dump
24 | 
25 | Step 13: Download test_preprocessed.zip from https://drive.google.com/file/d/1PMOE_jQJM_avY3MItAdEI0s3GJU_Km31/view?usp=sharing, extract it, and put the contents (preprocessed pickle files of the test data) into Target_Model_decoder/dump and Target_Model_kvmem/dump
26 | 
27 | Step 14: Run ./run.sh for training (as shown in the run.sh file itself), where dump_dir is the 'Target_Model_decoder' folder created earlier and data_dir is the directory containing the downloaded data
28 | 
29 | Step 15: Run ./run_test.sh for testing (as shown in the run_test.sh file itself).
30 | 
31 | Step 16: For evaluating the model separately on each question type, run the following (see the sketch after this list for scripting the full sweep):
32 | ./run_test.sh Target_Model_decoder verify
33 | ./run_test.sh Target_Model_decoder quantitative_count
34 | ./run_test.sh Target_Model_decoder comparative_count
35 | ./run_test.sh Target_Model_kvmem simple
36 | ./run_test.sh Target_Model_kvmem logical
37 | ./run_test.sh Target_Model_kvmem quantitative
38 | ./run_test.sh Target_Model_kvmem comparative
39 | 
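The Step 16 sweep can be driven in one go; a minimal sketch (assuming this script sits next to run_test.sh and both target-model folders exist):

    import subprocess
    jobs = [('Target_Model_decoder', ['verify', 'quantitative_count', 'comparative_count']),
            ('Target_Model_kvmem', ['simple', 'logical', 'quantitative', 'comparative'])]
    for model_dir, question_types in jobs:
        for qtype in question_types:
            # same invocation as the commands listed in Step 16
            subprocess.check_call(['./run_test.sh', model_dir, qtype])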
doc.add(Field("wiki_name_analyzed_nopunct", str(v_punct_removed), Field.Store.YES, Field.Index.ANALYZED)) 56 | v_stop_removed = " ".join([x for x in nltk.word_tokenize(v_punct_removed) if x not in stop]) 57 | doc.add(Field("wiki_name_analyzed_nopunct_nostop", str(v_stop_removed), Field.Store.YES, Field.Index.ANALYZED)) 58 | writer.addDocument(doc) 59 | i=i+1 60 | if i%10000==0: 61 | print 'finished ',i 62 | print 'num errors while indexing ', num_errors 63 | writer.close() 64 | index.close() 65 | 66 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | PIPE = "|" 4 | COMMA = "," 5 | TAB = "\t" 6 | SPACE = " " 7 | 8 | def union(*sets): 9 | target_set = set([]) 10 | for s in sets: 11 | target_set = target_set.union(s) 12 | return target_set 13 | 14 | 15 | def extract_dimension_from_tuples_as_list(list_of_tuples, dim): 16 | result = [] 17 | for tuple in list_of_tuples: 18 | result.append(tuple[dim]) 19 | return result 20 | 21 | 22 | def get_str_of_seq(entities): 23 | return PIPE.join(entities) 24 | 25 | 26 | def get_str_of_nested_seq(paths): 27 | result = [] 28 | for path in paths: 29 | result.append(COMMA.join(path)) 30 | return PIPE.join(result) 31 | 32 | 33 | def pad(arr, L): 34 | arr_cpy = list(arr) 35 | assert (len(arr_cpy) <= L) 36 | while len(arr_cpy) < L: 37 | arr_cpy.append(0) 38 | return arr_cpy 39 | -------------------------------------------------------------------------------- /evaluate/calculate_bleu.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from collections import OrderedDict 3 | import re 4 | 5 | decoder_dir = sys.argv[1] 6 | kvmem_dir = sys.argv[2] 7 | test_type = sys.argv[3] 8 | state = sys.argv[4] 9 | type = sys.argv[5] 10 | pred_file = decoder_dir+'/test_output_'+test_type+'_'+state+'/pred_sent.txt' 11 | ent_file = decoder_dir+'/test_output_'+test_type+'_'+state+'/top20_ent_from_'+type+'.txt' 12 | 13 | def replace_kb_ent_in_resp(prob_memory_entities, pred_op): 14 | kb_name_list_unique = list(OrderedDict.fromkeys(prob_memory_entities))[:20] 15 | k = 0 16 | max_k = 10 17 | length = len(pred_op) 18 | replace_kb = True 19 | for j in range(len(pred_op)): 20 | if pred_op[j] in ['','','','','']: 21 | pred_op[j] = '' 22 | if pred_op[j]=='': 23 | length = j 24 | if pred_op[j].startswith(''): 25 | if not replace_kb: 26 | pred_op[j] = '' 27 | continue 28 | if k == len(kb_name_list_unique) or k == max_k: 29 | replace_kb = False 30 | pred_op[j] = '' 31 | continue 32 | pred_op[j] = kb_name_list_unique[k] 33 | k = k+1 34 | pred_op = pred_op[:length] 35 | pred_op = re.sub(' +',' ',' '.join(pred_op)).strip() 36 | return pred_op 37 | 38 | with open(pred_file) as pred_lines, open(ent_file) as ent_lines: 39 | for pred, ent in zip(pred_lines, ent_lines): 40 | word_list = pred.strip().split(' ') 41 | 42 | kb_count = 1 43 | for word in pred.strip().split(' '): 44 | if word=='': 45 | word_list.append('_'+str(kb_count)) 46 | kb_count = kb_count+1 47 | else: 48 | word_list.append(word) 49 | 50 | word_list = list(OrderedDict.fromkeys(word_list)) 51 | if '' in word_list: 52 | word_list.remove('') 53 | if '' in word_list: 54 | word_list.remove('') 55 | if '' in word_list: 56 | word_list.remove('') 57 | if '' in word_list: 58 | word_list.remove('') 59 | if '|' in ent: 60 | ent = ent.strip().split('|') 61 | else: 62 | ent = [x.strip() for x in ent.strip().split(',')] 63 | pred = 
-------------------------------------------------------------------------------- /evaluate/calculate_bleu.py: --------------------------------------------------------------------------------
1 | import sys
2 | from collections import OrderedDict
3 | import re
4 | 
5 | decoder_dir = sys.argv[1]
6 | kvmem_dir = sys.argv[2]
7 | test_type = sys.argv[3]
8 | state = sys.argv[4]
9 | type = sys.argv[5]
10 | pred_file = decoder_dir+'/test_output_'+test_type+'_'+state+'/pred_sent.txt'
11 | ent_file = decoder_dir+'/test_output_'+test_type+'_'+state+'/top20_ent_from_'+type+'.txt'
12 | 
13 | def replace_kb_ent_in_resp(prob_memory_entities, pred_op):
14 |     kb_name_list_unique = list(OrderedDict.fromkeys(prob_memory_entities))[:20]
15 |     k = 0
16 |     max_k = 10
17 |     length = len(pred_op)
18 |     replace_kb = True
19 |     for j in range(len(pred_op)):
20 |         if pred_op[j] in ['','','','','']:
21 |             pred_op[j] = ''
22 |         if pred_op[j]=='':
23 |             length = j
24 |         if pred_op[j].startswith('<kb>'):
25 |             if not replace_kb:
26 |                 pred_op[j] = ''
27 |                 continue
28 |             if k == len(kb_name_list_unique) or k == max_k:
29 |                 replace_kb = False
30 |                 pred_op[j] = ''
31 |                 continue
32 |             pred_op[j] = kb_name_list_unique[k]
33 |             k = k+1
34 |     pred_op = pred_op[:length]
35 |     pred_op = re.sub(' +',' ',' '.join(pred_op)).strip()
36 |     return pred_op
37 | 
38 | with open(pred_file) as pred_lines, open(ent_file) as ent_lines:
39 |     for pred, ent in zip(pred_lines, ent_lines):
40 |         word_list = []
41 | 
42 |         kb_count = 1
43 |         for word in pred.strip().split(' '):
44 |             if word=='<kb>':
45 |                 word_list.append('<kb>_'+str(kb_count))
46 |                 kb_count = kb_count+1
47 |             else:
48 |                 word_list.append(word)
49 | 
50 |         word_list = list(OrderedDict.fromkeys(word_list))
51 |         if '' in word_list:
52 |             word_list.remove('')
53 |         if '' in word_list:
54 |             word_list.remove('')
55 |         if '' in word_list:
56 |             word_list.remove('')
57 |         if '' in word_list:
58 |             word_list.remove('')
59 |         if '|' in ent:
60 |             ent = ent.strip().split('|')
61 |         else:
62 |             ent = [x.strip() for x in ent.strip().split(',')]
63 |         pred = replace_kb_ent_in_resp(ent, word_list)
64 |         print pred
65 | 
66 | 
67 | 
-------------------------------------------------------------------------------- /evaluate/calculate_bleu.sh: --------------------------------------------------------------------------------
1 | ~/anaconda/bin/python calculate_bleu.py $1 $2 $3 $4 "mem" > $2/test_output_$3_$4/twostep_pred_sent_replaced_from_mem.txt
2 | echo $2/test_output_$3_$4/twostep_pred_sent_replaced_from_mem.txt
3 | perl multi_bleu.pl -lc $2/test_output_$3_$4/gold_resp.txt < $2/test_output_$3_$4/twostep_pred_sent_replaced_from_mem.txt
4 | ~/anaconda/bin/python calculate_bleu.py $1 $2 $3 $4 "kb" > $2/test_output_$3_$4/twostep_pred_sent_replaced_from_kb.txt
5 | echo $2/test_output_$3_$4/twostep_pred_sent_replaced_from_kb.txt
6 | perl multi_bleu.pl -lc $2/test_output_$3_$4/gold_resp.txt < $2/test_output_$3_$4/twostep_pred_sent_replaced_from_kb.txt
-------------------------------------------------------------------------------- /evaluate/compute_bool_acc.py: --------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | import re
3 | import sys
4 | from words2number import *
5 | gold_file = "model_softmax_kvmem_validtrim_unfilt_newversion/test_output_"+sys.argv[2]+"_"+sys.argv[1]+"/gold_resp.txt"#sys.argv[1]
6 | pred_file = "model_softmax_decoder_newversion/test_output_"+sys.argv[2]+"_"+sys.argv[1]+"/pred_sent.txt" #sys.argv[2]
7 | ent_file = "model_softmax_kvmem_validtrim_unfilt_newversion/test_output_"+sys.argv[2]+"_"+sys.argv[1]+"/top20_ent_from_mem.txt"
8 | goldlines = open(gold_file).readlines()
9 | predlines = open(pred_file).readlines()
10 | entlines = open(ent_file).readlines()
11 | 
12 | def replace_kb_ent_in_resp(prob_memory_entities, pred_op):
13 |     kb_name_list_unique = list(OrderedDict.fromkeys(prob_memory_entities))[:20]
14 |     k = 0
15 |     max_k = 10
16 |     length = len(pred_op)
17 |     replace_kb = True
18 |     for j in range(len(pred_op)):
19 |         if pred_op[j] in ['','','','','']:
20 |             pred_op[j] = ''
21 |         if pred_op[j]=='':
22 |             length = j
23 |         if pred_op[j].startswith('<kb>'):
24 |             if not replace_kb:
25 |                 pred_op[j] = ''
26 |                 continue
27 |             if k == len(kb_name_list_unique) or k == max_k:
28 |                 replace_kb = False
29 |                 pred_op[j] = ''
30 |                 continue
31 |             pred_op[j] = kb_name_list_unique[k]
32 |             k = k+1
33 |     pred_op = pred_op[:length]
34 |     pred_op = re.sub(' +',' ',' '.join(pred_op)).strip()
35 |     return pred_op
36 | 
37 | acc = 0.0
38 | count = 0.0
39 | for goldline, predline,entline in zip(goldlines,predlines,entlines):
40 |     goldline = goldline.strip().lower()
41 |     predline = predline.strip().lower()
42 |     word_list = predline.strip().split(' ')
43 |     '''
44 |     kb_count = 1
45 |     for word in pred.strip().split(' '):
46 |         if word=='<kb>':
47 |             word_list.append('<kb>_'+str(kb_count))
48 |             kb_count = kb_count+1
49 |         else:
50 |             word_list.append(word)
51 |     '''
52 |     word_list = list(OrderedDict.fromkeys(word_list))
53 |     if '' in word_list:
54 |         word_list.remove('')
55 |     if '' in word_list:
56 |         word_list.remove('')
57 |     if '' in word_list:
58 |         word_list.remove('')
59 |     if '' in word_list:
60 |         word_list.remove('')
61 |     if '|' in entline:
62 |         entline = entline.strip().split('|')
63 |     else:
64 |         entline = [x.strip() for x in entline.strip().split(',')]
65 |     predline = replace_kb_ent_in_resp(entline, word_list)
66 |     acc_old = acc
67 |     print goldline,'||', predline,
68 |     goldline = " ".join([x for x in goldline.lower().split(' ') if x in ['yes','no']])
69 |     predline = " ".join([x for x in predline.lower().split(' ') 
if x in ['yes','no']]) 70 | print '||',predline, 71 | if goldline==predline: 72 | acc=acc+1.0 73 | ''' 74 | else: 75 | goldline = goldline.split(" ") 76 | predline = predline.split(" ") 77 | frac_acc = 0.0 78 | for i in range(min(len(goldline),len(predline))): 79 | if goldline[i]==predline[i]: 80 | frac_acc = frac_acc + 1.0/float(len(goldline)) 81 | acc=acc+frac_acc 82 | ''' 83 | print '||',acc-acc_old 84 | count=count+1.0 85 | #print predline,'||',goldline 86 | print acc/count 87 | 88 | -------------------------------------------------------------------------------- /evaluate/compute_count_acc.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import re 3 | import sys 4 | from words2number import * 5 | gold_file = "model_softmax_kvmem_validtrim_unfilt_newversion/test_output_"+sys.argv[2]+"_"+sys.argv[1]+"/gold_resp.txt"#sys.argv[1] 6 | pred_file = "model_softmax_decoder_newversion/test_output_"+sys.argv[2]+"_"+sys.argv[1]+"/pred_sent.txt" #sys.argv[2] 7 | ent_file = "model_softmax_kvmem_validtrim_unfilt_newversion/test_output_"+sys.argv[2]+"_"+sys.argv[1]+"/top20_ent_from_mem.txt" 8 | goldlines = open(gold_file).readlines() 9 | predlines = open(pred_file).readlines() 10 | entlines = open(ent_file).readlines() 11 | 12 | def replace_kb_ent_in_resp(prob_memory_entities, pred_op): 13 | kb_name_list_unique = list(OrderedDict.fromkeys(prob_memory_entities))[:20] 14 | k = 0 15 | max_k = 10 16 | length = len(pred_op) 17 | replace_kb = True 18 | for j in range(len(pred_op)): 19 | if pred_op[j] in ['','','','','']: 20 | pred_op[j] = '' 21 | if pred_op[j]=='': 22 | length = j 23 | if pred_op[j].startswith(''): 24 | if not replace_kb: 25 | pred_op[j] = '' 26 | continue 27 | if k == len(kb_name_list_unique) or k == max_k: 28 | replace_kb = False 29 | pred_op[j] = '' 30 | continue 31 | pred_op[j] = kb_name_list_unique[k] 32 | k = k+1 33 | pred_op = pred_op[:length] 34 | pred_op = re.sub(' +',' ',' '.join(pred_op)).strip() 35 | return pred_op 36 | 37 | def parse(line): 38 | line = line.replace('respectively','').strip() 39 | number1 = None 40 | entity1 = None 41 | number2 = None 42 | entity2 = None 43 | try: 44 | number1=text2int(line) 45 | #full line was a single numerical (in words) 46 | except: 47 | if line.isdigit(): 48 | number1 = int(line) 49 | #full line was a single numerical (in digits) 50 | else: 51 | line = ' '+line.strip()+' ' 52 | if ' and ' in line: 53 | parts = [x.strip() for x in line.split(' and ')] 54 | if len(parts)>2: 55 | print 'not handled ', line 56 | try: 57 | number1 = int(parts[0].split(' ')[0]) 58 | entity1 = " ".join(parts[0].split(' ')[1:]) 59 | except: 60 | try: 61 | number1 = int(parts[0]) 62 | except: 63 | try: 64 | number1 = text2int(parts[0]) 65 | except: 66 | try: 67 | number1 = text2int(parts[0].split(' ')[0]) 68 | entity1 = " ".join(parts[0].split(' ')[1:]) 69 | except: 70 | if line.isdigit(): 71 | number1 = int(line) 72 | else: 73 | number1 = None 74 | entity1 = None 75 | try: 76 | number2 = int(parts[1].split(' ')[0]) 77 | entity2 = " ".join(parts[1].split(' ')[1:]) 78 | except: 79 | try: 80 | number2 = int(parts[1]) 81 | except: 82 | try: 83 | number2 = text2int(parts[1]) 84 | except: 85 | try: 86 | number2 = text2int(parts[1].split(' ')[0]) 87 | entity2 = " ".join(parts[1].split(' ')[1:]) 88 | except: 89 | if line.isdigit(): 90 | number2 = int(line) 91 | else: 92 | number2 = None 93 | entity2 = None 94 | else: 95 | line = line.strip() 96 | try: 97 | number1 = int(line.split(' ')[0]) 98 
| entity1 = " ".join(line.split(' ')[1:]) 99 | except: 100 | try: 101 | number1 = text2int(line.split(' ')[0]) 102 | entity1 = " ".join(line.split(' ')[1:]) 103 | except: 104 | print 'not handled ', line 105 | #if entity1 is None: 106 | # print line, ' ---->', number1 107 | #else: 108 | # print line, ' ---->', entity1, '(',number1,') ',entity2, '(',number2,') ' 109 | return number1, entity1, number2, entity2 110 | acc = 0.0 111 | count = 0.0 112 | for goldline, predline,entline in zip(goldlines,predlines,entlines): 113 | goldline = goldline.strip().lower() 114 | predline = predline.strip().lower() 115 | word_list = predline.strip().split(' ') 116 | ''' 117 | kb_count = 1 118 | for word in pred.strip().split(' '): 119 | if word=='': 120 | word_list.append('_'+str(kb_count)) 121 | kb_count = kb_count+1 122 | else: 123 | word_list.append(word) 124 | ''' 125 | word_list = list(OrderedDict.fromkeys(word_list)) 126 | if '' in word_list: 127 | word_list.remove('') 128 | if '' in word_list: 129 | word_list.remove('') 130 | if '' in word_list: 131 | word_list.remove('') 132 | if '' in word_list: 133 | word_list.remove('') 134 | if '|' in entline: 135 | entline = entline.strip().split('|') 136 | else: 137 | entline = [x.strip() for x in entline.strip().split(',')] 138 | predline = replace_kb_ent_in_resp(entline, word_list) 139 | #print goldline+' || '+predline 140 | gold_number1, gold_entity1, gold_number2, gold_entity2 = parse(goldline) 141 | pred_number1, pred_entity1, pred_number2, pred_entity2 = parse(predline) 142 | print gold_number1, '::',gold_entity1, ' ::',gold_number2, '::',gold_entity2 143 | print pred_number1, '::',pred_entity1, '::',pred_number2, '::',pred_entity2 144 | acc_old = acc 145 | if gold_number1==pred_number1 and gold_entity1==pred_entity1 and gold_number2==pred_number2 and gold_entity2==pred_entity2: 146 | acc=acc+1.0 147 | elif gold_number1==pred_number2 and gold_entity1==pred_entity2 and gold_number2==pred_number1 and gold_entity2==pred_entity1: 148 | acc=acc+1.0 149 | elif gold_number1 is not None and gold_entity1 is not None and gold_number1==pred_number1 and gold_entity1==pred_entity1 and gold_number2 is None and gold_entity2 is None: 150 | acc=acc+1.0 151 | elif gold_number1 is not None and gold_entity1 is not None and gold_number1==pred_number1 and gold_entity1==pred_entity1 and gold_number2 is not None and gold_entity2 is not None:# and sys.argv[1]=='15': 152 | acc=acc+0.5 153 | elif gold_number1 is not None and gold_entity1 is not None and gold_number1==pred_number2 and gold_entity1==pred_entity2 and gold_number2 is None and gold_entity2 is None: 154 | acc=acc+1.0 155 | elif gold_number1 is not None and gold_entity1 is not None and gold_number1==pred_number2 and gold_entity1==pred_entity2 and gold_number2 is not None and gold_entity2 is not None:# and sys.argv[1]=='15': 156 | acc=acc+0.5 157 | elif gold_number1==pred_number1 and gold_entity1==pred_entity1 and gold_number2 is None: 158 | acc=acc+1.0 159 | elif gold_number1==pred_number2 and gold_entity1==pred_entity2 and gold_number2 is None: 160 | acc=acc+1.0 161 | #elif gold_number1 is not None and (gold_number1==pred_number1 or gold_number1==pred_number2) and gold_number2 is None and gold_entity1 is None:# and sys.argv[1]=='15': 162 | # acc=acc+0.5 163 | #elif gold_number1 is not None and (gold_number1==pred_number1 or gold_number1==pred_number2) and gold_number2 is None and gold_entity1 is not None:# and sys.argv[1]=='15': 164 | # acc=acc+0.5 165 | #elif gold_entity1 is not None and (gold_entity1==pred_entity1 or 
gold_entity1==pred_entity2) and gold_number2 is None:# and sys.argv[1]=='15': 166 | # acc=acc+0.5 167 | #elif gold_number1 is not None and (gold_number1==pred_number1 or gold_number1==pred_number2) and gold_number2 is not None:# and sys.argv[1]=='15': 168 | # acc=acc+0.25 169 | #elif gold_entity1 is not None and (gold_entity1==pred_entity1 or gold_entity1==pred_entity2) and gold_number2 is not None:# and sys.argv[1]=='15': 170 | # acc=acc+0.25 171 | #elif gold_number2 is not None and (gold_number2==pred_number1 or gold_number2==pred_number2):# and sys.argv[1]=='15': 172 | # acc=acc+0.25 173 | #elif gold_entity2 is not None and (gold_entity2==pred_entity1 or gold_entity2==pred_entity2):# and sys.argv[1]=='15': 174 | # acc=acc+0.25 175 | #if (acc-acc_old)>0: 176 | ''' 177 | if gold_entity1 is None: 178 | gold_entity1 = "None" 179 | if gold_number1 is None: 180 | gold_number1 = "None" 181 | if gold_entity2 is None: 182 | gold_entity2 = "None" 183 | if gold_number2 is None: 184 | gold_number2 = "None" 185 | ''' 186 | gold_parsed = str(gold_entity1)+'('+str(gold_number1)+') '+str(gold_entity2)+'('+str(gold_number2)+') ' 187 | pred_parsed = str(pred_entity1)+'('+str(pred_number1)+') '+str(pred_entity2)+'('+str(pred_number2)+') ' 188 | #if acc-acc_old>0: 189 | #print goldline, '::', gold_parsed,' ||| ',predline, '::', pred_parsed, ' ||| ', (acc-acc_old) 190 | #print '' 191 | count=count+1.0 192 | print acc/count 193 | -------------------------------------------------------------------------------- /evaluate/compute_count_acc2.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import re 3 | import sys 4 | from words2number import * 5 | gold_file = "../new_model_softmax_decoder2/test_output_"+sys.argv[1]+"/true_sent.txt"#sys.argv[1] 6 | pred_file = "../new_model_softmax_decoder2/test_output_"+sys.argv[1]+"/pred_sent.txt" #sys.argv[2] 7 | goldlines = open(gold_file).readlines() 8 | predlines = open(pred_file).readlines() 9 | #entlines = open(ent_file).readlines() 10 | 11 | def replace_kb_ent_in_resp(prob_memory_entities, pred_op): 12 | kb_name_list_unique = list(OrderedDict.fromkeys(prob_memory_entities))[:20] 13 | k = 0 14 | max_k = 10 15 | length = len(pred_op) 16 | replace_kb = True 17 | for j in range(len(pred_op)): 18 | if pred_op[j] in ['','','','','']: 19 | pred_op[j] = '' 20 | if pred_op[j]=='': 21 | length = j 22 | if pred_op[j].startswith(''): 23 | if not replace_kb: 24 | pred_op[j] = '' 25 | continue 26 | if k == len(kb_name_list_unique) or k == max_k: 27 | replace_kb = False 28 | pred_op[j] = '' 29 | continue 30 | pred_op[j] = kb_name_list_unique[k] 31 | k = k+1 32 | pred_op = pred_op[:length] 33 | pred_op = re.sub(' +',' ',' '.join(pred_op)).strip() 34 | return pred_op 35 | 36 | def parse(line): 37 | line = line.replace('respectively','').strip() 38 | number1 = None 39 | entity1 = None 40 | number2 = None 41 | entity2 = None 42 | try: 43 | number1=text2int(line) 44 | #full line was a single numerical (in words) 45 | except: 46 | if line.isdigit(): 47 | number1 = int(line) 48 | #full line was a single numerical (in digits) 49 | else: 50 | line = ' '+line.strip()+' ' 51 | if ' and ' in line: 52 | parts = [x.strip() for x in line.split(' and ')] 53 | if len(parts)>2: 54 | print 'not handled ', line 55 | try: 56 | number1 = int(parts[0].split(' ')[0]) 57 | entity1 = " ".join(parts[0].split(' ')[1:]) 58 | except: 59 | try: 60 | number1 = int(parts[0]) 61 | except: 62 | try: 63 | number1 = text2int(parts[0]) 64 | 
except: 65 | try: 66 | number1 = text2int(parts[0].split(' ')[0]) 67 | entity1 = " ".join(parts[0].split(' ')[1:]) 68 | except: 69 | if line.isdigit(): 70 | number1 = int(line) 71 | else: 72 | number1 = None 73 | entity1 = None 74 | try: 75 | number2 = int(parts[1].split(' ')[0]) 76 | entity2 = " ".join(parts[1].split(' ')[1:]) 77 | except: 78 | try: 79 | number2 = int(parts[1]) 80 | except: 81 | try: 82 | number2 = text2int(parts[1]) 83 | except: 84 | try: 85 | number2 = text2int(parts[1].split(' ')[0]) 86 | entity2 = " ".join(parts[1].split(' ')[1:]) 87 | except: 88 | if line.isdigit(): 89 | number2 = int(line) 90 | else: 91 | number2 = None 92 | entity2 = None 93 | else: 94 | line = line.strip() 95 | try: 96 | number1 = int(line.split(' ')[0]) 97 | entity1 = " ".join(line.split(' ')[1:]) 98 | except: 99 | try: 100 | number1 = text2int(line.split(' ')[0]) 101 | entity1 = " ".join(line.split(' ')[1:]) 102 | except: 103 | print 'not handled ', line 104 | #if entity1 is None: 105 | # print line, ' ---->', number1 106 | #else: 107 | # print line, ' ---->', entity1, '(',number1,') ',entity2, '(',number2,') ' 108 | return number1, entity1, number2, entity2 109 | acc = 0.0 110 | count = 0.0 111 | for goldline, predline in zip(goldlines,predlines): 112 | goldline = goldline.strip().lower() 113 | predline = predline.strip().lower().replace('you mean','').replace('you and','').replace('you','').replace('?','').strip() 114 | predline = predline.replace('','').replace('','') 115 | if len(goldline.strip())==0 or 'did you mean' in goldline: 116 | continue 117 | word_list = predline.strip().split(' ') 118 | ''' 119 | kb_count = 1 120 | for word in pred.strip().split(' '): 121 | if word=='': 122 | word_list.append('_'+str(kb_count)) 123 | kb_count = kb_count+1 124 | else: 125 | word_list.append(word) 126 | ''' 127 | word_list = list(OrderedDict.fromkeys(word_list)) 128 | if '' in word_list: 129 | word_list.remove('') 130 | if '' in word_list: 131 | word_list.remove('') 132 | if '' in word_list: 133 | word_list.remove('') 134 | if '' in word_list: 135 | word_list.remove('') 136 | predline = ' '.join(word_list) 137 | #predline = replace_kb_ent_in_resp(entline, word_list) 138 | #print goldline+' || '+predline 139 | #print goldline, '||', predline 140 | gold_number1, gold_entity1, gold_number2, gold_entity2 = parse(goldline) 141 | pred_number1, pred_entity1, pred_number2, pred_entity2 = parse(predline) 142 | #print gold_number1, '::',gold_entity1, ' ::',gold_number2, '::',gold_entity2 143 | #print pred_number1, '::',pred_entity1, '::',pred_number2, '::',pred_entity2 144 | acc_old = acc 145 | if gold_number1==pred_number1 and gold_entity1==pred_entity1 and gold_number2==pred_number2 and gold_entity2==pred_entity2: 146 | acc=acc+1.0 147 | elif gold_number1==pred_number2 and gold_entity1==pred_entity2 and gold_number2==pred_number1 and gold_entity2==pred_entity1: 148 | acc=acc+1.0 149 | elif gold_number1 is not None and gold_entity1 is not None and gold_number1==pred_number1 and gold_entity1==pred_entity1 and gold_number2 is None and gold_entity2 is None: 150 | acc=acc+1.0 151 | elif gold_number1 is not None and gold_entity1 is not None and gold_number1==pred_number1 and gold_entity1==pred_entity1 and gold_number2 is not None and gold_entity2 is not None:# and sys.argv[1]=='15': 152 | acc=acc+0.5 153 | elif gold_number1 is not None and gold_entity1 is not None and gold_number1==pred_number2 and gold_entity1==pred_entity2 and gold_number2 is None and gold_entity2 is None: 154 | acc=acc+1.0 155 | elif 
gold_number1 is not None and gold_entity1 is not None and gold_number1==pred_number2 and gold_entity1==pred_entity2 and gold_number2 is not None and gold_entity2 is not None:# and sys.argv[1]=='15': 156 | acc=acc+0.5 157 | elif gold_number1==pred_number1 and gold_entity1==pred_entity1 and gold_number2 is None: 158 | acc=acc+1.0 159 | elif gold_number1==pred_number2 and gold_entity1==pred_entity2 and gold_number2 is None: 160 | acc=acc+1.0 161 | ''' 162 | if gold_number1==pred_number1 and gold_number2==pred_number2: 163 | acc=acc+1.0 164 | elif gold_number1==pred_number2 and gold_number2==pred_number1: 165 | acc=acc+1.0 166 | elif gold_number1 is not None and gold_number1==pred_number1 and gold_number2 is None: 167 | acc=acc+1.0 168 | elif gold_number1 is not None and gold_number1==pred_number1 and gold_entity1==pred_entity1 and gold_number2 is not None:# and sys.argv[1]=='15': 169 | acc=acc+0.5 170 | elif gold_number1 is not None and gold_number1==pred_number2 and gold_number2 is None: 171 | acc=acc+1.0 172 | elif gold_number1 is not None and gold_number1==pred_number2 and gold_number2 is not None:# and sys.argv[1]=='15': 173 | acc=acc+0.5 174 | elif gold_number1==pred_number1 and gold_number2 is None: 175 | acc=acc+1.0 176 | elif gold_number1==pred_number2 and gold_number2 is None: 177 | acc=acc+1.0 178 | ''' 179 | #elif gold_number1 is not None and (gold_number1==pred_number1 or gold_number1==pred_number2) and gold_number2 is None and gold_entity1 is None:# and sys.argv[1]=='15': 180 | # acc=acc+0.5 181 | #elif gold_number1 is not None and (gold_number1==pred_number1 or gold_number1==pred_number2) and gold_number2 is None and gold_entity1 is not None:# and sys.argv[1]=='15': 182 | # acc=acc+0.5 183 | #elif gold_entity1 is not None and (gold_entity1==pred_entity1 or gold_entity1==pred_entity2) and gold_number2 is None:# and sys.argv[1]=='15': 184 | # acc=acc+0.5 185 | #elif gold_number1 is not None and (gold_number1==pred_number1 or gold_number1==pred_number2) and gold_number2 is not None:# and sys.argv[1]=='15': 186 | # acc=acc+0.25 187 | #elif gold_entity1 is not None and (gold_entity1==pred_entity1 or gold_entity1==pred_entity2) and gold_number2 is not None:# and sys.argv[1]=='15': 188 | # acc=acc+0.25 189 | #elif gold_number2 is not None and (gold_number2==pred_number1 or gold_number2==pred_number2):# and sys.argv[1]=='15': 190 | # acc=acc+0.25 191 | #elif gold_entity2 is not None and (gold_entity2==pred_entity1 or gold_entity2==pred_entity2):# and sys.argv[1]=='15': 192 | # acc=acc+0.25 193 | #if (acc-acc_old)>0: 194 | ''' 195 | if gold_entity1 is None: 196 | gold_entity1 = "None" 197 | if gold_number1 is None: 198 | gold_number1 = "None" 199 | if gold_entity2 is None: 200 | gold_entity2 = "None" 201 | if gold_number2 is None: 202 | gold_number2 = "None" 203 | ''' 204 | gold_parsed = str(gold_entity1)+'('+str(gold_number1)+') '+str(gold_entity2)+'('+str(gold_number2)+') ' 205 | pred_parsed = str(pred_entity1)+'('+str(pred_number1)+') '+str(pred_entity2)+'('+str(pred_number2)+') ' 206 | #if acc-acc_old>0: 207 | print goldline, '::', gold_parsed,' ||| ',predline, '::', pred_parsed, ' ||| ', (acc-acc_old) 208 | #print '' 209 | count=count+1.0 210 | print acc/count 211 | -------------------------------------------------------------------------------- /evaluate/compute_count_accuracy.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import re 3 | import sys 4 | from words2number import * 5 | gold_file = 
"../"+sys.argv[2]+"/test_output_"+sys.argv[1]+"/true_sent.txt"#sys.argv[1] 6 | pred_file = "../"+sys.argv[2]+"/test_output_"+sys.argv[1]+"/pred_sent.txt" #sys.argv[2] 7 | goldlines = open(gold_file).readlines() 8 | predlines = open(pred_file).readlines() 9 | #entlines = open(ent_file).readlines() 10 | 11 | def replace_kb_ent_in_resp(prob_memory_entities, pred_op): 12 | kb_name_list_unique = list(OrderedDict.fromkeys(prob_memory_entities))[:20] 13 | k = 0 14 | max_k = 10 15 | length = len(pred_op) 16 | replace_kb = True 17 | for j in range(len(pred_op)): 18 | if pred_op[j] in ['','','','','']: 19 | pred_op[j] = '' 20 | if pred_op[j]=='': 21 | length = j 22 | if pred_op[j].startswith(''): 23 | if not replace_kb: 24 | pred_op[j] = '' 25 | continue 26 | if k == len(kb_name_list_unique) or k == max_k: 27 | replace_kb = False 28 | pred_op[j] = '' 29 | continue 30 | pred_op[j] = kb_name_list_unique[k] 31 | k = k+1 32 | pred_op = pred_op[:length] 33 | pred_op = re.sub(' +',' ',' '.join(pred_op)).strip() 34 | return pred_op 35 | 36 | def parse(line): 37 | line = line.replace('respectively','').strip() 38 | number1 = None 39 | entity1 = None 40 | number2 = None 41 | entity2 = None 42 | try: 43 | number1=text2int(line) 44 | #full line was a single numerical (in words) 45 | except: 46 | if line.isdigit(): 47 | number1 = int(line) 48 | #full line was a single numerical (in digits) 49 | else: 50 | line = ' '+line.strip()+' ' 51 | if ' and ' in line: 52 | parts = [x.strip() for x in line.split(' and ')] 53 | if len(parts)>2: 54 | print 'not handled ', line 55 | try: 56 | number1 = int(parts[0].split(' ')[0]) 57 | entity1 = " ".join(parts[0].split(' ')[1:]) 58 | except: 59 | try: 60 | number1 = int(parts[0]) 61 | except: 62 | try: 63 | number1 = text2int(parts[0]) 64 | except: 65 | try: 66 | number1 = text2int(parts[0].split(' ')[0]) 67 | entity1 = " ".join(parts[0].split(' ')[1:]) 68 | except: 69 | if line.isdigit(): 70 | number1 = int(line) 71 | else: 72 | number1 = None 73 | entity1 = None 74 | try: 75 | number2 = int(parts[1].split(' ')[0]) 76 | entity2 = " ".join(parts[1].split(' ')[1:]) 77 | except: 78 | try: 79 | number2 = int(parts[1]) 80 | except: 81 | try: 82 | number2 = text2int(parts[1]) 83 | except: 84 | try: 85 | number2 = text2int(parts[1].split(' ')[0]) 86 | entity2 = " ".join(parts[1].split(' ')[1:]) 87 | except: 88 | if line.isdigit(): 89 | number2 = int(line) 90 | else: 91 | number2 = None 92 | entity2 = None 93 | else: 94 | line = line.strip() 95 | try: 96 | number1 = int(line.split(' ')[0]) 97 | entity1 = " ".join(line.split(' ')[1:]) 98 | except: 99 | try: 100 | number1 = text2int(line.split(' ')[0]) 101 | entity1 = " ".join(line.split(' ')[1:]) 102 | except: 103 | print 'not handled ', line 104 | return number1, entity1, number2, entity2 105 | 106 | def parse2(line): 107 | return [x for x in line.strip().split(' ') if x.isdigit()] 108 | prec = 0.0 109 | rec = 0.0 110 | jacc = 0.0 111 | count = 0.0 112 | for goldline, predline in zip(goldlines,predlines): 113 | goldline = goldline.strip().lower() 114 | predline = predline.strip().lower().replace('you mean','').replace('you and','').replace('you','').replace('?','').strip() 115 | predline = predline.replace('','').replace('','') 116 | if len(goldline.strip())==0 or 'did you mean' in goldline: 117 | continue 118 | word_list = predline.strip().split(' ') 119 | word_list = list(OrderedDict.fromkeys(word_list)) 120 | if '' in word_list: 121 | word_list.remove('') 122 | if '' in word_list: 123 | word_list.remove('') 124 | if '' in 
word_list:
125 |         word_list.remove('')
126 |     if '' in word_list:
127 |         word_list.remove('')
128 |     predline = ' '.join(word_list)
129 |     gold_number1, gold_entity1, gold_number2, gold_entity2 = parse(goldline)
130 |     pred_number1, pred_entity1, pred_number2, pred_entity2 = parse(predline)
131 |     #acc_old = acc
132 |     gold_set = parse2(goldline)#[x for x in [str(gold_number1), str(gold_number2)] if x!='None']
133 |     pred_set = parse2(predline)#[x for x in [str(pred_number1), str(pred_number2)] if x!='None']
134 |     ints = float(len([x for x in pred_set if x in gold_set]))
135 |     union = ints + float(len([x for x in pred_set if x not in gold_set])) + float(len([x for x in gold_set if x not in pred_set]))
136 |     #print 'pred_line ', predline, 'gold_set ', gold_set, 'pred_set ', pred_set,' ints ', ints, 'union ', union
137 |     if union>0:
138 |         jacc += ints/union
139 |     if len(pred_set)>0:
140 |         prec += ints/float(len(pred_set))
141 |     if len(gold_set)>0:
142 |         rec += ints/float(len(gold_set))
143 |     '''
144 |     if gold_number1==pred_number1 and gold_entity1==pred_entity1 and gold_number2==pred_number2 and gold_entity2==pred_entity2:
145 |         acc=acc+1.0
146 |     elif gold_number1==pred_number2 and gold_entity1==pred_entity2 and gold_number2==pred_number1 and gold_entity2==pred_entity1:
147 |         acc=acc+1.0
148 |     elif gold_number1 is not None and gold_entity1 is not None and gold_number1==pred_number1 and gold_entity1==pred_entity1 and gold_number2 is None and gold_entity2 is None:
149 |         acc=acc+1.0
150 |     elif gold_number1 is not None and gold_entity1 is not None and gold_number1==pred_number1 and gold_entity1==pred_entity1 and gold_number2 is not None and gold_entity2 is not None:# and sys.argv[1]=='15':
151 |         acc=acc+0.5
152 |     elif gold_number1 is not None and gold_entity1 is not None and gold_number1==pred_number2 and gold_entity1==pred_entity2 and gold_number2 is None and gold_entity2 is None:
153 |         acc=acc+1.0
154 |     elif gold_number1 is not None and gold_entity1 is not None and gold_number1==pred_number2 and gold_entity1==pred_entity2 and gold_number2 is not None and gold_entity2 is not None:# and sys.argv[1]=='15':
155 |         acc=acc+0.5
156 |     elif gold_number1==pred_number1 and gold_entity1==pred_entity1 and gold_number2 is None:
157 |         acc=acc+1.0
158 |     elif gold_number1==pred_number2 and gold_entity1==pred_entity2 and gold_number2 is None:
159 |         acc=acc+1.0
160 |     gold_parsed = str(gold_entity1)+'('+str(gold_number1)+') '+str(gold_entity2)+'('+str(gold_number2)+') '
161 |     pred_parsed = str(pred_entity1)+'('+str(pred_number1)+') '+str(pred_entity2)+'('+str(pred_number2)+') '
162 |     #if acc-acc_old>0:
163 |     print goldline, '::', gold_parsed,' ||| ',predline, '::', pred_parsed, ' ||| ', (acc-acc_old)
164 |     #print ''
165 |     '''
166 |     count=count+1.0
167 | prec /= count
168 | rec /= count
169 | jacc /= count
170 | f1 = (2*prec*rec)/(prec+rec)
171 | print 'total prec ', prec, ' total rec ', rec, 'total jacc ', jacc, ' count ', count
172 | print 'precision ', prec
173 | print 'recall ', rec
174 | print 'jacc ', jacc
175 | print 'f1 ', f1
176 | 
-------------------------------------------------------------------------------- /evaluate/compute_precision_active_set.py: --------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | import sys, re, logging
3 | sys.path.append('../')
4 | from itertools import izip
5 | import time
6 | from load_wikidata_wfn import *
7 | # from parse_active_set import *
8 | 
9 | states = [3,4,5,6,16,17,18,19]
10 | test_types = ['hard']
11 | #ks = [2,5,10,20]
12 | 
13 | #act_set_file = 'test_easy_active_set_-1.txt'
14 | #pred_ent_name_file = 'top20_ent_id_from_mem.txt'
15 | #k = 20
16 | 
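Stripped of the file plumbing, the metric this script reports (both in the commented-out driver further down and in the live loop at the end) is precision-at-k over ranked QID lists; a minimal sketch with toy ids in place of real model output:

    def precision_at_k(ranked_qids, gold_qids, k):
        # fraction of the top-k predicted QIDs that fall inside the gold set
        top_k = ranked_qids[:k]
        return len([q for q in top_k if q in gold_qids]) / float(len(top_k))

    gold = set(['Q15', 'Q30'])            # toy stand-in for the active set
    preds = ['Q15', 'Q5', 'Q30', 'Q142']  # toy ranked model output
    for k in [2, 5, 10, 20]:
        print k, precision_at_k(preds, gold, k)   # 0.5 at every k here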
17 | def parse_active_set(active_set, target):
18 |     active_set = active_set.strip()
19 |     anding = False
20 |     orring = False
21 |     notting1 = False
22 |     notting2 = False
23 |     if active_set.startswith('AND(') or active_set.startswith('OR('):
24 |         if active_set.startswith('AND('):
25 |             anding = True
26 |             active_set = re.sub('^\(|\)$','',active_set.replace('AND', '',1))
27 |         if active_set.startswith('OR('):
28 |             orring = True
29 |             active_set = re.sub('^\(|\)$','',active_set.replace('OR', '',1))
30 |         while active_set.startswith('(') and active_set.endswith(')'):
31 |             active_set = re.sub('^\(|\)$','',active_set)
32 |         active_set_parts = active_set.split(', ')
33 |         active_set_part1 = active_set_parts[0].strip()
34 |         active_set_part2 = active_set_parts[1].strip()
35 |         if active_set_part1.startswith('NOT('):
36 |             active_set_part1 = re.sub('^\(|\)$','',active_set_part1.replace('NOT','',1))
37 |             notting1 = True
38 |         is_present1 = parse_basic_active_set(active_set_part1.strip(), target)
39 |         if notting1:
40 |             is_present1 = not is_present1
41 |         if active_set_part2.startswith('NOT('):
42 |             active_set_part2 = re.sub('^\(|\)$','',active_set_part2.replace('NOT','',1))
43 |             notting2 = True
44 |         is_present2 = parse_basic_active_set(active_set_part2.strip(), target)
45 |         if notting2:
46 |             is_present2 = not is_present2
47 |         if anding:
48 |             if is_present1 and is_present2:
49 |                 return True
50 |             else:
51 |                 return False
52 |         if orring:
53 |             if is_present1 or is_present2:
54 |                 return True
55 |             else:
56 |                 return False
57 |     else:
58 |         return parse_basic_active_set(active_set, target)
59 | 
60 | def parse_basic_active_set(active_set, target):
61 |     # st_time = time.time()
62 |     if len(active_set) == 0:
63 |         return False
64 |     active_set_orig = active_set
65 | 
66 |     while active_set.startswith('(') and active_set.endswith(')'):
67 |         active_set = re.sub('^\(|\)$','',active_set)
68 |     while active_set.startswith('(') and not active_set.endswith(')'):
69 |         active_set = re.sub('^\(','',active_set)
70 |     while active_set.endswith(')') and not active_set.startswith('('):
71 |         active_set = re.sub('\)$','',active_set)
72 | 
73 |     assert not active_set.startswith('(')
74 |     assert not active_set.endswith(')')
75 | 
76 |     # print 'time taken for regex proc = %f' % (time.time() - st_time)
77 | 
78 |     parent = None
79 |     parent1 = None
80 |     parent2 = None
81 | 
82 |     try:
83 |         assert len(active_set.strip().split(',')) == 3
84 |     except:
85 |         print 'active_set = %s' % active_set_orig
86 |         logging.exception('Something awful happened')
87 |         raise Exception('ERROR!!!')
88 |     parts = active_set.strip().split(',')
89 | 
90 |     if parts[0].startswith('c') and parts[2].startswith('c'):
91 |         parent1 = parts[0].split('(')[1].split(')')[0].strip()
92 |         parent2 = parts[2].split('(')[1].split(')')[0].strip()
93 |     elif parts[0].startswith('c'):
94 |         parent = parts[0].split('(')[1].split(')')[0].strip()
95 |         ent = parts[2].strip()
96 |     elif parts[2].startswith('c'):
97 |         parent = parts[2].split('(')[1].split(')')[0].strip()
98 |         ent = parts[0].strip()
99 |     rel = parts[1].strip()
100 |     if parent and ent:
101 |         try:
102 |             if child_par_dict[target] == parent and (target in wikidata[ent][rel] or ent in wikidata[target][rel] or target in reverse_dict[ent][rel] or ent in reverse_dict[target][rel]):
103 |                 # print 'time taken = %f' % (time.time() - st_time)
104 |                 return True
105 |             else:
106 |                 # print 'time taken = %f' % (time.time() - st_time)
107 |                 return False
108 |         except:
109 |             return False
110 |     elif parent1 and parent2:
111 |         try:
112 |             if child_par_dict[target] == parent1:
113 |                 children2 = par_child_dict[parent2]
114 |                 for ent in children2:
115 |                     if target in wikidata[ent][rel] or ent in wikidata[target][rel] or target in reverse_dict[ent][rel] or ent in reverse_dict[target][rel]:
116 |                         # print 'time taken = %f' % (time.time() - st_time)
117 |                         return True
118 |             elif child_par_dict[target] == parent2:
119 |                 children1 = par_child_dict[parent1]
120 |                 for ent in children1:
121 |                     if target in wikidata[ent][rel] or ent in wikidata[target][rel] or target in reverse_dict[ent][rel] or ent in reverse_dict[target][rel]:
122 |                         # print 'time taken = %f' % (time.time() - st_time)
123 |                         return True
124 |         except:
125 |             # print 'time taken = %f' % (time.time() - st_time)
126 |             return False
127 |     # print 'time taken = %f' % (time.time() - st_time)
128 |     return False
129 | 
130 | 
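To make the parser's expected input concrete, here is a toy invocation, stubbing the Wikidata globals that load_wikidata_wfn normally provides (all QIDs/PIDs hypothetical):

    child_par_dict = {'Q42': 'Q5'}                    # entity -> parent type
    par_child_dict = {'Q5': ['Q42']}                  # type -> member entities
    wikidata = {'Q30': {'P27': ['Q42']}, 'Q42': {}}   # subject -> relation -> objects
    reverse_dict = {'Q42': {'P27': ['Q30']}, 'Q30': {}}
    # basic set: entities of type Q5 related to Q30 via P27 -- does Q42 belong?
    print parse_active_set('(Q30,P27,c(Q5))', 'Q42')                                  # True
    print parse_active_set('AND((Q30,P27,c(Q5)), NOT((Q30,P27,c(Q5))))', 'Q42')      # False

Compound sets combine two basic '(entity, relation, c(type))' triples with AND/OR and optional NOT; multiple '#'-separated sets per response are handled by is_contained_in_act_set below.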
131 | def is_contained_in_act_set(active_set, ent_id):
132 |     active_set_tokens = active_set.split('#')
133 | 
134 |     for active_set_token in active_set_tokens:
135 |         try:
136 |             start_time = time.time()
137 |             if parse_active_set(active_set_token, ent_id):
138 |                 return True
139 |             # print 'active_set = %s' % active_set_token
140 |             # print 'time taken = %f' % (time.time() - start_time)
141 |         except:
142 |             print active_set_token
143 |             logging.exception('Something awful happened')
144 |             raise Exception('ERROR!!!')
145 |     return False
146 | '''
147 | for state in states:
148 |     for test_type in test_types:
149 |         for k in ks:
150 |             n_prec_sum = 0
151 |             count = 0
152 |             sTime= time.time()
153 |             act_set_file = sys.argv[1]+'/test_output_'+test_type+'_'+str(state)+'/active_set.txt'
154 |             pred_ent_name_file = sys.argv[1]+'/test_output_'+test_type+'_'+str(state)+'/top20_ent_id_from_mem.txt'
155 |             with open(act_set_file) as act_set_lines, open(pred_ent_name_file) as pred_lines:
156 |                 for act_set_line, pred_line in izip(act_set_lines,pred_lines):
157 |                     if time.time() - sTime > 1000:
158 |                         break
159 |                     try:
160 |                         if count % 1000 == 0:
161 |                             print count
162 |                         active_set = act_set_line.rstrip()
163 |                         pred_entities = pred_line.rstrip().split(', ') # QIDs
164 | 
165 |                         pred_entities = pred_entities[:k]
166 |                         n_topK = len([x for x in pred_entities if is_contained_in_act_set(active_set, x)])
167 |                         n_prec_sum += n_topK*1.0/float(len(pred_entities))
168 |                         count += 1
169 |                     except:
170 |                         break
171 | 
172 |             avg_prec = n_prec_sum*1.0/float(count)
173 |             print 'File ', pred_ent_name_file
174 |             print 'Avg. 
prec for k= ',k,'= %f' % avg_prec 175 | print '' 176 | ''' 177 | 178 | def replace_kb_ent_in_resp(prob_memory_entities, pred_op): 179 | kb_name_list_unique = list(OrderedDict.fromkeys(prob_memory_entities))[:20] 180 | k = 0 181 | max_k = 20 182 | length = len(pred_op) 183 | replace_kb = True 184 | top_k_ent = [] 185 | for j in range(len(pred_op)): 186 | if pred_op[j] in ['','','','','']: 187 | pred_op[j] = '' 188 | if pred_op[j]=='': 189 | length = j 190 | if pred_op[j].startswith(''): 191 | if not replace_kb: 192 | pred_op[j] = '' 193 | continue 194 | if k == len(kb_name_list_unique) or k == max_k: 195 | replace_kb = False 196 | pred_op[j] = '' 197 | continue 198 | top_k_ent.append(kb_name_list_unique[k]) 199 | k = k+1 200 | return top_k_ent 201 | 202 | for state in states: 203 | for test_type in test_types: 204 | n_prec_sum = 0 205 | count = 0 206 | sTime= time.time() 207 | act_set_file = sys.argv[1]+'/test_output_'+test_type+'_'+str(state)+'/active_set.txt' 208 | pred_ent_name_file = sys.argv[1]+'/test_output_'+test_type+'_'+str(state)+'/top20_ent_id_from_mem.txt' 209 | pred_file =sys.argv[2]+'/test_output_'+test_type+'_'+str(state)+'/pred_sent.txt' 210 | with open(act_set_file) as act_set_lines, open(pred_ent_name_file) as ent_lines, open(pred_file) as pred_lines: 211 | for pred, ent, act_set_line in izip(pred_lines, ent_lines, act_set_lines): 212 | word_list = pred.strip().split(' ') 213 | kb_count = 1 214 | for word in pred.strip().split(' '): 215 | if word=='': 216 | word_list.append('_'+str(kb_count)) 217 | kb_count = kb_count+1 218 | else: 219 | word_list.append(word) 220 | word_list = list(OrderedDict.fromkeys(word_list)) 221 | if '|' in ent: 222 | ent = ent.strip().split('|') 223 | else: 224 | ent = [x.strip() for x in ent.strip().split(',')] 225 | top_k_ent = replace_kb_ent_in_resp(ent, word_list) 226 | if time.time() - sTime > 10000: 227 | continue 228 | #try: 229 | if count % 1000 == 0: 230 | print count 231 | active_sets = act_set_line.rstrip().split('#') 232 | pred_entities = top_k_ent 233 | if(len(pred_entities))>0: 234 | n_topK = 0 235 | for x in pred_entities: 236 | for active_set in active_sets: 237 | if is_contained_in_act_set(active_set, x): 238 | n_topK += 1 239 | n_prec_sum += n_topK*1.0/float(len(pred_entities)) 240 | count += 1 241 | #except: 242 | # continue 243 | avg_prec = n_prec_sum*1.0/float(count) 244 | print 'File ', pred_ent_name_file 245 | print 'Avg. 
prec = %f' % avg_prec 246 | -------------------------------------------------------------------------------- /evaluate/compute_precision_active_set.sh: -------------------------------------------------------------------------------- 1 | python compute_precision_active_set.py ../new_model_softmax_kvmem easy 12 2 2 | python compute_precision_active_set.py ../new_model_softmax_kvmem easy 12 5 3 | python compute_precision_active_set.py ../new_model_softmax_kvmem easy 12 10 4 | python compute_precision_active_set.py ../new_model_softmax_kvmem easy 12 20 5 | 6 | python compute_precision_active_set.py ../new_model_softmax_kvmem easy 13 2 7 | python compute_precision_active_set.py ../new_model_softmax_kvmem easy 13 5 8 | python compute_precision_active_set.py ../new_model_softmax_kvmem easy 13 10 9 | python compute_precision_active_set.py ../new_model_softmax_kvmem easy 13 20 10 | 11 | python compute_precision_active_set.py ../new_model_softmax_kvmem easy 14 2 12 | python compute_precision_active_set.py ../new_model_softmax_kvmem easy 14 5 13 | python compute_precision_active_set.py ../new_model_softmax_kvmem easy 14 10 14 | python compute_precision_active_set.py ../new_model_softmax_kvmem easy 14 20 15 | 16 | python compute_precision_active_set.py ../new_model_softmax_kvmem easy 15 2 17 | python compute_precision_active_set.py ../new_model_softmax_kvmem easy 15 5 18 | python compute_precision_active_set.py ../new_model_softmax_kvmem easy 15 10 19 | python compute_precision_active_set.py ../new_model_softmax_kvmem easy 15 20 20 | 21 | -------------------------------------------------------------------------------- /evaluate/compute_recall.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import sys 3 | from itertools import izip 4 | 5 | gold_ent_name_file = sys.argv[1] 6 | pred_ent_name_file = sys.argv[2] 7 | pred_file = sys.argv[3] 8 | n_recall_sum = 0 9 | count = 0 10 | n_prec_sum = 0 11 | n_jacc_sum = 0 12 | def replace_kb_ent_in_resp(prob_memory_entities, pred_op): 13 | kb_name_list_unique = list(OrderedDict.fromkeys(prob_memory_entities))[:20] 14 | #print kb_name_list_unique, '::',prob_memory_entities 15 | k = 0 16 | max_k = 100000 17 | length = len(pred_op) 18 | replace_kb = True 19 | top_k_ent = [] 20 | for j in range(len(pred_op)): 21 | if pred_op[j] in ['','','','','']: 22 | pred_op[j] = '' 23 | if pred_op[j]=='': 24 | length = j 25 | if pred_op[j].startswith(''): 26 | if not replace_kb: 27 | pred_op[j] = '' 28 | continue 29 | if k == len(kb_name_list_unique) or k == max_k: 30 | replace_kb = False 31 | pred_op[j] = '' 32 | continue 33 | top_k_ent.append(kb_name_list_unique[k]) 34 | k = k+1 35 | return k 36 | 37 | with open(gold_ent_name_file) as gold_, open(pred_ent_name_file) as pred_, open(pred_file) as sent_: 38 | gold_lines = gold_.readlines() 39 | pred_lines = pred_.readlines() 40 | sent_lines = sent_.readlines() 41 | sent_lines = sent_lines[:len(gold_lines)] 42 | for gold_line, ent, pred in izip(gold_lines,pred_lines, sent_lines): 43 | #for gold_line, ent in izip(gold_lines, pred_lines): 44 | word_list = pred.strip().split(' ') 45 | #print 'sent', word_list 46 | ''' 47 | kb_count = 1 48 | for word in pred.strip().split(' '): 49 | if word=='': 50 | word_list.append('_'+str(kb_count)) 51 | kb_count = kb_count+1 52 | else: 53 | word_list.append(word) 54 | word_list = list(OrderedDict.fromkeys(word_list)) 55 | ''' 56 | if '|' in ent: 57 | ent = ent.strip().split('|') 58 | else: 59 | ent = [x.strip() for x in 
ent.strip().split(',')] 60 | index = replace_kb_ent_in_resp(ent, word_list) 61 | #print 'num of kb words ', index 62 | top_k_ent = ent[:index] 63 | if '|' in gold_line: 64 | gold_entities = gold_line.rstrip().split('|') 65 | else: 66 | gold_entities = [x.strip() for x in gold_line.rstrip().split(',')] 67 | #print 'topk ', top_k_ent, ' :::::: gold', gold_entities 68 | if len(gold_entities) > 0: 69 | pred_entities = top_k_ent 70 | #print 'len of pred entities', len(pred_entities) 71 | n_topK = len(set(gold_entities).intersection(set(pred_entities))) 72 | union= len(set(gold_entities).union(set(pred_entities))) 73 | n_recall_sum += n_topK*1.0/float(len(gold_entities)) 74 | count += 1 75 | if len(pred_entities) > 0: 76 | n_prec_sum += n_topK*1.0/float(len(pred_entities)) 77 | n_jacc_sum += float(n_topK)/float(union) 78 | avg_recall = n_recall_sum*100.0/float(count) 79 | print gold_ent_name_file 80 | print 'Avg. recall over= %f' % avg_recall 81 | avg_prec= n_prec_sum*100.0/float(count) 82 | print 'Avg. precision over= %f' % avg_prec 83 | avg_jacc = n_jacc_sum*100.0/float(count) 84 | print 'Avg. jaccard over= %f' % avg_jacc 85 | print 'Avg. F1 over= ',(2.0*avg_recall*avg_prec)/(avg_recall+avg_prec) 86 | print 'total prec ', n_prec_sum, ' total rec ', n_recall_sum, 'total jacc ', n_jacc_sum, ' count ', count 87 | print 'All numbers in %' 88 | -------------------------------------------------------------------------------- /evaluate/compute_recall.sh: -------------------------------------------------------------------------------- 1 | ~/anaconda/bin/python compute_recall.py ../new_model_softmax_kvmem_direct_gold2/test_output_$1/gold_resp_ent_id.txt ../new_model_softmax_kvmem_direct_gold2/test_output_$1/top20_ent_id_from_mem.txt ../new_model_softmax_decoder_direct_gold2/test_output_simple/pred_sent.txt 2 | #python compute_recall.py $1/test_output_easy_$2/gold_ent.txt $1/test_output_easy_$2/top20_ent_from_mem.txt $3/test_output_easy_$2/pred_sent.txt 3 | #python compute_recall.py $1/test_output_easy_$2/gold_ent.txt $1/test_output_easy_$2/top20_ent_from_mem.txt 2 4 | #python compute_recall.py $1/test_output_easy_$2/gold_ent.txt $1/test_output_easy_$2/top20_ent_from_mem.txt 5 5 | #python compute_recall.py $1/test_output_easy_$2/gold_ent.txt $1/test_output_easy_$2/top20_ent_from_mem.txt 10 6 | #python compute_recall.py $1/test_output_easy_$2/gold_ent.txt $1/test_output_easy_$2/top20_ent_from_mem.txt 20 7 | 8 | 9 | #python compute_recall.py $1/test_output_hard_$2/gold_ent.txt $1/test_output_hard_$2/top20_ent_from_mem.txt 2 10 | #python compute_recall.py $1/test_output_hard_$2/gold_ent.txt $1/test_output_hard_$2/top20_ent_from_mem.txt 5 11 | #python compute_recall.py $1/test_output_hard_$2/gold_ent.txt $1/test_output_hard_$2/top20_ent_from_mem.txt 10 12 | #python compute_recall.py $1/test_output_hard_$2/gold_ent.txt $1/test_output_hard_$2/top20_ent_from_mem.txt 20 13 | -------------------------------------------------------------------------------- /evaluate/compute_recall_active_set.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import sys, re 3 | from itertools import izip 4 | import time 5 | #from load_wikidata_wfn import * 6 | from parse_active_set import * 7 | #par_child_dict = json.load(open('/dccstor/cssblr/vardaan/dialog-qa/par_child_dict.json')) 8 | total_entities = len(par_child_dict) 9 | states = [6,9,10,12,13,14,15]#,7,8] 10 | test_types = ['hard'] 11 | ks = [2,5,10,20] 12 | 13 | #act_set_file = 
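'test_easy_active_set_-1.txt'
# A minimal sketch of the precision@k protocol this script implements, kept
# here for reference only: it is not called by the loops below, and it assumes
# the is_contained_in_act_set helper that this script uses further down.
def precision_at_k(pred_entities, active_set, k):
    # Keep the top-k predicted QIDs and score the fraction of them that
    # satisfy the gold active set.
    top_k = pred_entities[:k]
    if len(top_k) == 0:
        return 0.0
    n_hits = len([x for x in top_k if is_contained_in_act_set(active_set, x)])
    return n_hits / float(len(top_k))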
14 | #pred_ent_name_file = 'top20_ent_id_from_mem.txt' 15 | #k = 20 16 | def get_active_set_size(active_set): 17 | active_set = active_set.strip() 18 | anding = False 19 | orring = False 20 | notting1 = False 21 | notting2 = False 22 | if active_set.startswith('AND(') or active_set.startswith('OR('): 23 | if active_set.startswith('AND('): 24 | anding = True 25 | active_set = re.sub('^\(|\)$','',active_set.replace('AND', '',1)) 26 | if active_set.startswith('OR('): 27 | orring = True 28 | active_set = re.sub('^\(|\)$','',active_set.replace('OR', '',1)) 29 | while active_set.startswith('(') and active_set.endswith(')'): 30 | active_set = re.sub('^\(|\)$','',active_set) 31 | active_set_parts = active_set.split(', ') 32 | active_set_part1 = active_set_parts[0].strip() 33 | active_set_part2 = active_set_parts[1].strip() 34 | if active_set_part1.startswith('NOT('): 35 | active_set_part1 = re.sub('^\(|\)$','',active_set_part1.replace('NOT','',1)) 36 | notting1 = True 37 | set_part1 = get_basic_active_set(active_set_part1.strip()) 38 | if active_set_part2.startswith('NOT('): 39 | active_set_part2 = re.sub('^\(|\)$','',active_set_part2.replace('NOT','',1)) 40 | notting2 = True 41 | set_part2 = get_basic_active_set(active_set_part2.strip()) 42 | set_final = set([]) 43 | set_final_len = 0 44 | if anding: 45 | if notting1 and not notting2: 46 | set_final = set_part2 - set_part1 47 | set_final_len = len(set_final) 48 | elif notting2 and not notting1: 49 | set_final = set_part1 - set_part2 50 | set_final_len = len(set_final) 51 | elif not notting1 and not notting2: 52 | set_final = set_part1.intersection(set_part2) 53 | set_final_len = len(set_final) 54 | elif notting1 and notting2: 55 | #print 'found notting1 and notting2 ', active_set 56 | set_final.update(set_part1) 57 | set_final.update(set_part2) 58 | set_final_len = total_entities - len(set_final) 59 | if orring: 60 | if notting2 and not notting1: 61 | #set_final.update(set_part1) 62 | set_final = set_part1.intersection(set_part2) 63 | set_final_len = total_entities - len(set_part2) + len(set_final) 64 | elif notting1 and not notting2: 65 | set_final = set_part1.intersection(set_part2) 66 | set_final_len = total_entities - len(set_part1) + len(set_final) 67 | elif not notting1 and not notting2: 68 | set_final.update(set_part1) 69 | set_final.update(set_part2) 70 | set_final_len = len(set_final) 71 | elif notting1 and notting2: 72 | set_final= set_part1.intersection(set_part2) 73 | set_final_len = total_entities - len(set_final) 74 | return set_final_len 75 | else: 76 | return len(get_basic_active_set(active_set.strip())) 77 | 78 | def parse_active_set(active_set, target): 79 | active_set = active_set.strip() 80 | anding = False 81 | orring = False 82 | notting1 = False 83 | notting2 = False 84 | if active_set.startswith('AND(') or active_set.startswith('OR('): 85 | if active_set.startswith('AND('): 86 | anding = True 87 | active_set = re.sub('^\(|\)$','',active_set.replace('AND', '',1)) 88 | if active_set.startswith('OR('): 89 | orring = True 90 | active_set = re.sub('^\(|\)$','',active_set.replace('OR', '',1)) 91 | while active_set.startswith('(') and active_set.endswith(')'): 92 | active_set = re.sub('^\(|\)$','',active_set) 93 | active_set_parts = active_set.split(', ') 94 | active_set_part1 = active_set_parts[0].strip() 95 | active_set_part2 = active_set_parts[1].strip() 96 | if active_set_part1.startswith('NOT('): 97 | active_set_part1 = re.sub('^\(|\)$','',active_set_part1.replace('NOT','',1)) 98 | notting1 = True
99 | is_present1 = parse_basic_active_set(active_set_part1.strip(), target) 100 | if notting1: 101 | is_present1 = not is_present1 102 | if active_set_part2.startswith('NOT('): 103 | active_set_part2 = re.sub('^\(|\)$','',active_set_part2.replace('NOT','',1)) 104 | notting2 = True 105 | is_present2 = parse_basic_active_set(active_set_part2.strip(), target) 106 | set_final = set([]) 107 | set_final_len = 0 108 | if notting2: 109 | is_present2 = not is_present2 110 | if anding: 111 | if is_present1 and is_present2: 112 | return True 113 | else: 114 | return False 115 | if orring: 116 | if is_present1 or is_present2: 117 | return True 118 | else: 119 | return False 120 | else: 121 | return parse_basic_active_set(active_set, target) 122 | 123 | def get_basic_active_set(active_set): 124 | if len(active_set) == 0: 125 | return set([]) 126 | active_set_orig = active_set 127 | 128 | while active_set.startswith('(') and active_set.endswith(')'): 129 | active_set = re.sub('^\(|\)$','',active_set) 130 | while active_set.startswith('(') and not active_set.endswith(')'): 131 | active_set = re.sub('^\(','',active_set) 132 | while active_set.endswith(')') and not active_set.startswith('('): 133 | active_set = re.sub('\)$','',active_set) 134 | 135 | assert not active_set.startswith('(') 136 | assert not active_set.endswith(')') 137 | 138 | # print 'time taken for regex proc = %f' % (time.time() - st_time) 139 | 140 | parent = None 141 | parent1 = None 142 | parent2 = None 143 | 144 | try: 145 | assert len(active_set.strip().split(',')) == 3 146 | except: 147 | print 'active_set = %s' % active_set_orig 148 | logging.exception('Something awful happened') 149 | raise Exception('ERROR!!!') 150 | parts = active_set.strip().split(',') 151 | 152 | if parts[0].startswith('c') and parts[2].startswith('c'): 153 | parent1 = parts[0].split('(')[1].split(')')[0].strip() 154 | parent2 = parts[2].split('(')[1].split(')')[0].strip() 155 | if parts[0].startswith('c'): 156 | parent = parts[0].split('(')[1].split(')')[0].strip() 157 | ent = parts[2].strip() 158 | elif parts[2].startswith('c'): 159 | parent = parts[2].split('(')[1].split(')')[0].strip() 160 | ent = parts[0].strip() 161 | rel = parts[1].strip() 162 | feasible_children = set([]) 163 | if parent and ent: 164 | if parent not in par_child_dict: 165 | return set([]) 166 | sources = set([]) 167 | targets = set([]) 168 | 169 | try: 170 | sources.update(wikidata[ent][rel]) 171 | targets.update(wikidata[ent][rel]) 172 | sources.update(reverse_dict[ent][rel]) 173 | targets.update(reverse_dict[ent][rel]) 174 | except: 175 | pass 176 | all_entities = set([]) 177 | all_entities.update(sources) 178 | all_entities.update(targets) 179 | children_of_par = par_child_dict[parent] 180 | feasible_children = set(children_of_par).intersection(all_entities) 181 | num_children = len(feasible_children) 182 | 183 | elif parent1 and parent2: 184 | if parent1 not in par_child_dict or parent2 not in par_child_dict: 185 | return set([]) 186 | children2 = par_child_dict[parent2] 187 | children1 = par_child_dict[parent1] 188 | if len(children1) 1000: 298 | break 299 | try: 300 | if count % 1000 == 0: 301 | print count 302 | active_set = act_set_line.rstrip() 303 | pred_entities = pred_line.rstrip().split(', ') # QIDs 304 | 305 | pred_entities = pred_entities[:k] 306 | n_topK = len([x for x in pred_entities if is_contained_in_act_set(active_set, x)]) 307 | n_prec_sum += n_topK*1.0/float(len(pred_entities)) 308 | count += 1 309 | except: 310 | break 311 | 312 | avg_prec = 
n_prec_sum*1.0/float(count) 313 | print 'File ', pred_ent_name_file 314 | print 'Avg. prec for k= ',k,'= %f' % avg_prec 315 | print '' 316 | ''' 317 | 318 | 319 | def replace_kb_ent_in_resp(prob_memory_entities, pred_op): 320 | kb_name_list_unique = list(OrderedDict.fromkeys(prob_memory_entities))[:20] 321 | k = 0 322 | max_k = 20 323 | length = len(pred_op) 324 | replace_kb = True 325 | top_k_ent = [] 326 | for j in range(len(pred_op)): 327 | if pred_op[j] in ['','','','','']: 328 | pred_op[j] = '' 329 | if pred_op[j]=='': 330 | length = j 331 | if pred_op[j].startswith(''): 332 | if not replace_kb: 333 | pred_op[j] = '' 334 | continue 335 | if k == len(kb_name_list_unique) or k == max_k: 336 | replace_kb = False 337 | pred_op[j] = '' 338 | continue 339 | top_k_ent.append(kb_name_list_unique[k]) 340 | k = k+1 341 | return top_k_ent 342 | fw = open('out7.txt','w') 343 | dir1='model_softmax_kvmem_validtrim_unfilt_newversion' 344 | dir2='model_softmax_decoder_newversion' 345 | 346 | for state in states: 347 | for test_type in test_types: 348 | n_prec_sum = 0 349 | count = 0 350 | sTime= time.time() 351 | act_set_file = dir1+'/test_output_'+test_type+'_'+str(state)+'/active_set.txt' 352 | pred_ent_name_file = dir1+'/test_output_'+test_type+'_'+str(state)+'/top20_ent_id_from_mem.txt' 353 | pred_file = dir2+'/test_output_'+test_type+'_'+str(state)+'/pred_sent.txt' 354 | with open(act_set_file) as act_set_lines, open(pred_ent_name_file) as ent_lines, open(pred_file) as pred_lines: 355 | for pred, ent, act_set_line in izip(pred_lines, ent_lines, act_set_lines): 356 | word_list = pred.strip().split(' ') 357 | kb_count = 1 358 | for word in pred.strip().split(' '): 359 | if word=='': 360 | word_list.append('_'+str(kb_count)) 361 | kb_count = kb_count+1 362 | else: 363 | word_list.append(word) 364 | word_list = list(OrderedDict.fromkeys(word_list)) 365 | if '|' in ent: 366 | ent = ent.strip().split('|') 367 | else: 368 | ent = [x.strip() for x in ent.strip().split(',')] 369 | top_k_ent = replace_kb_ent_in_resp(ent, word_list) 370 | if time.time() - sTime > 100000: 371 | #print " time break" 372 | continue 373 | #try: 374 | #if count % 1000 == 0: 375 | # print count 376 | 377 | active_sets = act_set_line.rstrip().split('#') 378 | active_set_sizes = [] 379 | for active_set in active_sets: 380 | active_set_size = get_active_set_size(active_set) 381 | active_set_sizes.append(active_set_size) 382 | 383 | pred_entities = top_k_ent 384 | n_topK = 0 385 | for x in pred_entities: 386 | for i,active_set in enumerate(active_sets): 387 | if is_contained_in_act_set(active_set,x): 388 | n_topK +=1 389 | break 390 | active_set_size = sum(active_set_sizes) 391 | n_topK = active_set_size - n_topK 392 | if active_set_size>0: 393 | n_prec_sum += 1-(n_topK*1.0/float(active_set_size)) 394 | count += 1 395 | #except: 396 | # break 397 | avg_prec = n_prec_sum*1.0/float(count) 398 | fw.write('File '+pred_ent_name_file+'\n') 399 | fw.write('Avg. prec = '+str(avg_prec)+'\n') 400 | print 'File '+pred_ent_name_file 401 | print 'Avg. 
prec = '+str(avg_prec) 402 | -------------------------------------------------------------------------------- /evaluate/compute_recall_per_state.sh: -------------------------------------------------------------------------------- 1 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_-1/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_-1/top20_ent_from_mem.txt 2 2 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_-1/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_-1/top20_ent_from_mem.txt 5 3 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_-1/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_-1/top20_ent_from_mem.txt 10 4 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_-1/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_-1/top20_ent_from_mem.txt 20 5 | 6 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_3/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_3/top20_ent_from_mem.txt 2 7 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_3/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_3/top20_ent_from_mem.txt 5 8 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_3/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_3/top20_ent_from_mem.txt 10 9 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_3/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_3/top20_ent_from_mem.txt 20 10 | 11 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_4/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_4/top20_ent_from_mem.txt 2 12 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_4/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_4/top20_ent_from_mem.txt 5 13 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_4/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_4/top20_ent_from_mem.txt 10 14 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_4/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_4/top20_ent_from_mem.txt 20 15 | 16 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_5/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_5/top20_ent_from_mem.txt 2 17 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_5/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_5/top20_ent_from_mem.txt 5 18 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_5/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_5/top20_ent_from_mem.txt 10 19 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_5/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_5/top20_ent_from_mem.txt 20 20 | 21 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_6/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_6/top20_ent_from_mem.txt 2 22 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_6/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_6/top20_ent_from_mem.txt 5 23 
| #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_6/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_6/top20_ent_from_mem.txt 10 24 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_6/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_6/top20_ent_from_mem.txt 20 25 | 26 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_7/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_7/top20_ent_from_mem.txt 2 27 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_7/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_7/top20_ent_from_mem.txt 5 28 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_7/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_7/top20_ent_from_mem.txt 10 29 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_7/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_7/top20_ent_from_mem.txt 20 30 | 31 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_7a/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_7a/top20_ent_from_mem.txt 2 32 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_7a/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_7a/top20_ent_from_mem.txt 5 33 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_7a/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_7a/top20_ent_from_mem.txt 10 34 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_7a/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_7a/top20_ent_from_mem.txt 20 35 | 36 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_7b/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_7b/top20_ent_from_mem.txt 2 37 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_7b/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_7b/top20_ent_from_mem.txt 5 38 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_7b/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_7b/top20_ent_from_mem.txt 10 39 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_7b/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_7b/top20_ent_from_mem.txt 20 40 | 41 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_8/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_8/top20_ent_from_mem.txt 2 42 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_8/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_8/top20_ent_from_mem.txt 5 43 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_8/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_8/top20_ent_from_mem.txt 10 44 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_8/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_8/top20_ent_from_mem.txt 20 45 | 46 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_8a/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_8a/top20_ent_from_mem.txt 2 47 | #python 
compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_8a/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_8a/top20_ent_from_mem.txt 5 48 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_8a/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_8a/top20_ent_from_mem.txt 10 49 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_8a/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_8a/top20_ent_from_mem.txt 20 50 | 51 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_8b/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_8b/top20_ent_from_mem.txt 2 52 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_8b/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_8b/top20_ent_from_mem.txt 5 53 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_8b/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_8b/top20_ent_from_mem.txt 10 54 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_8b/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_8b/top20_ent_from_mem.txt 20 55 | 56 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_9/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_9/top20_ent_from_mem.txt 2 57 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_9/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_9/top20_ent_from_mem.txt 5 58 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_9/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_9/top20_ent_from_mem.txt 10 59 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_9/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_9/top20_ent_from_mem.txt 20 60 | 61 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_10/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_10/top20_ent_from_mem.txt 2 62 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_10/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_10/top20_ent_from_mem.txt 5 63 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_10/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_10/top20_ent_from_mem.txt 10 64 | #python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_10/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_10/top20_ent_from_mem.txt 20 65 | 66 | python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_12/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_12/top20_ent_from_mem.txt 2 67 | python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_12/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_12/top20_ent_from_mem.txt 5 68 | python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_12/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_12/top20_ent_from_mem.txt 10 69 | python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_12/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_12/top20_ent_from_mem.txt 20 70 | 71 | python compute_recall.py 
model_softmax_kvmem_valid_trip_unfilt/test_output_hard_13/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_13/top20_ent_from_mem.txt 2 72 | python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_13/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_13/top20_ent_from_mem.txt 5 73 | python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_13/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_13/top20_ent_from_mem.txt 10 74 | python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_13/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_13/top20_ent_from_mem.txt 20 75 | 76 | python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_14/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_14/top20_ent_from_mem.txt 2 77 | python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_14/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_14/top20_ent_from_mem.txt 5 78 | python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_14/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_14/top20_ent_from_mem.txt 10 79 | python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_14/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_14/top20_ent_from_mem.txt 20 80 | 81 | python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_15/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_15/top20_ent_from_mem.txt 2 82 | python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_15/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_15/top20_ent_from_mem.txt 5 83 | python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_15/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_15/top20_ent_from_mem.txt 10 84 | python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_hard_15/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_hard_15/top20_ent_from_mem.txt 20 85 | 86 | 87 | python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_easy_12/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_easy_12/top20_ent_from_mem.txt 2 88 | python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_easy_12/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_easy_12/top20_ent_from_mem.txt 5 89 | python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_easy_12/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_easy_12/top20_ent_from_mem.txt 10 90 | python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_easy_12/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_easy_12/top20_ent_from_mem.txt 20 91 | 92 | python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_easy_13/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_easy_13/top20_ent_from_mem.txt 2 93 | python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_easy_13/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_easy_13/top20_ent_from_mem.txt 5 94 | python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_easy_13/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_easy_13/top20_ent_from_mem.txt 10 95 | python compute_recall.py 
model_softmax_kvmem_valid_trip_unfilt/test_output_easy_13/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_easy_13/top20_ent_from_mem.txt 20 96 | 97 | python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_easy_14/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_easy_14/top20_ent_from_mem.txt 2 98 | python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_easy_14/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_easy_14/top20_ent_from_mem.txt 5 99 | python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_easy_14/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_easy_14/top20_ent_from_mem.txt 10 100 | python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_easy_14/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_easy_14/top20_ent_from_mem.txt 20 101 | 102 | python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_easy_15/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_easy_15/top20_ent_from_mem.txt 2 103 | python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_easy_15/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_easy_15/top20_ent_from_mem.txt 5 104 | python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_easy_15/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_easy_15/top20_ent_from_mem.txt 10 105 | python compute_recall.py model_softmax_kvmem_valid_trip_unfilt/test_output_easy_15/gold_ent.txt model_softmax_kvmem_valid_trip_unfilt/test_output_easy_15/top20_ent_from_mem.txt 20 106 | -------------------------------------------------------------------------------- /evaluate/load_wikidata_wfn.py: -------------------------------------------------------------------------------- 1 | import json, codecs, random, requests, pickle, traceback, logging, os, math 2 | 3 | with codecs.open('/dccstor/cssblr/vardaan/dialog-qa/wikidata_short_1.json','r','utf-8') as data_file: 4 | wikidata = json.load(data_file) 5 | print 'Successfully loaded wikidata1' 6 | 7 | with codecs.open('/dccstor/cssblr/vardaan/dialog-qa/wikidata_short_2.json','r','utf-8') as data_file: 8 | wikidata2 = json.load(data_file) 9 | print 'Successfully loaded wikidata2' 10 | 11 | wikidata.update(wikidata2) 12 | del wikidata2 13 | 14 | with codecs.open('/dccstor/cssblr/vardaan/dialog-qa/items_wikidata_n.json','r','utf-8') as data_file: 15 | item_data = json.load(data_file) 16 | print 'Successfully loaded items json' 17 | 18 | with codecs.open('/dccstor/cssblr/vardaan/dialog-qa/comp_wikidata_rev.json','r','utf-8') as data_file: 19 | reverse_dict = json.load(data_file) 20 | print 'Successfully loaded reverse_dict json' 21 | 22 | with codecs.open('/dccstor/cssblr/vardaan/dialog-qa/wikidata_fanout_dict.json','r','utf-8') as data_file: 23 | wikidata_fanout_dict = json.load(data_file) 24 | print 'Successfully loaded wikidata_fanout_dict json' 25 | 26 | wikidata_fanout_dict_list = pickle.load(open('/dccstor/cssblr/vardaan/dialog-qa/wikidata_fanout_dict_list.pickle', 'rb')) 27 | print 'Successfully loaded wikidata_fanout_dict_list pickle' 28 | 29 | with codecs.open('/dccstor/cssblr/vardaan/dialog-qa/child_par_dict_save.json','r','utf-8') as data_file: 30 | child_par_dict = json.load(data_file) 31 | print 'Successfully loaded child_par_dict json' 32 | 33 | 34 | with codecs.open('/dccstor/cssblr/vardaan/dialog-qa/filtered_property_wikidata4.json','r','utf-8') as data_file: 35 | prop_data = 
json.load(data_file) 36 | 37 | wikidata_remove_list = [q for q in wikidata if q not in item_data] 38 | 39 | wikidata_remove_list.extend([q for q in wikidata if 'P31' not in wikidata[q] and 'P279' not in wikidata[q]]) 40 | 41 | wikidata_remove_list.extend([u'Q7375063', u'Q24284139', u'Q1892495', u'Q22980687', u'Q25093915', u'Q22980685', u'Q22980688', u'Q25588222', u'Q1668023', u'Q20794889', u'Q22980686',u'Q297106',u'Q1293664']) 42 | 43 | # wikidata_remove_list.extend([q for q in wikidata if q not in child_par_dict]) 44 | 45 | for q in wikidata_remove_list: 46 | wikidata.pop(q,None) 47 | # ************************************************************************ 48 | with codecs.open('/dccstor/cssblr/vardaan/dialog-qa/wikidata_type_dict.json','r','utf-8') as f1: 49 | wikidata_type_dict = json.load(f1) 50 | 51 | with codecs.open('/dccstor/cssblr/vardaan/dialog-qa/wikidata_rev_type_dict.json','r','utf-8') as f1: 52 | wikidata_type_rev_dict = json.load(f1) 53 | 54 | with codecs.open('/dccstor/cssblr/vardaan/dialog-qa/par_child_dict.json','r','utf-8') as f1: 55 | par_child_dict = json.load(f1) 56 | 57 | with codecs.open('/dccstor/cssblr/vardaan/dialog-qa/child_par_dict_name3.json','r','utf-8') as f1: 58 | child_par_dict_name_2 = json.load(f1) 59 | 60 | # with codecs.open('dict_val/sing_sub_annot.json','r','utf-8') as f1: 61 | # sing_sub_annot = json.load(f1) 62 | 63 | # with codecs.open('dict_val/plu_sub_annot.json','r','utf-8') as f1: 64 | # plu_sub_annot = json.load(f1) 65 | 66 | # with codecs.open('dict_val/sing_obj_annot.json','r','utf-8') as f1: 67 | # sing_obj_annot = json.load(f1) 68 | 69 | # with codecs.open('dict_val/plu_obj_annot.json','r','utf-8') as f1: 70 | # plu_obj_annot = json.load(f1) 71 | 72 | # # with codecs.open('dict_val/neg_sub_annot.json','r','utf-8') as f1: 73 | # # neg_sub_annot = json.load(f1) 74 | 75 | # with codecs.open('dict_val/neg_plu_sub_annot.json','r','utf-8') as f1: 76 | # neg_plu_sub_annot = json.load(f1) 77 | 78 | # # with codecs.open('dict_val/neg_obj_annot.json','r','utf-8') as f1: 79 | # # neg_obj_annot = json.load(f1) 80 | 81 | # with codecs.open('dict_val/neg_plu_obj_annot.json','r','utf-8') as f1: 82 | # neg_plu_obj_annot = json.load(f1) 83 | 84 | # # ****************************************************************** 85 | 86 | # with codecs.open('dict_val/sing_sub_annot_wh.json','r','utf-8') as f1: 87 | # sing_sub_annot_wh = json.load(f1) 88 | 89 | # with codecs.open('dict_val/plu_sub_annot_wh.json','r','utf-8') as f1: 90 | # plu_sub_annot_wh = json.load(f1) 91 | 92 | # with codecs.open('dict_val/sing_obj_annot_wh.json','r','utf-8') as f1: 93 | # sing_obj_annot_wh = json.load(f1) 94 | 95 | # with codecs.open('dict_val/plu_obj_annot_wh.json','r','utf-8') as f1: 96 | # plu_obj_annot_wh = json.load(f1) 97 | 98 | # # with codecs.open('dict_val/neg_sub_annot_wh.json','r','utf-8') as f1: 99 | # # neg_sub_annot_wh = json.load(f1) 100 | 101 | # with codecs.open('dict_val/neg_plu_sub_annot_wh.json','r','utf-8') as f1: 102 | # neg_plu_sub_annot_wh = json.load(f1) 103 | 104 | # # with codecs.open('dict_val/neg_obj_annot_wh.json','r','utf-8') as f1: 105 | # # neg_obj_annot_wh = json.load(f1) 106 | 107 | # with codecs.open('dict_val/neg_plu_obj_annot_wh.json','r','utf-8') as f1: 108 | # neg_plu_obj_annot_wh = json.load(f1) 109 | 110 | # with codecs.open('prop_obj_90_map5.json','r','utf-8') as data_file: 111 | # obj_90_map = json.load(data_file) 112 | 113 | # with codecs.open('prop_sub_90_map5.json','r','utf-8') as data_file: 114 | # sub_90_map = json.load(data_file) 
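# The dicts loaded above all index Wikidata triples as
# {subject_qid: {relation_pid: [object_qids]}}, with reverse_dict holding the
# same triples keyed from the object side. A small helper illustrating the
# four-way membership test the evaluation scripts perform on a candidate pair
# (a sketch for reference only; the name is ours and nothing here calls it):
def linked_by(e1, rel, e2):
    # True if e1 and e2 are connected by rel in either direction of either index.
    return e2 in wikidata.get(e1, {}).get(rel, []) \
        or e1 in wikidata.get(e2, {}).get(rel, []) \
        or e2 in reverse_dict.get(e1, {}).get(rel, []) \
        or e1 in reverse_dict.get(e2, {}).get(rel, [])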
115 | -------------------------------------------------------------------------------- /evaluate/multi_bleu.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # 3 | # This file is part of moses. Its use is licensed under the GNU Lesser General 4 | # Public License version 2.1 or, at your option, any later version. 5 | 6 | # $Id$ 7 | use warnings; 8 | use strict; 9 | 10 | my $lowercase = 0; 11 | if ($ARGV[0] eq "-lc") { 12 | $lowercase = 1; 13 | shift; 14 | } 15 | 16 | my $stem = $ARGV[0]; 17 | if (!defined $stem) { 18 | print STDERR "usage: multi-bleu.pl [-lc] reference < hypothesis\n"; 19 | print STDERR "Reads the references from reference or reference0, reference1, ...\n"; 20 | exit(1); 21 | } 22 | 23 | $stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0"; 24 | 25 | my @REF; 26 | my $ref=0; 27 | while(-e "$stem$ref") { 28 | &add_to_ref("$stem$ref",\@REF); 29 | $ref++; 30 | } 31 | &add_to_ref($stem,\@REF) if -e $stem; 32 | die("ERROR: could not find reference file $stem") unless scalar @REF; 33 | 34 | sub add_to_ref { 35 | my ($file,$REF) = @_; 36 | my $s=0; 37 | open(REF,$file) or die "Can't read $file"; 38 | while(<REF>) { 39 | chop; 40 | push @{$$REF[$s++]}, $_; 41 | } 42 | close(REF); 43 | } 44 | 45 | my(@CORRECT,@TOTAL,$length_translation,$length_reference); 46 | my $s=0; 47 | while(<STDIN>) { 48 | chop; 49 | $_ = lc if $lowercase; 50 | my @WORD = split; 51 | my %REF_NGRAM = (); 52 | my $length_translation_this_sentence = scalar(@WORD); 53 | my ($closest_diff,$closest_length) = (9999,9999); 54 | foreach my $reference (@{$REF[$s]}) { 55 | # print "$s $_ <=> $reference\n"; 56 | $reference = lc($reference) if $lowercase; 57 | my @WORD = split(' ',$reference); 58 | my $length = scalar(@WORD); 59 | my $diff = abs($length_translation_this_sentence-$length); 60 | if ($diff < $closest_diff) { 61 | $closest_diff = $diff; 62 | $closest_length = $length; 63 | # print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n"; 64 | } elsif ($diff == $closest_diff) { 65 | $closest_length = $length if $length < $closest_length; 66 | # from two references with the same closeness to me 67 | # take the *shorter* into account, not the "first" one. 68 | } 69 | for(my $n=1;$n<=4;$n++) { 70 | my %REF_NGRAM_N = (); 71 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 72 | my $ngram = "$n"; 73 | for(my $w=0;$w<$n;$w++) { 74 | $ngram .= " ".$WORD[$start+$w]; 75 | } 76 | $REF_NGRAM_N{$ngram}++; 77 | } 78 | foreach my $ngram (keys %REF_NGRAM_N) { 79 | if (!defined($REF_NGRAM{$ngram}) || 80 | $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) { 81 | $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram};
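# %REF_NGRAM keeps, for every n-gram, its maximum count in any single
# reference; candidate n-gram counts are clipped to these values further
# down ("modified" n-gram precision, as in the original BLEU definition).
82 | # print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}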
\n"; 83 | } 84 | } 85 | } 86 | } 87 | $length_translation += $length_translation_this_sentence; 88 | $length_reference += $closest_length; 89 | for(my $n=1;$n<=4;$n++) { 90 | my %T_NGRAM = (); 91 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 92 | my $ngram = "$n"; 93 | for(my $w=0;$w<$n;$w++) { 94 | $ngram .= " ".$WORD[$start+$w]; 95 | } 96 | $T_NGRAM{$ngram}++; 97 | } 98 | foreach my $ngram (keys %T_NGRAM) { 99 | $ngram =~ /^(\d+) /; 100 | my $n = $1; 101 | # my $corr = 0; 102 | # print "$i e $ngram $T_NGRAM{$ngram}
\n"; 103 | $TOTAL[$n] += $T_NGRAM{$ngram}; 104 | if (defined($REF_NGRAM{$ngram})) { 105 | if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) { 106 | $CORRECT[$n] += $T_NGRAM{$ngram}; 107 | # $corr = $T_NGRAM{$ngram}; 108 | # print "$i e correct1 $T_NGRAM{$ngram}
\n"; 109 | } 110 | else { 111 | $CORRECT[$n] += $REF_NGRAM{$ngram}; 112 | # $corr = $REF_NGRAM{$ngram}; 113 | # print "$i e correct2 $REF_NGRAM{$ngram}
\n"; 114 | } 115 | } 116 | # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram}; 117 | # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n" 118 | } 119 | } 120 | $s++; 121 | } 122 | my $brevity_penalty = 1; 123 | my $bleu = 0; 124 | 125 | my @bleu=(); 126 | 127 | for(my $n=1;$n<=4;$n++) { 128 | if (defined ($TOTAL[$n]) and defined($CORRECT[$n])){ 129 | $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0; 130 | print "DONE" 131 | #print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n"; 132 | }else{ 133 | $bleu[$n]=0.0000001; 134 | } 135 | printf " %d %f %f \n",$n,$bleu[$n],my_log($bleu[$n]); 136 | } 137 | 138 | if ($length_reference==0){ 139 | printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n"; 140 | exit(1); 141 | } 142 | 143 | #if ($length_translation<$length_reference) { 144 | # $brevity_penalty = exp(1-$length_reference/$length_translation); 145 | #} 146 | 147 | $bleu = $brevity_penalty * exp((my_log( $bleu[1] ) + 148 | my_log( $bleu[2] ) + 149 | my_log( $bleu[3] ) + 150 | my_log( $bleu[4] ) ) / 4) ; 151 | printf "BLEU = %.2f, %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n", 152 | 100*$bleu, 153 | 100*$bleu[1], 154 | 100*$bleu[2], 155 | 100*$bleu[3], 156 | 100*$bleu[4], 157 | $brevity_penalty, 158 | $length_translation / $length_reference, 159 | $length_translation, 160 | $length_reference; 161 | 162 | sub my_log { 163 | return -9999999999 unless $_[0]; 164 | return log($_[0]); 165 | } 166 | -------------------------------------------------------------------------------- /evaluate/parse_active_set.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from load_wikidata2 import load_wikidata 3 | import json 4 | from itertools import izip 5 | import re 6 | wikidata, reverse_dict, prop_data, child_par_dict, wikidata_fanout_dict = load_wikidata() 7 | par_child_dict = json.load(open('/dccstor/cssblr/vardaan/dialog-qa/par_child_dict.json')) 8 | 9 | def parse_active_set(active_set, target): 10 | active_set = active_set.strip() 11 | anding = False 12 | orring = False 13 | notting1 = False 14 | notting2 = False 15 | if active_set.startswith('AND(') or active_set.startswith('OR('): 16 | if active_set.startswith('AND('): 17 | anding = True 18 | active_set = re.sub('^\(|\)$','',active_set.replace('AND', '',1)) 19 | if active_set.startswith('OR('): 20 | anding = True 21 | active_set = re.sub('^\(|\)$','',active_set.replace('OR', '',1)) 22 | while active_set.startswith('(') and active_set.endswith(')'): 23 | active_set = re.sub('^\(|\)$','',active_set) 24 | active_set_parts = active_set.split(', ') 25 | active_set_part1 = active_set_parts[0].strip() 26 | active_set_part2 = active_set_parts[1].strip() 27 | if active_set_part1.startswith('NOT('): 28 | active_set_part1 = re.sub('^\(|\)$','',active_set_part1.replace('NOT','',1)) 29 | notting1 = True 30 | is_present1 = parse_basic_active_set(active_set_part1.strip(), target) 31 | if not notting1: 32 | is_present1 = not is_present1 33 | if active_set_part2.startswith('NOT('): 34 | active_set_part2 = re.sub('^\(|\)$','',active_set_part2.replace('NOT','',1)) 35 | notting2 = True 36 | is_present2 = parse_basic_active_set(active_set_part2.strip(), target) 37 | if notting2: 38 | is_present2 = not is_present2 39 | if anding: 40 | if is_present1 and is_present2: 41 | return True 42 | else: 43 | return False 44 | if orring: 45 | if is_present1 or is_present2: 46 | return True 47 | else: 48 | return False 49 | else: 50 | return 
parse_basic_active_set(active_set, target) 51 | 52 | def parse_basic_active_set(active_set, target): 53 | while active_set.startswith('(') and active_set.endswith(')'): 54 | active_set = re.sub('^\(|\)$','',active_set) 55 | while active_set.startswith('(') and not active_set.endswith(')'): 56 | active_set = re.sub('^\(','',active_set) 57 | print active_set 58 | parent = None 59 | parent1 = None 60 | parent2 = None 61 | parts = active_set.strip().split(',') 62 | if parts[0].startswith('c') and parts[2].startswith('c'): 63 | parent1 = parts[0].split('(')[1].split(')')[0].strip() 64 | parent2 = parts[2].split('(')[1].split(')')[0].strip() 65 | if parts[0].startswith('c'): 66 | parent = parts[0].split('(')[1].split(')')[0].strip() 67 | ent = parts[2].strip() 68 | elif parts[2].startswith('c'): 69 | parent = parts[2].split('(')[1].split(')')[0].strip() 70 | ent = parts[0].strip() 71 | rel = parts[1].strip() 72 | if parent and ent: 73 | try: 74 | if target in par_child_dict[parent] and (target in wikidata[ent][rel] or ent in wikidata[target][rel] or target in reverse_dict[ent][rel] or ent in reverse_dict[target][rel]): 75 | return True 76 | else: 77 | return False 78 | except: 79 | return False 80 | elif parent1 and parent2: 81 | try: 82 | if target in par_child_dict[parent1]: 83 | children2 = par_child_dict[parent2] 84 | for ent in children2: 85 | if target in wikidata[ent][rel] or ent in wikidata[target][rel] or target in reverse_dict[ent][rel] or ent in reverse_dict[target][rel]: 86 | return True 87 | elif target in par_child_dict[parent2]: 88 | children1 = par_child_dict[parent1] 89 | for ent in children1: 90 | if target in wikidata[ent][rel] or ent in wikidata[target][rel] or target in reverse_dict[ent][rel] or ent in reverse_dict[target][rel]: 91 | return True 92 | except: 93 | return False 94 | 95 | return False 96 | 97 | 98 | if __name__=="__main__": 99 | dir="/dccstor/cssblr/amrita/dialog_qa/code/hred_kvmem2_softmax/model_softmax_new/dump" 100 | target_file = dir+'/test_hard_target_'+sys.argv[1]+'.txt' 101 | active_set_file = dir+'/test_hard_active_set_'+sys.argv[1]+'.txt' 102 | with open(target_file) as targetlines, open(active_set_file) as activelines: 103 | for target, active_set in izip(targetlines, activelines): 104 | target = target.strip().split('|') 105 | active_set = active_set.strip() 106 | for target_i in target: 107 | is_present = parse_active_set(active_set, target_i) 108 | print 'ACTIVE SET: ',active_set, ' TARGET: ',target_i, ' IS_PRESENT: ',is_present 109 | -------------------------------------------------------------------------------- /evaluate/postprocess_bool.py: -------------------------------------------------------------------------------- 1 | import sys 2 | file = sys.argv[1] 3 | ''' 4 | target_fw = open(sys.argv[2],'w') 5 | for line in open(file).readlines(): 6 | line = line.strip() 7 | words = line.split(' ') 8 | words_yes_no = [] 9 | for w in words: 10 | if (w=='yes' or w=='no') and w not in words_yes_no: 11 | words_yes_no.append(w) 12 | new_line = ' and '.join(words_yes_no) 13 | target_fw.write(new_line.strip()+'\n') 14 | print line, '----->',new_line 15 | target_fw.close() 16 | ''' 17 | acc = 0.0 18 | count = 0.0 19 | goldlines = open(sys.argv[1]).readlines() 20 | predlines = open(sys.argv[2]).readlines() 21 | for goldline, predline in zip(goldlines, predlines): 22 | goldline = goldline.lower().strip() 23 | predline = predline.lower().strip() 24 | goldline = " ".join([x for x in goldline.lower().split(' ') if x in ['yes','no']]) 25 | predline = " ".join([x 
for x in predline.lower().split(' ') if x in ['yes','no']]) 26 | if goldline==predline: 27 | acc=acc+1.0 28 | count+=1.0 29 | print acc/count 30 | -------------------------------------------------------------------------------- /evaluate/run_calculate_bleu.sh: -------------------------------------------------------------------------------- 1 | ./calculate_bleu.sh model_softmax_decoder_newversion model_softmax_kvmem_validtrim_unfilt_newversion easy -1 2 | #./calculate_bleu.sh model_softmax_decoder_newversion model_softmax_kvmem_validtrim_unfilt_newversion easy 2 3 | ./calculate_bleu.sh model_softmax_decoder_newversion model_softmax_kvmem_validtrim_unfilt_newversion easy 3 4 | ./calculate_bleu.sh model_softmax_decoder_newversion model_softmax_kvmem_validtrim_unfilt_newversion easy 4 5 | ./calculate_bleu.sh model_softmax_decoder_newversion model_softmax_kvmem_validtrim_unfilt_newversion easy 5 6 | ./calculate_bleu.sh model_softmax_decoder_newversion model_softmax_kvmem_validtrim_unfilt_newversion easy 6 7 | ./calculate_bleu.sh model_softmax_decoder_newversion model_softmax_kvmem_validtrim_unfilt_newversion easy 7 8 | ./calculate_bleu.sh model_softmax_decoder_newversion model_softmax_kvmem_validtrim_unfilt_newversion easy 8 9 | ./calculate_bleu.sh model_softmax_decoder_newversion model_softmax_kvmem_validtrim_unfilt_newversion easy 9 10 | ./calculate_bleu.sh model_softmax_decoder_newversion model_softmax_kvmem_validtrim_unfilt_newversion easy 10 11 | ./calculate_bleu.sh model_softmax_decoder_newversion model_softmax_kvmem_validtrim_unfilt_newversion easy 12 12 | ./calculate_bleu.sh model_softmax_decoder_newversion model_softmax_kvmem_validtrim_unfilt_newversion easy 13 13 | ./calculate_bleu.sh model_softmax_decoder_newversion model_softmax_kvmem_validtrim_unfilt_newversion easy 14 14 | ./calculate_bleu.sh model_softmax_decoder_newversion model_softmax_kvmem_validtrim_unfilt_newversion easy 15 15 | 16 | ./calculate_bleu.sh model_softmax_decoder_newversion model_softmax_kvmem_validtrim_unfilt_newversion hard -1 17 | #./calculate_bleu.sh model_softmax_decoder_newversion model_softmax_kvmem_validtrim_unfilt_newversion hard 2 18 | ./calculate_bleu.sh model_softmax_decoder_newversion model_softmax_kvmem_validtrim_unfilt_newversion hard 3 19 | ./calculate_bleu.sh model_softmax_decoder_newversion model_softmax_kvmem_validtrim_unfilt_newversion hard 4 20 | ./calculate_bleu.sh model_softmax_decoder_newversion model_softmax_kvmem_validtrim_unfilt_newversion hard 5 21 | ./calculate_bleu.sh model_softmax_decoder_newversion model_softmax_kvmem_validtrim_unfilt_newversion hard 6 22 | ./calculate_bleu.sh model_softmax_decoder_newversion model_softmax_kvmem_validtrim_unfilt_newversion hard 7 23 | ./calculate_bleu.sh model_softmax_decoder_newversion model_softmax_kvmem_validtrim_unfilt_newversion hard 8 24 | ./calculate_bleu.sh model_softmax_decoder_newversion model_softmax_kvmem_validtrim_unfilt_newversion hard 9 25 | ./calculate_bleu.sh model_softmax_decoder_newversion model_softmax_kvmem_validtrim_unfilt_newversion hard 10 26 | ./calculate_bleu.sh model_softmax_decoder_newversion model_softmax_kvmem_validtrim_unfilt_newversion hard 12 27 | ./calculate_bleu.sh model_softmax_decoder_newversion model_softmax_kvmem_validtrim_unfilt_newversion hard 13 28 | ./calculate_bleu.sh model_softmax_decoder_newversion model_softmax_kvmem_validtrim_unfilt_newversion hard 14 29 | ./calculate_bleu.sh model_softmax_decoder_newversion model_softmax_kvmem_validtrim_unfilt_newversion hard 15 30 | 
-------------------------------------------------------------------------------- /evaluate/run_compute_recall.sh: -------------------------------------------------------------------------------- 1 | ./compute_recall.sh model_softmax_kvmem_validtrim_unfilt_newversion -1 model_softmax_decoder_newversion 2 | ./compute_recall.sh model_softmax_kvmem_validtrim_unfilt_newversion 3 model_softmax_decoder_newversion 3 | ./compute_recall.sh model_softmax_kvmem_validtrim_unfilt_newversion 4 model_softmax_decoder_newversion 4 | ./compute_recall.sh model_softmax_kvmem_validtrim_unfilt_newversion 5 model_softmax_decoder_newversion 5 | ./compute_recall.sh model_softmax_kvmem_validtrim_unfilt_newversion 6 model_softmax_decoder_newversion 6 | ./compute_recall.sh model_softmax_kvmem_validtrim_unfilt_newversion 7 model_softmax_decoder_newversion 7 | ./compute_recall.sh model_softmax_kvmem_validtrim_unfilt_newversion 8 model_softmax_decoder_newversion 8 | ./compute_recall.sh model_softmax_kvmem_validtrim_unfilt_newversion 9 model_softmax_decoder_newversion 9 | ./compute_recall.sh model_softmax_kvmem_validtrim_unfilt_newversion 10 model_softmax_decoder_newversion 10 | ./compute_recall.sh model_softmax_kvmem_validtrim_unfilt_newversion 12 model_softmax_decoder_newversion 11 | ./compute_recall.sh model_softmax_kvmem_validtrim_unfilt_newversion 13 model_softmax_decoder_newversion 12 | ./compute_recall.sh model_softmax_kvmem_validtrim_unfilt_newversion 14 model_softmax_decoder_newversion 13 | ./compute_recall.sh model_softmax_kvmem_validtrim_unfilt_newversion 15 model_softmax_decoder_newversion 14 | -------------------------------------------------------------------------------- /evaluate/split_count_op.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | from itertools import izip 3 | 4 | ques_id = int(sys.argv[1]) 5 | 6 | model_dir = 'model_softmax_kvmem_valid_trip_unfilt' 7 | op_dir_orig_hard = os.path.join(model_dir,'test_output_hard_%d' % ques_id) 8 | 9 | op_dir_orig_hard_p1 = os.path.join(model_dir,'test_output_hard_%da' % ques_id) 10 | op_dir_orig_hard_p2 = os.path.join(model_dir,'test_output_hard_%db' % ques_id) 11 | 12 | if not os.path.exists(op_dir_orig_hard_p1): 13 | os.makedirs(op_dir_orig_hard_p1) 14 | 15 | if not os.path.exists(op_dir_orig_hard_p2): 16 | os.makedirs(op_dir_orig_hard_p2) 17 | 18 | which_line_ids = [] 19 | how_line_ids = [] 20 | 21 | with open(os.path.join(model_dir,'dump','test_hard_context_%d.txt' % ques_id)) as context_lines: 22 | for i,line in enumerate(context_lines): 23 | _, ques = line.split('|') 24 | ques_tokenized = ques.split(' ') 25 | 26 | if ques_tokenized[0].lower() == 'which': 27 | which_line_ids.append(i) 28 | else: 29 | how_line_ids.append(i) 30 | 31 | print 'which lc = %d' % len(which_line_ids) 32 | print 'how lc = %d' % len(how_line_ids) 33 | 34 | for filename in os.listdir(op_dir_orig_hard): 35 | f1 = open(os.path.join(op_dir_orig_hard,filename),'r') 36 | line_list = f1.readlines() 37 | line_list_a = [line_list[i] for i in which_line_ids] 38 | line_list_b = [line_list[i] for i in how_line_ids] 39 | 40 | f2 = open(os.path.join(op_dir_orig_hard_p1,filename),'w') 41 | f3 = open(os.path.join(op_dir_orig_hard_p2,filename),'w') 42 | 43 | for line in line_list_a: 44 | f2.write(line) 45 | 46 | for line in line_list_b: 47 | f3.write(line) 48 | f1.close() 49 | f2.close() 50 | f3.close() 51 | 52 | act_set_filename = os.path.join(model_dir,'dump','test_hard_active_set_%d.txt' % ques_id) 53 | target_filename = 
os.path.join(model_dir,'dump','test_hard_target_%d.txt' % ques_id) 54 | 55 | f1 = open(act_set_filename,'r') 56 | line_list = f1.readlines() 57 | line_list_a = [line_list[i] for i in which_line_ids] 58 | line_list_b = [line_list[i] for i in how_line_ids] 59 | 60 | f2 = open(os.path.join(model_dir,'dump','test_hard_active_set_%da.txt' % ques_id),'w') 61 | f3 = open(os.path.join(model_dir,'dump','test_hard_active_set_%db.txt' % ques_id),'w') 62 | 63 | for line in line_list_a: 64 | f2.write(line) 65 | 66 | for line in line_list_b: 67 | f3.write(line) 68 | f1.close() 69 | f2.close() 70 | f3.close() 71 | 72 | f1 = open(target_filename,'r') 73 | line_list = f1.readlines() 74 | line_list_a = [line_list[i] for i in which_line_ids] 75 | line_list_b = [line_list[i] for i in how_line_ids] 76 | 77 | f2 = open(os.path.join(model_dir,'dump','test_hard_target_%da.txt' % ques_id),'w') 78 | f3 = open(os.path.join(model_dir,'dump','test_hard_target_%db.txt' % ques_id),'w') 79 | 80 | for line in line_list_a: 81 | f2.write(line) 82 | 83 | for line in line_list_b: 84 | f3.write(line) 85 | f1.close() 86 | f2.close() 87 | f3.close() -------------------------------------------------------------------------------- /evaluate/words2number.py: -------------------------------------------------------------------------------- 1 | import inflect 2 | 3 | def text2int(textnum, numwords={}): 4 | textnum = textnum.replace(',','') 5 | textnum = textnum.replace('-',' ') 6 | 7 | if not numwords: 8 | units = [ 9 | "zero", "one", "two", "three", "four", "five", "six", "seven", "eight", 10 | "nine", "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen", 11 | "sixteen", "seventeen", "eighteen", "nineteen", 12 | ] 13 | 14 | tens = ["", "", "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety"] 15 | 16 | scales = ["hundred", "thousand", "million", "billion", "trillion"] 17 | 18 | numwords["and"] = (1, 0) 19 | for idx, word in enumerate(units): numwords[word] = (1, idx) 20 | for idx, word in enumerate(tens): numwords[word] = (1, idx * 10) 21 | for idx, word in enumerate(scales): numwords[word] = (10 ** (idx * 3 or 2), 0) 22 | 23 | current = result = 0 24 | for word in textnum.split(): 25 | if word not in numwords: 26 | raise Exception("Illegal word: " + word) 27 | 28 | scale, increment = numwords[word] 29 | current = current * scale + increment 30 | if scale > 100: 31 | result += current 32 | current = 0 33 | 34 | return result + current 35 | ''' 36 | inf_eng = inflect.engine() 37 | num = 1995 38 | 39 | x = inf_eng.number_to_words(int(num)) 40 | print x 41 | print text2int(x) 42 | 43 | for n in range(1000000): 44 | if n % 100 == 0: 45 | print n 46 | n_rec = text2int(inf_eng.number_to_words(int(n))) 47 | try: 48 | assert n == n_rec 49 | except: 50 | print n 51 | ''' 52 | -------------------------------------------------------------------------------- /load_wikidata2.py: -------------------------------------------------------------------------------- 1 | import json, codecs, random, pickle, traceback, logging, os, math 2 | 3 | def load_wikidata(wikidata_dir): 4 | with codecs.open(wikidata_dir+'/wikidata_short_1.json','r','utf-8') as data_file: 5 | wikidata = json.load(data_file) 6 | print 'Successfully loaded wikidata1' 7 | 8 | with codecs.open(wikidata_dir+'/wikidata_short_2.json','r','utf-8') as data_file: 9 | wikidata2 = json.load(data_file) 10 | print 'Successfully loaded wikidata2' 11 | 12 | wikidata.update(wikidata2) 13 | del wikidata2 14 | 15 | with 
codecs.open(wikidata_dir+'/items_wikidata_n.json','r','utf-8') as data_file: 16 | item_data = json.load(data_file) 17 | print 'Successfully loaded items json' 18 | 19 | with codecs.open(wikidata_dir+'/comp_wikidata_rev.json','r','utf-8') as data_file: 20 | reverse_dict = json.load(data_file) 21 | print 'Successfully loaded reverse_dict json' 22 | 23 | with codecs.open(wikidata_dir+'/wikidata_fanout_dict.json','r','utf-8') as data_file: 24 | wikidata_fanout_dict = json.load(data_file) 25 | print 'Successfully loaded wikidata_fanout_dict json' 26 | 27 | with codecs.open(wikidata_dir+'/child_par_dict_save.json','r','utf-8') as data_file: 28 | child_par_dict = json.load(data_file) 29 | print 'Successfully loaded child_par_dict json' 30 | 31 | with codecs.open(wikidata_dir+'/child_all_parents_till_5_levels.json','r','utf-8') as data_file: 32 | child_all_parents_dict = json.load(data_file) 33 | print 'Successfully loaded child_all_parents_dict json' 34 | 35 | with codecs.open(wikidata_dir+'/filtered_property_wikidata4.json','r','utf-8') as data_file: 36 | prop_data = json.load(data_file) 37 | 38 | with codecs.open(wikidata_dir+'/par_child_dict.json','r','utf-8') as f1: 39 | par_child_dict = json.load(f1) 40 | 41 | wikidata_remove_list = [q for q in wikidata if q not in item_data] 42 | 43 | wikidata_remove_list.extend([q for q in wikidata if 'P31' not in wikidata[q] and 'P279' not in wikidata[q]]) 44 | 45 | wikidata_remove_list.extend([u'Q7375063', u'Q24284139', u'Q1892495', u'Q22980687', u'Q25093915', u'Q22980685', u'Q22980688', u'Q25588222', u'Q1668023', u'Q20794889', u'Q22980686',u'Q297106',u'Q1293664']) 46 | 47 | # wikidata_remove_list.extend([q for q in wikidata if q not in child_par_dict]) 48 | 49 | for q in wikidata_remove_list: 50 | wikidata.pop(q,None) 51 | 52 | with codecs.open(wikidata_dir+'/child_par_dict_immed.json','r','utf-8') as data_file: 53 | child_par_dict_immed = json.load(data_file) 54 | #************************ FIX for weird parent types (wikimedia, metaclass etc.)******************************** 55 | stop_par_list = ['Q21025364', 'Q19361238', 'Q21027609', 'Q20088085', 'Q15184295', 'Q11266439', 'Q17362920', 'Q19798645', 'Q26884324', 'Q14204246', 'Q13406463', 'Q14827288', 'Q4167410', 'Q21484471', 'Q17442446', 'Q4167836', 'Q19478619', 'Q24017414', 'Q19361238', 'Q24027526', 'Q15831596', 'Q24027474', 'Q23958852', 'Q24017465', 'Q24027515', 'Q1924819'] 56 | stop_par_immed_list = ['Q10876391', 'Q1351452', 'Q1423994', 'Q1443451', 'Q14943910', 'Q151', 'Q15156455', 'Q15214930', 'Q15407973', 'Q15647814', 'Q15671253', 'Q162032', 'Q16222597', 'Q17146139', 'Q17633526', 'Q19798644', 'Q19826567', 'Q19842659', 'Q19887878', 'Q20010800', 'Q20113609', 'Q20116696', 'Q20671729', 'Q20769160', 'Q20769287', 'Q21281405', 'Q21286738', 'Q21450877', 'Q21469493', 'Q21705225', 'Q22001316', 'Q22001389', 'Q22001390', 'Q23840898', 'Q23894246', 'Q24025936', 'Q24046192', 'Q24571886', 'Q24731821', 'Q2492014', 'Q252944', 'Q26267864', 'Q35120', 'Q351749', 'Q367', 'Q370', 'Q3933727', 'Q4663903', 'Q4989363', 'Q52', 'Q5296', 'Q565', 'Q6540697', 'Q79786', 'Q964'] # courtesy Amrita Saha 57 | 58 | ent_list = [] 59 | 60 | for x in stop_par_list: 61 | ent_list.extend(par_child_dict[x]) 62 | 63 | ent_list = list(set(ent_list)) 64 | ent_list_resolved = [x for x in ent_list if x in child_par_dict_immed and child_par_dict_immed[x] not in stop_par_list and child_par_dict_immed[x] not in stop_par_immed_list] 65 | 66 | child_par_dict_val = list(set(child_par_dict.values())) 67 | old_2_new_pars_map = {x:x for x in 
child_par_dict_val} 68 | rem_par_list = set() 69 | 70 | for x in ent_list_resolved: 71 | child_par_dict[x] = child_par_dict_immed[x] 72 | old_2_new_pars_map[child_par_dict[x]] = child_par_dict_immed[x] 73 | rem_par_list.add(child_par_dict[x]) 74 | 75 | ent_list_discard = list(set(ent_list) - set(ent_list_resolved)) 76 | 77 | for q in ent_list_discard: 78 | par_q = None 79 | if q in child_par_dict: 80 | child_par_dict.pop(q,None) 81 | if q in wikidata: 82 | wikidata.pop(q,None) 83 | if q in reverse_dict: 84 | reverse_dict.pop(q,None) 85 | 86 | return wikidata, reverse_dict, prop_data, child_par_dict, child_all_parents_dict, wikidata_fanout_dict, par_child_dict 87 | 88 | 89 | -------------------------------------------------------------------------------- /params.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.ops import rnn,rnn_cell 3 | import os 4 | def get_params(dir): 5 | param={} 6 | dir= str(dir) 7 | param['train_dir_loc']="/dccstor/cssblr/vardaan/dialog-qa/CSQA_v7/train" 8 | param['valid_dir_loc']="/dccstor/cssblr/vardaan/dialog-qa/CSQA_v7/valid/" 9 | param['test_dir_loc']="/dccstor/cssblr/vardaan/dialog-qa/CSQA_v7/test/" 10 | param['wikidata_dir']="/dccstor/cssblr/vardaan/dialog-qa/" 11 | param['transe_dir']="transe_dir/" 12 | param['lucene_dir']="lucene_dir/" 13 | param['glove_dir']="/dccstor/cssblr/amrita/resources/glove/" 14 | param['dump_dir_loc']=dir+"/dump/" 15 | param['test_output_dir']=dir+"/test_output/" 16 | param['vocab_file']=dir+"/vocab.pkl" 17 | param['train_data_file']=[dir+"/dump/"+x for x in os.listdir(dir+"/dump") if x.startswith('train_')] 18 | param['valid_data_file']=[dir+"/dump/"+x for x in os.listdir(dir+"/dump") if x.startswith('valid_')] 19 | param['test_data_file']=dir+"/dump/test_data_file.pkl" 20 | param['vocab_file']=dir+"/vocab.pkl" 21 | param['response_vocab_file']=dir+"/response_vocab.pkl" 22 | param['vocab_stats_file']=dir+"/vocab_stats.pkl" 23 | param['model_path']=dir+"/model" 24 | param['terminal_op']=dir+"/terminal_output.txt" 25 | param['logs_path']=dir+"/log" 26 | param['type_of_loss']="decoder" 27 | param['text_embedding_size'] = 300 28 | param['activation'] = None #tf.tanh 29 | param['output_activation'] = None #tf.nn.softmax 30 | param['cell_size']= 512 31 | param['cell_type']=rnn_cell.GRUCell 32 | param['batch_size']=64 33 | param['vocab_freq_cutoff']=5 34 | param['learning_rate']=0.0004 35 | param['patience']=200 36 | param['early_stop']=100 37 | param['max_epochs']=1000000 38 | param['max_len']=20 39 | param['max_utter']=2 40 | param['print_train_freq']=100 41 | param['show_grad_freq']=20 42 | param['valid_freq']=5000 43 | param['max_gradient_norm']=0.1 44 | param['train_loss_incremenet_tolerance']=0.01 45 | param['wikidata_embed_size']= 100 46 | param['memory_size'] = 50000 47 | param['gold_target_size'] = 10 48 | param['stopwords'] = 'stopwords.pkl' 49 | param['stopwords_histogram'] = 'stopwords_histogram.txt' 50 | param['vocab_max_len'] = 40000 51 | return param 52 | -------------------------------------------------------------------------------- /params_test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.ops import rnn,rnn_cell 3 | import os 4 | def get_params(dir, ques_type_id): 5 | param={} 6 | dir= str(dir) 7 | param['train_dir_loc']="/dccstor/cssblr/vardaan/dialog-qa/QA_train_final6/train" 8 | 
param['valid_dir_loc']="/dccstor/cssblr/vardaan/dialog-qa/QA_train_final6/valid" 9 | param['test_dir_loc']="/dccstor/cssblr/vardaan/dialog-qa/QA_train_final6/test/" 10 | param['wikidata_dir']="/dccstor/cssblr/vardaan/dialog-qa/" 11 | param['transe_dir']="transe_dir/" 12 | param['lucene_dir']="lucene_dir/" 13 | param['glove_dir']="/dccstor/cssblr/amrita/resources/glove/" 14 | param['dump_dir_loc']=dir+"/dump/" 15 | param['test_output_dir']=dir+"/test_output_"+ques_type_id+"/" 16 | param['vocab_file']=dir+"/vocab.pkl" 17 | param['train_data_file']=dir+"/dump/train_data_file.pkl" 18 | param['valid_data_file']=dir+"/dump/valid_data_file.pkl" 19 | param['ques_type_id'] = int(ques_type_id) 20 | param['test_data_file']=dir+"/dump/test_data_file_%s.pkl" % ques_type_id 21 | param['vocab_file']=dir+"/vocab.pkl" 22 | param['vocab_stats_file']=dir+"/vocab_stats.pkl" 23 | param['model_path']=dir+"/model" 24 | param['logs_path']=dir+"/log" 25 | param['type_of_loss']="decoder" 26 | param['response_vocab_file']=dir+"/response_vocab.pkl" 27 | param['text_embedding_size'] = 300 28 | param['activation'] = None #tf.tanh 29 | param['output_activation'] = None #tf.nn.softmax 30 | param['cell_size']= 512 31 | param['cell_type']=rnn_cell.GRUCell 32 | param['batch_size']=64 33 | param['vocab_freq_cutoff']=5 34 | param['learning_rate']=0.0004 35 | param['patience']=200 36 | param['early_stop']=100 37 | param['max_epochs']=1000000 38 | param['max_len']=20 39 | param['max_utter']=2 40 | param['print_train_freq']=100 41 | param['show_grad_freq']=20 42 | param['valid_freq']=1000 43 | param['max_gradient_norm']=0.1 44 | param['train_loss_incremenet_tolerance']=0.01 45 | param['wikidata_embed_size']= 100 46 | param['memory_size'] = 10000 47 | param['gold_target_size'] = 10 48 | param['stopwords'] = 'stopwords.pkl' 49 | param['stopwords_histogram'] = 'stopwords_histogram.txt' 50 | param['vocab_max_len'] = 40000 51 | return param 52 | -------------------------------------------------------------------------------- /read_data.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import cPickle as pkl 4 | import os 5 | #from params import * 6 | from prepare_data_for_hred import PrepareData as PrepareData 7 | start_symbol_index = 0 8 | end_symbol_index = 1 9 | unk_symbol_index = 2 10 | pad_symbol_index = 3 11 | kb_pad_idx = 0 12 | nkb = 1 13 | import os 14 | 15 | 16 | def get_dialog_dict(param): 17 | train_dir_loc = param['train_dir_loc'] 18 | valid_dir_loc = param['valid_dir_loc'] 19 | test_dir_loc = param['test_dir_loc'] 20 | dump_dir_loc = param['dump_dir_loc'] 21 | vocab_file = param['vocab_file'] 22 | vocab_stats_file = param['vocab_stats_file'] 23 | vocab_freq_cutoff = param['vocab_freq_cutoff'] 24 | train_data_file = param['train_data_file'] 25 | valid_data_file = param['valid_data_file'] 26 | test_data_file = param['test_data_file'] 27 | max_utter = param['max_utter'] 28 | max_len = param['max_len'] 29 | stopwords = param['stopwords'] 30 | stopwords_histogram = param['stopwords_histogram'] 31 | max_mem_size = param['memory_size'] 32 | max_target_size = param['gold_target_size'] 33 | ques_type_id = param['ques_type_id'] 34 | ques_type_name = param['ques_type_name'] 35 | vocab_max_len = param['vocab_max_len'] 36 | wikidata_dir = param['wikidata_dir'] 37 | lucene_dir = param['lucene_dir'] 38 | transe_dir = param['transe_dir'] 39 | glove_dir = param['glove_dir'] 40 | preparedata = PrepareData(max_utter, max_len, start_symbol_index, end_symbol_index, 
unk_symbol_index, pad_symbol_index, kb_pad_idx, nkb, stopwords, stopwords_histogram, lucene_dir, transe_dir, wikidata_dir, glove_dir, max_mem_size, max_target_size, vocab_max_len, True, cutoff=vocab_freq_cutoff) 41 | if os.path.isfile(vocab_file): 42 | print 'found existing vocab file in '+str(vocab_file)+', ... reading from there' 43 | print 'to delete later ',os.path.join(dump_dir_loc, "train") 44 | preparedata.prepare_data(train_dir_loc, vocab_file, vocab_stats_file, os.path.join(dump_dir_loc, "train"), train_data_file, ques_type_id, ques_type_name) 45 | preparedata.prepare_data(valid_dir_loc, vocab_file, vocab_stats_file, os.path.join(dump_dir_loc, "valid"), valid_data_file, ques_type_id, ques_type_name) 46 | #preparedata.prepare_data(test_dir_loc, vocab_file, vocab_stats_file, os.path.join(dump_dir_loc, "test"), test_data_file, ques_type_id) 47 | 48 | 49 | def get_dialog_dict_for_test(param): 50 | test_dir_loc = param['test_dir_loc'] 51 | dump_dir_loc = param['dump_dir_loc'] 52 | vocab_file = param['vocab_file'] 53 | vocab_stats_file = param['vocab_stats_file'] 54 | response_vocab_file = param['response_vocab_file'] 55 | vocab_freq_cutoff = param['vocab_freq_cutoff'] 56 | test_data_file = param['test_data_file'] 57 | max_utter = param['max_utter'] 58 | max_len = param['max_len'] 59 | stopwords = param['stopwords'] 60 | stopwords_histogram = param['stopwords_histogram'] 61 | max_mem_size = param['memory_size'] 62 | max_target_size = param['gold_target_size'] 63 | ques_type_id = param['ques_type_id'] 64 | ques_type_name = param['ques_type_name'] 65 | vocab_max_len = param['vocab_max_len'] 66 | wikidata_dir = param['wikidata_dir'] 67 | lucene_dir = param['lucene_dir'] 68 | transe_dir = param['transe_dir'] 69 | glove_dir = param['glove_dir'] 70 | preparedata = PrepareData(max_utter, max_len, start_symbol_index, end_symbol_index, unk_symbol_index, pad_symbol_index, kb_pad_idx, nkb, stopwords, stopwords_histogram, lucene_dir, transe_dir, wikidata_dir, glove_dir, max_mem_size, max_target_size, vocab_max_len, True, cutoff=vocab_freq_cutoff) 71 | if os.path.isfile(vocab_file): 72 | print 'found existing vocab file in '+str(vocab_file)+', ... 
reading from there' 73 | print 'to delete later ',os.path.join(dump_dir_loc, "train") 74 | preparedata.prepare_data(test_dir_loc, vocab_file, vocab_stats_file, response_vocab_file, os.path.join(dump_dir_loc, "test"), test_data_file, ques_type_id, ques_type_name) 75 | 76 | 77 | 78 | def get_utter_seq_len(dialogue_dict_w2v, dialogue_dict_kb, dialogue_target, dialogue_response, dialogue_response_length, dialogue_sources, dialogue_rel, dialogue_key_target, max_len, max_utter, max_target_size, max_mem_size, batch_size, is_test=False): 79 | padded_utters_id_w2v = None 80 | padded_utters_id_kb = None 81 | padded_target =[] 82 | decode_seq_len = [] 83 | padded_utters_id_w2v = np.asarray([[xij for xij in dialogue_i] for dialogue_i in dialogue_dict_w2v]) 84 | padded_utters_id_kb = np.asarray([[xij for xij in dialogue_i] for dialogue_i in dialogue_dict_kb]) 85 | if not is_test: 86 | padded_target = np.asarray([xi for xi in dialogue_target]) 87 | else: 88 | padded_target = dialogue_target 89 | padded_response = np.asarray([xi for xi in dialogue_response]) 90 | pad_to_response = np.reshape(np.array([pad_symbol_index]*batch_size), (batch_size, 1)) 91 | padded_decoder_input = np.concatenate((pad_to_response, padded_response[:,:-1]), axis=1) 92 | padded_response_length = np.asarray(dialogue_response_length) 93 | padded_sources = np.asarray([xi[:-1-max(0, len(xi)-max_mem_size)]+[kb_pad_idx]*(max(0,max_mem_size-len(xi)))+[xi[-1]] for xi in dialogue_sources], dtype=np.int32) 94 | padded_rel = np.asarray([xi[:-1-max(0, len(xi)-max_mem_size)]+[kb_pad_idx]*(max(0,max_mem_size-len(xi)))+[xi[-1]] for xi in dialogue_rel], dtype=np.int32) 95 | padded_key_target = np.asarray([xi[:-1-max(0, len(xi)-max_mem_size)]+[kb_pad_idx]*(max(0,max_mem_size-len(xi)))+[xi[-1]] for xi in dialogue_key_target], dtype=np.int32) 96 | return padded_utters_id_w2v, padded_utters_id_kb, padded_target, padded_response, padded_response_length, padded_decoder_input, padded_sources, padded_rel, padded_key_target 97 | 98 | def get_weights(batch_size, max_len, actual_len): 99 | remaining_len = max_len - actual_len 100 | weights = [[1.]*actual_len_i+[0.]*remaining_len_i for actual_len_i,remaining_len_i in zip(actual_len,remaining_len)] 101 | weights = np.asarray(weights) 102 | return weights 103 | 104 | def get_memory_weights(batch_size, max_mem_size, sources, rel, target): 105 | weights = np.ones((batch_size, max_mem_size)) 106 | weights[np.where(sources==kb_pad_idx)] = 0. 107 | weights[np.where(rel==kb_pad_idx)] = 0. 108 | weights[np.where(target==kb_pad_idx)] = 0. 109 | weights[np.where(sources==nkb)] = 0. 110 | weights[np.where(rel==nkb)] = 0. 111 | weights[np.where(target==nkb)] = 0. 
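# Descriptive note (editorial comment): memory slots whose subject, relation, or
# object equals kb_pad_idx (0) or the 'no fact' marker nkb (1) carry no real KB
# triple, so their weight is zeroed; attention mass can only fall on fully
# populated (subject, relation, object) memory entries.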
112 | return weights 113 | 114 | def get_batch_data(max_len, max_utter, max_mem_size, max_target_size, batch_size, data_dict, overriding_memory=None, is_test=False): 115 | data_dict = np.asarray(data_dict) 116 | batch_enc_w2v = data_dict[:,0] 117 | batch_enc_kb = data_dict[:,1] 118 | batch_target = data_dict[:,2] 119 | batch_response = data_dict[:,3] 120 | batch_response_length = data_dict[:,4] 121 | batch_orig_response = data_dict[:,5] 122 | batch_sources = [x.split("|") for x in data_dict[:,6]] 123 | batch_rel = [x.split("|") for x in data_dict[:,7]] 124 | batch_key_target = [x.split("|") for x in data_dict[:,8]] 125 | if is_test: 126 | batch_orig_response_entities = data_dict[:,10] 127 | else: 128 | batch_orig_response_entities = ['']*data_dict.shape[0] 129 | if isinstance(batch_orig_response_entities, np.ndarray): 130 | batch_orig_response_entities = batch_orig_response_entities.tolist() 131 | if len(batch_orig_response_entities)!=batch_size: 132 | batch_orig_response_entities.extend(['']*(batch_size-len(batch_orig_response_entities))) 133 | if overriding_memory is not None: 134 | batch_sources = [x[:-1][:overriding_memory-1]+[x[-1]] for x in batch_sources] 135 | batch_rel = [x[:-1][:overriding_memory-1]+[x[-1]] for x in batch_rel] 136 | batch_key_target = [x[:-1][:overriding_memory-1]+[x[-1]] for x in batch_key_target] 137 | ''' 138 | try: 139 | batch_active_set = data_dict[:,9] 140 | except: 141 | batch_active_set = ['']*len(data_dict) 142 | ''' 143 | batch_active_set = ['']*len(data_dict) 144 | batch_response_len = [len(response) for response in batch_response] 145 | orig_lens = [len(batch_sources_i) for batch_sources_i in batch_sources] 146 | max_mem_size = max(orig_lens) 147 | avg_mem_size = float(sum(orig_lens))/float(len(batch_sources)) 148 | if max_mem_size - avg_mem_size >500: 149 | print 'WARNING: max_mem_size ',max_mem_size, 'avg_mem_size ',avg_mem_size 150 | if len(data_dict) % batch_size != 0: 151 | batch_enc_w2v, batch_enc_kb, batch_target, batch_response, batch_response_length, batch_orig_response, batch_sources, batch_rel, batch_key_target, batch_active_set = check_padding(batch_enc_w2v, batch_enc_kb, batch_target, batch_response, batch_response_length, batch_orig_response, batch_sources, batch_rel, batch_key_target, batch_active_set, max_len, max_utter, max_mem_size, max_target_size, batch_size, is_test) 152 | 153 | padded_enc_w2v, padded_enc_kb, padded_target, padded_response, padded_response_length, padded_decoder_input, padded_batch_sources, padded_batch_rel, padded_batch_key_target = get_utter_seq_len(batch_enc_w2v, batch_enc_kb, batch_target, batch_response, batch_response_length, batch_sources, batch_rel, batch_key_target, max_len, max_utter, max_target_size, max_mem_size, batch_size, is_test) 154 | 155 | padded_weights = get_weights(batch_size, max_len, padded_response_length) 156 | padded_memory_weights = get_memory_weights(batch_size, max_mem_size, padded_batch_sources, padded_batch_rel, padded_batch_key_target) 157 | 158 | padded_enc_w2v, padded_enc_kb, padded_target, padded_orig_target, padded_response, padded_weights, padded_decoder_input, padded_batch_sources, padded_batch_rel, padded_batch_key_target = transpose_utterances(padded_enc_w2v, padded_enc_kb, padded_target, padded_response, padded_weights, padded_decoder_input, padded_batch_sources, padded_batch_rel, padded_batch_key_target, max_mem_size, batch_size, is_test) 159 | 160 | return max_mem_size, padded_enc_w2v, padded_enc_kb, padded_target, padded_orig_target, padded_response, batch_orig_response, 
padded_weights, padded_memory_weights, padded_decoder_input, padded_batch_sources, padded_batch_rel, padded_batch_key_target, batch_active_set, batch_orig_response_entities 161 | 162 | def transpose_utterances(padded_enc_w2v, padded_enc_kb, padded_target, padded_response, padded_weights, padded_decoder_input, batch_sources, batch_rel, batch_key_target, max_mem_size, batch_size, is_test): 163 | 164 | batch_key_target = np.asarray(batch_key_target) # batch_size * max_mem_size 165 | # padded_target : batch_size * max_target_size 166 | if not is_test: 167 | #print 'padded_target shape ', padded_target.shape 168 | mapped_padded_target = np.zeros(padded_target.shape) 169 | for i in xrange(mapped_padded_target.shape[0]): 170 | for j in xrange(mapped_padded_target.shape[1]): 171 | if padded_target[i,j] in batch_key_target[i,:] and padded_target[i,j] != kb_pad_idx: 172 | mapped_padded_target[i,j] = np.nonzero(batch_key_target[i,:] == padded_target[i,j])[0][0] 173 | elif padded_target[i,j] not in batch_key_target[i,:]: 174 | mapped_padded_target[i,j] = max_mem_size-1 175 | else: 176 | mapped_padded_target[i,j] = max_mem_size-1 177 | padded_transposed_enc_w2v = padded_enc_w2v.transpose((1,2,0)) 178 | padded_transposed_enc_kb = padded_enc_kb.transpose((1,2,0)) 179 | if not is_test: 180 | padded_transposed_target = mapped_padded_target.transpose((1,0)) 181 | else: 182 | padded_transposed_target = padded_target 183 | padded_transposed_response = padded_response.transpose((1,0)) 184 | padded_transposed_weights = padded_weights.transpose((1,0)) 185 | padded_transposed_decoder_input = padded_decoder_input.transpose((1,0)) 186 | if not is_test: 187 | padded_transposed_orig_target = padded_target.transpose((1,0)) 188 | else: 189 | padded_transposed_orig_target = padded_target 190 | padded_batch_sources = np.asarray(batch_sources).transpose((1,0)) 191 | padded_batch_rel = np.asarray(batch_rel).transpose((1,0)) 192 | padded_batch_key_target = np.asarray(batch_key_target).transpose((1,0)) 193 | 194 | return padded_transposed_enc_w2v, padded_transposed_enc_kb, padded_transposed_target, padded_transposed_orig_target, padded_transposed_response, padded_transposed_weights, padded_transposed_decoder_input, padded_batch_sources, padded_batch_rel, padded_batch_key_target 195 | 196 | def batch_padding_context(data_mat, max_len, max_utter, pad_size): 197 | empty_data = [start_symbol_index, end_symbol_index]+[pad_symbol_index]*(max_len-2) 198 | empty_data = [empty_data]*max_utter 199 | empty_data_mat = [empty_data]*pad_size 200 | data_mat=data_mat.tolist() 201 | data_mat.extend(empty_data_mat) 202 | return data_mat 203 | 204 | def batch_padding_target(data_mat, max_target_size, pad_size, is_test=False): 205 | if not is_test: 206 | empty_data = [kb_pad_idx] * max_target_size 207 | empty_data = [empty_data] * pad_size 208 | data_mat=data_mat.tolist() 209 | data_mat.extend(empty_data) 210 | else: 211 | if isinstance(data_mat, list): 212 | data_mat.extend(['']*pad_size) 213 | else: 214 | data_mat=data_mat.tolist() 215 | data_mat.extend(['']*pad_size) 216 | return data_mat 217 | 218 | def batch_padding_response(data_mat, max_len, pad_size): 219 | empty_data = [start_symbol_index, end_symbol_index]+[pad_symbol_index]*(max_len-2) 220 | empty_data_mat = [empty_data]*pad_size 221 | data_mat=data_mat.tolist() 222 | data_mat.extend(empty_data_mat) 223 | return data_mat 224 | 225 | def batch_padding_response_len(data_mat, pad_size): 226 | empty_data_mat = [2]*pad_size 227 | data_mat = data_mat.tolist() 228 | 
data_mat.extend(empty_data_mat) 229 | return data_mat 230 | 231 | def batch_padding_orig_response(data_mat, pad_size): 232 | data_mat = data_mat.tolist() 233 | data_mat.extend(['']*pad_size) 234 | return data_mat 235 | 236 | def batch_padding_active_set(data_mat, pad_size): 237 | if not isinstance(data_mat, list): 238 | data_mat = data_mat.tolist() 239 | data_mat.extend(['']*pad_size) 240 | return data_mat 241 | 242 | def batch_padding_memory_ent(data_mat, max_mem_size, pad_size): 243 | empty_data = [kb_pad_idx]*(max_mem_size) 244 | empty_data = [empty_data]*pad_size 245 | if not isinstance(data_mat, list): 246 | data_mat=data_mat.tolist() 247 | data_mat.extend(empty_data) 248 | return data_mat 249 | 250 | def batch_padding_memory_rel(data_mat, max_mem_size, pad_size): 251 | empty_data = [kb_pad_idx]*(max_mem_size) 252 | empty_data = [empty_data]*pad_size 253 | if not isinstance(data_mat, list): 254 | data_mat=data_mat.tolist() 255 | data_mat.extend(empty_data) 256 | return data_mat 257 | 258 | def check_padding(batch_enc_w2v, batch_enc_kb, batch_target, batch_response, batch_response_length, batch_orig_response, batch_sources, batch_rel, batch_key_target, batch_active_set, max_len, max_utter, max_mem_size, max_target_size, batch_size, is_test=False): 259 | pad_size = batch_size - len(batch_target) % batch_size 260 | batch_enc_w2v = batch_padding_context(batch_enc_w2v, max_len, max_utter, pad_size) 261 | batch_enc_kb = batch_padding_context(batch_enc_kb, max_len, max_utter, pad_size) 262 | batch_target = batch_padding_target(batch_target, max_target_size, pad_size, is_test) 263 | batch_response = batch_padding_response(batch_response, max_len, pad_size) 264 | batch_response_length = batch_padding_response_len(batch_response_length, pad_size) 265 | batch_orig_response = batch_padding_orig_response(batch_orig_response, pad_size) 266 | batch_sources = batch_padding_memory_ent(batch_sources, max_mem_size, pad_size) # adding one dummy entry for OOM entities 267 | batch_rel = batch_padding_memory_rel(batch_rel, max_mem_size, pad_size) # adding one dummy entry for OOM entities 268 | batch_key_target = batch_padding_memory_ent(batch_key_target, max_mem_size, pad_size) # adding one dummy entry for OOM entities 269 | batch_active_set = batch_padding_active_set(batch_active_set, pad_size) 270 | return batch_enc_w2v, batch_enc_kb, batch_target, batch_response, batch_response_length, batch_orig_response, batch_sources, batch_rel, batch_key_target, batch_active_set 271 | 272 | 273 | def load_valid_test_target(data_dict): 274 | return np.asarray(data_dict)[:,3] 275 | 276 | -------------------------------------------------------------------------------- /relation_linker/annoy_index_rel/glove_embedding_of_vocab.ann: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amritasaha1812/CSQA_Code/3d6724b7cab1972f2636affdde94aedecaf13978/relation_linker/annoy_index_rel/glove_embedding_of_vocab.ann -------------------------------------------------------------------------------- /relation_linker/annoy_index_rel/index2word.pkl: -------------------------------------------------------------------------------- 1 | (dp1 2 | I0 3 | S'coach' 4 | p2 5 | sI1 6 | S'founder' 7 | p3 8 | sI2 9 | S'spacecraft' 10 | p4 11 | sI3 12 | S'legal' 13 | p5 14 | sI4 15 | S'settlement' 16 | p6 17 | sI5 18 | S'religious' 19 | p7 20 | sI6 21 | S'discoverer' 22 | p8 23 | sI7 24 | S'zone' 25 | p9 26 | sI8 27 | S'religion' 28 | p10 29 | sI9 30 | S'father' 31 | p11 32 | sI10 33 | 
S'languages' 34 | p12 35 | sI11 36 | S'overlaps' 37 | p13 38 | sI12 39 | S'jurisdiction' 40 | p14 41 | sI13 42 | S'detention' 43 | p15 44 | sI14 45 | S'strand' 46 | p16 47 | sI15 48 | S'gender' 49 | p17 50 | sI16 51 | S'venue' 52 | p18 53 | sI17 54 | S'manager' 55 | p19 56 | sI18 57 | S'standards' 58 | p20 59 | sI19 60 | S'presenter' 61 | p21 62 | sI20 63 | S'team' 64 | p22 65 | sI21 66 | S'biological' 67 | p23 68 | sI22 69 | S'occurrence' 70 | p24 71 | sI23 72 | S'licensed' 73 | p25 74 | sI24 75 | S'affiliation' 76 | p26 77 | sI25 78 | S'street' 79 | p27 80 | sI26 81 | S'port' 82 | p28 83 | sI27 84 | S'repeals' 85 | p29 86 | sI28 87 | S'anthem' 88 | p30 89 | sI29 90 | S'shares' 91 | p31 92 | sI30 93 | S'cell' 94 | p32 95 | sI31 96 | S'capital' 97 | p33 98 | sI32 99 | S'ammunition' 100 | p34 101 | sI33 102 | S'method' 103 | p35 104 | sI34 105 | S'movement' 106 | p36 107 | sI35 108 | S'body' 109 | p37 110 | sI36 111 | S'drafted' 112 | p38 113 | sI37 114 | S'degree' 115 | p39 116 | sI38 117 | S'exchange' 118 | p40 119 | sI39 120 | S'component' 121 | p41 122 | sI40 123 | S'water' 124 | p42 125 | sI41 126 | S'operating' 127 | p43 128 | sI42 129 | S'based' 130 | p44 131 | sI43 132 | S'publisher' 133 | p45 134 | sI44 135 | S'inspired' 136 | p46 137 | sI45 138 | S'airport' 139 | p47 140 | sI46 141 | S'periodic' 142 | p48 143 | sI47 144 | S'published' 145 | p49 146 | sI48 147 | S'military' 148 | p50 149 | sI49 150 | S'named' 151 | p51 152 | sI50 153 | S'family' 154 | p52 155 | sI51 156 | S'composer' 157 | p53 158 | sI52 159 | S'use' 160 | p54 161 | sI53 162 | S'eye' 163 | p55 164 | sI54 165 | S'contains' 166 | p56 167 | sI55 168 | S'arms' 169 | p57 170 | sI56 171 | S'sports' 172 | p58 173 | sI57 174 | S'archives' 175 | p59 176 | sI58 177 | S'vehicle' 178 | p60 179 | sI59 180 | S'australian' 181 | p61 182 | sI60 183 | S'manufacturer' 184 | p62 185 | sI61 186 | S'type' 187 | p63 188 | sI62 189 | S'minor' 190 | p64 191 | sI63 192 | S'conductor' 193 | p65 194 | sI64 195 | S'successful' 196 | p66 197 | sI65 198 | S'company' 199 | p67 200 | sI66 201 | S'award' 202 | p68 203 | sI67 204 | S'flag' 205 | p69 206 | sI68 207 | S'commissioned' 208 | p70 209 | sI69 210 | S'officeholder' 211 | p71 212 | sI70 213 | S'appointed' 214 | p72 215 | sI71 216 | S'crosses' 217 | p73 218 | sI72 219 | S'work' 220 | p74 221 | sI73 222 | S'history' 223 | p75 224 | sI74 225 | S'dedicated' 226 | p76 227 | sI75 228 | S'performer' 229 | p77 230 | sI76 231 | S'patron' 232 | p78 233 | sI77 234 | S'stated' 235 | p79 236 | sI78 237 | S'process' 238 | p80 239 | sI79 240 | S'currency' 241 | p81 242 | sI80 243 | S'discography' 244 | p82 245 | sI81 246 | S'coincident' 247 | p83 248 | sI82 249 | S'organizer' 250 | p84 251 | sI83 252 | S'cites' 253 | p85 254 | sI84 255 | S'native' 256 | p86 257 | sI85 258 | S'influenced' 259 | p87 260 | sI86 261 | S'feature' 262 | p88 263 | sI87 264 | S'species' 265 | p89 266 | sI88 267 | S'located' 268 | p90 269 | sI89 270 | S'criterion' 271 | p91 272 | sI90 273 | S'adjacent' 274 | p92 275 | sI91 276 | S'inflows' 277 | p93 278 | sI92 279 | S'stock' 280 | p94 281 | sI93 282 | S'product' 283 | p95 284 | sI94 285 | S'designer' 286 | p96 287 | sI95 288 | S'collection' 289 | p97 290 | sI96 291 | S'catalog' 292 | p98 293 | sI97 294 | S'birthday' 295 | p99 296 | sI98 297 | S'blood' 298 | p100 299 | sI99 300 | S'ethnic' 301 | p101 302 | sI100 303 | S'ortholog' 304 | p102 305 | sI101 306 | S'coat' 307 | p103 308 | sI102 309 | S'exhibition' 310 | p104 311 | sI103 312 | S'fictional' 313 | p105 314 | sI104 315 | 
S'chief' 316 | p106 317 | sI105 318 | S'order' 319 | p107 320 | sI106 321 | S'producer' 322 | p108 323 | sI107 324 | S'astronomical' 325 | p109 326 | sI108 327 | S'mission' 328 | p110 329 | sI109 330 | S'held' 331 | p111 332 | sI110 333 | S'costume' 334 | p112 335 | sI111 336 | S'statistical' 337 | p113 338 | sI112 339 | S'arrangement' 340 | p114 341 | sI113 342 | S'developer' 343 | p115 344 | sI114 345 | S'style' 346 | p116 347 | sI115 348 | S'group' 349 | p117 350 | sI116 351 | S'tonality' 352 | p118 353 | sI117 354 | S'crew' 355 | p119 356 | sI118 357 | S'symptoms' 358 | p120 359 | sI119 360 | S'platform' 361 | p121 362 | sI120 363 | S'production' 364 | p122 365 | sI121 366 | S'condition' 367 | p123 368 | sI122 369 | S'main' 370 | p124 371 | sI123 372 | S'qualifies' 373 | p125 374 | sI124 375 | S'combination' 376 | p126 377 | sI125 378 | S'material' 379 | p127 380 | sI126 381 | S'rank' 382 | p128 383 | sI127 384 | S'killed' 385 | p129 386 | sI128 387 | S'association' 388 | p130 389 | sI129 390 | S'storyboard' 391 | p131 392 | sI130 393 | S'name' 394 | p132 395 | sI131 396 | S'feast' 397 | p133 398 | sI132 399 | S'victory' 400 | p134 401 | sI133 402 | S'found' 403 | p135 404 | sI134 405 | S'burial' 406 | p136 407 | sI135 408 | S'list' 409 | p137 410 | sI136 411 | S'brother' 412 | p138 413 | sI137 414 | S'operator' 415 | p139 416 | sI138 417 | S'decays' 418 | p140 419 | sI139 420 | S'sponsor' 421 | p141 422 | sI140 423 | S'year' 424 | p142 425 | sI141 426 | S'distributor' 427 | p143 428 | sI142 429 | S'cathedral' 430 | p144 431 | sI143 432 | S'sexual' 433 | p145 434 | sI144 435 | S'event' 436 | p146 437 | sI145 438 | S'category' 439 | p147 440 | sI146 441 | S'network' 442 | p148 443 | sI147 444 | S'item' 445 | p149 446 | sI148 447 | S'day' 448 | p150 449 | sI149 450 | S'cause' 451 | p151 452 | sI150 453 | S'highway' 454 | p152 455 | sI151 456 | S'base' 457 | p153 458 | sI152 459 | S'formation' 460 | p154 461 | sI153 462 | S'advisor' 463 | p155 464 | sI154 465 | S'interaction' 466 | p156 467 | sI155 468 | S'district' 469 | p157 470 | sI156 471 | S'language' 472 | p158 473 | sI157 474 | S'launch' 475 | p159 476 | sI158 477 | S'route' 478 | p160 479 | sI159 480 | S'programming' 481 | p161 482 | sI160 483 | S'place' 484 | p162 485 | sI161 486 | S'competed' 487 | p163 488 | sI162 489 | S'first' 490 | p164 491 | sI163 492 | S'origin' 493 | p165 494 | sI164 495 | S'major' 496 | p166 497 | sI165 498 | S'characters' 499 | p167 500 | sI166 501 | S'owned' 502 | p168 503 | sI167 504 | S'endemic' 505 | p169 506 | sI168 507 | S'legislative' 508 | p170 509 | sI169 510 | S'city' 511 | p171 512 | sI170 513 | S'given' 514 | p172 515 | sI171 516 | S'publication' 517 | p173 518 | sI172 519 | S'service' 520 | p174 521 | sI173 522 | S'system' 523 | p175 524 | sI174 525 | S'station' 526 | p176 527 | sI175 528 | S'saint' 529 | p177 530 | sI176 531 | S'molecular' 532 | p178 533 | sI177 534 | S'final' 535 | p179 536 | sI178 537 | S'recovered' 538 | p180 539 | sI179 540 | S'hub' 541 | p181 542 | sI180 543 | S'broadcast' 544 | p182 545 | sI181 546 | S'part' 547 | p183 548 | sI182 549 | S'kept' 550 | p184 551 | sI183 552 | S'translation' 553 | p185 554 | sI184 555 | S'inventor' 556 | p186 557 | sI185 558 | S'powerplant' 559 | p187 560 | sI186 561 | S'diplomatic' 562 | p188 563 | sI187 564 | S'architect' 565 | p189 566 | sI188 567 | S'treated' 568 | p190 569 | sI189 570 | S'nominee' 571 | p191 572 | sI190 573 | S'headquarters' 574 | p192 575 | sI191 576 | S'manner' 577 | p193 578 | sI192 579 | S'border' 580 | p194 581 
| sI193 582 | S'instrumentation' 583 | p195 584 | sI194 585 | S'chess' 586 | p196 587 | sI195 588 | S'tracklist' 589 | p197 590 | sI196 591 | S'track' 592 | p198 593 | sI197 594 | S'measures' 595 | p199 596 | sI198 597 | S'translator' 598 | p200 599 | sI199 600 | S'mouth' 601 | p201 602 | sI200 603 | S'significant' 604 | p202 605 | sI201 606 | S'organization' 607 | p203 608 | sI202 609 | S'device' 610 | p204 611 | sI203 612 | S'screenwriter' 613 | p205 614 | sI204 615 | S'outflow' 616 | p206 617 | sI205 618 | S'medical' 619 | p207 620 | sI206 621 | S'points' 622 | p208 623 | sI207 624 | S'canonization' 625 | p209 626 | sI208 627 | S'institutions' 628 | p210 629 | sI209 630 | S'rating' 631 | p211 632 | sI210 633 | S'radix' 634 | p212 635 | sI211 636 | S'lyrics' 637 | p213 638 | sI212 639 | S'certification' 640 | p214 641 | sI213 642 | S'armament' 643 | p215 644 | sI214 645 | S'nominated' 646 | p216 647 | sI215 648 | S'relation' 649 | p217 650 | sI216 651 | S'subsidiary' 652 | p218 653 | sI217 654 | S'taxonomic' 655 | p219 656 | sI218 657 | S'occupation' 658 | p220 659 | sI219 660 | S'ratio' 661 | p221 662 | sI220 663 | S'office' 664 | p222 665 | sI221 666 | S'title' 667 | p223 668 | sI222 669 | S'winner' 670 | p224 671 | sI223 672 | S'state' 673 | p225 674 | sI224 675 | S'employer' 676 | p226 677 | sI225 678 | S'monuments' 679 | p227 680 | sI226 681 | S'photography' 682 | p228 683 | sI227 684 | S'occupant' 685 | p229 686 | sI228 687 | S'artist' 688 | p230 689 | sI229 690 | S'teams' 691 | p231 692 | sI230 693 | S'peak' 694 | p232 695 | sI231 696 | S'river' 697 | p233 698 | sI232 699 | S'view' 700 | p234 701 | sI233 702 | S'art' 703 | p235 704 | sI234 705 | S'creator' 706 | p236 707 | sI235 708 | S'physically' 709 | p237 710 | sI236 711 | S'terrain' 712 | p238 713 | sI237 714 | S'relative' 715 | p239 716 | sI238 717 | S'edition' 718 | p240 719 | sI239 720 | S'operated' 721 | p241 722 | sI240 723 | S'sport' 724 | p242 725 | sI241 726 | S'constellation' 727 | p243 728 | sI242 729 | S'subject' 730 | p244 731 | sI243 732 | S'record' 733 | p245 734 | sI244 735 | S'electoral' 736 | p246 737 | sI245 738 | S'label' 739 | p247 740 | sI246 741 | S'written' 742 | p248 743 | sI247 744 | S'discovery' 745 | p249 746 | sI248 747 | S'parent' 748 | p250 749 | sI249 750 | S'genre' 751 | p251 752 | sI250 753 | S'received' 754 | p252 755 | sI251 756 | S'license' 757 | p253 758 | sI252 759 | S'country' 760 | p254 761 | sI253 762 | S'drug' 763 | p255 764 | sI254 765 | S'facet' 766 | p256 767 | sI255 768 | S'point' 769 | p257 770 | sI256 771 | S'color' 772 | p258 773 | sI257 774 | S'period' 775 | p259 776 | sI258 777 | S'maintained' 778 | p260 779 | sI259 780 | S'vessel' 781 | p261 782 | sI260 783 | S'writing' 784 | p262 785 | sI261 786 | S'described' 787 | p263 788 | sI262 789 | S'political' 790 | p264 791 | sI263 792 | S'convicted' 793 | p265 794 | sI264 795 | S'genetic' 796 | p266 797 | sI265 798 | S'offers' 799 | p267 800 | sI266 801 | S'basic' 802 | p268 803 | sI267 804 | S'basin' 805 | p269 806 | sI268 807 | S'territory' 808 | p270 809 | sI269 810 | S'board' 811 | p271 812 | sI270 813 | S'engine' 814 | p272 815 | sI271 816 | S'direction' 817 | p273 818 | sI272 819 | S'educated' 820 | p274 821 | sI273 822 | S'league' 823 | p275 824 | sI274 825 | S'tense' 826 | p276 827 | sI275 828 | S'child' 829 | p277 830 | sI276 831 | S'present' 832 | p278 833 | sI277 834 | S'mount' 835 | p279 836 | sI278 837 | S'cast' 838 | p280 839 | sI279 840 | S'watercourse' 841 | p281 842 | sI280 843 | S'applies' 844 | p282 845 | sI281 
846 | S'aid' 847 | p283 848 | sI282 849 | S'voice' 850 | p284 851 | sI283 852 | S'orientation' 853 | p285 854 | sI284 855 | S'played' 856 | p286 857 | sI285 858 | S'surface' 859 | p287 860 | sI286 861 | S'player' 862 | p288 863 | sI287 864 | S'partner' 865 | p289 866 | sI288 867 | S'astronaut' 868 | p290 869 | sI289 870 | S'interacts' 871 | p291 872 | sI290 873 | S'participating' 874 | p292 875 | sI291 876 | S'damaged' 877 | p293 878 | sI292 879 | S'participant' 880 | p294 881 | sI293 882 | S'author' 883 | p295 884 | sI294 885 | S'administration' 886 | p296 887 | sI295 888 | S'member' 889 | p297 890 | sI296 891 | S'wheel' 892 | p298 893 | sI297 894 | S'chromosome' 895 | p299 896 | sI298 897 | S'afflicts' 898 | p300 899 | sI299 900 | S'party' 901 | p301 902 | sI300 903 | S'conflict' 904 | p302 905 | sI301 906 | S'higher' 907 | p303 908 | sI302 909 | S'status' 910 | p304 911 | sI303 912 | S'noble' 913 | p305 914 | sI304 915 | S'used' 916 | p306 917 | sI305 918 | S'director' 919 | p307 920 | sI306 921 | S'student' 922 | p308 923 | sI307 924 | S'fabrication' 925 | p309 926 | sI308 927 | S'practiced' 928 | p310 929 | sI309 930 | S'game' 931 | p311 932 | sI310 933 | S'academic' 934 | p312 935 | sI311 936 | S'mother' 937 | p313 938 | sI312 939 | S'position' 940 | p314 941 | sI313 942 | S'musical' 943 | p315 944 | sI314 945 | S'bodies' 946 | p316 947 | sI315 948 | S'kinship' 949 | p317 950 | sI316 951 | S'executive' 952 | p318 953 | sI317 954 | S'electrification' 955 | p319 956 | sI318 957 | S'aspect' 958 | p320 959 | sI319 960 | S'continent' 961 | p321 962 | sI320 963 | S'discipline' 964 | p322 965 | sI321 966 | S'doctoral' 967 | p323 968 | sI322 969 | S'death' 970 | p324 971 | sI323 972 | S'classification' 973 | p325 974 | sI324 975 | S'candidate' 976 | p326 977 | sI325 978 | S'grammatical' 979 | p327 980 | sI326 981 | S'character' 982 | p328 983 | sI327 984 | S'lake' 985 | p329 986 | sI328 987 | S'contractor' 988 | p330 989 | sI329 990 | S'instrument' 991 | p331 992 | sI330 993 | S'location' 994 | p332 995 | sI331 996 | S'input' 997 | p333 998 | sI332 999 | S'identical' 1000 | p334 1001 | sI333 1002 | S'government' 1003 | p335 1004 | sI334 1005 | S'cultural' 1006 | p336 1007 | sI335 1008 | S'birth' 1009 | p337 1010 | sI336 1011 | S'depicts' 1012 | p338 1013 | sI337 1014 | S'officer' 1015 | p339 1016 | sI338 1017 | S'served' 1018 | p340 1019 | sI339 1020 | S'works' 1021 | p341 1022 | sI340 1023 | S'gauge' 1024 | p342 1025 | sI341 1026 | S'diocese' 1027 | p343 1028 | sI342 1029 | S'hair' 1030 | p344 1031 | sI343 1032 | S'legislated' 1033 | p345 1034 | sI344 1035 | S'home' 1036 | p346 1037 | sI345 1038 | S'contested' 1039 | p347 1040 | sI346 1041 | S'mood' 1042 | p348 1043 | sI347 1044 | S'filming' 1045 | p349 1046 | sI348 1047 | S'connecting' 1048 | p350 1049 | sI349 1050 | S'leader' 1051 | p351 1052 | sI350 1053 | S'mode' 1054 | p352 1055 | sI351 1056 | S'assembly' 1057 | p353 1058 | sI352 1059 | S'notable' 1060 | p354 1061 | sI353 1062 | S'lakes' 1063 | p355 1064 | sI354 1065 | S'architectural' 1066 | p356 1067 | sI355 1068 | S'stage' 1069 | p357 1070 | sI356 1071 | S'sister' 1072 | p358 1073 | sI357 1074 | S'lifestyle' 1075 | p359 1076 | sI358 1077 | S'universe' 1078 | p360 1079 | sI359 1080 | S'industry' 1081 | p361 1082 | sI360 1083 | S'taxon' 1084 | p362 1085 | sI361 1086 | S'ideology' 1087 | p363 1088 | sI362 1089 | S'airline' 1090 | p364 1091 | sI363 1092 | S'tributary' 1093 | p365 1094 | sI364 1095 | S'software' 1096 | p366 1097 | sI365 1098 | S'destination' 1099 | p367 1100 | sI366 1101 
| S'donated' 1102 | p368 1103 | sI367 1104 | S'contributor' 1105 | p369 1106 | sI368 1107 | S'handedness' 1108 | p370 1109 | sI369 1110 | S'next' 1111 | p371 1112 | sI370 1113 | S'approved' 1114 | p372 1115 | sI371 1116 | S'start' 1117 | p373 1118 | sI372 1119 | S'describes' 1120 | p374 1121 | sI373 1122 | S'editor' 1123 | p375 1124 | sI374 1125 | S'heritage' 1126 | p376 1127 | sI375 1128 | S'function' 1129 | p377 1130 | sI376 1131 | S'head' 1132 | p378 1133 | sI377 1134 | S'form' 1135 | p379 1136 | sI378 1137 | S'specialty' 1138 | p380 1139 | sI379 1140 | S'encodes' 1141 | p381 1142 | sI380 1143 | S'encoded' 1144 | p382 1145 | sI381 1146 | S'line' 1147 | p383 1148 | sI382 1149 | S'highest' 1150 | p384 1151 | sI383 1152 | S'terminus' 1153 | p385 1154 | sI384 1155 | S'illustrator' 1156 | p386 1157 | sI385 1158 | S'official' 1159 | p387 1160 | sI386 1161 | S'signed' 1162 | p388 1163 | sI387 1164 | S'planet' 1165 | p389 1166 | sI388 1167 | S'crystal' 1168 | p390 1169 | sI389 1170 | S'intangible' 1171 | p391 1172 | sI390 1173 | S'distribution' 1174 | p392 1175 | sI391 1176 | S'citizenship' 1177 | p393 1178 | sI392 1179 | S'general' 1180 | p394 1181 | sI393 1182 | S'education' 1183 | p395 1184 | sI394 1185 | S'film' 1186 | p396 1187 | sI395 1188 | S'edibility' 1189 | p397 1190 | sI396 1191 | S'locality' 1192 | p398 1193 | sI397 1194 | S'innervated' 1195 | p399 1196 | sI398 1197 | S'treatment' 1198 | p400 1199 | sI399 1200 | S'actor' 1201 | p401 1202 | sI400 1203 | S'field' 1204 | p402 1205 | sI401 1206 | S'role' 1207 | p403 1208 | sI402 1209 | S'branch' 1210 | p404 1211 | sI403 1212 | S'residence' 1213 | p405 1214 | sI404 1215 | S'class' 1216 | p406 1217 | sI405 1218 | S'time' 1219 | p407 1220 | sI406 1221 | S'develops' 1222 | p408 1223 | sI407 1224 | S'registry' 1225 | p409 1226 | sI408 1227 | S'conferred' 1228 | p410 1229 | sI409 1230 | S'spouse' 1231 | p411 1232 | sI410 1233 | S'building' 1234 | p412 1235 | sI411 1236 | S'chairperson' 1237 | p413 1238 | sI412 1239 | S'space' 1240 | p414 1241 | sI413 1242 | S'original' 1243 | p415 1244 | sI414 1245 | S'narrative' 1246 | p416 1247 | sI415 1248 | S'determination' 1249 | p417 1250 | sI416 1251 | S'cpu' 1252 | p418 1253 | s. 
-------------------------------------------------------------------------------- /relation_linker/annoy_index_rel_noisy/glove_embedding_of_vocab.ann: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amritasaha1812/CSQA_Code/3d6724b7cab1972f2636affdde94aedecaf13978/relation_linker/annoy_index_rel_noisy/glove_embedding_of_vocab.ann -------------------------------------------------------------------------------- /relation_linker/build_annoy_index_over_relation_words.py: -------------------------------------------------------------------------------- 1 | import gensim 2 | import cPickle as pkl 3 | word2vec_pretrain_embed = gensim.models.Word2Vec.load_word2vec_format('/dccstor/cssblr/amrita/resources/glove/GoogleNews-vectors-negative300.bin', binary=True) 4 | from annoy import AnnoyIndex 5 | 6 | f=300 7 | index = AnnoyIndex(f, metric='euclidean') 8 | index_desc = {} 9 | vocab = pkl.load(open('vocab_count.pkl')) 10 | count = 0 11 | for word in vocab: 12 | word = word[0] 13 | if word in word2vec_pretrain_embed: 14 | embed = word2vec_pretrain_embed[word] 15 | index.add_item(count, embed) 16 | index_desc[count] = word 17 | count = count+1 18 | index.build(100) 19 | index.save('annoy_index/glove_embedding_of_vocab.ann') 20 | pkl.dump(index_desc, open('annoy_index/index2word.pkl','wb')) 21 | -------------------------------------------------------------------------------- /relation_linker/create_relation_annoy_index.py: -------------------------------------------------------------------------------- 1 | import string 2 | from string import maketrans 3 | import re 4 | import sys 5 | import traceback 6 | import nltk 7 | from nltk.corpus import stopwords 8 | stop = set(stopwords.words('english')) 9 | import os 10 | import gensim 11 | import cPickle as pkl 12 | from annoy import AnnoyIndex 13 | import json 14 | wikidata_id_name_map={k:re.sub(r'[^\x00-\x7F]+',' ',v) for k,v in json.load(open('/dccstor/cssblr/vardaan/dialog-qa/item_data_filt.json')).items()} 15 | print 'loaded wikidata_id_name_map' 16 | relations = {} 17 | for line in open('predicates_bw.tsv').readlines(): 18 | line = line.strip().lower().split('\t') 19 | rel = line[0] 20 | label = [x for x in ' '.join(line[1:]).split(' ') if x not in stop] 21 | for w in label: 22 | if w not in relations: 23 | relations[w] = set([]) 24 | relations[w].add(rel) # record the relation id for every occurrence of the word 25 | for line in open('predicates_fw.tsv').readlines(): 26 | line = line.strip().lower().split('\t') 27 | rel = line[0] 28 | label = [x for x in ' '.join(line[1:]).split(' ') if x not in stop] 29 | for w in label: 30 | if w not in relations: 31 | relations[w] = set([]) 32 | relations[w].add(rel) 33 | all_relation_words = set([]) 34 | all_relation_words.update(relations.keys()) 35 | word2vec_pretrain_embed = gensim.models.Word2Vec.load_word2vec_format('/dccstor/cssblr/amrita/resources/glove/GoogleNews-vectors-negative300.bin', binary=True) 36 | f=300 37 | index = AnnoyIndex(f, metric='euclidean') 38 | index_desc = {} 39 | count = 0 40 | for word in all_relation_words: 41 | if word in word2vec_pretrain_embed: 42 | embed = word2vec_pretrain_embed[word] 43 | index.add_item(count, embed) 44 | index_desc[count] = word 45 | count = count+1 46 | index.build(100) 47 | index.save('annoy_index_noisy/glove_embedding_of_vocab.ann') 48 | pkl.dump(index_desc, open('annoy_index_noisy/index2word.pkl','wb')) 49 | 50 | 
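# Usage sketch (illustrative only, not part of the original script): after
# build(), the in-memory index can be queried for the relation words nearest
# an arbitrary query word; 'president' here is a hypothetical input.
#   vec = word2vec_pretrain_embed['president']
#   print [index_desc[i] for i in index.get_nns_by_vector(vec, 5)]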
-------------------------------------------------------------------------------- /relation_linker/perform_relation_identification.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | import string 4 | import traceback 5 | import nltk 6 | from nltk.corpus import stopwords 7 | stop = set(stopwords.words('english')) 8 | import os 9 | import gensim 10 | import cPickle as pkl 11 | from annoy import AnnoyIndex 12 | import json 13 | wikidata_id_name_map={k:re.sub(r'[^\x00-\x7F]+',' ',v) for k,v in json.load(open('/dccstor/cssblr/vardaan/dialog-qa/item_data_filt.json')).items()} 14 | glove_embedding = gensim.models.Word2Vec.load_word2vec_format('/dccstor/cssblr/amrita/resources/glove/GoogleNews-vectors-negative300.bin', binary=True) 15 | print 'loaded glove embeddings' 16 | ann = AnnoyIndex(300, metric='euclidean') 17 | ann.load('annoy_index_rel_noisy/glove_embedding_of_vocab.ann') 18 | ann_pickle = pkl.load(open('annoy_index_rel_noisy/index2word.pkl')) 19 | data_dir = '/dccstor/cssblr/vardaan/dialog-qa/QA_train_final5/' 20 | count_files=0 21 | rel_name_to_id = {} 22 | for line in open('predicates_bw.tsv').readlines(): 23 | line = line.strip().lower().split('\t') 24 | id = line[0] 25 | name = [x for x in ' '.join(line[1:]).split(' ') if x not in stop] 26 | for name_i in name: 27 | if name_i not in rel_name_to_id: 28 | rel_name_to_id[name_i] = set([]) 29 | rel_name_to_id[name_i].add(id) 30 | for line in open('predicates_fw.tsv').readlines(): 31 | line = line.strip().lower().split('\t') 32 | id = line[0] 33 | name = [x for x in ' '.join(line[1:]).split(' ') if x not in stop] 34 | for name_i in name: 35 | if name_i not in rel_name_to_id: 36 | rel_name_to_id[name_i] = set([]) 37 | rel_name_to_id[name_i].add(id) 38 | ann_pickle_rel = {} 39 | for id,name in ann_pickle.items(): 40 | ann_pickle_rel[id] = rel_name_to_id[name] 41 | pkl.dump(ann_pickle_rel, open('annoy_index_rel_noisy/index2rel.pkl','w')) 42 | prec = 0.0 43 | rec = 0.0 44 | count_acc = 0.0 45 | for dir in os.listdir(data_dir): 46 | if 'txt' in dir or 'pickle' in dir or 'xls' in dir: 47 | continue 48 | print dir 49 | for dir2 in os.listdir(data_dir+'/'+dir): 50 | if 'txt' in dir2 or 'pickle' in dir2 or 'xls' in dir2: 51 | continue 52 | for file in os.listdir(data_dir+'/'+dir+'/'+dir2): 53 | if not file.endswith('json'): 54 | continue 55 | count_files+=1 56 | if count_files%100==0: 57 | print 'finished ',count_files 58 | data = json.load(open(data_dir+'/'+dir+'/'+dir2+'/'+file)) 59 | for utter in data: 60 | try: 61 | utterance = utter['utterance'].lower() 62 | utterance = re.sub(r'[^\x00-\x7F]+',' ',utterance) 63 | utterance = str(utterance).translate(string.maketrans("",""),string.punctuation) 64 | if 'relations' in utter: 65 | relations = [x.lower() for x in utter['relations']] 66 | else: 67 | continue 68 | if 'entities' in utter: 69 | entities = utter['entities'] 70 | else: 71 | entities = [] 72 | if 'Qid' in utter: 73 | entities.append(utter['Qid']) 74 | if 'prop_Qid_par' in utter: 75 | parent = utter['prop_Qid_par'] 76 | else: 77 | parent = None 78 | entity_names = [wikidata_id_name_map[id].lower() for id in entities if id in wikidata_id_name_map] 79 | if parent is not None and parent in wikidata_id_name_map: 80 | parent_name = wikidata_id_name_map[parent].lower() 81 | else: 82 | parent_name = None 83 | utterance_replaced = utterance 84 | for e in entity_names: 85 | if e is not None: 86 | utterance_replaced = utterance_replaced.replace(e,'') 87 | if parent_name is not None: 
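# Descriptive note (editorial comment): the matched entity names and the
# parent-type name are blanked out of the utterance first, so that only the
# leftover words (the likely relation-describing ones) are embedded and looked
# up in the relation-word Annoy index below.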
88 | utterance_replaced = utterance_replaced.replace(parent_name, '') 89 | utterance_replaced = re.sub(' +',' ', utterance_replaced) 90 | words = set([x for x in utterance_replaced.split(' ') if not x.isdigit() and len(x)>1 and x in glove_embedding]) 91 | words = words - stop 92 | rel_ids = set([]) 93 | for word in words: 94 | if word not in glove_embedding: 95 | continue 96 | word_vec = glove_embedding[word] 97 | nns = ann.get_nns_by_vector(word_vec, 2) 98 | for nn in nns: 99 | nn_word = ann_pickle[nn] 100 | rel_ids.update(rel_name_to_id[nn_word]) 101 | #print 'predicted rel ids ', rel_ids 102 | #print 'true rel ids ', relations 103 | true_rel_ids = rel_ids.intersection(relations) 104 | if len(rel_ids)>0: 105 | prec += float(len(true_rel_ids))/float(len(rel_ids)) # precision: correct predictions / all predictions 106 | if len(relations)>0: 107 | rec += float(len(true_rel_ids))/float(len(relations)) # recall: correct predictions / gold relations 108 | count_acc +=1. 109 | if count_acc % 1000==0: 110 | print 'Prec ' , float(prec)/float(count_acc), ' over ', int(count_acc) 111 | print 'Rec ', float(rec)/float(count_acc), ' over ', int(count_acc) 112 | except: 113 | #print 'error in utterance ' 114 | traceback.print_exc(file=sys.stdout) 115 | continue 116 | print 'Prec ' , float(prec)/float(count_acc) 117 | print 'Rec ', float(rec)/float(count_acc) 118 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.10.0-cp27-none-linux_x86_64.whl 2 | pattern 3 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #EXAMPLE RUN: ./run.sh MODEL_DIR 2 | cp $1/params.py . 3 | python run_model.py $1 4 | 5 | -------------------------------------------------------------------------------- /run_test.sh: -------------------------------------------------------------------------------- 1 | #EXAMPLE RUN: ./run_test.sh MODEL_DIR QUES_TYPE_ID (QUES_TYPE_ID can be either of 'simple','logical','quantitative','comparative','verify','quantitative_count','comparative_count') 2 | cp $1/params_test.py . 
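# $1 = model directory: its params_test.py is copied into the working directory
# so run_test.py imports that configuration; $2 = question-type id forwarded to run_test.py.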
3 | python run_test.py $1 $2 4 | -------------------------------------------------------------------------------- /run_test_jobs.sh: -------------------------------------------------------------------------------- 1 | #./run_test.sh model_softmax_kvmem_validtrim_unfilt_newversion/ easy x86_24h -1 2 | #./run_test.sh model_softmax_kvmem_validtrim_unfilt_newversion/ easy x86_24h 2 3 | #./run_test.sh model_softmax_kvmem_validtrim_unfilt_newversion/ easy x86_24h 3 4 | #./run_test.sh model_softmax_kvmem_validtrim_unfilt_newversion/ easy x86_24h 4 5 | #./run_test.sh model_softmax_kvmem_validtrim_unfilt_newversion/ easy x86_24h 5 6 | #./run_test.sh model_softmax_kvmem_validtrim_unfilt_newversion/ easy x86_24h 6 7 | #./run_test.sh model_softmax_kvmem_validtrim_unfilt_newversion/ easy x86_24h 7 8 | #./run_test.sh model_softmax_kvmem_validtrim_unfilt_newversion/ easy x86_24h 8 9 | #./run_test.sh model_softmax_kvmem_validtrim_unfilt_newversion/ easy x86_24h 9 10 | #./run_test.sh model_softmax_kvmem_validtrim_unfilt_newversion/ easy x86_24h 12 11 | #./run_test.sh model_softmax_kvmem_validtrim_unfilt_newversion/ easy x86_24h 13 12 | #./run_test.sh model_softmax_kvmem_validtrim_unfilt_newversion/ easy x86_24h 14 13 | #./run_test.sh model_softmax_kvmem_validtrim_unfilt_newversion/ easy x86_24h 15 14 | #./run_test.sh model_softmax_decoder_newversion/ easy x86_6h -1 15 | #./run_test.sh model_softmax_decoder_newversion/ easy x86_6h 2 16 | ./run_test.sh model_softmax_decoder_newversion/ easy x86_6h 3 17 | ./run_test.sh model_softmax_decoder_newversion/ easy x86_6h 4 18 | ./run_test.sh model_softmax_decoder_newversion/ easy x86_6h 5 19 | ./run_test.sh model_softmax_decoder_newversion/ easy x86_6h 6 20 | ./run_test.sh model_softmax_decoder_newversion/ easy x86_6h 7 21 | ./run_test.sh model_softmax_decoder_newversion/ easy x86_6h 8 22 | ./run_test.sh model_softmax_decoder_newversion/ easy x86_6h 9 23 | ./run_test.sh model_softmax_decoder_newversion/ easy x86_6h 10 24 | #./run_test.sh model_softmax_decoder_newversion/ easy x86_6h 11 25 | ./run_test.sh model_softmax_decoder_newversion/ easy x86_6h 12 26 | ./run_test.sh model_softmax_decoder_newversion/ easy x86_6h 13 27 | ./run_test.sh model_softmax_decoder_newversion/ easy x86_6h 14 28 | ./run_test.sh model_softmax_decoder_newversion/ easy x86_6h 15 29 | #./run_test.sh model_softmax_kvmem_validtrim_unfilt_newversion/ hard x86_24h -1 30 | #./run_test.sh model_softmax_kvmem_validtrim_unfilt_newversion/ hard x86_24h 2 31 | #./run_test.sh model_softmax_kvmem_validtrim_unfilt_newversion/ hard x86_24h 3 32 | #./run_test.sh model_softmax_kvmem_validtrim_unfilt_newversion/ hard x86_24h 4 33 | #./run_test.sh model_softmax_kvmem_validtrim_unfilt_newversion/ hard x86_24h 5 34 | #./run_test.sh model_softmax_kvmem_validtrim_unfilt_newversion/ hard x86_24h 6 35 | #./run_test.sh model_softmax_kvmem_validtrim_unfilt_newversion/ hard x86_24h 7 36 | #./run_test.sh model_softmax_kvmem_validtrim_unfilt_newversion/ hard x86_24h 8 37 | #./run_test.sh model_softmax_kvmem_validtrim_unfilt_newversion/ hard x86_24h 9 38 | #./run_test.sh model_softmax_kvmem_validtrim_unfilt_newversion/ hard x86_24h 10 39 | #./run_test.sh model_softmax_kvmem_validtrim_unfilt_newversion/ hard x86_24h 11 40 | #./run_test.sh model_softmax_kvmem_validtrim_unfilt_newversion/ hard x86_24h 12 41 | #./run_test.sh model_softmax_kvmem_validtrim_unfilt_newversion/ hard x86_24h 13 42 | #./run_test.sh model_softmax_kvmem_validtrim_unfilt_newversion/ hard x86_24h 14 43 | #./run_test.sh 
model_softmax_kvmem_validtrim_unfilt_newversion/ hard x86_24h 15 44 | #./run_test.sh model_softmax_decoder_newversion/ hard x86_24h -1 45 | #./run_test.sh model_softmax_decoder_newversion/ hard x86_24h 2 46 | ./run_test.sh model_softmax_decoder_newversion/ hard x86_24h 3 47 | ./run_test.sh model_softmax_decoder_newversion/ hard x86_24h 4 48 | ./run_test.sh model_softmax_decoder_newversion/ hard x86_24h 5 49 | ./run_test.sh model_softmax_decoder_newversion/ hard x86_24h 6 50 | ./run_test.sh model_softmax_decoder_newversion/ hard x86_24h 7 51 | ./run_test.sh model_softmax_decoder_newversion/ hard x86_24h 8 52 | ./run_test.sh model_softmax_decoder_newversion/ hard x86_24h 9 53 | ./run_test.sh model_softmax_decoder_newversion/ hard x86_24h 10 54 | #./run_test.sh model_softmax_decoder_newversion/ hard x86_24h 11 55 | ./run_test.sh model_softmax_decoder_newversion/ hard x86_24h 12 56 | ./run_test.sh model_softmax_decoder_newversion/ hard x86_24h 13 57 | ./run_test.sh model_softmax_decoder_newversion/ hard x86_24h 14 58 | ./run_test.sh model_softmax_decoder_newversion/ hard x86_24h 15 59 | -------------------------------------------------------------------------------- /stopwords_histogram.txt: -------------------------------------------------------------------------------- 1 | a 9502 2 | a few 2470 3 | a film 1975 4 | a movie 726 5 | a writer 1349 6 | about 3947 7 | act 2523 8 | act in 2505 9 | acted 3244 10 | acted in 3244 11 | actor 863 12 | actor in 862 13 | actors 3256 14 | actors in 1604 15 | an 1089 16 | an actor 862 17 | and 2350 18 | and the 734 19 | appear 1692 20 | appear in 1692 21 | appears 861 22 | appears in 861 23 | are 4335 24 | are about 1675 25 | are the 1622 26 | as 930 27 | as director 834 28 | author 604 29 | author of 603 30 | be 2362 31 | be described 2175 32 | by 7645 33 | by this 587 34 | by who 1505 35 | can 4048 36 | can be 2175 37 | can you 1802 38 | creator 641 39 | creator of 638 40 | date 3093 41 | date of 3042 42 | describe 4979 43 | describe film 1032 44 | describe movie 1439 45 | describe the 939 46 | described 2175 47 | described by 1754 48 | describing 967 49 | did 10237 50 | direct 2046 51 | directed 7766 52 | directed by 2574 53 | directed the 2450 54 | directed which 692 55 | director 8818 56 | director for 2045 57 | director of 5549 58 | director that 873 59 | does 3904 60 | fall 593 61 | fall under 515 62 | few 2470 63 | few words 2452 64 | film 20246 65 | film did 2046 66 | film directed 703 67 | film genre 522 68 | film is 2084 69 | film script 638 70 | film the 2854 71 | film written 1242 72 | films 7346 73 | films are 843 74 | films can 880 75 | films did 3432 76 | films does 857 77 | films was 671 78 | for 7508 79 | for the 2003 80 | from 517 81 | genre 7194 82 | genre does 515 83 | genre for 1509 84 | genre is 1048 85 | genre of 4122 86 | girl 509 87 | give 1009 88 | give a 970 89 | i 670 90 | in 24254 91 | in a 1722 92 | in the 5560 93 | in which 862 94 | is 21907 95 | is a 2488 96 | is listed 834 97 | is movie 540 98 | is the 11804 99 | it 2355 100 | it released 1967 101 | john 635 102 | kind 1125 103 | kind of 1073 104 | language 2608 105 | language in 883 106 | language is 871 107 | language spoken 849 108 | last 516 109 | life 512 110 | listed 834 111 | listed as 834 112 | love 899 113 | man 1218 114 | me 563 115 | movie 19524 116 | movie did 1341 117 | movie is 2107 118 | movie the 3299 119 | movie written 693 120 | movies 9943 121 | movies are 831 122 | movies can 1295 123 | movies did 3010 124 | movies was 2901 125 | my 820 126 | night 642 
127 | of 31170 128 | of a 557 129 | of film 2080 130 | of movie 2088 131 | of the 11977 132 | of which 660 133 | on 2569 134 | person 2668 135 | person directed 892 136 | person wrote 1180 137 | primary 616 138 | primary language 605 139 | release 5974 140 | release date 3042 141 | release year 2927 142 | released 7879 143 | screenplay 584 144 | screenplay for 583 145 | screenwriter 624 146 | screenwriter wrote 624 147 | script 1262 148 | script for 1262 149 | some 503 150 | sort 1038 151 | sort of 1038 152 | spoken 849 153 | spoken in 849 154 | star 2739 155 | star in 2601 156 | starred 4856 157 | starred in 1633 158 | starred which 1646 159 | starred who 1571 160 | stars 1671 161 | stars in 1627 162 | story 1734 163 | story for 1370 164 | that 1559 165 | that directed 873 166 | the 78241 167 | the actors 1604 168 | the author 603 169 | the creator 638 170 | the director 7980 171 | the film 12558 172 | the genre 4538 173 | the language 849 174 | the movie 12844 175 | the primary 605 176 | the release 5969 177 | the screenplay 583 178 | the script 624 179 | the story 1431 180 | the world 765 181 | the writer 4410 182 | this 767 183 | this person 587 184 | to 2031 185 | topics 1589 186 | topics is 1589 187 | type 996 188 | type of 996 189 | under 642 190 | was 27713 191 | was directed 925 192 | was it 1962 193 | was the 15388 194 | was who 871 195 | was written 580 196 | what 41930 197 | what does 2516 198 | what film 1225 199 | what films 5102 200 | what genre 2134 201 | what is 6598 202 | what kind 1039 203 | what language 871 204 | what movie 670 205 | what movies 5494 206 | what sort 1038 207 | what topics 532 208 | what type 996 209 | what was 7821 210 | what words 1047 211 | what year 2910 212 | when 5115 213 | when was 4969 214 | which 15947 215 | which actors 1646 216 | which film 1342 217 | which films 2238 218 | which movie 1015 219 | which movies 4443 220 | which person 2072 221 | which screenwriter 624 222 | which topics 1057 223 | which words 1499 224 | who 27772 225 | who acted 3244 226 | who are 1604 227 | who directed 2735 228 | who in 615 229 | who is 6015 230 | who starred 1634 231 | who stars 1621 232 | who was 3698 233 | who wrote 2414 234 | whos 1581 235 | whos the 1537 236 | with 1021 237 | words 5493 238 | words describe 2546 239 | words describing 967 240 | world 945 241 | world wrote 615 242 | write 4699 243 | write the 1370 244 | writer 5765 245 | writer of 4410 246 | writer on 1349 247 | written 3141 248 | written by 3136 249 | wrote 4833 250 | wrote the 3466 251 | year 5976 252 | year of 2953 253 | year was 2913 254 | you 2299 255 | you give 967 256 | -------------------------------------------------------------------------------- /text_util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | """ 4 | Utilities for cleaning the text data 5 | """ 6 | import unicodedata 7 | 8 | def clean_word(word): 9 | word = word.strip('\n') 10 | word = word.strip('\r') 11 | word = word.lower() 12 | word = word.replace('%', '') #99 and 44/100% dead 13 | word = word.strip() 14 | word = word.replace(',', '') 15 | word = word.replace('.', '') 16 | word = word.replace('"', '') 17 | word = word.replace('\'', '') 18 | word = word.replace('?', '') 19 | word = word.replace('|', '') 20 | word = unicode(word, "utf-8") #Convert str -> unicode (Remember default encoding is ascii in python) 21 | word = unicodedata.normalize('NFKD', word).encode('ascii','ignore') #Convert normalized unicode to python str 22 | word = word.lower() 
#Don't remove this line, lowercase after the unicode normalization 23 | return word 24 | 25 | 26 | def clean_line(line): 27 | """ 28 | Do not replace PIPE here. 29 | """ 30 | line = line.strip('\n') 31 | line = line.strip('\r') 32 | line = line.strip() 33 | line = line.lower() 34 | return line 35 | 36 | def append_word_to_str(text, str): 37 | if len(text) == 0: 38 | return str 39 | else: 40 | return text + " " + str 41 | 42 | if __name__ == "__main__": 43 | print "__"+clean_word(" ")+"__" -------------------------------------------------------------------------------- /type_linker/annoy_index_type/glove_embedding_of_vocab.ann: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amritasaha1812/CSQA_Code/3d6724b7cab1972f2636affdde94aedecaf13978/type_linker/annoy_index_type/glove_embedding_of_vocab.ann -------------------------------------------------------------------------------- /type_linker/create_type_annoy_index.py: -------------------------------------------------------------------------------- 1 | import pattern.en 2 | import re 3 | import sys 4 | import string 5 | import traceback 6 | import json 7 | import nltk 8 | from nltk.corpus import stopwords 9 | import gensim 10 | from annoy import AnnoyIndex 11 | import os 12 | import cPickle as pkl 13 | stop = set(stopwords.words('english')) 14 | 15 | types = {} 16 | good_types = set([]) 17 | for x in json.load(open('/dccstor/cssblr/vardaan/dialog-qa/prop_obj_90_map5.json')).values(): 18 | good_types.update(x) 19 | for x in json.load(open('/dccstor/cssblr/vardaan/dialog-qa/prop_sub_90_map5.json')).values(): 20 | good_types.update(x) 21 | for k,v in json.load(open('/dccstor/cssblr/vardaan/dialog-qa/child_par_dict_name_2_corr.json')).items(): 22 | if k not in good_types: 23 | continue 24 | v = set([x for x in v.lower().strip().split(' ') if x not in stop]) 25 | plur_v = set([pattern.en.pluralize(vi) for vi in v]) 26 | v = v.union(plur_v) 27 | for vi in v: 28 | if vi not in types: 29 | types[vi] = [] 30 | if k not in types[vi]: 31 | types[vi].append(k) 32 | for k,v in json.load(open('/dccstor/cssblr/vardaan/dialog-qa/type_set_dict.json')).items(): 33 | if k not in good_types: 34 | continue 35 | v = set([x for x in v.lower().strip().split(' ') if x not in stop]) 36 | plur_v = set([pattern.en.pluralize(vi) for vi in v]) 37 | v = v.union(plur_v) 38 | for vi in v: 39 | if vi not in types: 40 | types[vi] = [] 41 | if k not in types[vi]: 42 | types[vi].append(k) 43 | json.dump(types, open('annoy_index_type/type_names.json','w'), indent=1) 44 | #sys.exit(1) # the original exited here with an error status, leaving the annoy index below unbuilt; commented out so the script builds and saves the index it is named for 45 | word2vec_pretrain_embed = gensim.models.Word2Vec.load_word2vec_format('/dccstor/cssblr/amrita/resources/glove/GoogleNews-vectors-negative300.bin', binary=True) 46 | f=300 47 | index = AnnoyIndex(f, metric='euclidean') 48 | index_desc = {} 49 | count = 0 50 | for word,ids in types.items(): 51 | if word not in word2vec_pretrain_embed: 52 | print 'could not find ::::', word 53 | continue 54 | embed = word2vec_pretrain_embed[word] 55 | index.add_item(count, embed) 56 | index_desc[count] = ids 57 | count = count + 1 58 | index.build(100) 59 | index.save('annoy_index_type/glove_embedding_of_vocab.ann') 60 | pkl.dump(index_desc, open('annoy_index_type/index2type.pkl','wb')) 61 | 62 | 63 | -------------------------------------------------------------------------------- /type_linker/perform_type_identification.py: -------------------------------------------------------------------------------- 1 | import pattern.en 2 | import re 3 | import sys
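# Sketch of the core lookup this script performs (hypothetical standalone usage,
# assuming the index built by create_type_annoy_index.py is on disk):
#   ann = AnnoyIndex(300, metric='euclidean')
#   ann.load('annoy_index_type/glove_embedding_of_vocab.ann')
#   nn_ids = ann.get_nns_by_vector(glove_embedding['movies'], 5)
#   candidate_types = set().union(*[ann_pickle_type[i] for i in nn_ids])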
4 | import string 5 | import traceback 6 | import json 7 | import nltk 8 | from nltk.corpus import stopwords 9 | from annoy import AnnoyIndex 10 | stop = set(stopwords.words('english')) 11 | import os 12 | import cPickle as pkl 13 | import gensim 14 | wikidata_id_name_map={k:re.sub(r'[^\x00-\x7F]+',' ',v) for k,v in json.load(open('/dccstor/cssblr/vardaan/dialog-qa/item_data_filt.json')).items()} 15 | glove_embedding = gensim.models.KeyedVectors.load_word2vec_format('/dccstor/cssblr/amrita/resources/glove/GoogleNews-vectors-negative300.bin', binary=True) 16 | print 'loaded glove embeddings' 17 | ann = AnnoyIndex(300, metric='euclidean') 18 | ann.load('annoy_index_type/glove_embedding_of_vocab.ann') 19 | ann_pickle_type = pkl.load(open('annoy_index_type/index2type.pkl')) 20 | 21 | types = json.load(open('annoy_index_type/type_names.json')) 22 | def get_type(active_set): 23 | types_in_active_set = set([]) 24 | for x in active_set.split(','): 25 | if 'c(' in x: 26 | t = x.replace('c(','').replace(')','').strip().replace('(','').split('|') 27 | types_in_active_set.update(t) 28 | return types_in_active_set 29 | 30 | def get_filtered_utterance(utterance, utter): 31 | if 'entities' in utter: 32 | entities = utter['entities'] 33 | else: 34 | entities = [] 35 | if 'Qid' in utter: 36 | entities.append(utter['Qid']) 37 | entity_names = [wikidata_id_name_map[id].lower() for id in entities if id in wikidata_id_name_map] 38 | utterance_replaced = utterance 39 | for e in entity_names: 40 | if e is not None: 41 | utterance_replaced = utterance_replaced.replace(e,'') 42 | utterance_replaced = re.sub(' +',' ', utterance_replaced) 43 | words = set([x for x in utterance_replaced.split(' ') if not x.isdigit() and len(x)>1 and x in glove_embedding]) 44 | words = words - stop 45 | return words 46 | 47 | data_dir = '/dccstor/cssblr/vardaan/dialog-qa/QA_train_final6/' 48 | prec = 0.0 49 | rec = 0.0 50 | count_files = 0 51 | count_acc = 0.0 52 | for dir in os.listdir(data_dir): 53 | if 'txt' in dir or 'pickle' in dir or 'xls' in dir: 54 | continue 55 | print dir 56 | for dir2 in os.listdir(data_dir+'/'+dir): 57 | if 'txt' in dir2 or 'pickle' in dir2 or 'xls' in dir2: 58 | continue 59 | for file in os.listdir(data_dir+'/'+dir+'/'+dir2): 60 | if not file.endswith('json'): 61 | continue 62 | count_files+=1 63 | if count_files%100==0: 64 | print 'finished ',count_files 65 | data = json.load(open(data_dir+'/'+dir+'/'+dir2+'/'+file)) 66 | for utter in data: 67 | #print ('active_set' in utter), 'utterance ', utter['utterance'] 68 | 69 | try: 70 | if 'active_set' not in utter: 71 | utterance = utter['utterance'].lower() 72 | utterance = re.sub(r'[^\x00-\x7F]+',' ',utterance) 73 | utterance = str(utterance).translate(string.maketrans("",""),string.punctuation) 74 | utterance_filtered = get_filtered_utterance(utterance, utter) 75 | else: 76 | gold_types = set([]) 77 | for x in utter['active_set']: 78 | gold_types.update(get_type(x)) 79 | predicted_types = set([]) 80 | predicted_type_names = set([]) 81 | for type in types: 82 | if type in utterance: 83 | predicted_types.update(types[type]) 84 | predicted_type_names.add(type) 85 | if len(predicted_type_names)==0: #fall back to embedding nearest neighbours only when string matching found nothing; the original condition '>=0' was always true 86 | for word in utterance_filtered: 87 | sing_w = pattern.en.singularize(word) 88 | plur_w = pattern.en.pluralize(word) 89 | nn_words = set([]) 90 | if sing_w in glove_embedding: 91 | nn_words.update(ann.get_nns_by_vector(glove_embedding[sing_w], 5)) 92 | if plur_w in glove_embedding: 93 | nn_words.update(ann.get_nns_by_vector(glove_embedding[plur_w], 5)) 94 |
for nn in nn_words: 95 | predicted_types.update(ann_pickle_type[nn]) 96 | 97 | ints = predicted_types.intersection(gold_types) 98 | #if len(ints)==0 and len(gold_types)>0: 99 | # print gold_types, '::::', predicted_types, ':::', predicted_type_names 100 | if len(predicted_types)>0: 101 | prec += float(len(ints))/float(len(predicted_types)) #precision = |predicted & gold| / |predicted| 102 | if len(gold_types)>0: 103 | rec += float(len(ints))/float(len(gold_types)) #recall = |predicted & gold| / |gold| (the two ratios were swapped in the original) 104 | if len(gold_types)>0: 105 | count_acc += 1.0 106 | if count_acc %1000==0: 107 | print 'Prec ', prec/count_acc, ' Recall ', rec/count_acc, ' Over ', int(count_acc) 108 | #utterance = None 109 | #gold_types = None 110 | #predicted_types = None 111 | except: 112 | traceback.print_exc() #print_exc() prints the trace itself; the original 'print traceback.print_exc()' also printed a spurious 'None' 113 | continue 114 | 115 | print 'Prec ', prec/count_acc, ' Recall ', rec/count_acc, ' Over ', int(count_acc) 116 | -------------------------------------------------------------------------------- /utils/cp_files.sh: -------------------------------------------------------------------------------- 1 | # cp model_softmax_kvmem_valid_trip_unfilt/test_output_hard_3/top20_ent_id_from_mem.txt model_softmax_kvmem_valid_trip_unfilt/prec_files/hard/top20_ent_id_from_mem_3.txt 2 | # cp model_softmax_kvmem_valid_trip_unfilt/test_output_hard_4/top20_ent_id_from_mem.txt model_softmax_kvmem_valid_trip_unfilt/prec_files/hard/top20_ent_id_from_mem_4.txt 3 | # cp model_softmax_kvmem_valid_trip_unfilt/test_output_hard_5/top20_ent_id_from_mem.txt model_softmax_kvmem_valid_trip_unfilt/prec_files/hard/top20_ent_id_from_mem_5.txt 4 | # cp model_softmax_kvmem_valid_trip_unfilt/test_output_hard_6/top20_ent_id_from_mem.txt model_softmax_kvmem_valid_trip_unfilt/prec_files/hard/top20_ent_id_from_mem_6.txt 5 | # cp model_softmax_kvmem_valid_trip_unfilt/test_output_hard_7/top20_ent_id_from_mem.txt model_softmax_kvmem_valid_trip_unfilt/prec_files/hard/top20_ent_id_from_mem_7.txt 6 | # cp model_softmax_kvmem_valid_trip_unfilt/test_output_hard_8/top20_ent_id_from_mem.txt model_softmax_kvmem_valid_trip_unfilt/prec_files/hard/top20_ent_id_from_mem_8.txt 7 | # cp model_softmax_kvmem_valid_trip_unfilt/test_output_hard_9/top20_ent_id_from_mem.txt model_softmax_kvmem_valid_trip_unfilt/prec_files/hard/top20_ent_id_from_mem_9.txt 8 | # cp model_softmax_kvmem_valid_trip_unfilt/test_output_hard_10/top20_ent_id_from_mem.txt model_softmax_kvmem_valid_trip_unfilt/prec_files/hard/top20_ent_id_from_mem_10.txt 9 | 10 | cp model_softmax_kvmem_valid_trip_unfilt/test_output_hard_7a/top20_ent_id_from_mem.txt model_softmax_kvmem_valid_trip_unfilt/prec_files/hard/top20_ent_id_from_mem_7a.txt 11 | cp model_softmax_kvmem_valid_trip_unfilt/test_output_hard_7b/top20_ent_id_from_mem.txt model_softmax_kvmem_valid_trip_unfilt/prec_files/hard/top20_ent_id_from_mem_7b.txt 12 | 13 | cp model_softmax_kvmem_valid_trip_unfilt/test_output_hard_8a/top20_ent_id_from_mem.txt model_softmax_kvmem_valid_trip_unfilt/prec_files/hard/top20_ent_id_from_mem_8a.txt 14 | cp model_softmax_kvmem_valid_trip_unfilt/test_output_hard_8b/top20_ent_id_from_mem.txt model_softmax_kvmem_valid_trip_unfilt/prec_files/hard/top20_ent_id_from_mem_8b.txt -------------------------------------------------------------------------------- /utils/find_overlap.py: -------------------------------------------------------------------------------- 1 | import codecs, sys, pickle 2 | from itertools import izip 3 | 4 | # mem_target_file = sys.argv[1] 5 | # gold_target_file = sys.argv[2] 6 | 7 | # overlap_ratio_sum = 0 8 | # n_lines = 0 9 | # n_err = 0 10 | # with open(mem_target_file) as f1,
open(gold_target_file) as f2: 11 | # for mem_target_line, gold_target_line in izip(f1, f2): 12 | # mem_target_line = mem_target_line.rstrip() 13 | # gold_target_line = gold_target_line.rstrip() 14 | 15 | # if len(mem_target_line) > 0: 16 | # mem_target_tokens = mem_target_line.split('|') 17 | # else: 18 | # mem_target_tokens = [] 19 | 20 | # if len(gold_target_line)>0: 21 | # gold_target_tokens = gold_target_line.split('|') 22 | # else: 23 | # gold_target_tokens = [] 24 | 25 | # common_tokens = set(mem_target_tokens).intersection(set(gold_target_tokens)) 26 | 27 | # if len(gold_target_tokens)>0: 28 | # overlap_ratio = len(common_tokens)*1.0/len(gold_target_tokens) 29 | # overlap_ratio_sum += overlap_ratio 30 | # n_lines += 1 31 | # print 'overlap_ratio = %f' % overlap_ratio 32 | # else: 33 | # n_err += 1 34 | 35 | # avg_overlap_ratio = overlap_ratio_sum*1.0/n_lines 36 | 37 | # print 'Avg. overlap_ratio = %f' % avg_overlap_ratio 38 | # print 'n_err = %d' % n_err 39 | 40 | data_pkl_file = sys.argv[1] 41 | # data_pkl_file = 'model_2/dump/train_data_file.pkl' 42 | 43 | data = pickle.load(open(data_pkl_file,'r')) 44 | 45 | overlap_ent = [set(data[i][2]).intersection(set(data[i][8])) for i in range(len(data))] 46 | overlap_ent_filt = [(x - set([0])) for x in overlap_ent] 47 | gold_target_len = [len(set(data[i][2]) - set([0])) for i in range(len(data))] 48 | 49 | overlap_ratio = [len(x)*1.0/y for x,y in izip(overlap_ent_filt, gold_target_len) if y != 0] 50 | 51 | avg_overlap_ratio = sum(overlap_ratio)*1.0/len(overlap_ratio) 52 | avg_nonzero_overlap_ratio = len([x for x in overlap_ent_filt if len(x)>0])*1.0/len(data) 53 | 54 | print 'Avg. overlap_ratio = %f' % avg_overlap_ratio 55 | print 'avg_nonzero_overlap_ratio = %f' % avg_nonzero_overlap_ratio 56 | 57 | 58 | -------------------------------------------------------------------------------- /utils/get_cont_chunks.py: -------------------------------------------------------------------------------- 1 | import nltk, re, copy 2 | 3 | s1 = 'washington, d.c., new york city' 4 | s2 = 'Which administrative territorial entity is the capital of United States of America and is not the location of Bonekampstraat ?' 5 | s3 = 'Epernay, hautvillers, damery' 6 | s4 = 'what about la vicogne?' 7 | s5 = 'vauciennes, Epernay, damery, hautvillers, le petit-quevilly, cumieres' 8 | s6 = 'which type of french administrative division shares border with mardeuil ?' 9 | s7 = 'metro-goldwyn-mayer, marvel studios, bad robot productions' 10 | s8 = 'which work of art were produced by metro-goldwyn-mayer, marvel studios and bad robot productions ?' 11 | s9 = 'amsterdam, athens, barcelona, beijing, beirut, bethlehem, chicago, damascus, domodedovo, famagusta, istanbul, ljubljana, los angeles, madrid, mexico city, moscow, naples, nicosia, rabat, reggio calabria, rio de janeiro, sofia, washington, d.c., boston, montreal, genoa, florence, lisbon, cali, prague, warsaw, tbilisi, havana, cusco, atlanta, belgrade' 12 | s10 = 'which administrative territorial entity is a sister city of those administrative territorial entities ?' 13 | s11 = 'Oldham Athletic A.F.C., Kilmarnock F.C.' 
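# s1-s11 above are standalone samples (a question or a comma-separated answer
# list); s12-s15 below pair a question with its answer list, joined by '|';
# s16-s17 are further standalone samples. `text` selects which one to chunk.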
14 | 15 | s12 = 'Which administrative territorial entity is the capital of United States of America and is not the location of Bonekampstraat ?|washington, d.c., new york city' 16 | s13 = 'which type of french administrative division shares border with mardeuil ?|vauciennes, Epernay, damery, hautvillers, le petit-quevilly, cumieres' 17 | s14 = 'which administrative territorial entity is a sister city of those administrative territorial entities ?|amsterdam, athens, barcelona, beijing, beirut, bethlehem, chicago, damascus, domodedovo, famagusta, istanbul, ljubljana, los angeles, madrid, mexico city, moscow, naples, nicosia, rabat, reggio calabria, rio de janeiro, sofia, washington, d.c., boston, montreal, genoa, florence, lisbon, cali, prague, warsaw, tbilisi, havana, cusco, atlanta, belgrade' 18 | s15 = 'which work of art were produced by metro-goldwyn-mayer, marvel studios and bad robot productions ?|Oldham Athletic A.F.C., Kilmarnock F.C.' 19 | s16 = 'member of the European Parliament, Member of the Chamber of Deputies of the Parliament of the Czech Republic' 20 | s17 = 'Which administrative territorial entity is the capital of United States of America and is not the location of Bonekampstraat ?' 21 | 22 | text = s17 23 | 24 | lemmatizer = nltk.WordNetLemmatizer() 25 | stemmer = nltk.stem.porter.PorterStemmer() 26 | #Taken from Su Nam Kim Paper... 27 | grammar = r""" 28 | NBAR: 29 | {<NN.*|JJ>*<NN.*>} # Nouns and Adjectives, terminated with Nouns 30 | 31 | NP: 32 | {<NBAR>} 33 | {<NBAR><IN><NBAR>} # Above, connected with in/of/etc... 34 | """ 35 | chunker = nltk.RegexpParser(grammar) 36 | 37 | # values = set([]) 38 | toks = nltk.word_tokenize(text) 39 | postoks = nltk.pos_tag(toks) 40 | tree = chunker.parse(postoks) 41 | # print tree 42 | super_list = [w for w,t in tree.leaves()] 43 | return_dict = {} 44 | 45 | for subtree in tree.subtrees(): 46 | # print subtree 47 | if subtree==tree: 48 | continue 49 | chunk_list = [x[0].strip() for x in subtree.leaves()] 50 | chunk = ' '.join(chunk_list).strip() 51 | if len(chunk)<=1: 52 | continue 53 | if chunk not in return_dict: 54 | return_dict[chunk] = chunk_list 55 | # values.add(chunk) 56 | -------------------------------------------------------------------------------- /utils/get_nounphrases.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | import re 3 | import copy 4 | # Used when tokenizing words 5 | sentence_re = r'''(?x) # set flag to allow verbose regexps 6 | ([A-Z])(\.[A-Z])+\.? # abbreviations, e.g. U.S.A. 7 | | \w+(-\w+)* # words with optional internal hyphens 8 | | \$?\d+(\.\d+)?%? # currency and percentages, e.g. $12.40, 82% 9 | | \.\.\. # ellipsis 10 | | [][.,;"'?():-_`] # these are separate tokens 11 | ''' 12 | 13 | lemmatizer = nltk.WordNetLemmatizer() 14 | stemmer = nltk.stem.porter.PorterStemmer() 15 | #Taken from Su Nam Kim Paper... 16 | grammar = r""" 17 | NBAR: 18 | {<NN.*|JJ>*<NN.*>} # Nouns and Adjectives, terminated with Nouns 19 | 20 | NP: 21 | {<NBAR>} 22 | {<NBAR><IN><NBAR>} # Above, connected with in/of/etc... 23 | """ 24 | chunker = nltk.RegexpParser(grammar) 25 | 26 | d="which administrative territorial entity is the capital of united states of america and is not the location of bonekampstraat ?"
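# With the Su Nam Kim grammar above, chunking this sentence typically yields
# noun-phrase chunks such as 'administrative territorial entity' and
# 'united states'; the exact output depends on the installed POS tagger,
# so treat these as illustrative rather than guaranteed.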
27 | values = set([]) 28 | toks = nltk.word_tokenize(d) 29 | postoks = nltk.pos_tag(toks) 30 | tree = chunker.parse(postoks) 31 | print tree 32 | for subtree in tree.subtrees(): 33 | if subtree==tree: 34 | continue 35 | chunk = ' '.join([x[0].strip() for x in subtree.leaves()]).strip() 36 | if len(chunk)<=1: 37 | continue 38 | values.add(chunk) 39 | print values 40 | -------------------------------------------------------------------------------- /utils/search_entities.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys, os, lucene 3 | from lucene import * 4 | from java.io import File 5 | from org.apache.lucene.analysis.standard import StandardAnalyzer 6 | from org.apache.lucene.index import DirectoryReader, IndexReader 7 | from org.apache.lucene.index import Term 8 | from org.apache.lucene.search import BooleanClause, BooleanQuery, PhraseQuery, TermQuery 9 | from org.apache.lucene.queryparser.classic import QueryParser 10 | from org.apache.lucene.store import SimpleFSDirectory 11 | from org.apache.lucene.search import IndexSearcher 12 | from org.apache.lucene.util import Version 13 | 14 | class LuceneSearch(): 15 | def __init__(self,lucene_index_dir='/dccstor/cssblr/amrita/dialog_qa/code/prepro_lucene/lucene_index/'): 16 | lucene.initVM(vmargs=['-Djava.awt.headless=true']) 17 | directory = SimpleFSDirectory(File(lucene_index_dir)) 18 | self.searcher = IndexSearcher(DirectoryReader.open(directory)) 19 | self.num_docs_to_return =5 20 | self.ireader = IndexReader.open(directory) 21 | 22 | def search(self, value): 23 | query = TermQuery(Term("wiki_name",value.lower())) 24 | #query = BooleanQuery() 25 | #query.add(new TermQuery(Term("wikidata_name",v)),BooleanClause.Occur.SHOULD) 26 | #query.add(new TermQuery(Term("wikidata_name",v)),BooleanClause.Occur.SHOULD) 27 | scoreDocs = self.searcher.search(query, self.num_docs_to_return).scoreDocs 28 | for scoreDoc in scoreDocs: 29 | doc = self.searcher.doc(scoreDoc.doc) 30 | for f in doc.getFields(): 31 | print f.name(),':', f.stringValue(),', ' 32 | print '' 33 | print '-------------------------------------\n' 34 | 35 | if __name__=="__main__": 36 | ls = LuceneSearch() 37 | ls.search("United States") 38 | ls.search("India") 39 | ls.search("Barrack Obama") 40 | ls.search("Obama") 41 | -------------------------------------------------------------------------------- /vocabs/response_vocab.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amritasaha1812/CSQA_Code/3d6724b7cab1972f2636affdde94aedecaf13978/vocabs/response_vocab.pkl -------------------------------------------------------------------------------- /vocabs/vocab.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amritasaha1812/CSQA_Code/3d6724b7cab1972f2636affdde94aedecaf13978/vocabs/vocab.pkl -------------------------------------------------------------------------------- /wikidata_entities_with_digitnames.pkl: -------------------------------------------------------------------------------- (binary artifact: a protocol-0 pickle of a Python set of roughly 500 Wikidata QID strings whose entity names contain digits, e.g. Q18109127, Q25588, Q9554684, ...; the raw serialized dump is omitted)
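A minimal sketch of reading this pickle (Python 2, matching the rest of the repository; the file stores a plain set of QID strings):

import cPickle as pkl

# load the pickled set of Wikidata entity ids whose names contain digits
digit_entities = pkl.load(open('wikidata_entities_with_digitnames.pkl', 'rb'))
print len(digit_entities)              # size of the set
print 'Q18109127' in digit_entities    # membership test for one of the stored ids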
-------------------------------------------------------------------------------- /words2number.py: -------------------------------------------------------------------------------- 1 | import inflect 2 | 3 | def text2int(textnum, numwords={}): 4 | textnum = textnum.replace(',','') 5 | textnum = textnum.replace('-',' ') 6 | 7 | if not numwords: 8 | units = [ 9 | "zero", "one", "two", "three", "four", "five", "six", "seven", "eight", 10 | "nine", "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen", 11 | "sixteen", "seventeen", "eighteen", "nineteen", 12 | ] 13 | 14 | tens = ["", "", "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety"] 15 | 16 | scales = ["hundred", "thousand", "million", "billion", "trillion"] 17 | 18 | numwords["and"] = (1, 0) 19 | for idx, word in enumerate(units): numwords[word] = (1, idx) 20 | for idx, word in enumerate(tens): numwords[word] = (1, idx * 10) 21 | for idx, word in enumerate(scales): numwords[word] = (10 ** (idx * 3 or 2), 0) 22 | 23 | current = result = 0 24 | for word in textnum.split(): 25 | if word not in numwords: 26 | raise Exception("Illegal word: " + word) 27 | 28 | scale, increment = numwords[word] 29 | current = current * scale + increment 30 | if scale > 100: 31 | result += current 32 | current = 0 33 | 34 | return result + current 35 | ''' 36 | inf_eng = inflect.engine() 37 | num = 1995 38 | 39 | x = inf_eng.number_to_words(int(num)) 40 | print x 41 | print text2int(x) 42 | 43 | for n in range(1000000): 44 | if n % 100 == 0: 45 | print n 46 | n_rec = text2int(inf_eng.number_to_words(int(n))) 47 | try: 48 | assert n == n_rec 49 | except: 50 | print n 51 | ''' 52 | --------------------------------------------------------------------------------
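A short usage sketch for text2int, round-tripping through inflect as in the commented-out test block above (assumes the inflect package is installed and words2number.py is importable):

import inflect
from words2number import text2int

inf_eng = inflect.engine()
words = inf_eng.number_to_words(342)   # 'three hundred and forty-two'
print text2int(words)                  # prints 342
print text2int('twenty five')          # prints 25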