├── .gitignore ├── README.md ├── pathnet ├── __init__.py ├── data │ ├── __init__.py │ ├── data_utils.py │ ├── obqa_data_reader │ │ ├── __init__.py │ │ └── data_reader_obqa.py │ ├── rel_vocab.py │ └── wikihop_data_reader │ │ ├── __init__.py │ │ └── data_reader_wikihop.py ├── model │ ├── __init__.py │ └── qa_with_raw │ │ ├── __init__.py │ │ ├── pathnet_full_modular.py │ │ └── pathnet_semi_modular.py ├── nn │ ├── __init__.py │ ├── layers.py │ └── util.py ├── pathfinder │ ├── __init__.py │ ├── obqa_path_extractor.py │ ├── path_extractor.py │ └── util.py ├── predictors │ ├── __init__.py │ └── wikihop_predictor.py └── tokenizers │ ├── __init__.py │ └── spacy_tokenizer.py ├── requirements.txt ├── scripts ├── __init__.py ├── break_orig_wikihop_train.py ├── break_train_data_obqa.py ├── break_train_data_wikihop.py ├── download.sh ├── evaluator.py ├── expand_vocabulary.py ├── install_requirements.sh ├── path_adjustments_obqa.sh ├── path_adjustments_wikihop.sh ├── path_finder_obqa.sh ├── path_finder_wikihop.sh ├── path_finder_wrapper.py ├── predict_wikihop.sh ├── prepare_outfile.py ├── prepro │ ├── __init__.py │ ├── obqa_path_finder.py │ ├── obqa_prep_data_with_lemma.py │ ├── path_finder_wikihop.py │ ├── preprocess_obqa.py │ ├── preprocess_wikihop.py │ └── wikihop_prep_data_with_lemma.py ├── preprocess_obqa.sh ├── preprocess_wikihop.sh ├── run_full_obqa.sh └── run_full_wikihop.sh └── training_configs ├── config_obqa.json ├── config_wikihop.json └── config_wikihop_makevocab.json /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | *.pyc 107 | /data/ 108 | /models/ 109 | .idea/ 110 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PathNet: Exploiting Explicit Paths for Multi-hop Reading Comprehension 2 | 3 | 4 | This repository contains the source code for the paper [Exploiting Explicit Paths for Multi-hop Reading Comprehension](https://arxiv.org/abs/1811.01127). 5 | This work was published at Association of Computational Linguistics (ACL) 2019. 6 | If you find the paper or this repository helpful in your work, please use the following citation: 7 | 8 | ``` 9 | @inproceedings{pathnet, 10 | title={ Exploiting Explicit Paths for Multi-hop Reading Comprehension }, 11 | author={ Souvik Kundu and Tushar Khot and Ashish Sabharwal and Peter Clark }, 12 | booktitle={ ACL }, 13 | year={ 2019 } 14 | } 15 | ``` 16 | 17 | ### Setup 18 | 19 | We used Python-3.6.2. Consider creating a virtual/conda environment for development. 20 | This code repository is built using [AllenNLP](https://github.com/allenai/allennlp). 21 | To install all dependencies please run the following: 22 | ```bash 23 | sh scripts/install_requirements.sh 24 | ``` 25 | 26 | 27 | ### Download 28 | 29 | To download all the required files, run `scripts/download.sh` 30 | 31 | 32 | ### Prediction for WikiHop 33 | 34 | Once you run the `scripts/download.sh`, you should have our pretrained model for WikiHop 35 | in `data/datasets/WikiHop/pretrained-model/`. 36 | For generating the predictions using this model, follow the steps given in `scripts/predict_wikihop.sh`. 37 | 38 | 39 | ### Training 40 | 41 | Follow the steps in `scripts/run_full_wikihop.sh` and `scripts/run_full_obqa.sh` for training new models 42 | for WikiHop and OBQA, respectively. 43 | 44 | 45 | ### Path Extraction Demo 46 | 47 | Run the `scripts/path_finder_wrapper.py` for simply visualizing the paths. 48 | 49 | ``` 50 | >>> from scripts.path_finder_wrapper import find_paths 51 | >>> documents = ["...this is doc 1 ...", "...this is doc 2 ...", ...] 
52 | >>> question = "question text" 53 | >>> candidate = "candidate text" 54 | >>> pathlist = find_paths(documents, question, candidate, style="plain") 55 | -------------------------------------------------------------------------------- /pathnet/__init__.py: -------------------------------------------------------------------------------- 1 | from pathnet.data.wikihop_data_reader.data_reader_wikihop import WikiHopMultiChoiceJsonReader 2 | from pathnet.data.obqa_data_reader.data_reader_obqa import OBQAMultiChoiceJsonReader 3 | -------------------------------------------------------------------------------- /pathnet/data/__init__.py: -------------------------------------------------------------------------------- 1 | from pathnet.data.wikihop_data_reader.data_reader_wikihop import WikiHopMultiChoiceJsonReader 2 | from pathnet.data.obqa_data_reader.data_reader_obqa import OBQAMultiChoiceJsonReader 3 | -------------------------------------------------------------------------------- /pathnet/data/data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import re 3 | import json 4 | from copy import deepcopy 5 | from numpy import array 6 | from typing import List, Tuple, Dict, Any 7 | from nltk import PorterStemmer 8 | 9 | stemmer = PorterStemmer() 10 | 11 | POSTAGS = ['UNKPOSTAG', '', 'VBP', '.', 'MD', 'SYM', 'VBD', 'POS', 'NNS', 'VBZ', 12 | 'PRP$', 'IN', '``', 'NN', 'WP', 'VBG', "''", 'TO', 'PRP', 13 | '-RRB-', 'LS', 'JJR', 'ADD', '$', 'UH', 'JJS', 'WP$', 'AFX', 14 | 'NNPS', 'VB', 'CD', 'DT', ':', ',', 'VBN', '_SP', '-LRB-', 'EX', 15 | 'RBS', 'WDT', 'FW', 'HYPH', 'PDT', 'RB', 'RP', 'CC', 'WRB', 'JJ', 16 | 'NFP', 'RBR', 'XX', 'NNP'] 17 | NERTAGS = ['UNKNERTAG', '', 'CARDINAL', 'DATE', 'EVENT', 'FAC', 'GPE', 'LANGUAGE', 18 | 'LAW', 'LOC', 'MONEY', 'NORP', 'ORDINAL', 'ORG', 'PERCENT', 19 | 'PERSON', 'PRODUCT', 'QUANTITY', 'TIME', 'WORK_OF_ART'] 20 | 21 | 22 | def extract_ent_reps(docreps: List[array], 23 | docidx: int, locs: List[Tuple[int, int]], 24 | aggregate: str = 'max') -> array: 25 | """ 26 | :param docreps: N x T x H 27 | :param docidx: 28 | :param locs: 29 | :param aggregate: 30 | :return: 31 | """ 32 | hdim = docreps[0].shape[-1] # H 33 | combined_locs = [] 34 | for loc in locs: 35 | if loc[0] is None or loc[0] == -1: 36 | # x = numpy.zeros(hdim) 37 | # y = numpy.zeros(hdim) 38 | combined = numpy.zeros(hdim) 39 | else: 40 | # x = docreps[docidx][loc[0]] 41 | # y = docreps[docidx][loc[1]] 42 | combined = numpy.mean(docreps[docidx][loc[0]:loc[1] + 1, :], 0) 43 | if numpy.isnan(combined).any(): 44 | combined = numpy.zeros(hdim) 45 | # x, y format (TODO for other func types) 46 | # combined = numpy.concatenate((x, y)) 47 | combined_locs.append(combined) 48 | combined_locs = array(combined_locs) 49 | if aggregate == 'max': 50 | return numpy.max(combined_locs, 0) 51 | else: 52 | raise NotImplementedError 53 | # return None 54 | 55 | 56 | def pack_span_idxs(span_locs: List[List[List[Tuple[int, int]]]]): 57 | """ 58 | packing of span indices 59 | :param span_locs: C * P * L * 2 60 | :return: 61 | """ 62 | all_span_locs = sum(sum(span_locs, []), []) # CPL * 2 63 | loc_tracks = [] 64 | for cidx, c in enumerate(span_locs): 65 | for pidx, p in enumerate(c): 66 | for lidx, l in enumerate(p): 67 | loc_tracks.append([cidx, pidx, lidx]) 68 | return all_span_locs, loc_tracks 69 | 70 | 71 | def pack_doc_idxs(p_list: List[List[int]], 72 | span_locs: List[List[List[Tuple[int, int]]]]) -> List[int]: 73 | """ 74 | packing of document indices 75 | 
:param p_list: 76 | :param span_locs: 77 | :return: 78 | """ 79 | p_list_loc_rptd = [] 80 | for cidx, c in enumerate(span_locs): 81 | for pidx, p in enumerate(c): 82 | p_list_loc_rptd += [p_list[cidx][pidx] for _ in range(len(p))] 83 | return p_list_loc_rptd 84 | 85 | 86 | def find_cand_locs_fromdoclist(docs: List[str], 87 | choice_text_list: List[str], 88 | lowercase: bool = True): 89 | """ 90 | find candidate locations from a given document list 91 | :param docs: 92 | :param choice_text_list: 93 | :param lowercase: 94 | :return: 95 | """ 96 | offsets = [create_offsets(doc.split()) for doc in docs] 97 | if lowercase: 98 | docs = [' '.join(doc.split()).lower() for doc in docs] 99 | else: 100 | docs = [' '.join(doc.split()) for doc in docs] 101 | all_docidxs = [] 102 | all_spans = [] 103 | 104 | for choice_id, choice_text in enumerate(choice_text_list): 105 | # find the choice locations in doc 106 | if lowercase: 107 | choice_text = choice_text.lower() 108 | objs = [] 109 | docidxs = [] 110 | try: 111 | pat = re.compile('(^|\W)' + choice_text + '\W') 112 | for didx, doc in enumerate(docs): 113 | doc_objs = [] 114 | for x in pat.finditer(doc): 115 | doc_objs.append(x) 116 | objs += doc_objs 117 | docidxs += [didx] * len(doc_objs) 118 | except: 119 | print(f"Could not compile candidate {choice_text}") 120 | 121 | # might want to initialize with [(-1, -1)] as there will be path scores to avoid nan 122 | choice_spans = [(-1, -1)] # [(-1, -1)] 123 | found_words = [ob.group(0) for ob in objs] 124 | if len(objs) > 0: 125 | choice_ch_spans = [ob.span()[0] for ob in objs] 126 | assert len(found_words) == len(choice_ch_spans) 127 | widxs = [] 128 | for fwidx, fw in enumerate(found_words): 129 | start_offset = len(fw) - len(fw.lstrip()) 130 | choice_ch_spans[fwidx] += start_offset 131 | found_words[fwidx] = fw.strip() 132 | widx = get_widxs_from_chidxs([choice_ch_spans[fwidx]], 133 | deepcopy(offsets[docidxs[fwidx]]))[0] 134 | widxs.append(widx) 135 | 136 | # widxs = get_widxs_from_chidxs(choice_ch_spans, deepcopy(offsets)) 137 | choice_spans = [(widxs[i], widxs[i] + len(found_words[i].split()) - 1) 138 | for i in range(len(widxs))] 139 | if len(docidxs) == 0: 140 | docidxs = [0] 141 | assert len(choice_spans) == len(docidxs) 142 | all_spans.append(choice_spans) 143 | all_docidxs.append(docidxs) 144 | return all_spans, all_docidxs 145 | 146 | 147 | def get_widxs_from_chidxs(chidxs: List[int], 148 | offsets: List[List[int]]) -> List[int]: 149 | """ 150 | Find word indices given character indices 151 | :param chidxs: 152 | :param offsets: 153 | :return: 154 | """ 155 | last_ch_idx = offsets[0][0] 156 | assert max(chidxs) < offsets[-1][1] - last_ch_idx 157 | widxs = [] 158 | for chidx in chidxs: 159 | for oi in range(len(offsets)): 160 | if chidx in range(offsets[oi][0] - last_ch_idx, offsets[oi][1] - last_ch_idx): 161 | widxs.append(oi) 162 | break 163 | elif chidx in range(offsets[oi][1] - last_ch_idx, 164 | offsets[min(oi + 1, len(offsets))][0] - last_ch_idx): 165 | widxs.append(oi) 166 | break 167 | assert len(chidxs) == len(widxs) 168 | return widxs 169 | 170 | 171 | def create_offsets(doctoks: List[str]) -> List[List[int]]: 172 | """ 173 | create offsets for a document tokens 174 | :param doctoks: 175 | :return: 176 | """ 177 | offsets = [] 178 | char_count = 0 179 | for tok in doctoks: 180 | offsets.append([char_count, char_count + len(tok)]) 181 | char_count = char_count + len(tok) + 1 182 | return offsets 183 | 184 | 185 | def get_locs_forall(paths, doc_sent_boundary_dict, 186 | 
num_word_shift_dict, mod_docsents): 187 | """ 188 | get locations for corresponding to all the paths 189 | :param paths: 190 | :param doc_sent_boundary_dict: 191 | :param num_word_shift_dict: 192 | :param mod_docsents: 193 | :return: 194 | """ 195 | num_docs = len(mod_docsents) 196 | p1_list: List[List[int]] = [] # cands * num_paths 197 | p2_list: List[List[int]] = [] 198 | 199 | he_locs_list, e1wh_locs_list = [], [] # C * P * L * 2 200 | e1_locs_list, ca_locs_list = [], [] 201 | 202 | for cidx in range(len(paths)): 203 | p1s = [paths[cidx][i]['he_docidx'] if paths[cidx][i]['he_docidx'] is not None else 0 204 | for i in range(len(paths[cidx]))] 205 | p2s = [paths[cidx][i]['cand_docidx'] if paths[cidx][i]['cand_docidx'] is not None 206 | else num_docs - 1 207 | for i in range(len(paths[cidx]))] 208 | p1_list.append(p1s) 209 | p2_list.append(p2s) 210 | 211 | he_locs = [paths[cidx][i]['he_locs'] if paths[cidx][i]['he_locs'] is not None 212 | else [(-1, -1)] 213 | for i in range(len(paths[cidx]))] 214 | e1wh_locs = [paths[cidx][i]['e1wh_loc'] if paths[cidx][i]['e1wh_loc'] is not None 215 | else [(-1, -1)] 216 | for i in range(len(paths[cidx]))] 217 | e1_locs = [paths[cidx][i]['e1_locs'] if paths[cidx][i]['e1_locs'] is not None 218 | else [(-1, -1)] 219 | for i in range(len(paths[cidx]))] 220 | ca_locs = [paths[cidx][i]['cand_locs'] if paths[cidx][i]['cand_locs'] is not None 221 | else [(-1, -1)] 222 | for i in range(len(paths[cidx]))] 223 | 224 | if doc_sent_boundary_dict is not None and num_word_shift_dict is not None: 225 | for pidx in range(len(paths[cidx])): 226 | for locidx in range(len(he_locs[pidx])): 227 | if he_locs[pidx][locidx][0] != -1: 228 | he_locs[pidx][locidx] = (he_locs[pidx][locidx][0] - 229 | num_word_shift_dict[p1s[pidx]], 230 | he_locs[pidx][locidx][1] - 231 | num_word_shift_dict[p1s[pidx]]) 232 | if he_locs[pidx][locidx][0] > len(sum(mod_docsents[p1s[pidx]], [])) - 1: 233 | he_locs[pidx][locidx] = (len(sum(mod_docsents[p1s[pidx]], [])) - 1, 234 | len(sum(mod_docsents[p1s[pidx]], [])) - 1) 235 | 236 | for locidx in range(len(e1wh_locs[pidx])): 237 | if e1wh_locs[pidx][locidx][0] != -1: 238 | e1wh_locs[pidx][locidx] = ( 239 | e1wh_locs[pidx][locidx][0] - num_word_shift_dict[p1s[pidx]], 240 | e1wh_locs[pidx][locidx][1] - num_word_shift_dict[p1s[pidx]]) 241 | if e1wh_locs[pidx][locidx][0] > len(sum(mod_docsents[p1s[pidx]], [])) - 1: 242 | e1wh_locs[pidx][locidx] = (len(sum(mod_docsents[p1s[pidx]], [])) - 1, 243 | len(sum(mod_docsents[p1s[pidx]], [])) - 1) 244 | 245 | for locidx in range(len(e1_locs[pidx])): 246 | if e1_locs[pidx][locidx][0] != -1: 247 | e1_locs[pidx][locidx] = (e1_locs[pidx][locidx][0] - 248 | num_word_shift_dict[p2s[pidx]], 249 | e1_locs[pidx][locidx][1] - 250 | num_word_shift_dict[p2s[pidx]]) 251 | if e1_locs[pidx][locidx][0] > len(sum(mod_docsents[p2s[pidx]], [])) - 1: 252 | e1_locs[pidx][locidx] = (len(sum(mod_docsents[p2s[pidx]], [])) - 1, 253 | len(sum(mod_docsents[p2s[pidx]], [])) - 1) 254 | 255 | for locidx in range(len(ca_locs[pidx])): 256 | if ca_locs[pidx][locidx][0] != -1: 257 | ca_locs[pidx][locidx] = (ca_locs[pidx][locidx][0] - 258 | num_word_shift_dict[p2s[pidx]], 259 | ca_locs[pidx][locidx][1] - 260 | num_word_shift_dict[p2s[pidx]]) 261 | if ca_locs[pidx][locidx][0] > len(sum(mod_docsents[p2s[pidx]], [])) - 1: 262 | ca_locs[pidx][locidx] = (len(sum(mod_docsents[p2s[pidx]], [])) - 1, 263 | len(sum(mod_docsents[p2s[pidx]], [])) - 1) 264 | 265 | he_locs_list.append(he_locs) 266 | e1wh_locs_list.append(e1wh_locs) 267 | e1_locs_list.append(e1_locs) 
268 | ca_locs_list.append(ca_locs) 269 | 270 | return p1_list, p2_list, he_locs_list, e1wh_locs_list, e1_locs_list, ca_locs_list 271 | 272 | 273 | def get_max_locs(paths, he_locs_list, e1wh_locs_list, 274 | e1_locs_list, ca_locs_list): 275 | """ 276 | obtaining the maximum number of locations 277 | :param paths: 278 | :param he_locs_list: 279 | :param e1wh_locs_list: 280 | :param e1_locs_list: 281 | :param ca_locs_list: 282 | :return: 283 | """ 284 | max_he_locs = 0 285 | max_e1wh_locs = 0 286 | max_e1_locs = 0 287 | max_ca_locs = 0 288 | for cidx in range(len(paths)): 289 | for pidx in range(len(paths[cidx])): 290 | num_he_locs = len(he_locs_list[cidx][pidx]) 291 | max_he_locs = max(max_he_locs, num_he_locs) 292 | 293 | num_e1wh_locs = len(e1wh_locs_list[cidx][pidx]) 294 | max_e1wh_locs = max(max_e1wh_locs, num_e1wh_locs) 295 | 296 | num_e1_locs = len(e1_locs_list[cidx][pidx]) 297 | max_e1_locs = max(max_e1_locs, num_e1_locs) 298 | 299 | num_ca_locs = len(ca_locs_list[cidx][pidx]) 300 | max_ca_locs = max(max_ca_locs, num_ca_locs) 301 | 302 | return max_he_locs, max_e1wh_locs, max_e1_locs, max_ca_locs 303 | -------------------------------------------------------------------------------- /pathnet/data/obqa_data_reader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/PathNet/e47be03584f59beef1ea770b2894dadcaee75fc4/pathnet/data/obqa_data_reader/__init__.py -------------------------------------------------------------------------------- /pathnet/data/obqa_data_reader/data_reader_obqa.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Any, Tuple 2 | import json 3 | import logging 4 | 5 | from overrides import overrides 6 | 7 | from allennlp.data.dataset_readers.dataset_reader import DatasetReader 8 | from allennlp.data.fields import Field, TextField, LabelField, ListField, MetadataField, \ 9 | IndexField, SpanField 10 | from allennlp.data.instance import Instance 11 | from allennlp.data.token_indexers import SingleIdTokenIndexer, TokenIndexer 12 | from allennlp.data.tokenizers import Tokenizer, WordTokenizer 13 | 14 | from allennlp.data.tokenizers import word_splitter 15 | 16 | from pathnet.data.data_utils import pack_span_idxs, find_cand_locs_fromdoclist, \ 17 | get_locs_forall, get_max_locs, pack_doc_idxs 18 | 19 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 20 | 21 | 22 | @DatasetReader.register("obqa_data_reader") 23 | class OBQAMultiChoiceJsonReader(DatasetReader): 24 | """ 25 | This data is formatted as jsonl, one json-formatted instance per line. 26 | This dataset format is obtained after the path adjustment step. 27 | Parameters 28 | ---------- 29 | tokenizer : ``Tokenizer``, optional (default=``WordTokenizer()``) 30 | We use this ``Tokenizer`` for all. See :class:`Tokenizer`. 31 | token_indexers : ``Dict[str, TokenIndexer]``, optional (default=``{"tokens": SingleIdTokenIndexer()}``) 32 | See :class:`TokenIndexer`. 
33 | lazy : whether to use lazy mode for training (do not keep everything in the memory) 34 | cut_context : Whether to strip out the portions from documents which do not 35 | participate in constructing paths 36 | """ 37 | 38 | def __init__(self, 39 | tokenizer: Tokenizer = None, 40 | token_indexers: Dict[str, TokenIndexer] = None, 41 | lazy: bool = False, 42 | cut_context: bool = False) -> None: 43 | super().__init__(lazy) 44 | self._tokenizer = tokenizer or WordTokenizer( 45 | word_splitter=word_splitter.JustSpacesWordSplitter()) 46 | self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()} 47 | self._cut_context = cut_context 48 | 49 | @staticmethod 50 | def find_sentidx(doctoks: List[List[str]], 51 | word_idx: int) -> int: 52 | """ 53 | find the sentidx given word idx 54 | :param doctoks: 55 | :param word_idx: 56 | :return: 57 | """ 58 | count = 0 59 | for idx, doc in enumerate(doctoks): 60 | count += len(doc) 61 | if word_idx < count: 62 | return idx 63 | return len(doctoks) - 1 64 | 65 | def get_documents_markers(self, docsents: List[List[List[str]]], 66 | paths: List[List[Any]]) -> Any: 67 | """ 68 | Cut the documents 69 | :param docsents: 70 | :param paths: 71 | :return: 72 | """ 73 | if not self._cut_context: 74 | return None 75 | min_max_sentidx_dict = {} 76 | num_word_shift_dict = {} 77 | for didx in range(len(docsents)): 78 | min_max_sentidx_dict[didx] = [len(docsents[didx]) - 1, 0] 79 | num_word_shift_dict[didx] = 0 80 | 81 | for cidx in range(len(paths)): 82 | for pidx in range(len(paths[cidx])): 83 | he_docidx = paths[cidx][pidx]['he_docidx'] 84 | ca_docidx = paths[cidx][pidx]['cand_docidx'] 85 | he_locs = paths[cidx][pidx]['he_locs'] 86 | e1wh_locs = paths[cidx][pidx]['e1wh_loc'] 87 | e1_locs = paths[cidx][pidx]['e1_locs'] 88 | ca_locs = paths[cidx][pidx]['cand_locs'] 89 | 90 | # if he_docidx is not None and ca_docidx is not None 91 | if he_docidx == ca_docidx: 92 | all_start_sent_inds = [] 93 | all_end_sent_inds = [] 94 | for x in [he_locs, ca_locs]: 95 | if x is not None: 96 | for loc in x: 97 | if loc[0] != -1 or loc[1] != -1: 98 | all_start_sent_inds.append(self.find_sentidx( 99 | docsents[he_docidx], loc[0]) 100 | ) 101 | all_end_sent_inds.append(self.find_sentidx( 102 | docsents[he_docidx], loc[1]) 103 | ) 104 | if len(all_start_sent_inds) > 0: 105 | min_max_sentidx_dict[he_docidx][0] = min(min(all_start_sent_inds), 106 | min_max_sentidx_dict[he_docidx][0]) 107 | if len(all_end_sent_inds) > 0: 108 | min_max_sentidx_dict[he_docidx][1] = max(max(all_end_sent_inds), 109 | min_max_sentidx_dict[he_docidx][1]) 110 | else: 111 | idxs = [he_docidx, ca_docidx] 112 | for yidx, y in enumerate([[he_locs, e1wh_locs], [e1_locs, ca_locs]]): 113 | docidx = idxs[yidx] 114 | all_start_sent_inds = [] 115 | all_end_sent_inds = [] 116 | for x in y: 117 | if x is not None: 118 | for loc in x: 119 | if loc[0] != -1 or loc[1] != -1: 120 | all_start_sent_inds.append(self.find_sentidx( 121 | docsents[docidx], loc[0]) 122 | ) 123 | all_end_sent_inds.append(self.find_sentidx( 124 | docsents[docidx], loc[1]) 125 | ) 126 | if len(all_start_sent_inds) > 0: 127 | min_max_sentidx_dict[docidx][0] = min(min(all_start_sent_inds), 128 | min_max_sentidx_dict[docidx][0]) 129 | if len(all_end_sent_inds) > 0: 130 | min_max_sentidx_dict[docidx][1] = max(max(all_end_sent_inds), 131 | min_max_sentidx_dict[docidx][1]) 132 | 133 | for key in min_max_sentidx_dict.keys(): 134 | if min_max_sentidx_dict[key][0] > min_max_sentidx_dict[key][1]: 135 | min_max_sentidx_dict[key] = [0, 0] # useless doc 
136 | if min_max_sentidx_dict[key][0] > 0: 137 | num_word_shift_dict[key] = len(sum(docsents[key][:min_max_sentidx_dict[key][0]], [])) 138 | 139 | return min_max_sentidx_dict, num_word_shift_dict 140 | 141 | @overrides 142 | def _read(self, file_path: str): 143 | 144 | with open(file_path, 'r') as data_file: 145 | logger.info("Reading OBQA instances from jsonl data at: %s", file_path) 146 | for line in data_file: 147 | item_json = json.loads(line.strip()) 148 | 149 | item_id = item_json["id"] 150 | question = item_json['question'] 151 | docsents = item_json['docsents'] 152 | candidates_ = item_json['candidates'] 153 | orig_paths = item_json['paths'] # cands * num_paths 154 | if len(candidates_) != len(orig_paths): 155 | logger.info(f"******** {item_id} skipped ********") 156 | continue 157 | ans = ' '.join(item_json["answer"]) if "answer" in item_json else None 158 | 159 | yield self.text_to_instance(item_id, docsents, 160 | question, candidates_, orig_paths, 161 | ans) 162 | 163 | @overrides 164 | def text_to_instance(self, # type: ignore 165 | item_id, docsents, 166 | question, candidates_, orig_paths, 167 | ans): 168 | # filtering 169 | orig_candidates = [' '.join(c) for c in candidates_] 170 | assert len(orig_candidates) == len(orig_paths) 171 | candidates = [] 172 | paths = [] 173 | for c, p in zip(orig_candidates, orig_paths): 174 | if len(c) == 1 and ord(c) in range(91, 123): 175 | continue 176 | elif c in ["of", "he", "a", "an", "the", "as", 177 | "e .", "s .", "a .", '*', ',', '.', '"']: 178 | continue 179 | else: 180 | candidates.append(c) 181 | paths.append(p) 182 | 183 | choice_label_to_id: Dict = {} 184 | choice_text_list: List[str] = [] 185 | 186 | for choice_id, choice_text in enumerate(candidates): 187 | choice_label_to_id[choice_text] = choice_id 188 | choice_text_list.append(choice_text) 189 | 190 | if ans is not None: 191 | answer_id = choice_label_to_id[ans] 192 | else: 193 | answer_id = None 194 | 195 | question_text = " ".join(question) 196 | 197 | if self._cut_context and \ 198 | self.get_documents_markers(docsents, paths) is not None: 199 | doc_sent_boundary_dict, num_word_shift_dict = self.get_documents_markers( 200 | docsents, paths) 201 | mod_docsents = [] 202 | for didx in range(len(docsents)): 203 | start_sidx = doc_sent_boundary_dict[didx][0] 204 | end_sidx = doc_sent_boundary_dict[didx][1] 205 | mod_docsents.append(docsents[didx][start_sidx:end_sidx + 1]) 206 | else: 207 | mod_docsents = docsents 208 | doc_sent_boundary_dict, num_word_shift_dict = None, None 209 | 210 | documents_text_list: List[str] = [' '.join(sum(doc, [])) 211 | for doc in mod_docsents] 212 | 213 | p1_list, p2_list, he_locs_list, e1wh_locs_list, \ 214 | e1_locs_list, ca_locs_list = get_locs_forall(paths, 215 | doc_sent_boundary_dict, 216 | num_word_shift_dict, 217 | mod_docsents) 218 | 219 | max_he_locs, max_e1wh_locs, max_e1_locs, max_ca_locs = get_max_locs(paths, 220 | he_locs_list, 221 | e1wh_locs_list, 222 | e1_locs_list, 223 | ca_locs_list) 224 | 225 | flattened_he_locs_list, he_tracks = pack_span_idxs(he_locs_list) 226 | flattened_e1wh_locs_list, e1wh_tracks = pack_span_idxs(e1wh_locs_list) 227 | flattened_e1_locs_list, e1_tracks = pack_span_idxs(e1_locs_list) 228 | flattened_ca_locs_list, ca_tracks = pack_span_idxs(ca_locs_list) 229 | 230 | flattened_p1_list = pack_doc_idxs(p1_list, he_locs_list) 231 | flattened_p1_list_e1wh = pack_doc_idxs(p1_list, e1wh_locs_list) 232 | flattened_p2_list_e1 = pack_doc_idxs(p2_list, e1_locs_list) 233 | flattened_p2_list = pack_doc_idxs(p2_list, 
ca_locs_list) 234 | 235 | max_paths = max([len(p) for p in paths]) 236 | 237 | all_choice_locs, all_choice_docidxs = find_cand_locs_fromdoclist(documents_text_list, 238 | choice_text_list, 239 | lowercase=True) 240 | 241 | return self.formatted_text_to_instance(item_id, question_text, 242 | documents_text_list, 243 | flattened_p1_list, 244 | flattened_p1_list_e1wh, 245 | flattened_p2_list_e1, 246 | flattened_p2_list, 247 | flattened_he_locs_list, 248 | flattened_e1wh_locs_list, 249 | flattened_e1_locs_list, 250 | flattened_ca_locs_list, 251 | he_tracks, e1wh_tracks, 252 | e1_tracks, ca_tracks, 253 | max_paths, 254 | max_he_locs, max_e1wh_locs, 255 | max_e1_locs, max_ca_locs, 256 | choice_text_list, 257 | all_choice_locs, all_choice_docidxs, 258 | answer_id) 259 | 260 | def formatted_text_to_instance(self, # type: ignore 261 | item_id: Any, 262 | question_text: str, 263 | documents_text_list: List[str], 264 | flattened_p1_list: List[int], 265 | flattened_p1_list_e1wh: List[int], 266 | flattened_p2_list_e1: List[int], 267 | flattened_p2_list: List[int], 268 | flattened_he_locs_list: List[Tuple[int, int]], 269 | flattened_e1wh_locs_list: List[Tuple[int, int]], 270 | flattened_e1_locs_list: List[Tuple[int, int]], 271 | flattened_ca_locs_list: List[Tuple[int, int]], 272 | he_tracks: List[List[int]], 273 | e1wh_tracks: List[List[int]], 274 | e1_tracks: List[List[int]], 275 | ca_tracks: List[List[int]], 276 | max_paths: int, 277 | max_he_locs: int, max_e1wh_locs: int, 278 | max_e1_locs: int, max_ca_locs: int, 279 | choice_text_list: List[str], 280 | all_choice_locs: List[List[Tuple[int, int]]], 281 | all_choice_docidxs: List[List[int]], 282 | answer_id: int) -> Instance: 283 | # pylint: disable=arguments-differ 284 | fields: Dict[str, Field] = {} 285 | question_tokens = self._tokenizer.tokenize(question_text) 286 | documents_list_tokens = [self._tokenizer.tokenize(dt) for dt in documents_text_list] 287 | if len(sum(documents_list_tokens, [])) == 0: 288 | documents_list_tokens = [question_tokens] 289 | 290 | choices_list_tokens = [self._tokenizer.tokenize(x) for x in choice_text_list] 291 | 292 | fields['question'] = TextField(question_tokens, self._token_indexers) 293 | document_text_fields = [TextField(x, self._token_indexers) for x in documents_list_tokens] 294 | document_field = ListField(document_text_fields) 295 | fields['documents'] = document_field 296 | fields['candidates'] = ListField([TextField(x, self._token_indexers) for x in choices_list_tokens]) 297 | 298 | fields['flattened_p1list'] = ListField([IndexField(x, document_field) 299 | for x in flattened_p1_list]) 300 | fields['flattened_p1list_e1wh'] = ListField([IndexField(x, document_field) 301 | for x in flattened_p1_list_e1wh]) 302 | fields['flattened_p2list_e1'] = ListField([IndexField(x, document_field) 303 | for x in flattened_p2_list_e1]) 304 | fields['flattened_p2list'] = ListField([IndexField(x, document_field) 305 | for x in flattened_p2_list]) 306 | 307 | fields['flat_he_spans'] = ListField([SpanField(x[0], x[1], document_text_fields[flattened_p1_list[xidx]]) 308 | for xidx, x in enumerate(flattened_he_locs_list)]) 309 | fields['flat_e1wh_spans'] = ListField([SpanField(x[0], x[1], document_text_fields[flattened_p1_list_e1wh[xidx]]) 310 | for xidx, x in enumerate(flattened_e1wh_locs_list)]) 311 | fields['flat_e1_spans'] = ListField([SpanField(x[0], x[1], document_text_fields[flattened_p2_list_e1[xidx]]) 312 | for xidx, x in enumerate(flattened_e1_locs_list)]) 313 | fields['flat_choice_spans'] = ListField([SpanField(x[0], 
x[1], document_text_fields[flattened_p2_list[xidx]]) 314 | for xidx, x in enumerate(flattened_ca_locs_list)]) 315 | 316 | # all choice fields 317 | all_choice_docidx_field = [] 318 | all_choice_span_fileds = [] 319 | for choice_docidxs, choice_spans in zip(all_choice_docidxs, all_choice_locs): 320 | all_choice_docidx_field.append(ListField([IndexField(x, document_field) 321 | for x in choice_docidxs])) 322 | all_choice_span_fileds.append(ListField([SpanField(x[0], x[1], 323 | document_text_fields[choice_docidxs[xidx]]) 324 | for xidx, x in enumerate(choice_spans)])) 325 | fields['all_choice_docidxs'] = ListField(all_choice_docidx_field) 326 | fields['all_choice_locs'] = ListField(all_choice_span_fileds) 327 | 328 | if answer_id is not None: 329 | fields['label'] = LabelField(answer_id, skip_indexing=True) 330 | 331 | metadata = { 332 | "id": item_id, 333 | "question_text": question_text, 334 | "documents_text": documents_text_list, 335 | "choice_text_list": choice_text_list, 336 | "he_tracks": he_tracks, 337 | "e1wh_tracks": e1wh_tracks, 338 | "e1_tracks": e1_tracks, 339 | "choice_tracks": ca_tracks, 340 | "max_num_paths": max_paths, 341 | "max_num_he_locs": max_he_locs, 342 | "max_num_e1wh_locs": max_e1wh_locs, 343 | "max_num_e1_locs": max_e1_locs, 344 | "max_num_ca_locs": max_ca_locs, 345 | } 346 | 347 | fields["metadata"] = MetadataField(metadata) 348 | 349 | return Instance(fields) 350 | -------------------------------------------------------------------------------- /pathnet/data/rel_vocab.py: -------------------------------------------------------------------------------- 1 | rel_vocab = ["@norel", "time_of_spacecraft_landing", "inspired_by", "shape", 2 | "head_of_government", "winner", "highest_point", "symptoms", 3 | "journey_destination", "drug_used_for_treatment", "arterial_supply", 4 | "has_immediate_cause", "edition_or_translation_of", "legislative_body", 5 | "measured_physical_quantity", "relative", "affiliation", "located_on_street", 6 | "has_facility", "medical_treatment", "published_in", "eye_color", "head_coach", 7 | "product_certification", "streak_color", "first_performance", "crystal_system", 8 | "represents_organisation", "terminus", "item_operated", "origin_of_the_watercourse", 9 | "narrator", "doctoral_student", "statement_describes", "taxonomic_type", 10 | "given_name_version_for_other_gender", "physically_interacts_with", "used_by", 11 | "contains_administrative_territorial_entities", "temporal_range_end", 12 | "day_in_year_for_periodic_occurrence", "connecting_service", "codomain", 13 | "anatomical_location", "set_in_period", "writing_system", "academic_degree", 14 | "territory_claimed_by", "port_of_registry", "dedicated_to", "diocese", 15 | "honorific_prefix", "located_on_astronomical_body", "hymenium_type", 16 | "website_account_on", "diplomatic_relation", "ancestral_home", "enclave_within", 17 | "illustrator", "tonality", "currency", "conferred_by", 18 | "office_held_by_head_of_the_organisation", "location_of_final_assembly", 19 | "chairperson", "patron_saint", "wing_configuration", "replaces", "operating_system", 20 | "product", "interaction", "natural_product_of_taxon", "doctoral_advisor", 21 | "notable_work", "deity_of", "librettist", "first_flight", "found_in_taxon", 22 | "electoral_district", "participant", "temporal_range_start", "source_of_energy", 23 | "head_of_state", "ortholog", "home_venue", "journey_origin", "film_editor", 24 | "commissioned_by", "programmer", "student_of", "convicted_of", "tributary", 25 | "adjacent_station", "partner", 
"instrumentation", "afflicts", "heritage_status", 26 | "place_of_publication", "penalty", "legislated_by", "influenced_by", "different_from", 27 | "powerplant", "official_language", "canonization_status", "has_facet_polytope", 28 | "astronomical_body", "period", "deepest_point", "location_of_landing", 29 | "time_of_discovery", "floruit", "surface_played_on", "opposite_of", 30 | "time_of_spacecraft_launch", "has_cause", "replaced_by", "endemic_to", 31 | "sexual_orientation", "allegiance", "minor_planet_group", "space_launch_vehicle", 32 | "game_mode", "lake_outflow", "appointed_by", "field_of_this_profession", 33 | "lake_inflows", "depicts", "executive_producer", "stock_exchange", 34 | "discoverer_or_inventor", "noble_title", "programming_language", "legal_form", 35 | "from_fictional_universe", "said_to_be_the_same_as", "sister_city", "organizer", 36 | "collection", "director_of_photography", "spore_print_color", "medical_condition", 37 | "service_retirement", "parent_company", "date_of_official_opening", 38 | "place_of_burial", "characters", "voice_type", "native_language", 39 | "successful_candidate", "input_device", "cast_member", "applies_to_jurisdiction", 40 | "founder", "capital", "facet_of", "religious_order", "airline_alliance", "crosses", 41 | "point_in_time", "drafted_by", "start_time", "family_name", "connecting_line", 42 | "location_of_formation", "licensed_to_broadcast_to", "basic_form_of_government", 43 | "material_used", "site_of_astronomical_discovery", "maintained_by", "capital_of", 44 | "airline_hub", "end_time", "cause_of_death", "significant_event", "sister", 45 | "service_entry", "child", "language_of_work_or_name", "military_rank", "league", 46 | "software_engine", "residence", "based_on", "manner_of_death", "architectural_style", 47 | "mother", "nominated_for", "license", "spouse", "brother", "present_in_work", 48 | "occupant", "distribution", "constellation", "located_next_to_body_of_water", 49 | "work_location", "color", "use", "political_ideology", "lyrics_by", "architect", 50 | "continent", "designer", "medical_specialty", "participant_of", "named_after", 51 | "manufacturer", "screenwriter", "dissolved_or_abolished", "member_of", "industry", 52 | "member_of_sports_team", "series", "composer", "filming_location", "has_part", 53 | "performer", "location", "instrument", "military_branch", "director", "employer", 54 | "author", "creator", "mouth_of_the_watercourse", "award_received", "father", 55 | "field_of_work", "movement", "conflict", "operator", "located_on_terrain_feature", 56 | "ethnic_group", "distributor", "noble_family", "religion", "production_company", 57 | "platform", "office_contested", "original_network", "narrative_location", 58 | "owned_by", "developer", "main_subject", "educated_at", "given_name", "follows", 59 | "taxon_rank", "position_played_on_team_/_speciality", "producer", 60 | "shares_border_with", "is_a_list_of", "position_held", "followed_by", 61 | "original_language_of_work", "date_of_death", "languages_spoken_or_written", 62 | "country_of_origin", "publication_date", "publisher", "member_of_political_party", 63 | "sport", "subclass_of", "part_of", "headquarters_location", "country", 64 | "date_of_birth", "inception", "place_of_death", "parent_taxon", 65 | "country_of_citizenship", "genre", "record_label", "place_of_birth", 66 | "occupation", "located_in_the_administrative_territorial_entity", "instance_of"] -------------------------------------------------------------------------------- /pathnet/data/wikihop_data_reader/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/PathNet/e47be03584f59beef1ea770b2894dadcaee75fc4/pathnet/data/wikihop_data_reader/__init__.py -------------------------------------------------------------------------------- /pathnet/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/PathNet/e47be03584f59beef1ea770b2894dadcaee75fc4/pathnet/model/__init__.py -------------------------------------------------------------------------------- /pathnet/model/qa_with_raw/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/PathNet/e47be03584f59beef1ea770b2894dadcaee75fc4/pathnet/model/qa_with_raw/__init__.py -------------------------------------------------------------------------------- /pathnet/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from pathnet.nn.layers import JointEncoder 2 | from pathnet.nn.layers import AttnPooling 3 | -------------------------------------------------------------------------------- /pathnet/nn/layers.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import logging 3 | import torch 4 | from allennlp.common.from_params import FromParams 5 | from allennlp.modules import FeedForward, Seq2SeqEncoder 6 | from allennlp.nn.util import masked_softmax 7 | 8 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 9 | 10 | 11 | class CtxPathEncoder(torch.nn.Module, FromParams): 12 | """ 13 | Context-based Path Encoder Class 14 | """ 15 | def __init__(self, type='tree', he_e1_comp: Optional[FeedForward] = None, 16 | e1_ca_comp: Optional[FeedForward] = None, 17 | r1r2_comp: Optional[FeedForward] = None, 18 | rnn_comp: Optional[Seq2SeqEncoder] = None, 19 | ) -> None: 20 | """ 21 | :param type: tree | tree_shared | rnn | rnn_delim 22 | :param he_e1_comp: 23 | :param e1_ca_comp: 24 | :param r1r2_comp: 25 | :param rnn_comp: 26 | """ 27 | super(CtxPathEncoder, self).__init__() 28 | self._type = type 29 | if self._type == "tree": 30 | assert he_e1_comp is not None 31 | assert e1_ca_comp is not None 32 | assert r1r2_comp is not None 33 | elif self._type == "tree_shared": 34 | assert he_e1_comp is not None 35 | assert r1r2_comp is not None 36 | elif self._type in ["rnn", "rnn_delim"]: 37 | assert rnn_comp is not None 38 | else: 39 | raise NotImplementedError 40 | self._he_e1wh_projector = he_e1_comp 41 | self._e1_ca_projector = e1_ca_comp 42 | self._path_projector = r1r2_comp 43 | self._path_rnn = rnn_comp 44 | 45 | def forward(self, he_rep, e1wh_rep, e1_rep, ca_rep, he_mask_, e1wh_mask_, e1_mask_, ca_mask_): 46 | """ 47 | compute path representation based on type 48 | :param he_rep: BCP * 2H 49 | :param e1wh_rep: BCP * 2H 50 | :param e1_rep: BCP * 2H 51 | :param ca_rep: BCP * 2H 52 | :param he_mask_: BCP * 1 53 | :param e1wh_mask_: BCP * 1 54 | :param e1_mask_: BCP * 1 55 | :param ca_mask_: BCP * 1 56 | :return: BCP * H 57 | """ 58 | bcp = he_rep.size(0) 59 | hdim = he_rep.size(1) 60 | if torch.cuda.is_available(): 61 | device = he_rep.get_device() 62 | if self._type in ["tree", "tree_shared"]: 63 | he_e1wh_rep = self._he_e1wh_projector(torch.cat([he_rep, e1wh_rep], 1)) 64 | he_e1wh_rep = he_e1wh_rep * he_mask_.float() 65 | if self._type == "tree_shared": 66 | e1_ca_rep = 
self._he_e1wh_projector(torch.cat([e1_rep, ca_rep], 1)) 67 | else: 68 | e1_ca_rep = self._e1_ca_projector(torch.cat([e1_rep, ca_rep], 1)) 69 | e1_ca_rep = e1_ca_rep * ca_mask_.float() 70 | encoded_paths = self._path_projector(torch.cat([he_e1wh_rep, e1_ca_rep], 1)) 71 | encoded_paths = encoded_paths * ca_mask_.float() # BCP * H 72 | elif self._type in ["rnn", "rnn_delim"]: 73 | if self._type == "rnn_delim": 74 | if torch.cuda.is_available(): 75 | ones_float_ = torch.ones([bcp, 1, hdim]).cuda(device=device) 76 | ones_long_ = torch.ones([bcp, 1]).long().cuda(device=device) 77 | else: 78 | ones_float_ = torch.ones([bcp, 1, hdim]) 79 | ones_long_ = torch.ones([bcp, 1]).long() 80 | 81 | path_rnn_in = torch.cat([he_rep.unsqueeze(1), 82 | e1wh_rep.unsqueeze(1), 83 | ones_float_, 84 | e1_rep.unsqueeze(1), 85 | ca_rep.unsqueeze(1), 86 | ones_float_], 1) # BCP * 6 * 2H 87 | mask_in = torch.cat([he_mask_, e1wh_mask_, ones_long_, 88 | e1_mask_, ca_mask_, ones_long_], 1) # BCP * 6 89 | else: 90 | path_rnn_in = torch.cat([he_rep.unsqueeze(1), 91 | e1wh_rep.unsqueeze(1), 92 | e1_rep.unsqueeze(1), 93 | ca_rep.unsqueeze(1)], 1) # BCP * 4 * 2H 94 | mask_in = torch.cat([he_mask_, e1wh_mask_, e1_mask_, ca_mask_], 1) # BCP * 4 95 | encoded_paths = self._path_rnn(path_rnn_in, mask_in) # BCP * 4/5 * H 96 | encoded_paths = encoded_paths[:, -1, :] # BCP * H 97 | encoded_paths = encoded_paths * ca_mask_.float() # BCP * H 98 | else: 99 | raise NotImplementedError 100 | return encoded_paths 101 | 102 | 103 | class JointEncoder(torch.nn.Module, FromParams): 104 | def __init__(self, seq_encoder: Optional[Seq2SeqEncoder] = None) -> None: 105 | super(JointEncoder, self).__init__() 106 | self._seq_encoder = seq_encoder 107 | 108 | def forward(self, doc_encoding, q_encoding, doc_mask, q_mask): 109 | """ 110 | 111 | :param doc_encoding: B * N * T * H 112 | :param q_encoding: B * U * H 113 | :param doc_mask: B * N * T 114 | :param q_mask: B * U 115 | :return: B * N * T * 2H 116 | """ 117 | batch_size = doc_encoding.shape[0] 118 | num_docs = doc_encoding.shape[1] 119 | num_doc_tokens = doc_encoding.shape[2] 120 | num_q_tokens = q_encoding.shape[1] 121 | doc_encoding = doc_encoding.view(batch_size, num_docs * num_doc_tokens, -1) # B * NT * H 122 | attn_unnorm = doc_encoding.bmm(q_encoding.transpose(2, 1)) # B * NT * U 123 | attn = masked_softmax(attn_unnorm, q_mask.unsqueeze(1).expand(attn_unnorm.size()), 124 | dim=-1) # B * NT * U 125 | aggq = attn.bmm(q_encoding) # B * NT * H 126 | attn_t = attn_unnorm.transpose(2, 1).contiguous().view(batch_size, 127 | -1, num_docs, 128 | num_doc_tokens) # B * U * N * T 129 | attn_t = masked_softmax(attn_t, doc_mask.unsqueeze(1).expand(attn_t.size()), dim=-1) 130 | attn_t = attn_t.view(batch_size, num_q_tokens, -1) # B * U * NT 131 | aggdoc = attn_t.bmm(doc_encoding) # B * U * H 132 | aggq2 = attn.bmm(aggdoc) # B * NT * H 133 | if self._seq_encoder is not None: 134 | aggq2 = aggq2.view(batch_size * num_docs, num_doc_tokens, -1) # BN * T * H 135 | aggq2 = self._seq_encoder(aggq2, doc_mask.view(batch_size * num_docs, -1)) # BN * T * H 136 | 137 | aggq2 = aggq2.view(doc_encoding.size()) # B * N * T * H 138 | aggq = aggq.view(doc_encoding.size()) # B * N * T * H 139 | return torch.cat([aggq, aggq2], -1) 140 | 141 | 142 | class AttnPooling(torch.nn.Module, FromParams): 143 | def __init__(self, projector: FeedForward, 144 | intermediate_projector: FeedForward = None) -> None: 145 | super(AttnPooling, self).__init__() 146 | self._projector = projector 147 | self._int_proj = intermediate_projector 
148 | 149 | def forward(self, xinit: torch.FloatTensor, 150 | xmask: torch.LongTensor) -> torch.FloatTensor: 151 | """ 152 | 153 | :param xinit: B * T * H 154 | :param xmask: B * T 155 | :return: B * H 156 | """ 157 | if self._int_proj is not None: 158 | x = self._int_proj(xinit) 159 | x = x * xmask.unsqueeze(-1) 160 | else: 161 | x = xinit 162 | attn = self._projector(x) # B * T * 1 163 | attn = attn.squeeze(-1) # B * T 164 | attn = masked_softmax(attn, xmask, dim=-1) 165 | pooled = attn.unsqueeze(1).bmm(xinit).squeeze(1) # B * H 166 | return pooled 167 | -------------------------------------------------------------------------------- /pathnet/nn/util.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Optional 2 | import logging 3 | 4 | import torch 5 | from allennlp.common.util import gpu_memory_mb 6 | from allennlp.nn.util import combine_tensors 7 | 8 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 9 | 10 | 11 | def path_encoding(x, y, combine_str, fforward, gate): 12 | z = combine_tensors(combine_str, [x, y]) 13 | z = fforward(z) 14 | gatef = gate(z) 15 | return gatef * z 16 | 17 | 18 | def get_final_encoder_states(encoder_outputs: torch.Tensor, 19 | mask: torch.Tensor, 20 | bidirectional: bool = False) -> torch.Tensor: 21 | """ 22 | Modified over the original Allennlp function 23 | 24 | Given the output from a ``Seq2SeqEncoder``, with shape ``(batch_size, sequence_length, 25 | encoding_dim)``, this method returns the final hidden state for each element of the batch, 26 | giving a tensor of shape ``(batch_size, encoding_dim)``. This is not as simple as 27 | ``encoder_outputs[:, -1]``, because the sequences could have different lengths. We use the 28 | mask (which has shape ``(batch_size, sequence_length)``) to find the final state for each batch 29 | instance. 30 | 31 | Additionally, if ``bidirectional`` is ``True``, we will split the final dimension of the 32 | ``encoder_outputs`` into two and assume that the first half is for the forward direction of the 33 | encoder and the second half is for the backward direction. We will concatenate the last state 34 | for each encoder dimension, giving ``encoder_outputs[:, -1, :encoding_dim/2]`` concated with 35 | ``encoder_outputs[:, 0, encoding_dim/2:]``. 36 | """ 37 | # These are the indices of the last words in the sequences (i.e. length sans padding - 1). We 38 | # are assuming sequences are right padded. 
39 | # Shape: (batch_size,) 40 | last_word_indices = mask.sum(1).long() - 1 41 | 42 | # handle -1 cases 43 | ll_ = (last_word_indices != -1).long() 44 | last_word_indices = last_word_indices * ll_ 45 | 46 | batch_size, _, encoder_output_dim = encoder_outputs.size() 47 | expanded_indices = last_word_indices.view(-1, 1, 1).expand(batch_size, 1, encoder_output_dim) 48 | # Shape: (batch_size, 1, encoder_output_dim) 49 | final_encoder_output = encoder_outputs.gather(1, expanded_indices) 50 | final_encoder_output = final_encoder_output.squeeze(1) # (batch_size, encoder_output_dim) 51 | if bidirectional: 52 | final_forward_output = final_encoder_output[:, :(encoder_output_dim // 2)] 53 | final_backward_output = encoder_outputs[:, 0, (encoder_output_dim // 2):] 54 | final_encoder_output = torch.cat([final_forward_output, final_backward_output], dim=-1) 55 | return final_encoder_output 56 | 57 | 58 | def seq2vec_seq_aggregate(seq_tensor, mask, aggregate, bidirectional, dim=1): 59 | """ 60 | Takes the aggregation of sequence tensor 61 | :param seq_tensor: Batched sequence requires [batch, seq, hs] 62 | :param mask: binary mask with shape batch, seq_len, 1 63 | :param aggregate: max, avg, sum 64 | :param bidirectional: bool(True/False) 65 | :param dim: The dimension to take the max. for batch, seq, hs it is 1 66 | :return: 67 | """ 68 | seq_tensor_masked = seq_tensor * mask.unsqueeze(-1).float() 69 | aggr_func = None 70 | if aggregate == "last": 71 | seq = get_final_encoder_states(seq_tensor, mask, bidirectional) 72 | elif aggregate == "max": 73 | aggr_func = torch.max 74 | seq, _ = aggr_func(seq_tensor_masked, dim=dim) 75 | elif aggregate == "min": 76 | aggr_func = torch.min 77 | seq, _ = aggr_func(seq_tensor_masked, dim=dim) 78 | elif aggregate == "sum": 79 | aggr_func = torch.sum 80 | seq = aggr_func(seq_tensor_masked, dim=dim) 81 | elif aggregate == "avg": 82 | aggr_func = torch.sum 83 | seq = aggr_func(seq_tensor_masked, dim=dim) 84 | seq_lens = torch.sum(mask, dim=dim) # this returns batch_size, 1 85 | seq = seq / seq_lens.view([-1, 1]) 86 | else: 87 | raise NotImplementedError 88 | 89 | return seq 90 | 91 | 92 | def gather_vectors_using_index(src_tensor, index_tensor) -> torch.FloatTensor: 93 | """ 94 | Uses the indices in index_tensor to select vectors from src_tensor 95 | :param src_tensor: batch x N x h 96 | :param index_tensor: Indices with dim: batch x C x P x 1 97 | :return: selected embeddings with dim: batch x C x P x h 98 | """ 99 | if index_tensor.size()[-1] != 1: 100 | raise ValueError("Expecting last index to be 1. 
Found {}".format(index_tensor.size())) 101 | flat_idx_tensor = index_tensor.view(index_tensor.size(0), -1, 1) # B * CP * 1 102 | 103 | # B * CP * Th 104 | expanded_index_size = [x for x in flat_idx_tensor.size()[:-1]] + [src_tensor.size()[-1]] 105 | expanded_index_tensor = flat_idx_tensor.expand(expanded_index_size).long() # B * CP * H 106 | 107 | flat_extracted = torch.gather(src_tensor, 1, expanded_index_tensor) # B * CP * H 108 | 109 | extracted = flat_extracted.view(src_tensor.size(0), index_tensor.size(1), 110 | index_tensor.size(2), -1) # B * C * P * H 111 | return extracted 112 | 113 | 114 | def gather_tensors_using_index(src_tensor, index_tensor) -> torch.FloatTensor: 115 | """ 116 | Uses the indices in index_tensor to select matrices from src_tensor 117 | :param src_tensor: batch x N x T x h 118 | :param index_tensor: Indices with dim: batch x C x P x 1 119 | :return: selected embeddings with dim: batch x C x P x T x h 120 | """ 121 | if index_tensor.size()[-1] != 1: 122 | raise ValueError("Expecting last index to be 1. Found {}".format(index_tensor.size())) 123 | flat_idx_tensor = index_tensor.view(index_tensor.size(0), -1, 1, 1) # B * CP * 1 * 1 124 | 125 | # B * CP * T * h 126 | expanded_index_tensor = flat_idx_tensor.expand(flat_idx_tensor.shape[:-2] 127 | + src_tensor.shape[-2:]).long() # B * CP * T * h 128 | 129 | flat_extracted = torch.gather(src_tensor, 1, expanded_index_tensor) # B * CP * T * h 130 | 131 | extracted = flat_extracted.view(src_tensor.size(0), index_tensor.size(1), 132 | index_tensor.size(2), src_tensor.size(2), -1) # B * C * P * T * h 133 | return extracted 134 | 135 | 136 | def gather_tensor_masks_using_index(src_tensor_mask, index_tensor) -> torch.FloatTensor: 137 | """ 138 | Uses the indices in index_tensor to select vectors from src_tensor_mask 139 | :param src_tensor_mask: batch x N x T 140 | :param index_tensor: Indices with dim: batch x C x P x 1 141 | :return: selected embeddings with dim: batch x C x P x T 142 | """ 143 | if index_tensor.size()[-1] != 1: 144 | raise ValueError("Expecting last index to be 1. 
Found {}".format(index_tensor.size())) 145 | 146 | flat_idx_tensor = index_tensor.view(index_tensor.size(0), -1, 1) # B * CP * 1 147 | 148 | # B * CP * T 149 | expanded_index_size = [x for x in flat_idx_tensor.size()[:-1]] + [src_tensor_mask.size()[-1]] 150 | expanded_index_tensor = flat_idx_tensor.expand(expanded_index_size).long() # B * CP * T 151 | 152 | flat_extracted = torch.gather(src_tensor_mask, 1, expanded_index_tensor) # B * CP * T 153 | 154 | extracted_mask = flat_extracted.view(src_tensor_mask.size(0), index_tensor.size(1), 155 | index_tensor.size(2), -1) # B * C * P * T 156 | return extracted_mask 157 | 158 | 159 | def pad_packed_loc_tensors(tensor: torch.FloatTensor, 160 | num_cand: int, 161 | num_path: int, num_loc: int, 162 | track_list: List[List[List[int]]], 163 | mask_tensor: torch.LongTensor = None): 164 | """ 165 | Packing the location-based tensors 166 | This helps to reduce memory usage 167 | :param tensor: B * (cpl) * H 168 | :param num_cand: maximum number of candidates (C) 169 | :param num_path: maximum number of paths (P) 170 | :param num_loc: maximum number of locations (L) 171 | :param track_list: B * (cpl) * 3 172 | :param mask_tensor: B * (cpl) 173 | :return: B * C * P * L * H 174 | """ 175 | batch_size = tensor.size(0) 176 | cpl = tensor.size(1) 177 | hdim = tensor.size(-1) 178 | ind1_tensor = torch.zeros(batch_size, num_cand, num_path, num_loc) # B * C * P * L 179 | ind2_tensor = ind1_tensor + cpl # B * C * P * L 180 | if torch.cuda.is_available(): 181 | device = tensor.get_device() 182 | _zeros = torch.zeros([batch_size, 1, hdim]).cuda(device=device) 183 | _mask_zeros = torch.zeros([batch_size, 1]).long().cuda(device=device) 184 | ind1_tensor = ind1_tensor.cuda(device=device) 185 | ind2_tensor = ind2_tensor.cuda(device=device) 186 | else: 187 | _zeros = torch.zeros([batch_size, 1, hdim]) 188 | _mask_zeros = torch.zeros([batch_size, 1]).long() 189 | padded_tensor = torch.cat([tensor, _zeros], dim=1) 190 | 191 | for bidx in range(batch_size): 192 | tracks = track_list[bidx] # cpl * 3 193 | for trackidx, track in enumerate(tracks): 194 | candidx = track[0] 195 | pathidx = track[1] 196 | locidx = track[2] 197 | ind1_tensor[bidx, candidx, pathidx, locidx] = bidx 198 | ind2_tensor[bidx, candidx, pathidx, locidx] = trackidx 199 | 200 | output_tensor = padded_tensor[ind1_tensor.long(), ind2_tensor.long()] 201 | if torch.cuda.is_available(): 202 | output_tensor = output_tensor.cuda(device=device) 203 | if mask_tensor is not None: 204 | padded_mask_tensor = torch.cat([mask_tensor, _mask_zeros], dim=1) 205 | output_mask_tensor = padded_mask_tensor[ind1_tensor.long(), ind2_tensor.long()] 206 | if torch.cuda.is_available(): 207 | output_mask_tensor = output_mask_tensor.cuda(device=device) 208 | return output_tensor, output_mask_tensor 209 | 210 | return output_tensor 211 | 212 | 213 | def pad_packed_loc_tensors_with_docidxs(tensor: torch.FloatTensor, 214 | docidx_tensor: torch.LongTensor, 215 | num_cand: int, 216 | num_path: int, num_loc: int, 217 | track_list: List[List[List[int]]], 218 | mask_tensor: torch.LongTensor = None): 219 | """ 220 | padding and packing of the location-based tensors with document indices 221 | :param tensor: B * (cpl) * H 222 | :param docidx_tensor: B * (cpl) * 1 223 | :param num_cand: maximum number of candidates (C) 224 | :param num_path: maximum number of paths (P) 225 | :param num_loc: maximum number of locations (L) 226 | :param track_list: B * (cpl) * 3 227 | :param mask_tensor: B * (cpl) 228 | :return: B * C * P * L * H 229 | """ 230 | 
assert tensor.shape[:2] == docidx_tensor.shape[:2] 231 | batch_size = tensor.size(0) 232 | cpl = tensor.size(1) 233 | hdim = tensor.size(-1) 234 | ind1_tensor = torch.zeros(batch_size, num_cand, num_path, num_loc) # B * C * P * L 235 | ind2_tensor = ind1_tensor + cpl # B * C * P * L 236 | if torch.cuda.is_available(): 237 | device = tensor.get_device() 238 | _zeros = torch.zeros([batch_size, 1, hdim]).cuda(device=device) 239 | _docidx_zeros = torch.zeros([batch_size, 1, 1]).long().cuda(device=device) 240 | _mask_zeros = torch.zeros([batch_size, 1]).long().cuda(device=device) 241 | ind1_tensor = ind1_tensor.cuda(device=device) 242 | ind2_tensor = ind2_tensor.cuda(device=device) 243 | else: 244 | _zeros = torch.zeros([batch_size, 1, hdim]) 245 | _docidx_zeros = torch.zeros([batch_size, 1, 1]).long() 246 | _mask_zeros = torch.zeros([batch_size, 1]).long() 247 | padded_tensor = torch.cat([tensor, _zeros], dim=1) 248 | padded_docidx_tensor = torch.cat([docidx_tensor, _docidx_zeros], dim=1) 249 | 250 | for bidx in range(batch_size): 251 | tracks = track_list[bidx] # cpl * 3 252 | for trackidx, track in enumerate(tracks): 253 | candidx = track[0] 254 | pathidx = track[1] 255 | locidx = track[2] 256 | ind1_tensor[bidx, candidx, pathidx, locidx] = bidx 257 | ind2_tensor[bidx, candidx, pathidx, locidx] = trackidx 258 | 259 | output_tensor = padded_tensor[ind1_tensor.long(), ind2_tensor.long()] 260 | output_docidx_tensor = padded_docidx_tensor[ind1_tensor.long(), ind2_tensor.long()] 261 | if torch.cuda.is_available(): 262 | output_tensor = output_tensor.cuda(device=device) 263 | output_docidx_tensor = output_docidx_tensor.cuda(device=device) 264 | if mask_tensor is not None: 265 | padded_mask_tensor = torch.cat([mask_tensor, _mask_zeros], dim=1) 266 | output_mask_tensor = padded_mask_tensor[ind1_tensor.long(), ind2_tensor.long()] 267 | if torch.cuda.is_available(): 268 | output_mask_tensor = output_mask_tensor.cuda(device=device) 269 | return output_tensor, output_docidx_tensor, output_mask_tensor 270 | 271 | return output_tensor, output_docidx_tensor 272 | -------------------------------------------------------------------------------- /pathnet/pathfinder/__init__.py: -------------------------------------------------------------------------------- 1 | from pathnet.pathfinder.path_extractor import * -------------------------------------------------------------------------------- /pathnet/pathfinder/obqa_path_extractor.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Any, Tuple 2 | from copy import deepcopy 3 | 4 | from pathnet.pathfinder.util import create_offsets, get_widxs_from_chidxs, get_locations, \ 5 | cluster_postags, cluster_nertags, get_all_ents, find_sentidx, remove_overlapping_ents, \ 6 | get_lookup_words, get_locations_words 7 | 8 | 9 | class ObqaPathFinder(object): 10 | """class for path finder for the OBQA dataset""" 11 | 12 | def __init__(self, qid, docsentlist: List[List[List[str]]], 13 | # entity: str, relation: str, 14 | question: List[str], 15 | qpos: List[str], 16 | qner: List[str], 17 | candidates: List[str], 18 | candpos: List[List[str]], 19 | candner: List[List[str]], 20 | answer: str = None, 21 | sentlimit: int = 1, minentlen: int = 3, 22 | nearest_only: bool = False, 23 | qenttype="nps", candenttype="nps") -> None: 24 | self.docsentlist = docsentlist 25 | self.docoffsets = [] 26 | for docsent in self.docsentlist: 27 | doctoks = sum(docsent, []) 28 | self.docoffsets.append(create_offsets(doctoks)) 29 | 30 | 
self.qenttype = qenttype 31 | self.candenttype = candenttype 32 | self.question = question 33 | self.qpos = qpos 34 | self.qner = qner 35 | self.entity_set = get_lookup_words(self.question, self.qpos, 36 | self.qner, type=self.qenttype) 37 | 38 | self.candidates = candidates 39 | self.candpos = candpos 40 | self.candner = candner 41 | self.answer = answer 42 | self.sentlimit = sentlimit 43 | self.minentlen = minentlen 44 | self.nearest_only = nearest_only 45 | self.id = qid 46 | 47 | def find_entity_in_all_docs(self, entity_list: List[str]) -> (int, List[List[Any]]): 48 | num_locs = 0 49 | alldoc_he_locs = [] 50 | for docidx, doc in enumerate(self.docsentlist): 51 | doc_he_locs = [] 52 | for sentidx, sent in enumerate(doc): 53 | prev_num_words = len(sum(doc[:sentidx], [])) 54 | offsets_for_sent = self.docoffsets[docidx][prev_num_words: 55 | prev_num_words + len(sent)] 56 | he_locs = {} 57 | for entity in entity_list: 58 | cur_entity_locs = get_locations(sent, entity, 59 | offsets_for_sent, 60 | allow_partial=True) # {w1: [start_ind_1, ...], w2: [start_ind_1, ], ...} 61 | if len(cur_entity_locs) > 0: 62 | for key, values in cur_entity_locs.items(): 63 | if key not in he_locs.keys(): 64 | he_locs[key] = values 65 | else: 66 | he_locs[key] += values 67 | doc_he_locs.append(he_locs) 68 | num_locs += len(he_locs) 69 | alldoc_he_locs.append(doc_he_locs) 70 | return num_locs, alldoc_he_locs 71 | 72 | def get_all_he_locs(self) -> List[List[Any]]: 73 | """ 74 | get all possible head entity locations 75 | :return: 76 | """ 77 | num_locs, alldoc_he_locs = self.find_entity_in_all_docs(self.entity_set) 78 | if num_locs > 0: 79 | return alldoc_he_locs 80 | entity_list = [] 81 | for entity in self.entity_set: 82 | if "-" in entity: 83 | entity_list.append(entity.replace('-', ' - ')) 84 | if " of " in entity: 85 | entity_list.append(entity.split(" of ")[0].strip()) 86 | if ":" in entity: 87 | entity_list.append(entity.split(":")[0].strip()) 88 | num_locs, alldoc_he_locs = self.find_entity_in_all_docs(entity_list) 89 | if num_locs > 0: 90 | return alldoc_he_locs 91 | entity_list = [] 92 | for entity in self.entity_set: 93 | if entity[-1:] == 's': 94 | entity_list.append(entity[:-1]) 95 | num_locs, alldoc_he_locs = self.find_entity_in_all_docs(entity_list) 96 | if num_locs > 0: 97 | return alldoc_he_locs 98 | entity_list = [] 99 | for entity in self.entity_set: 100 | if entity[-2:] == 'es': 101 | entity_list.append(entity[:-2]) 102 | num_locs, alldoc_he_locs = self.find_entity_in_all_docs(entity_list) 103 | if num_locs > 0: 104 | return alldoc_he_locs 105 | return None 106 | 107 | def check_candidate_validity(self, cand: str) -> bool: 108 | # if length is 1 character and not a digit 109 | if len(cand) == 1 and ord(cand) in range(91, 123): 110 | return False 111 | if cand in ["of", "he", "a", "an", "the", "as", "e .", "s .", "a .", '*', ',', '.', '"']: 112 | return False 113 | if cand == ' '.join(self.question): 114 | return False 115 | return True 116 | 117 | def check_sentdist(self, entlocs: List[int], 118 | candlocs: List[int], k=2) -> bool: 119 | """ 120 | check if there any combination which 121 | falls under a specified distance threshold 122 | :param entlocs: 123 | :param candlocs: 124 | :param k: 125 | :return: 126 | """ 127 | for el in entlocs: 128 | for cl in candlocs: 129 | if abs(cl - el) < self.sentlimit + k: # window size 5 130 | return True 131 | return False 132 | 133 | def find_path_for_cand(self, allents: List[Tuple[int, str]], 134 | cand: str, 135 | candpos: List[str], candner: List[str], 
136 | curidx: int, head_ent_locs: Any, 137 | cand_find_window=2) -> List[Dict]: 138 | """ 139 | find paths for a candidate 140 | :param allents: 141 | :param curidx: 142 | :param cand: 143 | :param candpos: 144 | :param candner: 145 | :param curidx: 146 | :param head_ent_locs: 147 | :param cand_find_window: 148 | :return: List[ent, docidx, List[wordidx]] 149 | """ 150 | idxs = [] 151 | 152 | # code for single hop path 153 | cand_locs = get_locations_words(sum(self.docsentlist[curidx], []), 154 | get_lookup_words(cand.split(' '), 155 | candpos, candner, 156 | type=self.candenttype), 157 | self.docoffsets[curidx], 158 | allow_partial=False) 159 | if len(cand_locs) > 0: 160 | idxs.append({"head_ent_docidx": curidx, 161 | "head_ent": head_ent_locs, 162 | "e1_with_head_ent": None, 163 | "e1_with_head_widx": None, 164 | "e1": None, 165 | "e1_docidx": None, 166 | "e1_locs": None, 167 | "cand_locs": cand_locs}) 168 | 169 | for entwidx, ent in allents: 170 | if ent == cand: 171 | continue 172 | for i in list(range(curidx)) + list(range(curidx + 1, len(self.docsentlist))): 173 | doctoks = sum(self.docsentlist[i], []) 174 | assert len(doctoks) == len(self.docoffsets[i]) 175 | cand_locs = get_locations_words(doctoks, get_lookup_words(cand.split(' '), 176 | candpos, candner, 177 | type=self.candenttype), 178 | self.docoffsets[i], 179 | allow_partial=True) 180 | if len(cand_locs) > 0: 181 | ent_locs = get_locations(doctoks, ent, self.docoffsets[i], 182 | allow_partial=True) 183 | if len(ent_locs) > 0: 184 | for eloc in ent_locs.keys(): 185 | idxs.append({"head_ent_docidx": curidx, 186 | "head_ent": head_ent_locs, 187 | "e1_with_head_ent": ent, 188 | "e1_with_head_widx": entwidx, 189 | "e1": eloc, 190 | "e1_docidx": i, 191 | "e1_locs": ent_locs[eloc], 192 | "cand_locs": cand_locs}) 193 | return idxs 194 | 195 | def accum_paths_for_cand(self, alldoc_he_locs: List[List[Any]], 196 | docners: List[List[str]], 197 | docpostags: List[List[str]], 198 | cand: str, candpos: List[str], 199 | candner: List[str]) -> List[Any]: 200 | """ 201 | accumulate all the paths for a particular candidate 202 | :param alldoc_he_locs: 203 | :param docners: 204 | :param docpostags: 205 | :param cand: 206 | :param candpos: 207 | :param candner: 208 | :return: 209 | """ 210 | paths_to_cand = [] 211 | 212 | # exclude the ill-posed candidates 213 | if not self.check_candidate_validity(cand): 214 | return paths_to_cand 215 | 216 | for docidx, doc in enumerate(self.docsentlist): 217 | for sentidx, sent in enumerate(doc): 218 | prev_num_words = len(sum(doc[:sentidx], [])) 219 | valid_sents = doc[sentidx:min(len(doc), 220 | sentidx + self.sentlimit + 1)] 221 | nerstidx = prev_num_words 222 | nerendidx = nerstidx + len(sum(valid_sents, [])) 223 | 224 | he_locs = deepcopy(alldoc_he_locs[docidx][sentidx]) 225 | if len(he_locs) > 0: 226 | for key in he_locs.keys(): 227 | for i in range(len(he_locs[key])): 228 | he_locs[key][i] += prev_num_words 229 | 230 | ners = docners[docidx][nerstidx:nerendidx] 231 | postags = docpostags[docidx][nerstidx:nerendidx] 232 | word_toks = sum(valid_sents, []) 233 | assert len(ners) == len(word_toks) 234 | assert len(ners) == len(postags) 235 | 236 | neidxs, cwn, cne = cluster_nertags(word_toks, ners) 237 | poidxs, cwp, cpo = cluster_postags(word_toks, postags) 238 | neidxs = [n + nerstidx for n in neidxs] 239 | poidxs = [p + nerstidx for p in poidxs] 240 | 241 | entset = get_all_ents(cwn, cne, cwp, cpo, neidxs, poidxs) 242 | # remove the head entity(s) from the entset 243 | # entset = [es for es in entset if 
es[1].lower() != self.entity.lower()] 244 | entset = remove_overlapping_ents(he_locs, entset) 245 | 246 | paths_to_cand += self.find_path_for_cand(entset, cand, candpos, candner, 247 | docidx, he_locs) 248 | return paths_to_cand 249 | 250 | def get_paths(self, docners: List[List[str]], 251 | docpostags: List[List[str]]) -> Dict: 252 | """ 253 | get path lists for all candidates 254 | :param docners: 255 | :param docpostags: 256 | :return: 257 | """ 258 | alldoc_he_locs = None 259 | if len(self.question) > 0: 260 | alldoc_he_locs = self.get_all_he_locs() 261 | 262 | all_paths = {} 263 | for candidx, cand in enumerate(self.candidates): 264 | if len(self.question) == 0 or alldoc_he_locs is None: 265 | paths_to_cand = [] 266 | else: 267 | candpos = self.candpos[candidx] 268 | candner = self.candner[candidx] 269 | paths_to_cand = self.accum_paths_for_cand(alldoc_he_locs, 270 | docners, 271 | docpostags, 272 | cand, candpos, candner) 273 | 274 | all_paths[cand] = paths_to_cand 275 | 276 | return all_paths 277 | -------------------------------------------------------------------------------- /pathnet/pathfinder/path_extractor.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Any, Tuple 2 | from copy import deepcopy 3 | 4 | from pathnet.pathfinder.util import create_offsets, get_locations, \ 5 | cluster_postags, cluster_nertags, get_all_ents, find_sentidx, remove_overlapping_ents 6 | 7 | 8 | class PathFinder(object): 9 | """class for path finder """ 10 | 11 | def __init__(self, qid, docsentlist: List[List[List[str]]], 12 | entity: str, relation: str, 13 | candidates: List[str], answer: str = None, 14 | sentlimit: int = 1, minentlen: int = 3, 15 | nearest_only: bool = False) -> None: 16 | self.docsentlist = docsentlist 17 | self.docoffsets = [] 18 | for docsent in self.docsentlist: 19 | doctoks = sum(docsent, []) 20 | self.docoffsets.append(create_offsets(doctoks)) 21 | self.entity = entity 22 | self.relation = relation 23 | self.candidates = candidates 24 | self.answer = answer 25 | self.sentlimit = sentlimit 26 | self.minentlen = minentlen 27 | self.nearest_only = nearest_only 28 | self.id = qid 29 | 30 | def find_entity_in_all_docs(self, entity_list: List[str]) -> (int, List[List[Any]]): 31 | """ 32 | find the entities in all the documents 33 | :param entity_list: 34 | :return: 35 | """ 36 | num_locs = 0 37 | alldoc_he_locs = [] 38 | for docidx, doc in enumerate(self.docsentlist): 39 | doc_he_locs = [] 40 | for sentidx, sent in enumerate(doc): 41 | prev_num_words = len(sum(doc[:sentidx], [])) 42 | offsets_for_sent = self.docoffsets[docidx][prev_num_words: 43 | prev_num_words + len(sent)] 44 | he_locs = {} 45 | for entity in entity_list: 46 | cur_entity_locs = get_locations(sent, entity, 47 | offsets_for_sent, 48 | allow_partial=True) # {w1: [start_ind_1, ...], w2: [start_ind_1, ], ...} 49 | if len(cur_entity_locs) > 0: 50 | for key, values in cur_entity_locs.items(): 51 | if key not in he_locs.keys(): 52 | he_locs[key] = values 53 | else: 54 | he_locs[key] += values 55 | doc_he_locs.append(he_locs) 56 | num_locs += len(he_locs) 57 | alldoc_he_locs.append(doc_he_locs) 58 | return num_locs, alldoc_he_locs 59 | 60 | def get_all_he_locs(self) -> List[List[Any]]: 61 | """ 62 | get the locations of the head entity 63 | :return: 64 | """ 65 | num_locs, alldoc_he_locs = self.find_entity_in_all_docs([self.entity]) 66 | if num_locs > 0: 67 | return alldoc_he_locs 68 | entity_list = [] 69 | if "-" in self.entity: 70 | 
entity_list.append(self.entity.replace('-', ' - ')) 71 | if " of " in self.entity: 72 | entity_list.append(self.entity.split(" of ")[0].strip()) 73 | if ":" in self.entity: 74 | entity_list.append(self.entity.split(":")[0].strip()) 75 | num_locs, alldoc_he_locs = self.find_entity_in_all_docs(entity_list) 76 | if num_locs > 0: 77 | return alldoc_he_locs 78 | if self.entity[-1:] == 's': 79 | num_locs, alldoc_he_locs = self.find_entity_in_all_docs([self.entity[:-1]]) 80 | if num_locs > 0: 81 | return alldoc_he_locs 82 | if self.entity[-2:] == 'es': 83 | num_locs, alldoc_he_locs = self.find_entity_in_all_docs([self.entity[:-1]]) 84 | if num_locs > 0: 85 | return alldoc_he_locs 86 | num_toks = len(self.entity.split()) 87 | part1 = int(num_toks / 2) 88 | first_part = ' '.join(self.entity.split()[:part1 + 1]) 89 | last_part = ' '.join(self.entity.split()[part1 + 1:]) 90 | _, alldoc_he_locs = self.find_entity_in_all_docs([first_part, last_part]) 91 | return alldoc_he_locs 92 | 93 | def check_candidate_validity(self, cand: str) -> bool: 94 | # if length is 1 character and not a digit 95 | if len(cand) == 1 and ord(cand) in range(91, 123): 96 | return False 97 | if cand in ["of", "he", "a", "an", "the", "as", "e .", "s .", "a .", '*', ',', '.', '"']: 98 | return False 99 | if cand == self.entity: 100 | return False 101 | return True 102 | 103 | def check_sentdist(self, entlocs: List[int], 104 | candlocs: List[int], k=2) -> bool: 105 | """ 106 | check if there any combination which 107 | falls under a specified sentence distance threshold 108 | :param entlocs: 109 | :param candlocs: 110 | :param k: 111 | :return: 112 | """ 113 | for el in entlocs: 114 | for cl in candlocs: 115 | if abs(cl - el) < self.sentlimit + k: # window size 5 116 | return True 117 | return False 118 | 119 | def find_path_for_cand(self, allents: List[Tuple[int, str]], 120 | cand: str, 121 | curidx: int, head_ent_locs: Any, 122 | cand_find_window=2) -> List[Dict]: 123 | """ 124 | find the paths for a particular candidate 125 | :param allents: 126 | :param curidx: 127 | :param cand: 128 | :param curidx: 129 | :param head_ent_locs: 130 | :param cand_find_window: 131 | :return: List[ent, docidx, List[wordidx]] 132 | """ 133 | idxs = [] 134 | 135 | # code for single hop path 136 | cand_locs = get_locations(sum(self.docsentlist[curidx], []), 137 | cand, self.docoffsets[curidx], 138 | allow_partial=False) 139 | if len(cand_locs) > 0: 140 | idxs.append({"head_ent_docidx": curidx, 141 | "head_ent": head_ent_locs, 142 | "e1_with_head_ent": None, 143 | "e1_with_head_widx": None, 144 | "e1": None, 145 | "e1_docidx": None, 146 | "e1_locs": None, 147 | "cand_locs": cand_locs}) 148 | 149 | for entwidx, ent in allents: 150 | if ent == cand: 151 | continue 152 | for i in list(range(curidx)) + list(range(curidx + 1, len(self.docsentlist))): 153 | doctoks = sum(self.docsentlist[i], []) 154 | assert len(doctoks) == len(self.docoffsets[i]) 155 | cand_locs = get_locations(doctoks, cand, self.docoffsets[i], 156 | allow_partial=False) 157 | if len(cand_locs) > 0: 158 | ent_locs = get_locations(doctoks, ent, self.docoffsets[i], 159 | allow_partial=True) 160 | if len(ent_locs) > 0: 161 | for eloc in ent_locs.keys(): 162 | ewlocs = ent_locs[eloc] 163 | eslocs = [find_sentidx(self.docsentlist[i], ew) for 164 | ew in ewlocs] 165 | cand_sent_locs = {} 166 | for key in cand_locs.keys(): 167 | cand_sent_locs[key] = [find_sentidx( 168 | self.docsentlist[i], cw) 169 | for cw in cand_locs[key]] 170 | cand_sent_locs_values = sum(list(cand_sent_locs.values()), []) 
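                        # (descriptive note) cand_sent_locs maps each matched surface form of the
                        # candidate to the sentence indices of its occurrences in document i;
                        # flattening the values gives every candidate-mention sentence index so the
                        # window check below can compare them against the sentence positions of the
                        # intermediate entity (eslocs).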
171 | 172 | # check for atleast self.sentlimit + k sent gap 173 | if self.check_sentdist(eslocs, cand_sent_locs_values, 174 | k=cand_find_window): 175 | # check the closest ones 176 | valid_cand_locs = {} 177 | if self.nearest_only: 178 | for window in range(cand_find_window + 1): 179 | valid_cand_locs = {} 180 | for key in cand_locs.keys(): 181 | valid_cand_locs_for_key = [cand_locs[key][i] 182 | for i in range(len(cand_locs[key])) 183 | if self.check_sentdist( 184 | eslocs, [cand_sent_locs[key][i]], k=window 185 | )] 186 | if len(valid_cand_locs_for_key) > 0: 187 | valid_cand_locs[key] = valid_cand_locs_for_key 188 | if len(sum(list(valid_cand_locs.values()), [])) > 0: 189 | break 190 | else: 191 | valid_cand_locs = cand_locs 192 | idxs.append({"head_ent_docidx": curidx, 193 | "head_ent": head_ent_locs, 194 | "e1_with_head_ent": ent, 195 | "e1_with_head_widx": entwidx, 196 | "e1": eloc, 197 | "e1_docidx": i, 198 | "e1_locs": ent_locs[eloc], 199 | "cand_locs": valid_cand_locs}) 200 | return idxs 201 | 202 | def accum_paths_for_cand(self, alldoc_he_locs: List[List[Any]], 203 | docners: List[List[str]], 204 | docpostags: List[List[str]], 205 | cand: str) -> List[Any]: 206 | """ 207 | accumulate all the paths for a particular candidate 208 | :param alldoc_he_locs: 209 | :param docners: 210 | :param docpostags: 211 | :param cand: 212 | :return: 213 | """ 214 | paths_to_cand = [] 215 | 216 | # exclude the ill-posed candidates 217 | if not self.check_candidate_validity(cand): 218 | return paths_to_cand 219 | 220 | for docidx, doc in enumerate(self.docsentlist): 221 | for sentidx, sent in enumerate(doc): 222 | prev_num_words = len(sum(doc[:sentidx], [])) 223 | valid_sents = doc[sentidx:min(len(doc), 224 | sentidx + self.sentlimit + 1)] 225 | nerstidx = prev_num_words 226 | nerendidx = nerstidx + len(sum(valid_sents, [])) 227 | 228 | he_locs = deepcopy(alldoc_he_locs[docidx][sentidx]) 229 | if len(he_locs) > 0: 230 | for key in he_locs.keys(): 231 | for i in range(len(he_locs[key])): 232 | he_locs[key][i] += prev_num_words 233 | 234 | ners = docners[docidx][nerstidx:nerendidx] 235 | postags = docpostags[docidx][nerstidx:nerendidx] 236 | word_toks = sum(valid_sents, []) 237 | assert len(ners) == len(word_toks) 238 | assert len(ners) == len(postags) 239 | 240 | neidxs, cwn, cne = cluster_nertags(word_toks, ners) 241 | poidxs, cwp, cpo = cluster_postags(word_toks, postags) 242 | neidxs = [n + nerstidx for n in neidxs] 243 | poidxs = [p + nerstidx for p in poidxs] 244 | 245 | entset = get_all_ents(cwn, cne, cwp, cpo, neidxs, poidxs) 246 | # remove the head entity(s) from the entset 247 | # entset = [es for es in entset if es[1].lower() != self.entity.lower()] 248 | entset = remove_overlapping_ents(he_locs, entset) 249 | 250 | paths_to_cand += self.find_path_for_cand(entset, cand, 251 | docidx, he_locs) 252 | return paths_to_cand 253 | 254 | def get_paths(self, docners: List[List[str]], 255 | docpostags: List[List[str]]) -> Dict: 256 | """ 257 | get path lists for all candidates 258 | :param docners: 259 | :param docpostags: 260 | :return: 261 | """ 262 | alldoc_he_locs = None 263 | if len(self.entity.split()) > 0: 264 | alldoc_he_locs = self.get_all_he_locs() 265 | 266 | all_paths = {} 267 | for cand in self.candidates: 268 | if len(self.entity.split()) == 0: 269 | paths_to_cand = [] 270 | else: 271 | paths_to_cand = self.accum_paths_for_cand(alldoc_he_locs, 272 | docners, 273 | docpostags, 274 | cand) 275 | 276 | all_paths[cand] = paths_to_cand 277 | 278 | return all_paths 279 | 
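# ---------------------------------------------------------------------------
# Minimal usage sketch (added for illustration only; not part of the original
# preprocessing pipeline). It shows how PathFinder is typically driven: the
# real entry point is scripts/path_finder_wrapper.py, which builds the
# lemmatized sentences and the per-token NER/POS tag lists with SpacyTokenizer.
# All inputs below (documents, tags, query entity/relation, candidates) are
# hand-made toy values assumed purely for this demo.
if __name__ == "__main__":
    toy_docsents = [
        [["alice", "was", "born", "in", "paris", "."]],
        [["paris", "is", "the", "capital", "of", "france", "."]],
    ]
    # one flat tag list per document, aligned with the flattened document tokens
    toy_ners = [
        ["PERSON", "O", "O", "O", "GPE", "O"],
        ["GPE", "O", "O", "O", "O", "GPE", "O"],
    ]
    toy_pos = [
        ["NNP", "VBD", "VBN", "IN", "NNP", "."],
        ["NNP", "VBZ", "DT", "NN", "IN", "NNP", "."],
    ]
    finder = PathFinder("toy-qid", toy_docsents,
                        entity="alice", relation="country_of_birth",
                        candidates=["france", "germany"],
                        answer=None, sentlimit=1, nearest_only=False)
    # get_paths returns {candidate: [path dicts]}, where each path records the
    # head-entity, intermediate-entity and candidate locations across documents.
    all_paths = finder.get_paths(toy_ners, toy_pos)
    print(all_paths)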
-------------------------------------------------------------------------------- /pathnet/pathfinder/util.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import List, Dict, Any, Tuple 3 | from copy import deepcopy 4 | import string 5 | 6 | PUNKT_SET = set(string.punctuation) 7 | 8 | VALIDNE_TAGS = ['PRODUCT', 'NORP', 'WORK_OF_ART', 9 | 'LANGUAGE', 'LOC', 'GPE', 'PERSON', 10 | 'FAC', 'ORG', 'EVENT'] 11 | 12 | VALIDPOS_TAGS = ['NN', 'NNP', 'NNPS', 'NNS'] 13 | 14 | STOPWORDS = { 15 | 'i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', 'your', 16 | 'yours', 'yourself', 'yourselves', 'he', 'him', 'his', 'himself', 'she', 17 | 'her', 'hers', 'herself', 'it', 'its', 'itself', 'they', 'them', 'their', 18 | 'theirs', 'themselves', 'what', 'which', 'who', 'whom', 'this', 'that', 19 | 'these', 'those', 'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being', 20 | 'have', 'has', 'had', 'having', 'do', 'does', 'did', 'doing', 'a', 'an', 21 | 'the', 'and', 'but', 'if', 'or', 'because', 'as', 'until', 'while', 'of', 22 | 'at', 'by', 'for', 'with', 'about', 'against', 'between', 'into', 'through', 23 | 'during', 'before', 'after', 'above', 'below', 'to', 'from', 'up', 'down', 24 | 'in', 'out', 'on', 'off', 'over', 'under', 'again', 'further', 'then', 25 | 'once', 'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any', 26 | 'both', 'each', 'few', 'more', 'most', 'other', 'some', 'such', 'no', 'nor', 27 | 'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very', 's', 't', 'can', 28 | 'will', 'just', 'don', 'should', 'now', 'd', 'll', 'm', 'o', 're', 've', 29 | 'y', 'ain', 'aren', 'couldn', 'didn', 'doesn', 'hadn', 'hasn', 'haven', 30 | 'isn', 'ma', 'mightn', 'mustn', 'needn', 'shan', 'shouldn', 'wasn', 'weren', 31 | 'won', 'wouldn', "'ll", "'re", "'ve", "n't", "'s", "'d", "'m", "''", "``" 32 | } 33 | 34 | 35 | def get_widxs_from_chidxs(chidxs: List[int], 36 | offsets: List[List[int]]) -> List[int]: 37 | """ 38 | Find word indices given character indices 39 | :param chidxs: 40 | :param offsets: 41 | :return: 42 | """ 43 | last_ch_idx = offsets[0][0] 44 | assert max(chidxs) < offsets[-1][1] - last_ch_idx 45 | widxs = [] 46 | for chidx in chidxs: 47 | for oi in range(len(offsets)): 48 | if chidx in range(offsets[oi][0] - last_ch_idx, offsets[oi][1] - last_ch_idx): 49 | widxs.append(oi) 50 | break 51 | elif chidx in range(offsets[oi][1] - last_ch_idx, 52 | offsets[min(oi + 1, len(offsets))][0] - last_ch_idx): 53 | widxs.append(oi) 54 | break 55 | assert len(chidxs) == len(widxs) 56 | return widxs 57 | 58 | 59 | def create_offsets(doctoks: List[str]) -> List[List[int]]: 60 | """ 61 | create offsets for a document tokens 62 | :param doctoks: 63 | :return: 64 | """ 65 | offsets = [] 66 | char_count = 0 67 | for tok in doctoks: 68 | offsets.append([char_count, char_count + len(tok)]) 69 | char_count = char_count + len(tok) + 1 70 | return offsets 71 | 72 | 73 | def cluster_nertags(toks: List[str], 74 | nertags: List[str] 75 | ) -> (List[int], List[List[str]], List[List[str]]): 76 | """ 77 | cluster based on ner tags 78 | """ 79 | newtags = [] 80 | newtoks = [] 81 | startidxs = [] 82 | 83 | tidx = 0 84 | while tidx < len(nertags): 85 | curtags = [] 86 | curtoks = [] 87 | curtag = nertags[tidx] 88 | curtok = toks[tidx] 89 | 90 | startidxs.append(tidx) 91 | curtags.append(curtag) 92 | curtoks.append(curtok) 93 | tidx += 1 94 | prevtag = curtag 95 | 96 | while True: 97 | if tidx < len(nertags) and prevtag in VALIDNE_TAGS: 98 | 
curtag = nertags[tidx] 99 | curtok = toks[tidx] 100 | if curtag == prevtag: 101 | curtags.append(curtag) 102 | curtoks.append(curtok) 103 | tidx += 1 104 | else: 105 | break 106 | else: 107 | break 108 | 109 | newtags.append(curtags) 110 | newtoks.append(curtoks) 111 | 112 | return startidxs, newtoks, newtags 113 | 114 | 115 | def cluster_postags(toks: List[str], 116 | postags: List[str] 117 | ) -> (List[int], List[List[str]], List[List[str]]): 118 | """ 119 | cluster based on postags 120 | """ 121 | newtags = [] 122 | newtoks = [] 123 | startidxs = [] 124 | 125 | tidx = 0 126 | while tidx < len(postags): 127 | curtags = [] 128 | curtoks = [] 129 | curtag = postags[tidx] 130 | curtok = toks[tidx] 131 | 132 | startidxs.append(tidx) 133 | curtags.append(curtag) 134 | curtoks.append(curtok) 135 | tidx += 1 136 | prevtag = curtag 137 | 138 | while True: 139 | if tidx < len(postags) and prevtag in VALIDPOS_TAGS: 140 | curtag = postags[tidx] 141 | curtok = toks[tidx] 142 | if (prevtag == 'NNP' and curtag == 'NNP') or \ 143 | (curtag == 'NN' or curtag == 'NNS' and prevtag == 'NN') or \ 144 | (prevtag == 'NNP' and curtag == 'NNS') or \ 145 | (prevtag == 'NNP' and curtag == 'NNS'): 146 | curtags.append(curtag) 147 | curtoks.append(curtok) 148 | tidx += 1 149 | else: 150 | break 151 | else: 152 | break 153 | 154 | newtags.append(curtags) 155 | newtoks.append(curtoks) 156 | 157 | return startidxs, newtoks, newtags 158 | 159 | 160 | def get_all_ents(cwn: List[List[str]], cne: List[List[str]], 161 | cwp: List[List[str]], cpo: List[List[str]], 162 | neidxs: List[int], poidxs: List[int]) -> List[Tuple[int, str]]: 163 | """ 164 | pick entities based on NER and POS tags 165 | :param cwn: clustered words based on NER 166 | :param cne: clustered NER 167 | :param cwp: clustered words based on POS 168 | :param cpo: clustered POS 169 | :param neidxs: start indices for NER clusters 170 | :param poidxs: start indices for POS clusters 171 | :return: set of other entities 172 | """ 173 | entset = [] 174 | for ne in VALIDNE_TAGS: 175 | for idx in range(len(cne)): 176 | if ne in cne[idx]: 177 | entset.append((neidxs[idx], cwn[idx])) 178 | 179 | for pos in VALIDPOS_TAGS: 180 | for idx in range(len(cpo)): 181 | if pos in cpo[idx]: 182 | entset.append((poidxs[idx], cwp[idx])) 183 | entset = set([(e[0], ' '.join(e[1])) for e in entset]) 184 | uniq_entset = [] 185 | for e in entset: 186 | if e not in uniq_entset: 187 | uniq_entset.append(e) 188 | uniq_entset = filter_trailing_puncts(uniq_entset) 189 | return uniq_entset 190 | 191 | 192 | def filter_trailing_puncts(entity_list: List[Tuple[int, str]]): 193 | """ 194 | filtering based on trailing punctuations 195 | :param entity_list: [(0, word1), (10, word2), ...] 
196 | :return: 197 | """ 198 | new_ent_list = [] 199 | for eidx, e in enumerate(entity_list): 200 | e_toks = e[1].split(' ') 201 | if len(e_toks) == 1 and e_toks[0] in PUNKT_SET: 202 | continue 203 | else: 204 | if e_toks[0] in PUNKT_SET and e_toks[-1] in PUNKT_SET: 205 | new_ent_list.append((e[0] + 1, ' '.join(e_toks[1:-1]))) 206 | elif e_toks[0] in PUNKT_SET: 207 | new_ent_list.append((e[0] + 1, ' '.join(e_toks[1:]))) 208 | else: 209 | new_ent_list.append(e) 210 | return new_ent_list 211 | 212 | 213 | def find_word_re(toks: List[str], 214 | w: str) -> (bool, Any): 215 | """ 216 | find word in a doc using re 217 | :param toks: 218 | :param w: 219 | :return: 220 | """ 221 | try: 222 | pat = re.compile(w.lower() + '\W') 223 | except: 224 | return False, None 225 | 226 | objs = [] 227 | doc_str = ' '.join(toks).lower() 228 | for x in pat.finditer(doc_str): 229 | objs.append(x) 230 | if len(objs) > 0: 231 | # found_words = [ob.group(0).strip() for ob in objs] 232 | start_ch_idxs = [ob.span()[0] for ob in objs] 233 | start_widxs = [len(doc_str[:chidx].split()) for chidx in start_ch_idxs] 234 | for i in range(len(start_widxs)): 235 | if start_widxs[i] > len(toks) - 1: 236 | start_widxs[i] = len(toks) - 1 237 | return True, start_widxs 238 | else: 239 | return False, None 240 | 241 | 242 | def get_locations(doctoks: List[str], word: str, 243 | docoffsets: List[List[int]], 244 | allow_partial: bool = False) -> Dict: 245 | """ 246 | check the locations. if word = x; find X 247 | if word = XY; find X*Y 248 | if word = XYZ; find x*Z 249 | :param doctoks: 250 | :param word: 251 | :param docoffsets: 252 | :param allow_partial: 253 | :return: 254 | """ 255 | assert len(doctoks) == len(docoffsets) 256 | wordspanss = {} 257 | word = word.lower() 258 | try: 259 | re.compile(word) 260 | except: 261 | return wordspanss 262 | 263 | w_toks = word.split() 264 | if len(w_toks) < 1: 265 | return wordspanss 266 | 267 | if len(w_toks) == 1: 268 | pat = re.compile('(^|\W)' + re.escape(word) + '\W') 269 | else: 270 | if allow_partial: 271 | pat = re.compile('((^|\W)' + re.escape(word) + '\W|(^|\W)' + re.escape(w_toks[0]) + 272 | '\W(.{,40}?\W)??' 
+ re.escape(w_toks[-1]) + '\W)') 273 | else: 274 | pat = re.compile('(^|\W)' + re.escape(word) + '\W') 275 | 276 | objs = [] 277 | doc_str = ' '.join(doctoks).lower() 278 | for x in pat.finditer(doc_str): 279 | objs.append(x) 280 | if len(objs) > 0: 281 | found_words = [ob.group(0) for ob in objs] 282 | start_ch_idxs = [ob.span()[0] for ob in objs] 283 | assert len(found_words) == len(start_ch_idxs) 284 | for fwidx, fw in enumerate(found_words): 285 | start_offset = len(fw) - len(fw.lstrip()) 286 | start_ch_idxs[fwidx] += start_offset 287 | found_words[fwidx] = fw.strip() 288 | start_widxs = get_widxs_from_chidxs(start_ch_idxs, deepcopy(docoffsets)) 289 | for swidx in range(len(start_widxs)): 290 | if start_widxs[swidx] > len(doctoks) - 1: 291 | start_widxs[swidx] = len(doctoks) - 1 292 | for i, w in enumerate(found_words): 293 | if w not in list(wordspanss.keys()): 294 | wordspanss[w] = [start_widxs[i]] 295 | else: 296 | wordspanss[w].append(start_widxs[i]) 297 | return wordspanss 298 | 299 | 300 | def find_sentidx(doctoks: List[List[str]], 301 | word_idx: int) -> int: 302 | """ 303 | find the sentidx given word idx 304 | :param doctoks: 305 | :param word_idx: 306 | :return: 307 | """ 308 | count = 0 309 | for idx, doc in enumerate(doctoks): 310 | count += len(doc) 311 | if word_idx < count: 312 | return idx 313 | return len(doctoks) - 1 314 | 315 | 316 | def remove_overlapping_ents(he_loc_dict: Dict, 317 | ent_list: List[Tuple[int, str]]) -> List[Tuple[int, str]]: 318 | """ 319 | remove the overlapping entities (potential duplicates) 320 | :param he_loc_dict: 321 | :param ent_list: 322 | :return: 323 | """ 324 | 325 | def is_present(idx: int, spans: List[Tuple[int, int]]): 326 | for sp in spans: 327 | if idx in range(sp[0], sp[1]): 328 | return True 329 | return False 330 | 331 | he_span_dict = {} 332 | for key in list(he_loc_dict.keys()): 333 | # exclusive end 334 | he_span_dict[key] = [(he_loc_dict[key][i], 335 | he_loc_dict[key][i] + len(key.split(' '))) 336 | for i in range(len(he_loc_dict[key]))] 337 | he_all_spans = sum(he_span_dict.values(), []) 338 | ent_list_spans = [(e[0], e[0] + len(e[1].split(' ')), e[1]) for e in ent_list] 339 | new_ent_list = [] 340 | for eidx, e in enumerate(ent_list_spans): 341 | if is_present(e[0], he_all_spans) or is_present(e[1], he_all_spans): 342 | continue 343 | else: 344 | new_ent_list.append((e[0], e[2])) 345 | return new_ent_list 346 | 347 | 348 | def unroll(counts: List[int], l: List[Any]) -> List[List[Any]]: 349 | counts = [0] + counts 350 | unrolled_list = [] 351 | for idx in range(len(counts) - 1): 352 | curr_idx = sum(counts[:idx + 1]) 353 | next_idx = curr_idx + counts[idx + 1] 354 | unrolled_list.append(l[curr_idx:next_idx]) 355 | return unrolled_list 356 | 357 | 358 | def get_non_stop_words(toks: List[str]): 359 | """ 360 | retrieve non-stopwords 361 | :param toks: 362 | :return: 363 | """ 364 | lw = [] 365 | for tok in toks: 366 | tok = tok.lower() 367 | if tok not in STOPWORDS and tok not in PUNKT_SET: 368 | lw.append(tok) 369 | if len(lw) == 0: 370 | lw = [toks[0]] 371 | return lw 372 | 373 | 374 | def get_lookup_words(toks: List[str], 375 | pos: List[str], ner: List[str], 376 | type='nonstopwords') -> List[str]: 377 | """ 378 | get the potential (head) entities when (head) entity 379 | is not given specifically. 
This is necessary for 380 | OBQA like settings 381 | :param toks: question/candidate 382 | :param pos: postags 383 | :param ner: ners 384 | :param type: nonstopwords/noun phrases,ners(nps) 385 | :return: 386 | """ 387 | if type == 'nonstopwords': 388 | lw = get_non_stop_words(toks) 389 | elif type == 'nps': 390 | neidxs, cwn, cne = cluster_nertags(toks, ner) 391 | poidxs, cwp, cpo = cluster_postags(toks, pos) 392 | lw = get_all_ents(cwn, cne, cwp, cpo, neidxs, poidxs) 393 | lw = [tup[1] for tup in lw] 394 | if len(lw) == 0: 395 | # fall back to non stopwords 396 | lw = get_non_stop_words(toks) 397 | else: 398 | raise NotImplementedError 399 | return lw 400 | 401 | 402 | def get_locations_words(doctoks: List[str], words: List[str], 403 | docoffsets: List[List[int]], 404 | allow_partial: bool = False) -> Dict: 405 | """ 406 | get locations for the words 407 | :param doctoks: 408 | :param words: list of words 409 | :param docoffsets: 410 | :param allow_partial: 411 | :return: 412 | """ 413 | wordspans = {} 414 | for word in words: 415 | ws = get_locations(doctoks, word, docoffsets, 416 | allow_partial=allow_partial) 417 | if len(ws) > 0: 418 | for key, values in ws.items(): 419 | if key in wordspans: 420 | wordspans[key] += values 421 | else: 422 | wordspans[key] = values 423 | return wordspans 424 | 425 | 426 | def lemmatize_docsents(doc_sents, stem): 427 | """ 428 | lemmatize the document sentences 429 | :param doc_sents: 430 | :param stem: 431 | :return: 432 | """ 433 | doc_sents_lemma = [] 434 | for idx, docl in enumerate(doc_sents): 435 | doc_sent = doc_sents[idx] 436 | docsent_lemma = [] 437 | for senttoks in doc_sent: 438 | docsent_lemma.append([stem(tok) for tok in senttoks]) 439 | doc_sents_lemma.append(docsent_lemma) 440 | return doc_sents_lemma 441 | -------------------------------------------------------------------------------- /pathnet/predictors/__init__.py: -------------------------------------------------------------------------------- 1 | from pathnet.predictors.wikihop_predictor import WikiHopPredictor 2 | -------------------------------------------------------------------------------- /pathnet/predictors/wikihop_predictor.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | from overrides import overrides 3 | from typing import List 4 | 5 | from allennlp.common.util import JsonDict, sanitize 6 | from allennlp.data import Instance 7 | from allennlp.predictors.predictor import Predictor 8 | 9 | 10 | @Predictor.register('wikihop_predictor') 11 | class WikiHopPredictor(Predictor): 12 | """ 13 | predictor interface for WikiHop 14 | """ 15 | 16 | def predict_instance(self, instance: Instance) -> JsonDict: 17 | """ 18 | Override this method to create a formatted JSON 19 | :param instance: 20 | :return: 21 | """ 22 | output = self._model.forward_on_instance(instance) 23 | num_cands = len(output['metadata']['choice_text_list']) 24 | lp = output['label_probs'][:num_cands] 25 | ans = output['metadata']['choice_text_list'][numpy.argmax(lp)] 26 | item_id = output['metadata']['id'] 27 | 28 | output_json = { 29 | "id": item_id, 30 | "answer": ans 31 | } 32 | return sanitize(output_json) 33 | 34 | def predict_batch_instance(self, instances: List[Instance]) -> List[JsonDict]: 35 | outputs = self._model.forward_on_instances(instances) 36 | output_json_list = [] 37 | for i in range(len(outputs)): 38 | num_cands = len(outputs[i]['metadata']['choice_text_list']) 39 | lp = outputs[i]['label_probs'][:num_cands] 40 | ans = 
outputs[i]['metadata']['choice_text_list'][numpy.argmax(lp)] 41 | item_id = outputs[i]['metadata']['id'] 42 | output_json_list.append({ 43 | "id": item_id, 44 | "answer": ans 45 | }) 46 | 47 | return sanitize(output_json_list) 48 | 49 | @overrides 50 | def _json_to_instance(self, item_json: JsonDict) -> Instance: 51 | """ 52 | instatiate the data for the dataset reader 53 | """ 54 | item_id = item_json["id"] 55 | docsents = item_json['docsents'] 56 | question = item_json['question'] 57 | candidates = item_json['candidates'] 58 | paths = item_json['paths'] 59 | answer_str = None 60 | return self._dataset_reader.text_to_instance(item_id, docsents, 61 | question, candidates, paths, 62 | answer_str) 63 | -------------------------------------------------------------------------------- /pathnet/tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | from pathnet.tokenizers.spacy_tokenizer import Tokens, SpacyTokenizer -------------------------------------------------------------------------------- /pathnet/tokenizers/spacy_tokenizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Spacy package and english model is required. 3 | """ 4 | 5 | import spacy 6 | import copy 7 | from typing import List, Any, Set 8 | 9 | 10 | class Tokens(object): 11 | """ 12 | A class to represent a list of tokenized text. 13 | """ 14 | TEXT = 0 15 | CHAR = 1 16 | TEXT_WS = 2 17 | SPAN = 3 18 | POS = 4 19 | LEMMA = 5 20 | NER = 6 21 | 22 | def __init__(self, data: List[Any], annotators: Set, 23 | opts: Any = None, sents: Any = None) -> None: 24 | self.data = data 25 | self.annotators = annotators 26 | self.opts = opts or {} 27 | self.sents = sents 28 | 29 | def __len__(self): 30 | """ 31 | The number of tokens. 32 | """ 33 | return len(self.data) 34 | 35 | def slice(self, i: int = None, j: int = None): 36 | """ 37 | Return a view of the list of tokens from [i, j). 38 | """ 39 | new_tokens = copy.copy(self) 40 | new_tokens.data = self.data[i: j] 41 | return new_tokens 42 | 43 | def untokenize(self) -> str: 44 | """ 45 | Returns the original text (with whitespace reinserted). 46 | """ 47 | return ''.join([t[self.TEXT_WS] for t in self.data]).strip() 48 | 49 | def words(self, uncased: bool = False) -> List[str]: 50 | """ 51 | Returns a list of the text of each token 52 | Args: 53 | uncased: lower cases text 54 | """ 55 | if uncased: 56 | return [t[self.TEXT].lower() for t in self.data] 57 | else: 58 | return [t[self.TEXT] for t in self.data] 59 | 60 | def sentences(self, uncased: bool = False) -> List[List[str]]: 61 | """ 62 | Returns a list of the tokenized sentences 63 | Args: 64 | uncased: lower cases text 65 | """ 66 | if self.sents is not None: 67 | if uncased: 68 | sentences = [] 69 | for sen in self.sents: 70 | sentences.append([t.lower() for t in sen]) 71 | return sentences 72 | else: 73 | return self.sents 74 | 75 | def offsets(self) -> List[List[int]]: 76 | """ 77 | Returns a list of [start, end) character 78 | offsets of each token. 79 | """ 80 | return [t[self.SPAN] for t in self.data] 81 | 82 | def pos(self) -> Any: 83 | """ 84 | Returns a list of part-of-speech tags of each token. 85 | Returns None if this annotation was not included. 86 | """ 87 | if 'pos' not in self.annotators: 88 | return None 89 | return [t[self.POS] for t in self.data] 90 | 91 | def lemmas(self) -> Any: 92 | """ 93 | Returns a list of the lemmatized text of each token. 94 | Returns None if this annotation was not included. 
95 | """ 96 | if 'lemma' not in self.annotators: 97 | return None 98 | return [t[self.LEMMA] for t in self.data] 99 | 100 | def entities(self) -> Any: 101 | """ 102 | Returns a list of named-entity-recognition tags of each token. 103 | Returns None if this annotation was not included. 104 | """ 105 | if 'ner' not in self.annotators: 106 | return None 107 | return [t[self.NER] for t in self.data] 108 | 109 | def ngrams(self, n: int = 1, uncased: bool = False, 110 | filter_fn: Any = None, 111 | as_strings: bool = True) -> Any: 112 | """ 113 | Returns a list of all ngrams from length 1 to n. 114 | Args: 115 | n: upper limit of ngram length 116 | uncased: lower cases text 117 | filter_fn: user function that takes in an ngram list and returns 118 | True or False to keep or not keep the ngram 119 | as_strings: return the ngram as a string vs list 120 | """ 121 | 122 | def _skip(gram): 123 | if not filter_fn: 124 | return False 125 | return filter_fn(gram) 126 | 127 | words = self.words(uncased) 128 | ngrams = [(s, e + 1) 129 | for s in range(len(words)) 130 | for e in range(s, min(s + n, len(words))) 131 | if not _skip(words[s:e + 1])] 132 | 133 | if as_strings: 134 | ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams] 135 | 136 | return ngrams 137 | 138 | def entity_groups(self) -> Any: 139 | """ 140 | Group consecutive entity tokens 141 | with the same NER tag. 142 | """ 143 | entities = self.entities() 144 | if not entities: 145 | return None 146 | non_ent = self.opts.get('non_ent', 'O') 147 | groups = [] 148 | idx = 0 149 | while idx < len(entities): 150 | ner_tag = entities[idx] 151 | if ner_tag != non_ent: 152 | start = idx 153 | while (idx < len(entities) and entities[idx] == ner_tag): 154 | idx += 1 155 | groups.append((self.slice(start, idx).untokenize(), ner_tag)) 156 | else: 157 | idx += 1 158 | return groups 159 | 160 | 161 | class SpacyTokenizer(object): 162 | 163 | def __init__(self, **kwargs) -> None: 164 | """ 165 | Args: 166 | annotators: set that can include pos, lemma, and ner. 167 | model: spaCy model to use (either path, or keyword like 'en'). 
168 | """ 169 | model = kwargs.get('model', 'en') 170 | self.annotators = copy.deepcopy(kwargs.get('annotators', set())) 171 | self.nlp = spacy.load(model) 172 | if not any([p in self.annotators for p in ['lemma', 'pos', 'ner']]): 173 | self.nlp.remove_pipe('tagger') 174 | if 'ner' not in self.annotators: 175 | self.nlp.remove_pipe('ner') 176 | 177 | def tokenize(self, text: str) -> Tokens: 178 | clean_text = text.replace('\n', ' ') 179 | tokens = self.nlp(clean_text) 180 | 181 | sentences = [s for s in tokens.sents] 182 | sents = [] 183 | for s in sentences: 184 | sents.append([t.text for t in s]) 185 | 186 | data = [] 187 | for i in range(len(tokens)): 188 | # Get whitespace 189 | start_ws = tokens[i].idx 190 | if i + 1 < len(tokens): 191 | end_ws = tokens[i + 1].idx 192 | else: 193 | end_ws = tokens[i].idx + len(tokens[i].text) 194 | 195 | data.append(( 196 | tokens[i].text, 197 | tokens[i].text[0] if len(tokens[i].text) > 0 else '', 198 | text[start_ws: end_ws], 199 | (tokens[i].idx, tokens[i].idx + len(tokens[i].text)), 200 | tokens[i].tag_, 201 | tokens[i].lemma_, 202 | tokens[i].ent_type_, 203 | )) 204 | 205 | return Tokens(data, self.annotators, opts={'non_ent': ''}, sents=sents) 206 | 207 | def shutdown(self): 208 | pass 209 | 210 | def __del__(self): 211 | self.shutdown() 212 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/allenai/allennlp.git@0b852fb2ef1000fa6209815d1c91171a9b093fef#egg=allennlp 2 | 3 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- 1 | from pathnet.tokenizers.spacy_tokenizer import SpacyTokenizer -------------------------------------------------------------------------------- /scripts/break_orig_wikihop_train.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | 4 | 5 | def load_dict(fname): 6 | with open(fname, 'r') as fp: 7 | data = json.load(fp) 8 | return data 9 | 10 | 11 | def dump_dict(fname, data): 12 | with open(fname, 'w') as fp: 13 | json.dump(data, fp) 14 | 15 | 16 | def break_data(data, fname, bucket_size=5000): 17 | num_buckets = len(data) / bucket_size 18 | if num_buckets > int(num_buckets): 19 | num_buckets = int(num_buckets) + 1 20 | else: 21 | num_buckets = int(num_buckets) 22 | for i in range(num_buckets): 23 | print("Bucket: ", i) 24 | data_bucket = data[i * bucket_size:min((i + 1) * bucket_size, len(data))] 25 | dump_dict(fname[:-5] + str(i) + ".json", data_bucket) 26 | 27 | 28 | if __name__ == "__main__": 29 | data = load_dict(sys.argv[1]) 30 | break_data(data, sys.argv[1]) 31 | -------------------------------------------------------------------------------- /scripts/break_train_data_obqa.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | 4 | 5 | def load_examples(fpath): 6 | data = [] 7 | with open(fpath, 'r') as fp: 8 | for line in fp: 9 | data.append(json.loads(line)) 10 | return data 11 | 12 | 13 | def split_examples(data, dump_dir, bucket_size=500): 14 | num_split = int(len(data) / bucket_size) 15 | if num_split * bucket_size == len(data): 16 | splits = num_split 17 | else: 18 | splits = num_split + 1 19 | 20 | for i in range(splits): 21 | split_data = data[i * bucket_size: min((i + 1) * bucket_size, 
len(data))] 22 | with open(dump_dir + '/split_' + str(i) + '.json', 'w') as fp: 23 | for d in split_data: 24 | fp.write(json.dumps(d) + '\n') 25 | 26 | 27 | data = load_examples(sys.argv[1]) 28 | print("data loaded") 29 | split_examples(data, sys.argv[2]) 30 | print("done!") 31 | -------------------------------------------------------------------------------- /scripts/break_train_data_wikihop.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | import os 4 | 5 | 6 | def load_examples(fpath): 7 | data = [] 8 | with open(fpath, 'r') as fp: 9 | for line in fp: 10 | data.append(json.loads(line)) 11 | return data 12 | 13 | 14 | def split_examples(data, dump_dir, bucket_size=2000): 15 | num_split = int(len(data) / bucket_size) 16 | if num_split * bucket_size == len(data): 17 | splits = num_split 18 | else: 19 | splits = num_split + 1 20 | 21 | for i in range(splits): 22 | split_data = data[i * bucket_size: min((i + 1) * bucket_size, len(data))] 23 | with open(dump_dir + '/split_' + str(i) + '.json', 'w') as fp: 24 | for d in split_data: 25 | fp.write(json.dumps(d) + '\n') 26 | 27 | 28 | data = load_examples(sys.argv[1]) 29 | print("data loaded") 30 | dumpdir = sys.argv[2] 31 | if not os.path.isdir(dumpdir): 32 | os.makedirs(dumpdir) 33 | split_examples(data, dumpdir) 34 | print("done!") 35 | -------------------------------------------------------------------------------- /scripts/download.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | # Download GloVe embeddings 6 | EMBDIR="data/embeddings" 7 | mkdir -p $EMBDIR 8 | wget -P $EMBDIR https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz 9 | 10 | # Download WikiHop data 11 | WH_DATA_DIR="data/datasets/WikiHop" 12 | mkdir -p $WH_DATA_DIR 13 | cd $WH_DATA_DIR 14 | wget --load-cookies /tmp/cookies.txt "https://drive.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://drive.google.com/uc?export=download&id=1ytVZ4AhubFDOEL7o7XrIRIyhU8g9wvKA' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1ytVZ4AhubFDOEL7o7XrIRIyhU8g9wvKA" -O qangaroo_v1.1.zip && rm -rf /tmp/cookies.txt 15 | unzip qangaroo_v1.1.zip 16 | mv qangaroo_v1.1/wikihop/train.json ./ 17 | mv qangaroo_v1.1/wikihop/dev.json ./ # alternaltively create a softlink 18 | rm qangaroo_v1.1.zip 19 | cd ../../.. 
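# Note on the WikiHop download above: the nested wget works around Google
# Drive's "file too large to scan" confirmation page by first saving the
# session cookies, extracting the confirm token with sed, and then requesting
# the file again with that token appended to the download URL.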
20 | 21 | # Download OBQA data 22 | OBQA_DATA_DIR="data/datasets/OBQA" 23 | mkdir -p $OBQA_DATA_DIR 24 | wget -P $OBQA_DATA_DIR http://data.allenai.org/downloads/pathnet/obqa/inputs/obqa-commonsense-590k-wh-sorted100-train.json 25 | wget -P $OBQA_DATA_DIR http://data.allenai.org/downloads/pathnet/obqa/inputs/obqa-commonsense-590k-wh-sorted100-dev.json 26 | wget -P $OBQA_DATA_DIR http://data.allenai.org/downloads/pathnet/obqa/inputs/obqa-commonsense-590k-wh-sorted100-test.json 27 | 28 | # Download the pretrained WikiHop Model 29 | WH_MODEL_DIR="data/datasets/WikiHop/pretrained-model" 30 | mkdir -p $WH_MODEL_DIR 31 | wget -P $WH_MODEL_DIR http://data.allenai.org/downloads/pathnet/wikihop/model.tar.gz 32 | -------------------------------------------------------------------------------- /scripts/evaluator.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | 4 | 5 | def load(path): 6 | with open(path, 'r') as f: 7 | data = json.load(f) 8 | return data 9 | 10 | 11 | def compute_accuracy(predictions, gold_answers): 12 | assert len(predictions) == len(gold_answers) 13 | correct_prediction_counter = 0 14 | for (ID, gold_answer) in gold_answers.items(): 15 | correct_prediction_counter += (gold_answer == predictions[ID]) 16 | accuracy = float(correct_prediction_counter)/float(len(gold_answers)) 17 | return accuracy 18 | 19 | 20 | def main(args): 21 | """ 22 | Takes two arguments: (1) the directory of a file with predictions in JSON format, 23 | and (2) the name of the dataset part for which these predictions were computed. 24 | """ 25 | 26 | # parse input arguments 27 | predictions_path = args[1] # predictions file 28 | gold_path = args[2] # gold annotations file 29 | 30 | # load predictions and gold answers from corresponding files 31 | predictions = load(predictions_path) 32 | gold_data = load(gold_path) 33 | 34 | # dictionary with gold answer for each id 35 | gold_answers = {element['id']: element['answer'] for element in gold_data} 36 | 37 | # # compute accuracy and print. 
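    # predictions is a dict {id: predicted answer} (as written by
    # scripts/prepare_outfile.py); gold_answers is likewise {id: gold answer},
    # so compute_accuracy simply compares the two id-by-id.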
38 | accuracy = compute_accuracy(predictions, gold_answers) 39 | print(accuracy) 40 | 41 | 42 | if __name__ == "__main__": 43 | main(sys.argv) 44 | -------------------------------------------------------------------------------- /scripts/expand_vocabulary.py: -------------------------------------------------------------------------------- 1 | """ 2 | Mainly required for evaluating on blind test set 3 | """ 4 | 5 | import argparse 6 | import logging 7 | import os 8 | import sys 9 | from collections import defaultdict 10 | from typing import Dict 11 | 12 | import torch 13 | from allennlp.models.archival import CONFIG_NAME 14 | from allennlp.modules.token_embedders.embedding import _read_pretrained_embeddings_file 15 | from allennlp.common import Tqdm 16 | from allennlp.data import DatasetReader 17 | from allennlp.models import load_archive, archive_model 18 | sys.path.append("./") 19 | import pathnet 20 | 21 | logger = logging.getLogger('scripts.expand_vocabulary') 22 | logger.setLevel(logging.INFO) 23 | 24 | 25 | def main(file, embeddings, model, emb_wt_key, namespace, output_dir): 26 | archive = load_archive(model) 27 | config = archive.config 28 | os.makedirs(output_dir, exist_ok=True) 29 | config.to_file(os.path.join(output_dir, CONFIG_NAME)) 30 | 31 | model = archive.model 32 | # first expand the vocabulary 33 | dataset_reader = DatasetReader.from_params(config.pop('dataset_reader')) 34 | instances = dataset_reader.read(file) 35 | vocab = model.vocab 36 | 37 | # get all the tokens in the new file 38 | namespace_token_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int)) 39 | for instance in Tqdm.tqdm(instances): 40 | instance.count_vocab_items(namespace_token_counts) 41 | old_token_size = vocab.get_vocab_size(namespace) 42 | print("Before expansion: Number of instances in {} namespace: {}".format(namespace, 43 | old_token_size)) 44 | if namespace not in namespace_token_counts: 45 | logger.error("No tokens found for namespace: {} in the new input file".format(namespace)) 46 | # identify the new tokens in the new instances 47 | token_to_add = set() 48 | token_hits = 0 49 | for token, count in namespace_token_counts[namespace].items(): 50 | if token not in vocab._token_to_index[namespace]: 51 | # new token, must add 52 | token_to_add.add(token) 53 | else: 54 | token_hits += 1 55 | print("Found {} existing tokens and {} new tokens in {}".format(token_hits, 56 | len(token_to_add), file)) 57 | 58 | # add the new tokens to the vocab 59 | for token in token_to_add: 60 | vocab.add_token_to_namespace(token=token, namespace=namespace) 61 | archived_parameters = dict(model.named_parameters()) 62 | 63 | # second, expand the embedding matrix 64 | for name, weights in archived_parameters.items(): 65 | # find the wt matrix for the embeddings 66 | if name == emb_wt_key: 67 | if weights.dim() != 2: 68 | logger.error("Expected an embedding matrix for the parameter: {} instead" 69 | "found {} tensor".format(emb_wt_key, weights.shape)) 70 | emb_dim = weights.shape[-1] 71 | print("Before expansion: Size of emb matrix: {}".format(weights.shape)) 72 | # Loading embeddings for old and new tokens since that is cleaner than copying all 73 | # the embedding loading logic here 74 | all_embeddings = _read_pretrained_embeddings_file(embeddings, emb_dim, 75 | vocab, namespace) 76 | # concatenate the new entries i.e last token_to_add embeddings to the original weights 77 | if len(token_to_add) > 0: 78 | weights.data = torch.cat([weights.data, all_embeddings[-len(token_to_add):, :]]) 79 | 
print("After expansion: Size of emb matrix: {}".format(weights.shape)) 80 | 81 | # save the files needed by the model archiver 82 | model_path = os.path.join(output_dir, "weight.th") 83 | model_state = model.state_dict() 84 | torch.save(model_state, model_path) 85 | vocab.save_to_files(os.path.join(output_dir, "vocabulary")) 86 | archive_model(output_dir, weights="weight.th") 87 | 88 | # more debug messages 89 | new_token_size = vocab.get_vocab_size(namespace) 90 | for name, weights in archived_parameters.items(): 91 | if name == emb_wt_key: 92 | print("Size of emb matrix: {}".format(weights.shape)) 93 | print("After expansion: Number of instances in {} namespace: {}".format(namespace, 94 | new_token_size)) 95 | 96 | 97 | if __name__ == "__main__": 98 | """ 99 | Usage: 100 | python -u scripts/expand_vocabulary.py \ 101 | --file \ 102 | --emb_wt_key _text_field_embedder.token_embedder_tokens.weight \ 103 | --embeddings \ 104 | --model \ 105 | --output_dir 106 | """ 107 | parser = argparse.ArgumentParser(description='Expand vocabulary (and embeddings) of a model ' 108 | 'based on a new file') 109 | parser.add_argument('--file', type=str, required=True, 110 | help='Path to the new file (should be readable by the model\'s ' 111 | 'dataset reader') 112 | parser.add_argument('--emb_wt_key', type=str, required=True, 113 | help='Parameter name for the token embedding weight matrix') 114 | parser.add_argument('--namespace', type=str, default="tokens", help='Namespace to expand') 115 | parser.add_argument('--embeddings', type=str, required=True, help='Path to the embeddings file') 116 | parser.add_argument('--model', type=str, help='Path to the model file') 117 | parser.add_argument('--output_dir', type=str, help='The output directory to store the ' 118 | 'final model') 119 | 120 | args = parser.parse_args() 121 | main(args.file, args.embeddings, args.model, args.emb_wt_key, args.namespace, args.output_dir) -------------------------------------------------------------------------------- /scripts/install_requirements.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | pip install -r requirements.txt 4 | pip install nltk 5 | pip install -U spacy 6 | python -m spacy download en -------------------------------------------------------------------------------- /scripts/path_adjustments_obqa.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # path finder for splitted train files (OBQA) 4 | PREPROCESSED_DATA_DIR=$1 5 | PATHDIR=$2 6 | DUMPDIR=$3 7 | 8 | for y in 100; do 9 | echo "Numpaths -- $y" 10 | mkdir -p $DUMPDIR/paths${y} 11 | for x in train dev test; do 12 | echo "Split -- $x" 13 | echo "==============================" 14 | python scripts/prepro/obqa_prep_data_with_lemma.py $PREPROCESSED_DATA_DIR $PATHDIR $DUMPDIR/paths${y} \ 15 | --mode $x --maxnumpaths $y 16 | done 17 | done -------------------------------------------------------------------------------- /scripts/path_adjustments_wikihop.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # path finder for splitted train files (WikiHop) 4 | PREPROCESSED_DATA_DIR=$1 5 | PATHDIR=$2 6 | DUMPDIR=$3 7 | 8 | for y in 30; do 9 | mkdir -p $DUMPDIR/paths${y} 10 | for x in train dev; do 11 | echo "Split -- $x" 12 | echo "Numpaths -- $y" 13 | echo "==============================" 14 | python scripts/prepro/wikihop_prep_data_with_lemma.py $PREPROCESSED_DATA_DIR 
$PATHDIR $DUMPDIR/paths${y} \ 15 | --mode $x --maxnumpaths $y 16 | done 17 | done 18 | -------------------------------------------------------------------------------- /scripts/path_finder_obqa.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # path finder for splitted train files 4 | PREPROCESSED_DATA_DIR=$1 5 | PATHDIR=$2 6 | mkdir -p ${PATHDIR}/train-split 7 | 8 | for x in $(seq 0 9); do 9 | echo "Split -- $x" 10 | echo "==============================" 11 | python scripts/prepro/obqa_path_finder.py ${PREPROCESSED_DATA_DIR}/train-split/split_${x}.json ${PATHDIR}/train-split 12 | done 13 | 14 | # for dev 15 | python scripts/prepro/obqa_path_finder.py ${PREPROCESSED_DATA_DIR}/dev-processed-spacy.txt $PATHDIR 16 | 17 | # for test 18 | python scripts/prepro/obqa_path_finder.py ${PREPROCESSED_DATA_DIR}/test-processed-spacy.txt $PATHDIR -------------------------------------------------------------------------------- /scripts/path_finder_wikihop.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # path finder for splitted train files (WikiHop) 4 | PREPROCESSED_DATA_DIR=$1 5 | PATHDIR=$2 6 | mkdir -p ${PATHDIR}/train-split 7 | 8 | for x in $(seq 0 21); do 9 | echo "Split -- $x" 10 | echo "==============================" 11 | python scripts/prepro/path_finder_wikihop.py ${PREPROCESSED_DATA_DIR}/train-split/split_${x}.json ${PATHDIR}/train-split 12 | done 13 | 14 | # for dev 15 | python scripts/prepro/path_finder_wikihop.py ${PREPROCESSED_DATA_DIR}/dev-processed-spacy.txt $PATHDIR -------------------------------------------------------------------------------- /scripts/path_finder_wrapper.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Optional 2 | from nltk import PorterStemmer 3 | 4 | import sys 5 | sys.path.append("./") 6 | from pathnet.tokenizers.spacy_tokenizer import SpacyTokenizer 7 | from pathnet.pathfinder.path_extractor import PathFinder 8 | from pathnet.pathfinder.obqa_path_extractor import ObqaPathFinder 9 | from pathnet.pathfinder.util import lemmatize_docsents 10 | 11 | stemmer = PorterStemmer() 12 | stem = stemmer.stem 13 | 14 | ANNTOTORS = {'lemma', 'pos', 'ner'} 15 | TOK = SpacyTokenizer(annotators=ANNTOTORS) 16 | 17 | 18 | def tokenize(text: str) -> Dict: 19 | """Call the global process tokenizer 20 | on the input text. 
21 | """ 22 | tokens = TOK.tokenize(text) 23 | output = { 24 | 'words': tokens.words(), 25 | 'offsets': tokens.offsets(), 26 | 'pos': tokens.pos(), 27 | 'lemma': tokens.lemmas(), 28 | 'ner': tokens.entities(), 29 | 'sentences': tokens.sentences(), 30 | } 31 | return output 32 | 33 | 34 | def process_data(documents: List[str], question: str, candidate: str) -> Dict: 35 | """ 36 | process the instance 37 | :param documents: list of documents 38 | :param question: 39 | :param candidate: 40 | :return: 41 | """ 42 | data = dict() 43 | q_tokens = tokenize(question) 44 | data['question'] = q_tokens['words'] 45 | data['qlemma'] = q_tokens['lemma'] 46 | data['qpos'] = q_tokens['pos'] 47 | data['qner'] = q_tokens['ner'] 48 | cand_tokens = tokenize(candidate) 49 | data['candidate'] = cand_tokens['words'] 50 | data['cpos'] = cand_tokens['lemma'] 51 | data['cner'] = cand_tokens['ner'] 52 | data['clemma'] = cand_tokens['lemma'] 53 | doc_tokens = [tokenize(doc) for doc in documents] 54 | data['documents'] = [doc['words'] for doc in doc_tokens] 55 | data['docners'] = [doc['ner'] for doc in doc_tokens] 56 | data['docpostags'] = [doc['pos'] for doc in doc_tokens] 57 | data['docsents'] = [doc['sentences'] for doc in doc_tokens] 58 | return data 59 | 60 | 61 | def find_paths(documents: List[str], question: str, candidate: str, 62 | style='wikihop') -> Optional[List]: 63 | """ 64 | Get the list of paths for a given (documents, question, candidate) 65 | :param documents: list of documents 66 | :param question: 67 | :param candidate: 68 | :param style: "wikihop" or "OBQA" -- OBQA style is for plain text questions 69 | :return: 70 | """ 71 | sentlimit = 1 72 | nearest_only = False 73 | d = process_data(documents, question, candidate) 74 | 75 | doc_ners = d['docners'] 76 | doc_postags = d['docpostags'] 77 | doc_sents = d['docsents'] 78 | 79 | qpos = d["qpos"] 80 | qner = d["qner"] 81 | qlemma = d['qlemma'] 82 | rel = qlemma[0] 83 | entity = ' '.join(qlemma[1:]).lower() 84 | candidates = [] 85 | orig_candidates = [d['candidate']] 86 | for ctoks in orig_candidates: 87 | sctoks = [stemmer.stem(ca) for ca in ctoks] 88 | if sctoks in candidates: 89 | candidates.append(ctoks) 90 | else: 91 | candidates.append(sctoks) 92 | candidates = [' '.join(cand) for cand in candidates] 93 | candpos = [d['cpos']] 94 | candner = [d['cner']] 95 | 96 | doc_sents_lemma = lemmatize_docsents(doc_sents, stem) 97 | 98 | if style.strip().lower() == "wikihop": 99 | pf = PathFinder("qid", doc_sents_lemma, 100 | entity, rel, 101 | candidates, 102 | answer=None, 103 | sentlimit=sentlimit, 104 | nearest_only=nearest_only) 105 | else: 106 | pf = ObqaPathFinder("qid", doc_sents_lemma, 107 | qlemma, qpos, qner, 108 | candidates, candpos, candner, 109 | answer=None, sentlimit=sentlimit, 110 | nearest_only=nearest_only) 111 | 112 | paths = pf.get_paths(doc_ners, doc_postags) 113 | if len(paths) == 0: 114 | print("No Paths Found !!") 115 | return None 116 | # pathdict = {"id": "qid", "pathlist": paths[list(paths.keys())[0]]} 117 | return paths[list(paths.keys())[0]] 118 | -------------------------------------------------------------------------------- /scripts/predict_wikihop.sh: -------------------------------------------------------------------------------- 1 | # script for answer prediction on WikiHop dev/test data 2 | #!/usr/bin/env bash 3 | 4 | set -x 5 | set -e 6 | 7 | DATA_DIR=$1 8 | MODEL_FILE=$2 9 | OUTFILE=$3 10 | EMBFILE="data/embeddings/glove.840B.300d.txt.gz" 11 | 12 | PREPROCESSED_DATA_DIR="data/datasets/WikiHop/prepro-data" 13 | 
PATHDIR="data/datasets/WikiHop/dev-paths" 14 | DUMPDIR="data/datasets/WikiHop/dev-adjusted-paths" 15 | PREDDIR="data/datasets/WikiHop/model-preds" 16 | UPDATED_MODEL_DIR="data/datasets/WikiHop/models-updated" 17 | 18 | 19 | echo "Creating Data Directory" 20 | mkdir -p $PREPROCESSED_DATA_DIR 21 | echo "Preprocessing Data" 22 | python scripts/prepro/preprocess_wikihop.py $DATA_DIR $PREPROCESSED_DATA_DIR --split dev \ 23 | --num-workers 6 24 | 25 | echo "Path Finding" 26 | mkdir -p $PATHDIR 27 | python scripts/prepro/path_finder_wikihop.py $PREPROCESSED_DATA_DIR/dev-processed-spacy.txt $PATHDIR 28 | 29 | echo "Path Adjusting" 30 | mkdir -p $DUMPDIR 31 | python scripts/prepro/wikihop_prep_data_with_lemma.py $PREPROCESSED_DATA_DIR $PATHDIR $DUMPDIR \ 32 | --mode dev --maxnumpaths 30 33 | 34 | echo "Expanding vocabulary" 35 | mkdir -p $UPDATED_MODEL_DIR 36 | python scripts/expand_vocabulary.py --file $DUMPDIR/dev-path-lines.txt \ 37 | --emb_wt_key _text_field_embedder.token_embedder_tokens.weight \ 38 | --embeddings $EMBFILE \ 39 | --model $MODEL_FILE \ 40 | --output_dir $UPDATED_MODEL_DIR 41 | 42 | echo "Running Allennlp Prediction" 43 | mkdir -p $PREDDIR 44 | CUDA_VISIBLE_DEVICES=None allennlp predict --output-file $PREDDIR/dev_predictions.txt \ 45 | --predictor wikihop_predictor \ 46 | --batch-size 3 --include-package pathnet --cuda-device -1 \ 47 | --silent $UPDATED_MODEL_DIR/model.tar.gz \ 48 | $DUMPDIR/dev-path-lines.txt 49 | 50 | echo "Preparing Prediction Dictionary" 51 | python scripts/prepare_outfile.py $PREDDIR/dev_predictions.txt $OUTFILE 52 | -------------------------------------------------------------------------------- /scripts/prepare_outfile.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | 4 | 5 | def get_out_json(predf, outf): 6 | """ 7 | :param predf: 8 | :param outf: 9 | :return: 10 | """ 11 | preddict = {} 12 | with open(predf, 'r') as fp: 13 | for line in fp: 14 | data = json.loads(line) 15 | preddict[data["id"]] = data["answer"] 16 | 17 | with open(outf, 'w') as fp: 18 | json.dump(preddict, fp) 19 | 20 | 21 | if __name__ == "__main__": 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('predfile', type=str, 24 | help='Path to prdicted file') 25 | parser.add_argument('outfile', type=str, 26 | help='Output file') 27 | args = parser.parse_args() 28 | get_out_json(args.predfile, args.outfile) 29 | -------------------------------------------------------------------------------- /scripts/prepro/__init__.py: -------------------------------------------------------------------------------- 1 | from pathnet.tokenizers.spacy_tokenizer import * 2 | from pathnet.tokenizers import * 3 | from pathnet.tokenizers.spacy_tokenizer import SpacyTokenizer 4 | -------------------------------------------------------------------------------- /scripts/prepro/obqa_path_finder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for path finding task for the OBQA dataset 3 | """ 4 | 5 | import os 6 | import json 7 | import argparse 8 | import time 9 | 10 | from multiprocessing import Pool 11 | from functools import partial 12 | from typing import * 13 | from tqdm import tqdm 14 | from nltk import PorterStemmer 15 | 16 | import sys 17 | 18 | sys.path.append("./") 19 | from pathnet.pathfinder.obqa_path_extractor import ObqaPathFinder 20 | from pathnet.pathfinder.util import lemmatize_docsents 21 | 22 | stemmer = PorterStemmer() 23 | stem = stemmer.stem 24 | 25 | 
26 | def load_examples(fpath: str) -> List[Dict]: 27 | """Load the preprocessed examples 28 | """ 29 | data = [] 30 | with open(fpath, 'r') as fp: 31 | for line in fp: 32 | data.append(json.loads(line)) 33 | return data 34 | 35 | 36 | def init(): 37 | pass 38 | 39 | 40 | def process_examples(d: Dict) -> Dict: 41 | lemma = True 42 | sentlimit = 1 43 | nearest_only = False 44 | qid = d['id'] 45 | if "answer" in d: 46 | ans = ' '.join(d['answer']).lower() 47 | else: 48 | ans = 'DUMMYANSWER' 49 | doc_ners = d['docners'] 50 | doc_postags = d['docpostags'] 51 | doc_sents = d['docsents'] 52 | question = d['question'] 53 | qpos = d["qpos"] 54 | qner = d["qner"] 55 | orig_candidates = d['candidates'] 56 | candpos = d["candidatepos"] 57 | candner = d["candidatener"] 58 | 59 | if not lemma: 60 | candidates = [' '.join(cand) for cand in orig_candidates] 61 | pf = ObqaPathFinder(qid, doc_sents, 62 | question, 63 | qpos, qner, 64 | candidates, 65 | candpos, candner, 66 | answer=ans, 67 | sentlimit=sentlimit, 68 | nearest_only=nearest_only) 69 | else: 70 | qlemma = [stem(qtok) for qtok in question] 71 | candidates = [] 72 | for ctoks in orig_candidates: 73 | sctoks = [stem(ca) for ca in ctoks] 74 | if sctoks in candidates: 75 | candidates.append(ctoks) 76 | else: 77 | candidates.append(sctoks) 78 | candidates = [' '.join(cand) for cand in candidates] 79 | doc_sents_lemma = lemmatize_docsents(doc_sents, stem) 80 | 81 | pf = ObqaPathFinder(qid, doc_sents_lemma, 82 | qlemma, 83 | qpos, qner, 84 | candidates, 85 | candpos, candner, 86 | answer=ans, 87 | sentlimit=sentlimit, 88 | nearest_only=nearest_only) 89 | 90 | paths = pf.get_paths(doc_ners, doc_postags) 91 | if lemma: 92 | orig_candidates = [' '.join(cand) for cand in orig_candidates] 93 | keys = list(paths.keys()) 94 | if len(keys) < len(orig_candidates): 95 | uniq_cands = [] 96 | for cand in orig_candidates: 97 | if cand not in uniq_cands: 98 | uniq_cands.append(cand) 99 | else: 100 | uniq_cands = orig_candidates 101 | # assert len(orig_candidates) == 4 102 | new_paths = {} 103 | 104 | assert len(keys) == len(uniq_cands) 105 | for idx, key in enumerate(keys): 106 | new_paths[uniq_cands[idx]] = paths[key] 107 | pathdict = {"id": qid, "pathlist": new_paths} 108 | else: 109 | assert len(paths) == len(orig_candidates) 110 | pathdict = {"id": qid, "pathlist": paths} 111 | return pathdict 112 | 113 | 114 | if __name__ == "__main__": 115 | parser = argparse.ArgumentParser() 116 | parser.add_argument('datafile', type=str, 117 | help='Path to preprocessed data file') 118 | parser.add_argument('dumpdir', type=str, 119 | help='Directory to dump the paths') 120 | parser.add_argument('--sentlimit', type=int, default=1, 121 | help='how many next sentences to look for ne/nouns') 122 | parser.add_argument('--take_nearest_only', type=bool, default=False, 123 | help='whether to take nearest candidate only') 124 | parser.add_argument('--numworkers', type=int, default=8, 125 | help='number of workers for multiprocessing') 126 | parser.add_argument('--qenttype', type=str, default='nps', 127 | help='nps (NP/NER) OR nonstopword') 128 | parser.add_argument('--candenttype', type=str, default='nps', 129 | help='nps (NP/NER) OR nonstopword') 130 | args = parser.parse_args() 131 | 132 | t0 = time.time() 133 | 134 | infile = args.datafile 135 | data = load_examples(infile) 136 | print("Data Loaded..") 137 | 138 | num_paths = 0 139 | num_cands = 0 140 | 141 | print("Computing paths..") 142 | num_workers = args.numworkers 143 | make_pool = partial(Pool, num_workers, 
initializer=init) 144 | workers = make_pool(initargs=()) 145 | path_list = tqdm(workers.map(process_examples, data), total=len(data)) 146 | workers.close() 147 | workers.join() 148 | # path_list = list(tqdm(map_per_process(process_examples, data), total=len(data))) 149 | 150 | print("Analysing stats..") 151 | max_paths_per_cand = 0 152 | max_paths_per_q = 0 153 | for ps in path_list: 154 | paths_per_q = 0 155 | for p in ps['pathlist'].values(): 156 | num_paths += len(p) 157 | num_cands += 1 158 | paths_per_q += len(p) 159 | if len(p) > max_paths_per_cand: 160 | max_paths_per_cand = len(p) 161 | if paths_per_q > max_paths_per_q: 162 | max_paths_per_q = paths_per_q 163 | 164 | if not os.path.isdir(args.dumpdir): 165 | os.makedirs(args.dumpdir) 166 | with open(os.path.join(args.dumpdir, os.path.basename(infile) + '.paths'), 'w') as fp: 167 | for pp in path_list: 168 | fp.write(json.dumps(pp) + '\n') 169 | 170 | print("Avg #paths/question: %.4f" % (num_paths / len(data))) 171 | print("Avg #paths/candidate: %.4f" % (num_paths / num_cands)) 172 | print("Max #paths/question: %d" % max_paths_per_q) 173 | print("Max #paths/candidate: %d" % max_paths_per_cand) 174 | print('Total time: %.4f (s)' % (time.time() - t0)) 175 | 176 | 177 | -------------------------------------------------------------------------------- /scripts/prepro/obqa_prep_data_with_lemma.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for path adjustment step for the OBQA dataset 3 | """ 4 | 5 | import os 6 | import json 7 | from typing import List, Dict, Tuple 8 | import re 9 | from numpy import array 10 | import time 11 | import argparse 12 | import sys 13 | sys.path.append('./') 14 | from pathnet.pathfinder.util import STOPWORDS 15 | from nltk import PorterStemmer 16 | 17 | stemmer = PorterStemmer() 18 | 19 | 20 | def load_dict(fname): 21 | with open(fname, 'r') as fp: 22 | data = json.load(fp) 23 | return data 24 | 25 | 26 | def load_lines(fname): 27 | data = [] 28 | with open(fname, 'r') as fp: 29 | for line in fp: 30 | data.append(json.loads(line)) 31 | return data 32 | 33 | 34 | def get_locs_given_objs(doc: str, word: str, objs: List): 35 | doctoks = doc.split() 36 | ch_locs = [ob.span()[0] for ob in objs] 37 | found_words = [ob.group(0) for ob in objs] 38 | assert len(found_words) == len(ch_locs) 39 | for fwidx, fw in enumerate(found_words): 40 | start_offset = len(fw) - len(fw.lstrip()) 41 | ch_locs[fwidx] += start_offset 42 | found_words[fwidx] = fw.strip() 43 | widxs = get_widxs_from_chidxs(ch_locs, create_offsets(doctoks)) 44 | locs = [(widx, widx + len(found_words[i].split()) - 1) for i, widx in enumerate(widxs)] 45 | return locs 46 | 47 | 48 | def get_widxs_from_chidxs(chidxs: List[int], 49 | offsets: List[List[int]]) -> List[int]: 50 | """ 51 | Find word indices given character indices 52 | :param chidxs: 53 | :param offsets: 54 | :return: 55 | """ 56 | last_ch_idx = offsets[0][0] 57 | assert max(chidxs) < offsets[-1][1] - last_ch_idx 58 | widxs = [] 59 | for chidx in chidxs: 60 | for oi in range(len(offsets)): 61 | if chidx in range(offsets[oi][0] - last_ch_idx, offsets[oi][1] - last_ch_idx): 62 | widxs.append(oi) 63 | break 64 | elif chidx in range(offsets[oi][1] - last_ch_idx, 65 | offsets[min(oi + 1, len(offsets))][0] - last_ch_idx): 66 | widxs.append(oi) 67 | break 68 | assert len(chidxs) == len(widxs) 69 | return widxs 70 | 71 | 72 | def create_offsets(doctoks: List[str]) -> List[List[int]]: 73 | """ 74 | create offsets for a document tokens 75 | :param 
doctoks: 76 | :return: 77 | """ 78 | offsets = [] 79 | char_count = 0 80 | for tok in doctoks: 81 | offsets.append([char_count, char_count + len(tok)]) 82 | char_count = char_count + len(tok) + 1 83 | return offsets 84 | 85 | 86 | def find_backup_path(docsents, q, cand, k=40): 87 | """ 88 | If no path is found create a dummy backup path 89 | :param docsents: 90 | :param q: 91 | :param cand: 92 | :param k: 93 | :return: 94 | """ 95 | path_for_cand_dict = {"he_docidx": None, 96 | "he_locs": None, 97 | "e1wh_loc": None, 98 | "e1_docidx": None, 99 | "e1_locs": None, 100 | "cand_docidx": None, 101 | "cand_locs": None, 102 | "he_words": ["BACKUP"], 103 | "e1wh": "BACKUP", 104 | "e1": "BACKUP", 105 | "cand_words": ["BACKUP"] 106 | } 107 | 108 | ent_words = [qtok for qtok in q if qtok not in STOPWORDS] 109 | flag = 0 110 | for entw in ent_words: 111 | he = entw.lower() 112 | if len(he.split()) == 0: 113 | path_for_cand_dict['he_docidx'] = 0 114 | path_for_cand_dict['he_locs'] = [(-1, -1)] 115 | else: 116 | pat_he = re.compile('(^|\W)' + re.escape(he) + '\W') 117 | 118 | for docssidx, docss in enumerate(docsents): 119 | doc = ' '.join(' '.join(sum(docss, [])).split()) 120 | doc = doc.lower() 121 | he_objs = [] 122 | for x in pat_he.finditer(doc): 123 | he_objs.append(x) 124 | if len(he_objs) > 0: 125 | flag = 1 126 | path_for_cand_dict['he_docidx'] = docssidx 127 | path_for_cand_dict['he_locs'] = get_locs_given_objs(doc, he, he_objs)[:k] 128 | break 129 | if flag == 1: 130 | break 131 | 132 | cand_toks = cand.split() 133 | cand_words = [candtok for candtok in cand_toks if candtok not in STOPWORDS] 134 | flag = 0 135 | for cand in cand_words: 136 | cand = cand.lower() 137 | pat_cand = re.compile('(^|\W)' + re.escape(cand) + '\W') 138 | for docssidx, docss in enumerate(docsents): 139 | doc = ' '.join(' '.join(sum(docss, [])).split()) 140 | doc = doc.lower() 141 | ca_objs = [] 142 | for x in pat_cand.finditer(doc): 143 | ca_objs.append(x) 144 | if len(ca_objs) > 0: 145 | flag = 1 146 | path_for_cand_dict['cand_docidx'] = docssidx 147 | path_for_cand_dict['cand_locs'] = get_locs_given_objs(doc, cand, ca_objs)[:k] 148 | break 149 | if flag == 1: 150 | break 151 | 152 | if path_for_cand_dict['he_docidx'] is None or path_for_cand_dict['he_locs'] is None: 153 | path_for_cand_dict['he_docidx'] = 0 154 | path_for_cand_dict['he_locs'] = [(-1, -1)] 155 | if path_for_cand_dict['cand_docidx'] is None or path_for_cand_dict['cand_locs'] is None: 156 | path_for_cand_dict['cand_docidx'] = 0 157 | path_for_cand_dict['cand_locs'] = [(0, 0)] 158 | 159 | return path_for_cand_dict 160 | 161 | 162 | def get_min_abs_diff(list1: List[int], list2: List[int]) -> float: 163 | return float(min([min(abs(array(list1) - bi)) for bi in list2])) 164 | 165 | 166 | def get_start_locs(list1: List[Tuple[int, int]]) -> List[int]: 167 | return [l[0] for l in list1] 168 | 169 | 170 | def filter_paths(paths_for_cand: List[Dict], k=50) -> List[Dict]: 171 | if len(paths_for_cand) < k: 172 | return paths_for_cand 173 | 174 | scores = [] 175 | for path in paths_for_cand: 176 | # he_docidx = path['he_docidx'] 177 | he_words = path['he_words'] 178 | cand_words = path['cand_words'] 179 | scores.append(len(he_words) + len(cand_words)) 180 | assert len(scores) == len(paths_for_cand) 181 | 182 | temp = sorted(zip(paths_for_cand, scores), 183 | key=lambda x: x[1], 184 | reverse=True) 185 | sorted_paths, _ = map(list, zip(*temp)) 186 | sorted_paths = sorted_paths[:k] 187 | 188 | return sorted_paths 189 | 190 | 191 | def 
filter_paths_povrlp(paths_for_cand: List[Dict], 192 | docsents: List[List[List[str]]], k=50) -> List[Dict]: 193 | if len(paths_for_cand) < k: 194 | return paths_for_cand 195 | 196 | scores = [] 197 | for path in paths_for_cand: 198 | he_docidx = path['he_docidx'] 199 | cand_docidx = path['cand_docidx'] 200 | psg1 = sum(docsents[he_docidx], []) 201 | psg2 = sum(docsents[cand_docidx], []) 202 | ovlp_score = overlap_score(psg1, psg2) 203 | scores.append(ovlp_score) 204 | assert len(scores) == len(paths_for_cand) 205 | 206 | temp = sorted(zip(paths_for_cand, scores), 207 | key=lambda x: x[1], 208 | reverse=False) # ascending order 209 | sorted_paths, _ = map(list, zip(*temp)) 210 | sorted_paths = sorted_paths[:k] 211 | 212 | return sorted_paths 213 | 214 | 215 | def overlap_score(psg1: List[str], psg2: List[str]) -> float: 216 | """ 217 | 218 | :param psg1: 219 | :param psg2: 220 | :return: 221 | """ 222 | valid_p1_toks = [tok.lower() for tok in psg1 if tok.lower() not in STOPWORDS] 223 | valid_p2_toks = [tok.lower() for tok in psg2 if tok.lower() not in STOPWORDS] 224 | valid_p1_toks = set(valid_p1_toks) 225 | valid_p2_toks = set(valid_p2_toks) 226 | count = len(valid_p1_toks.intersection(valid_p2_toks)) 227 | return count 228 | 229 | 230 | def get_doc_len(docsents, docidx): 231 | doc = docsents[docidx] 232 | doc = ' '.join(sum(doc, [])).split() 233 | return len(doc) 234 | 235 | 236 | def adjust_word_idxs(toks, widxs): 237 | split_toks = ' '.join(toks).split() 238 | if len(toks) == len(split_toks): 239 | return widxs 240 | else: 241 | mod_widxs = [widx - (len(toks[:widx]) - len(' '.join(toks[:widx]).split())) 242 | for widx in widxs] 243 | return mod_widxs 244 | 245 | 246 | def process_path_for_cand(path, docsents): 247 | """ 248 | processing the paths for a candidate 249 | :param path: 250 | :param docsents: 251 | :return: 252 | """ 253 | he_docidx = path['head_ent_docidx'] 254 | he_doc_len = get_doc_len(docsents, he_docidx) 255 | he_ent_loc_dict = path['head_ent'] 256 | he_locs = [] 257 | he_words = list(he_ent_loc_dict.keys()) 258 | for he_w in list(he_ent_loc_dict.keys()): 259 | start_loc_list = he_ent_loc_dict[he_w] 260 | start_loc_list = adjust_word_idxs(sum(docsents[he_docidx], []), 261 | start_loc_list) 262 | for s in start_loc_list: 263 | assert s < he_doc_len 264 | end_loc_list = [max(s, s + len(he_w.split()) - 1) 265 | for s in start_loc_list] # end is inclusive 266 | 267 | for e in end_loc_list: 268 | assert e < he_doc_len 269 | combined_locs = [(s, e) for s, e in zip(start_loc_list, end_loc_list)] 270 | he_locs += combined_locs 271 | 272 | e1 = path['e1'] 273 | if e1 is None: 274 | e1_with_head_loc = [(-1, -1)] 275 | e1_docidx = None 276 | e1_locs = [(-1, -1)] 277 | cand_docidx = he_docidx 278 | else: 279 | e1wh_locs = [path['e1_with_head_widx']] 280 | e1wh_locs = adjust_word_idxs(sum(docsents[he_docidx], []), e1wh_locs) 281 | e1_with_head_loc = [(e1wh_locs[0], 282 | e1wh_locs[0] + len( 283 | path['e1_with_head_ent'].split()) - 1)] # inclusive end 284 | for s, e in e1_with_head_loc: 285 | assert s < he_doc_len 286 | assert e < he_doc_len 287 | e1_docidx = path['e1_docidx'] 288 | e1_doc_len = get_doc_len(docsents, e1_docidx) 289 | e1_start_locs = path['e1_locs'] 290 | e1_start_locs = adjust_word_idxs(sum(docsents[e1_docidx], []), e1_start_locs) 291 | e1_locs = [(e1s, e1s + len(e1.split()) - 1) for e1s in e1_start_locs] 292 | for s, e in e1_locs: 293 | assert s < e1_doc_len 294 | assert e < e1_doc_len 295 | cand_docidx = e1_docidx 296 | 297 | cand_loc_dict = path['cand_locs'] 298 | 
cand_words = list(cand_loc_dict.keys()) 299 | cand_doc_len = get_doc_len(docsents, cand_docidx) 300 | cand_locs = [] 301 | for ca_w in list(cand_loc_dict.keys()): 302 | ca_w_start_loc_list = cand_loc_dict[ca_w] 303 | ca_w_start_loc_list = adjust_word_idxs(sum(docsents[cand_docidx], []), 304 | ca_w_start_loc_list) 305 | ca_w_end_loc_list = [s + len(ca_w.split()) - 1 306 | for s in ca_w_start_loc_list] # end is inclusive 307 | combined_locs_ca_w = [(s, e) for s, e in zip(ca_w_start_loc_list, ca_w_end_loc_list)] 308 | for s, e in combined_locs_ca_w: 309 | assert s < cand_doc_len 310 | assert e < cand_doc_len 311 | cand_locs += combined_locs_ca_w 312 | 313 | path_for_cand_dict = {"he_docidx": he_docidx, 314 | "he_locs": he_locs, 315 | "e1wh_loc": e1_with_head_loc, 316 | "e1_docidx": e1_docidx, 317 | "e1_locs": e1_locs, 318 | "cand_docidx": cand_docidx, 319 | "cand_locs": cand_locs, 320 | "he_words": he_words, 321 | "e1wh": path['e1_with_head_ent'], 322 | "e1": e1, 323 | "cand_words": cand_words 324 | } 325 | return path_for_cand_dict 326 | 327 | 328 | def process_allpaths_for_cand(cand, path_data_for_cand, 329 | qtoks, docsents, max_num_paths): 330 | """ 331 | process all paths for a particular candidate 332 | :param cand: 333 | :param path_data_for_cand: 334 | :param qtoks: 335 | :param docsents: 336 | :param max_num_paths: 337 | :return: 338 | """ 339 | if len(qtoks) == 1 or len(path_data_for_cand) == 0: 340 | paths_for_cand = [find_backup_path(docsents, qtoks, cand)] 341 | return paths_for_cand 342 | 343 | paths_for_cand = [] 344 | for pathforcand in path_data_for_cand: 345 | path_for_cand_dict_ = process_path_for_cand(pathforcand, docsents) 346 | paths_for_cand.append(path_for_cand_dict_) 347 | 348 | if len(paths_for_cand) == 0: 349 | paths_for_cand = [find_backup_path(docsents, qtoks, cand)] 350 | if len(paths_for_cand) > max_num_paths: 351 | # # passage overlap-based scoring 352 | # paths_for_cand = filter_paths_povrlp(paths_for_cand, docsents, k=max_num_paths) 353 | # entity-overlap 354 | paths_for_cand = filter_paths(paths_for_cand, k=max_num_paths) 355 | return paths_for_cand 356 | 357 | 358 | def lemmatize_docsents(doc_sents: List[List[List[str]]]): 359 | doc_sents_lemma = [] 360 | for idx, docl in enumerate(doc_sents): 361 | doc_sent = doc_sents[idx] 362 | docsent_lemma = [] 363 | for senttoks in doc_sent: 364 | # startidx = len(sum(doc_sent[:sidx], [])) 365 | docsent_lemma.append([stemmer.stem(tok) for tok in senttoks]) 366 | doc_sents_lemma.append(docsent_lemma) 367 | return doc_sents_lemma 368 | 369 | 370 | def process_paths(d: Dict, pathdata: Dict, 371 | max_num_paths: int = 30, lemmatize: int = False) -> Dict: 372 | """ 373 | Process paths 374 | :param d: data dictionary from extarcted paths 375 | :return: 376 | """ 377 | # ans = ' '.join(d['answer']) 378 | docsents = d['docsents'] # List[List[List[List[str]]] 379 | qid = d['id'] 380 | question = d['question'] 381 | mod_docsents = [] 382 | for doc in docsents: 383 | sents = [] 384 | for sent in doc: 385 | sents.append(' '.join(sent).split()) 386 | mod_docsents.append(sents) 387 | 388 | path_dict = {'id': qid, 'question': d['question'], 389 | "docsents": mod_docsents, 390 | 'answer': d['answer'], 391 | 'candidates': d['candidates'], 392 | } 393 | 394 | if lemmatize: 395 | docsents = lemmatize_docsents(docsents) 396 | question = [stemmer.stem(qw) for qw in question] 397 | 398 | path_for_all_cands = [] 399 | for cand in pathdata['pathlist'].keys(): 400 | path_data_for_cand_ = pathdata['pathlist'][cand] 401 | paths_for_cand_ = 
process_allpaths_for_cand(cand, path_data_for_cand_, 402 | question, docsents, 403 | max_num_paths=max_num_paths) 404 | path_for_all_cands.append(paths_for_cand_) 405 | 406 | path_dict['paths'] = path_for_all_cands 407 | 408 | return path_dict 409 | 410 | 411 | if __name__ == "__main__": 412 | parser = argparse.ArgumentParser() 413 | parser.add_argument('datadir', type=str, 414 | help='Path to preprocessed data dir') 415 | parser.add_argument('pathdir', type=str, 416 | help='Directory to read the paths') 417 | parser.add_argument('dumpdir', type=str, 418 | help='Directory to dump the paths') 419 | parser.add_argument('--mode', type=str, default='dev', 420 | help='train/dev/test') 421 | parser.add_argument('--maxnumpaths', type=int, default=100, 422 | help='How many maximum paths to consider') 423 | parser.add_argument('--lemmatize', type=bool, default=True, 424 | help='whether to lemmatize for path extraction') 425 | args = parser.parse_args() 426 | t0 = time.time() 427 | dumpdir = args.dumpdir 428 | if not os.path.isdir(args.dumpdir): 429 | os.makedirs(args.dumpdir) 430 | pathdir = args.pathdir 431 | mode = args.mode 432 | max_num_paths = args.maxnumpaths 433 | lemmatize = args.lemmatize 434 | if mode == 'dev': 435 | dev_data = load_lines(args.datadir + '/dev-processed-spacy.txt') 436 | paths = load_lines(args.pathdir + '/dev-processed-spacy.txt.paths') 437 | print("All data loaded") 438 | if mode == 'test': 439 | dev_data = load_lines(args.datadir + '/test-processed-spacy.txt') 440 | paths = load_lines(args.pathdir + '/test-processed-spacy.txt.paths') 441 | print("All data loaded") 442 | 443 | with open(args.dumpdir + '/' + mode + '-path-lines.txt', 'w') as fp: 444 | if mode == "train": 445 | num_splits = 10 446 | else: 447 | num_splits = 1 448 | for sp in range(num_splits): 449 | if mode == 'train': 450 | dev_data = load_lines(args.datadir + '/train-split/split_' + str(sp) + '.json') 451 | paths = load_lines(args.pathdir + '/train-split/split_' + str(sp) + '.json.paths') 452 | print("Data loaded for split %d " % sp) 453 | assert len(dev_data) == len(paths) 454 | for dataidx, data in enumerate(dev_data): 455 | pathdata_ = paths[dataidx] 456 | path_dict_ = process_paths(data, pathdata_, 457 | max_num_paths=max_num_paths, 458 | lemmatize=lemmatize) 459 | fp.write(json.dumps(path_dict_) + '\n') 460 | 461 | print("Done!") 462 | print('Total time: %.4f (s)' % (time.time() - t0)) 463 | 464 | 465 | -------------------------------------------------------------------------------- /scripts/prepro/path_finder_wikihop.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for the pasth finding step for WikiHop 3 | """ 4 | 5 | import os 6 | import json 7 | import argparse 8 | import time 9 | from typing import List, Dict, Any 10 | from tqdm import tqdm 11 | 12 | from multiprocessing import Pool 13 | from functools import partial 14 | from nltk import PorterStemmer 15 | 16 | import sys 17 | sys.path.append("./") 18 | from pathnet.pathfinder.path_extractor import PathFinder 19 | 20 | stemmer = PorterStemmer() 21 | 22 | 23 | def load_examples(fpath: str) -> List[Dict]: 24 | """Load the preprocessed examples 25 | """ 26 | data = [] 27 | with open(fpath, 'r') as fp: 28 | for line in fp: 29 | data.append(json.loads(line)) 30 | return data 31 | 32 | 33 | def init(): 34 | pass 35 | 36 | 37 | def process_examples(d: Dict) -> Dict: 38 | lemma = True 39 | sentlimit = 1 40 | nearest_only = False 41 | qid = d['id'] 42 | if "answer" in d: 43 | ans = ' 
'.join(d['answer']).lower() 44 | else: 45 | ans = 'DUMMYANSWER' 46 | doc_ners = d['docners'] 47 | doc_postags = d['docpostags'] 48 | doc_sents = d['docsents'] 49 | 50 | if not lemma: 51 | rel = d['question'][0] 52 | entity = ' '.join(d['question'][1:]).lower() 53 | candidates = [' '.join(cand) for cand in d['candidates']] 54 | pf = PathFinder(qid, doc_sents, 55 | entity, rel, 56 | candidates, 57 | answer=ans, 58 | sentlimit=sentlimit, 59 | nearest_only=nearest_only) 60 | else: 61 | qlemma = [stemmer.stem(qtok) for qtok in d['question']] 62 | rel = qlemma[0] 63 | entity = ' '.join(qlemma[1:]).lower() 64 | candidates = [] 65 | orig_candidates = d['candidates'] 66 | for ctoks in orig_candidates: 67 | sctoks = [stemmer.stem(ca) for ca in ctoks] 68 | if sctoks in candidates: 69 | candidates.append(ctoks) 70 | else: 71 | candidates.append(sctoks) 72 | candidates = [' '.join(cand) for cand in candidates] 73 | doc_sents_lemma = [] 74 | for idx, docl in enumerate(doc_sents): 75 | doc_sent = doc_sents[idx] 76 | docsent_lemma = [] 77 | for senttoks in doc_sent: 78 | docsent_lemma.append([stemmer.stem(tok) for tok in senttoks]) 79 | doc_sents_lemma.append(docsent_lemma) 80 | 81 | pf = PathFinder(qid, doc_sents_lemma, 82 | entity, rel, 83 | candidates, 84 | answer=ans, 85 | sentlimit=sentlimit, 86 | nearest_only=nearest_only) 87 | 88 | paths = pf.get_paths(doc_ners, doc_postags) 89 | pathdict = {"id": qid, "pathlist": paths} 90 | return pathdict 91 | 92 | 93 | if __name__ == "__main__": 94 | parser = argparse.ArgumentParser() 95 | parser.add_argument('datafile', type=str, 96 | help='Path to preprocessed data file') 97 | parser.add_argument('dumpdir', type=str, 98 | help='Directory to dump the paths') 99 | parser.add_argument('--sentlimit', type=int, default=1, 100 | help='how many next sentences to look for ne/nouns') 101 | parser.add_argument('--take_nearest_only', type=bool, default=False, 102 | help='whether to take nearest candidate only') 103 | parser.add_argument('--numworkers', type=int, default=6, 104 | help='number of workers for multiprocessing') 105 | args = parser.parse_args() 106 | 107 | t0 = time.time() 108 | 109 | infile = args.datafile 110 | data = load_examples(infile) 111 | print("Data Loaded..") 112 | 113 | num_paths = 0 114 | num_cands = 0 115 | 116 | print("Computing paths..") 117 | workers = args.numworkers 118 | make_pool = partial(Pool, workers, initializer=init) 119 | 120 | workers = make_pool(initargs=()) 121 | path_list = tqdm(workers.map(process_examples, data), total=len(data)) 122 | workers.close() 123 | workers.join() 124 | 125 | print("Analysing stats..") 126 | max_paths_per_cand = 0 127 | max_paths_per_q = 0 128 | for ps in path_list: 129 | paths_per_q = 0 130 | for p in ps['pathlist'].values(): 131 | num_paths += len(p) 132 | num_cands += 1 133 | paths_per_q += len(p) 134 | if len(p) > max_paths_per_cand: 135 | max_paths_per_cand = len(p) 136 | if paths_per_q > max_paths_per_q: 137 | max_paths_per_q = paths_per_q 138 | 139 | if not os.path.isdir(args.dumpdir): 140 | os.makedirs(args.dumpdir) 141 | with open(os.path.join(args.dumpdir, os.path.basename(infile) + '.paths'), 'w') as fp: 142 | for pp in path_list: 143 | fp.write(json.dumps(pp) + '\n') 144 | 145 | print("Avg #paths/question: %.4f" % (num_paths / len(data))) 146 | print("Avg #paths/candidate: %.4f" % (num_paths / num_cands)) 147 | print("Max #paths/question: %d" % max_paths_per_q) 148 | print("Max #paths/candidate: %d" % max_paths_per_cand) 149 | print('Total time: %.4f (s)' % (time.time() - t0)) 150 | 
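Each output line written by the script above is a JSON object of the form `{"id": ..., "pathlist": {candidate: [path, ...]}}`, dumped to `<dumpdir>/<input-file-name>.paths`. Below is a minimal sketch for inspecting such a dump; the file location assumes the dev paths were written to `data/datasets/WikiHop/dev-paths/` as in `scripts/predict_wikihop.sh`, so adjust it to wherever you pointed the path finder.

```python
import json

# Assumed dump location; substitute the dumpdir you passed to path_finder_wikihop.py.
paths_file = "data/datasets/WikiHop/dev-paths/dev-processed-spacy.txt.paths"

with open(paths_file, "r") as fp:
    for line in fp:
        record = json.loads(line)  # {"id": ..., "pathlist": {candidate: [path, ...]}}
        counts = {cand: len(plist) for cand, plist in record["pathlist"].items()}
        print(record["id"], counts)
        break  # only inspect the first question
```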
-------------------------------------------------------------------------------- /scripts/prepro/preprocess_obqa.py: -------------------------------------------------------------------------------- 1 | """Preprocess the OBQA data.""" 2 | 3 | import sys 4 | import argparse 5 | import os 6 | from typing import List, Dict, Any 7 | 8 | import json 9 | import time 10 | import random 11 | 12 | from multiprocessing import Pool 13 | from multiprocessing.util import Finalize 14 | from functools import partial 15 | sys.path.append("./") 16 | from pathnet.tokenizers.spacy_tokenizer import SpacyTokenizer 17 | 18 | # ------------------------------------------------------------------------------ 19 | # Tokenize + annotate. 20 | # ------------------------------------------------------------------------------ 21 | 22 | TOK = None 23 | ANNTOTORS = {'lemma', 'pos', 'ner'} 24 | 25 | 26 | def init(): 27 | global TOK 28 | TOK = SpacyTokenizer(annotators=ANNTOTORS) 29 | Finalize(TOK, TOK.shutdown, exitpriority=100) 30 | 31 | 32 | def tokenize(text: str) -> Dict: 33 | """Call the global process tokenizer 34 | on the input text. 35 | """ 36 | global TOK 37 | tokens = TOK.tokenize(text) 38 | words = tokens.words() 39 | output = { 40 | 'words': words, 41 | 'offsets': tokens.offsets(), 42 | 'pos': tokens.pos(), 43 | 'lemma': tokens.lemmas(), 44 | 'ner': tokens.entities(), 45 | 'sentences': [words], 46 | } 47 | return output 48 | 49 | 50 | def splittext(text: str) -> Dict: 51 | output = { 52 | 'words': text.strip().split() 53 | } 54 | return output 55 | 56 | 57 | # ------------------------------------------------------------------------------ 58 | # Process data examples 59 | # ------------------------------------------------------------------------------ 60 | 61 | 62 | def load_dataset(path: str, 63 | shuffle_docs: bool = False) -> Dict: 64 | """Load json file and store 65 | fields separately. 66 | """ 67 | with open(path) as f: 68 | data = json.load(f) 69 | output = {'qids': [], 'questions': [], 'answers': [], 70 | 'contextlists': [], 'candidatelists': []} 71 | for ex in data: 72 | output['qids'].append(ex['id']) 73 | output['questions'].append(' '.join(ex['query'].split())) 74 | if "answer" in ex: 75 | output['answers'].append(' '.join(ex['answer'].split())) 76 | else: 77 | output['answers'].append("DUMMYANSWER") 78 | if shuffle_docs: 79 | random.shuffle(ex['supports']) 80 | supports = [' '.join(s.split()) for s in ex['supports']] 81 | output['contextlists'].append(supports) 82 | candidates = [' '.join(c.split()) for c in ex['candidates']] 83 | output['candidatelists'].append(candidates) 84 | return output 85 | 86 | 87 | def unroll(counts: List[int], l: List[Any]) -> List[List[Any]]: 88 | counts = [0] + counts 89 | unrolled_list = [] 90 | for idx in range(len(counts) - 1): 91 | curr_idx = sum(counts[:idx + 1]) 92 | next_idx = curr_idx + counts[idx + 1] 93 | unrolled_list.append(l[curr_idx:next_idx]) 94 | return unrolled_list 95 | 96 | 97 | def process_dataset(data: Dict, num_workers: int = None): 98 | """Iterate processing (tokenize, parse, etc) 99 | data multi-threaded. 
100 | """ 101 | if num_workers > 1: 102 | make_pool = partial(Pool, num_workers, initializer=init) 103 | workers = make_pool(initargs=()) 104 | 105 | q_tokens = workers.map(tokenize, data['questions']) 106 | workers.close() 107 | workers.join() 108 | else: 109 | init() 110 | q_tokens = [tokenize(q) for q in data['questions']] 111 | 112 | # documents are in list format 113 | dcounts = [len(c) for c in data['contextlists']] 114 | if num_workers > 1: 115 | workers = make_pool(initargs=()) 116 | c_tokens = workers.map(tokenize, sum(data['contextlists'], [])) 117 | workers.close() 118 | workers.join() 119 | else: 120 | print("Tokenizing docs without multiprocessing..") 121 | c_tokens = [tokenize(c) for c in sum(data['contextlists'], [])] 122 | context_tokens = unroll(dcounts, c_tokens) 123 | 124 | if "answers" in data: 125 | if num_workers > 1: 126 | workers = make_pool(initargs=()) 127 | ans_tokens = workers.map(tokenize, data['answers']) 128 | workers.close() 129 | workers.join() 130 | else: 131 | ans_tokens = [tokenize(a) for a in data['answers']] 132 | else: 133 | ans_tokens = None 134 | 135 | candcounts = [len(c) for c in data['candidatelists']] 136 | if num_workers > 1: 137 | workers = make_pool(initargs=()) 138 | cnd_tokens = workers.map(tokenize, sum(data['candidatelists'], [])) 139 | workers.close() 140 | workers.join() 141 | else: 142 | print("Tokenizing candidates without multiprocessing..") 143 | cnd_tokens = [tokenize(ca) for ca in data['candidatelists']] 144 | cand_tokens = unroll(candcounts, cnd_tokens) 145 | 146 | for idx in range(len(data['qids'])): 147 | question = q_tokens[idx]['words'] 148 | qlemma = q_tokens[idx]['lemma'] 149 | qpos = q_tokens[idx]['pos'] 150 | qner = q_tokens[idx]['ner'] 151 | 152 | # supporting documents are in list format 153 | documents = [c['words'] for c in context_tokens[idx]] 154 | offsets = [c['offsets'] for c in context_tokens[idx]] 155 | cpostags = [c['pos'] for c in context_tokens[idx]] 156 | cners = [c['ner'] for c in context_tokens[idx]] 157 | clemmas = [c['lemma'] for c in context_tokens[idx]] 158 | doc_sents = [c['sentences'] for c in context_tokens[idx]] 159 | 160 | if ans_tokens is not None: 161 | answer = ans_tokens[idx]['words'] 162 | else: 163 | answer = ["DUMMYANSWER"] 164 | 165 | candidates = [ca['words'] for ca in cand_tokens[idx]] 166 | candidatelemmas = [ca['lemma'] for ca in cand_tokens[idx]] 167 | candidatepos = [ca['pos'] for ca in cand_tokens[idx]] 168 | candidatener = [ca['ner'] for ca in cand_tokens[idx]] 169 | 170 | yield { 171 | 'id': data['qids'][idx], 172 | 'question': question, 173 | 'qlemma': qlemma, 174 | 'qpos': qpos, 175 | 'qner': qner, 176 | 'documents': documents, 177 | 'offsets': offsets, 178 | 'docsents': doc_sents, 179 | 'docpostags': cpostags, 180 | 'docners': cners, 181 | 'doclemmas': clemmas, 182 | 'answer': answer, 183 | 'candidates': candidates, 184 | 'candidatelemmas': candidatelemmas, 185 | 'candidatepos': candidatepos, 186 | 'candidatener': candidatener 187 | } 188 | 189 | 190 | # ----------------------------------------------------------------------------- 191 | # Commandline options 192 | # ----------------------------------------------------------------------------- 193 | 194 | if __name__ == "__main__": 195 | parser = argparse.ArgumentParser() 196 | parser.add_argument('data_file', type=str, help='Path to data file') 197 | parser.add_argument('out_dir', type=str, help='Path to output file dir') 198 | parser.add_argument('--split', type=str, help='Filename for train/dev split') 199 | 
parser.add_argument('--num-workers', type=int, default=6) 200 | args = parser.parse_args() 201 | 202 | t0 = time.time() 203 | 204 | in_file = args.data_file # os.path.join(args.data_dir, args.split + '.json') 205 | print('Loading data %s' % in_file, file=sys.stderr) 206 | dataset = load_dataset(in_file) 207 | 208 | if not os.path.isdir(args.out_dir): 209 | os.mkdir(args.out_dir) 210 | 211 | out_file = os.path.join( 212 | args.out_dir, '%s-processed-%s.txt' % (args.split, 'spacy') 213 | ) 214 | print('Will write to file %s' % out_file, file=sys.stderr) 215 | with open(out_file, 'w') as f: 216 | for ex in process_dataset(dataset, # args.tokenizer, 217 | args.num_workers): 218 | f.write(json.dumps(ex) + '\n') 219 | print('Total time: %.4f (s)' % (time.time() - t0)) 220 | -------------------------------------------------------------------------------- /scripts/prepro/preprocess_wikihop.py: -------------------------------------------------------------------------------- 1 | """Preprocess the WikiHop data.""" 2 | 3 | import sys 4 | import argparse 5 | import os 6 | from typing import List, Dict, Any 7 | 8 | import json 9 | import time 10 | import random 11 | from multiprocessing import Pool 12 | from multiprocessing.util import Finalize 13 | from functools import partial 14 | sys.path.append("./") 15 | from pathnet.tokenizers.spacy_tokenizer import SpacyTokenizer 16 | 17 | # ------------------------------------------------------------------------------ 18 | # Tokenize + annotate. 19 | # ------------------------------------------------------------------------------ 20 | 21 | TOK = None 22 | ANNTOTORS = {'lemma', 'pos', 'ner'} 23 | 24 | 25 | def init(): 26 | global TOK 27 | TOK = SpacyTokenizer(annotators=ANNTOTORS) 28 | Finalize(TOK, TOK.shutdown, exitpriority=100) 29 | 30 | 31 | def tokenize(text: str) -> Dict: 32 | """Call the global process tokenizer 33 | on the input text. 34 | """ 35 | global TOK 36 | tokens = TOK.tokenize(text) 37 | output = { 38 | 'words': tokens.words(), 39 | 'offsets': tokens.offsets(), 40 | 'pos': tokens.pos(), 41 | 'lemma': tokens.lemmas(), 42 | 'ner': tokens.entities(), 43 | 'sentences': tokens.sentences(), 44 | } 45 | return output 46 | 47 | 48 | def splittext(text: str) -> Dict: 49 | output = { 50 | 'words': text.strip().split() 51 | } 52 | return output 53 | 54 | 55 | # ------------------------------------------------------------------------------ 56 | # Process data examples 57 | # ------------------------------------------------------------------------------ 58 | 59 | 60 | def load_dataset(path: str, 61 | shuffle_docs: bool = False) -> Dict: 62 | """Load json file and store 63 | fields separately. 
64 | """ 65 | with open(path) as f: 66 | data = json.load(f) 67 | output = {'qids': [], 'questions': [], 'answers': [], 68 | 'contextlists': [], 'candidatelists': []} 69 | for ex in data: 70 | output['qids'].append(ex['id']) 71 | output['questions'].append(ex['query']) 72 | if "answer" in ex: 73 | output['answers'].append(ex['answer']) 74 | else: 75 | output['answers'].append("DUMMYANSWER") 76 | if shuffle_docs: 77 | random.shuffle(ex['supports']) 78 | output['contextlists'].append(ex['supports']) 79 | output['candidatelists'].append(ex['candidates']) 80 | return output 81 | 82 | 83 | def unroll(counts: List[int], l: List[Any]) -> List[List[Any]]: 84 | counts = [0] + counts 85 | unrolled_list = [] 86 | for idx in range(len(counts) - 1): 87 | curr_idx = sum(counts[:idx + 1]) 88 | next_idx = curr_idx + counts[idx + 1] 89 | unrolled_list.append(l[curr_idx:next_idx]) 90 | return unrolled_list 91 | 92 | 93 | def process_dataset(data: Dict, workers: int = None): 94 | """Iterate processing (tokenize, parse, etc) 95 | data multi-threaded. 96 | """ 97 | make_pool = partial(Pool, workers, initializer=init) 98 | 99 | workers = make_pool(initargs=()) 100 | q_tokens = workers.map(tokenize, data['questions']) 101 | workers.close() 102 | workers.join() 103 | 104 | c_tokens = [] 105 | print("Tokenizing Passages...") 106 | dcounts = [len(c) for c in data['contextlists']] 107 | all_ctxs = sum(data['contextlists'], []) 108 | num_buckets = len(all_ctxs) / 5000 109 | if num_buckets > int(num_buckets): 110 | num_buckets = int(num_buckets) + 1 111 | else: 112 | num_buckets = int(num_buckets) 113 | for ii in range(num_buckets): 114 | print("Bucket: ", ii) 115 | ctx_bucket = all_ctxs[ii * 5000:min((ii + 1) * 5000, len(all_ctxs))] 116 | workers = make_pool(initargs=()) 117 | c_tokens += workers.map(tokenize, ctx_bucket) 118 | workers.close() 119 | workers.join() 120 | context_tokens = unroll(dcounts, c_tokens) 121 | 122 | if "answers" in data: 123 | workers = make_pool(initargs=()) 124 | ans_tokens = workers.map(tokenize, data['answers']) 125 | workers.close() 126 | workers.join() 127 | else: 128 | ans_tokens = None 129 | 130 | candcounts = [len(c) for c in data['candidatelists']] 131 | workers = make_pool(initargs=()) 132 | cnd_tokens = workers.map(tokenize, sum(data['candidatelists'], [])) 133 | workers.close() 134 | workers.join() 135 | cand_tokens = unroll(candcounts, cnd_tokens) 136 | 137 | for idx in range(len(data['qids'])): 138 | question = q_tokens[idx]['words'] 139 | qlemma = q_tokens[idx]['lemma'] 140 | qpos = q_tokens[idx]['pos'] 141 | qner = q_tokens[idx]['ner'] 142 | 143 | # supporting documents are in list format 144 | documents = [c['words'] for c in context_tokens[idx]] 145 | offsets = [c['offsets'] for c in context_tokens[idx]] 146 | cpostags = [c['pos'] for c in context_tokens[idx]] 147 | cners = [c['ner'] for c in context_tokens[idx]] 148 | clemmas = [c['lemma'] for c in context_tokens[idx]] 149 | doc_sents = [c['sentences'] for c in context_tokens[idx]] 150 | 151 | if ans_tokens is not None: 152 | answer = ans_tokens[idx]['words'] 153 | else: 154 | answer = ["DUMMYANSWER"] 155 | 156 | candidates = [ca['words'] for ca in cand_tokens[idx]] 157 | candidatelemmas = [ca['lemma'] for ca in cand_tokens[idx]] 158 | 159 | yield { 160 | 'id': data['qids'][idx], 161 | 'question': question, 162 | 'qlemma': qlemma, 163 | 'qpos': qpos, 164 | 'qner': qner, 165 | 'documents': documents, 166 | 'offsets': offsets, 167 | 'docsents': doc_sents, 168 | 'docpostags': cpostags, 169 | 'docners': cners, 170 | 
'doclemmas': clemmas, 171 | 'answer': answer, 172 | 'candidates': candidates, 173 | 'candidatelemmas': candidatelemmas 174 | } 175 | 176 | 177 | # ----------------------------------------------------------------------------- 178 | # Commandline options 179 | # ----------------------------------------------------------------------------- 180 | 181 | if __name__ == "__main__": 182 | parser = argparse.ArgumentParser() 183 | parser.add_argument('data_dir', type=str, help='Path to data directory') 184 | parser.add_argument('out_dir', type=str, help='Path to output file dir') 185 | parser.add_argument('--split', type=str, help='Filename for train/dev split') 186 | parser.add_argument('--num-workers', type=int, default=8) 187 | args = parser.parse_args() 188 | 189 | t0 = time.time() 190 | 191 | in_file = os.path.join(args.data_dir, args.split + '.json') 192 | print('Loading data %s' % in_file, file=sys.stderr) 193 | dataset = load_dataset(in_file) 194 | 195 | if not os.path.isdir(args.out_dir): 196 | os.mkdir(args.out_dir) 197 | 198 | out_file = os.path.join( 199 | args.out_dir, '%s-processed-%s.txt' % (args.split, 'spacy') 200 | ) 201 | print('Will write to file %s' % out_file, file=sys.stderr) 202 | with open(out_file, 'w') as f: 203 | for ex in process_dataset(dataset, # args.tokenizer, 204 | args.num_workers): 205 | f.write(json.dumps(ex) + '\n') 206 | print('Total time: %.4f (s)' % (time.time() - t0)) 207 | -------------------------------------------------------------------------------- /scripts/prepro/wikihop_prep_data_with_lemma.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from typing import List, Dict, Tuple 4 | import re 5 | from numpy import array 6 | import time 7 | from tqdm import tqdm 8 | import argparse 9 | from nltk import PorterStemmer 10 | 11 | stemmer = PorterStemmer() 12 | 13 | 14 | def load_dict(fname): 15 | with open(fname, 'r') as fp: 16 | data = json.load(fp) 17 | return data 18 | 19 | 20 | def load_lines(fname): 21 | data = [] 22 | with open(fname, 'r') as fp: 23 | for line in fp: 24 | data.append(json.loads(line)) 25 | return data 26 | 27 | 28 | def get_locs_given_objs(doc: str, word: str, objs: List): 29 | doctoks = doc.split() 30 | ch_locs = [ob.span()[0] for ob in objs] 31 | found_words = [ob.group(0) for ob in objs] 32 | assert len(found_words) == len(ch_locs) 33 | for fwidx, fw in enumerate(found_words): 34 | start_offset = len(fw) - len(fw.lstrip()) 35 | ch_locs[fwidx] += start_offset 36 | found_words[fwidx] = fw.strip() 37 | widxs = get_widxs_from_chidxs(ch_locs, create_offsets(doctoks)) 38 | locs = [(widx, widx + len(found_words[i].split()) - 1) for i, widx in enumerate(widxs)] 39 | return locs 40 | 41 | 42 | def get_widxs_from_chidxs(chidxs: List[int], 43 | offsets: List[List[int]]) -> List[int]: 44 | """ 45 | Find word indices given character indices 46 | :param chidxs: 47 | :param offsets: 48 | :return: 49 | """ 50 | last_ch_idx = offsets[0][0] 51 | assert max(chidxs) < offsets[-1][1] - last_ch_idx 52 | widxs = [] 53 | for chidx in chidxs: 54 | for oi in range(len(offsets)): 55 | if chidx in range(offsets[oi][0] - last_ch_idx, offsets[oi][1] - last_ch_idx): 56 | widxs.append(oi) 57 | break 58 | elif chidx in range(offsets[oi][1] - last_ch_idx, 59 | offsets[min(oi + 1, len(offsets))][0] - last_ch_idx): 60 | widxs.append(oi) 61 | break 62 | assert len(chidxs) == len(widxs) 63 | return widxs 64 | 65 | 66 | def create_offsets(doctoks: List[str]) -> List[List[int]]: 67 | """ 68 | create 
offsets for a document tokens 69 | :param doctoks: 70 | :return: 71 | """ 72 | offsets = [] 73 | char_count = 0 74 | for tok in doctoks: 75 | offsets.append([char_count, char_count + len(tok)]) 76 | char_count = char_count + len(tok) + 1 77 | return offsets 78 | 79 | 80 | def find_backup_path(docsents, q, cand, k=40): 81 | """ 82 | Find a dummy backup path is no path could be found for a candidate 83 | :param docsents: 84 | :param q: 85 | :param cand: 86 | :param k: 87 | :return: 88 | """ 89 | path_for_cand_dict = {"he_docidx": None, 90 | "he_locs": None, 91 | "e1wh_loc": None, 92 | "e1_docidx": None, 93 | "e1_locs": None, 94 | "cand_docidx": None, 95 | "cand_locs": None, 96 | "he_words": ["BACKUP"], 97 | "e1wh": "BACKUP", 98 | "e1": "BACKUP", 99 | "cand_words": ["BACKUP"] 100 | } 101 | 102 | he = ' '.join(q[1:]).lower() 103 | if len(he.split()) == 0: 104 | path_for_cand_dict['he_docidx'] = 0 105 | path_for_cand_dict['he_locs'] = [(-1, -1)] 106 | else: 107 | pat_he = re.compile('(^|\W)' + re.escape(he) + '\W') 108 | 109 | for docssidx, docss in enumerate(docsents): 110 | doc = ' '.join(' '.join(sum(docss, [])).split()) 111 | doc = doc.lower() 112 | he_objs = [] 113 | for x in pat_he.finditer(doc): 114 | he_objs.append(x) 115 | if len(he_objs) > 0: 116 | path_for_cand_dict['he_docidx'] = docssidx 117 | path_for_cand_dict['he_locs'] = get_locs_given_objs(doc, he, he_objs)[:k] 118 | break 119 | 120 | cand = cand.lower() 121 | pat_cand = re.compile('(^|\W)' + re.escape(cand) + '\W') 122 | for docssidx, docss in enumerate(docsents): 123 | doc = ' '.join(' '.join(sum(docss, [])).split()) 124 | doc = doc.lower() 125 | ca_objs = [] 126 | for x in pat_cand.finditer(doc): 127 | ca_objs.append(x) 128 | if len(ca_objs) > 0: 129 | path_for_cand_dict['cand_docidx'] = docssidx 130 | path_for_cand_dict['cand_locs'] = get_locs_given_objs(doc, cand, ca_objs)[:k] 131 | break 132 | if path_for_cand_dict['he_docidx'] is None or path_for_cand_dict['he_locs'] is None: 133 | path_for_cand_dict['he_docidx'] = 0 134 | path_for_cand_dict['he_locs'] = [(-1, -1)] 135 | if path_for_cand_dict['cand_docidx'] is None or path_for_cand_dict['cand_locs'] is None: 136 | path_for_cand_dict['cand_docidx'] = 0 137 | path_for_cand_dict['cand_locs'] = [(0, 0)] 138 | 139 | return path_for_cand_dict 140 | 141 | 142 | def get_min_abs_diff(list1: List[int], list2: List[int]) -> float: 143 | return float(min([min(abs(array(list1) - bi)) for bi in list2])) 144 | 145 | 146 | def get_start_locs(list1: List[Tuple[int, int]]) -> List[int]: 147 | return [l[0] for l in list1] 148 | 149 | 150 | def filter_paths(paths_for_cand: List[Dict], k=30) -> List[Dict]: 151 | """ 152 | Scoring function for paths 153 | Keep only top-k paths 154 | 155 | path_for_cand_dict = {"he_docidx": he_docidx, #int 156 | "he_locs": he_locs, # List[int] 157 | "e1wh_loc": e1_with_head_loc, # List[int] 158 | "e1_docidx": e1_docidx, # int 159 | "e1_locs": e1_locs, # List[int] 160 | "cand_docidx": cand_docidx, # int 161 | "cand_locs": cand_locs # int 162 | } 163 | 164 | :param paths_for_cand: 165 | :param k: 166 | :return: 167 | """ 168 | if len(paths_for_cand) < k: 169 | return paths_for_cand 170 | 171 | big_score = 1e7 172 | scores = [] 173 | for path in paths_for_cand: 174 | he_docidx = path['he_docidx'] 175 | he_locs = path['he_locs'] 176 | e1wh_locs = path['e1wh_loc'] 177 | # e1_docidx = path['e1_docidx'] 178 | e1_locs = path['e1_locs'] 179 | cand_docidx = path['cand_docidx'] 180 | cand_locs = path['cand_locs'] 181 | he_locs = get_start_locs(he_locs) 182 | e1wh_locs = 
get_start_locs(e1wh_locs) 183 | e1_locs = get_start_locs(e1_locs) 184 | cand_locs = get_start_locs(cand_locs) 185 | if he_docidx == cand_docidx: 186 | if get_min_abs_diff(he_locs, cand_locs) == 0: 187 | scores.append(big_score) 188 | else: 189 | scores.append(-1.0) 190 | else: 191 | he_e1wh_mindiff = get_min_abs_diff(he_locs, e1wh_locs) 192 | e1_cand_mindiff = get_min_abs_diff(cand_locs, e1_locs) 193 | if he_e1wh_mindiff == 0 or e1_cand_mindiff == 0: 194 | scores.append(big_score) 195 | else: 196 | scores.append(he_e1wh_mindiff + e1_cand_mindiff) 197 | assert len(scores) == len(paths_for_cand) 198 | 199 | temp = sorted(zip(paths_for_cand, scores), 200 | key=lambda x: x[1], 201 | reverse=False) # ascending order 202 | sorted_paths, _ = map(list, zip(*temp)) 203 | sorted_paths = sorted_paths[:k] 204 | 205 | return sorted_paths 206 | 207 | 208 | def get_doc_len(docsents, docidx): 209 | doc = docsents[docidx] 210 | doc = ' '.join(sum(doc, [])).split() 211 | return len(doc) 212 | 213 | 214 | def adjust_word_idxs(toks, widxs): 215 | split_toks = ' '.join(toks).split() 216 | if len(toks) == len(split_toks): 217 | return widxs 218 | else: 219 | mod_widxs = [widx - (len(toks[:widx]) - len(' '.join(toks[:widx]).split())) 220 | for widx in widxs] 221 | return mod_widxs 222 | 223 | 224 | def process_path_for_cand(path, docsents): 225 | """ 226 | 227 | :param path: 228 | :param docsents: 229 | :return: 230 | """ 231 | he_docidx = path['head_ent_docidx'] 232 | he_doc_len = get_doc_len(docsents, he_docidx) 233 | he_ent_loc_dict = path['head_ent'] 234 | he_locs = [] 235 | he_words = list(he_ent_loc_dict.keys()) 236 | for he_w in list(he_ent_loc_dict.keys()): 237 | start_loc_list = he_ent_loc_dict[he_w] 238 | start_loc_list = adjust_word_idxs(sum(docsents[he_docidx], []), 239 | start_loc_list) 240 | for s in start_loc_list: 241 | assert s < he_doc_len 242 | end_loc_list = [max(s, s + len(he_w.split()) - 1) 243 | for s in start_loc_list] # end is inclusive 244 | 245 | for e in end_loc_list: 246 | assert e < he_doc_len 247 | combined_locs = [(s, e) for s, e in zip(start_loc_list, end_loc_list)] 248 | he_locs += combined_locs 249 | 250 | e1 = path['e1'] 251 | if e1 is None: 252 | e1_with_head_loc = [(-1, -1)] 253 | e1_docidx = None 254 | e1_locs = [(-1, -1)] 255 | cand_docidx = he_docidx 256 | else: 257 | e1wh_locs = [path['e1_with_head_widx']] 258 | e1wh_locs = adjust_word_idxs(sum(docsents[he_docidx], []), e1wh_locs) 259 | e1_with_head_loc = [(e1wh_locs[0], 260 | e1wh_locs[0] + len( 261 | path['e1_with_head_ent'].split()) - 1)] # inclusive end 262 | for s, e in e1_with_head_loc: 263 | assert s < he_doc_len 264 | assert e < he_doc_len 265 | e1_docidx = path['e1_docidx'] 266 | e1_doc_len = get_doc_len(docsents, e1_docidx) 267 | e1_start_locs = path['e1_locs'] 268 | e1_start_locs = adjust_word_idxs(sum(docsents[e1_docidx], []), e1_start_locs) 269 | e1_locs = [(e1s, e1s + len(e1.split()) - 1) for e1s in e1_start_locs] 270 | for s, e in e1_locs: 271 | assert s < e1_doc_len 272 | assert e < e1_doc_len 273 | cand_docidx = e1_docidx 274 | 275 | cand_loc_dict = path['cand_locs'] 276 | cand_words = list(cand_loc_dict.keys()) 277 | cand_doc_len = get_doc_len(docsents, cand_docidx) 278 | cand_locs = [] 279 | for ca_w in list(cand_loc_dict.keys()): 280 | ca_w_start_loc_list = cand_loc_dict[ca_w] 281 | ca_w_start_loc_list = adjust_word_idxs(sum(docsents[cand_docidx], []), 282 | ca_w_start_loc_list) 283 | ca_w_end_loc_list = [s + len(ca_w.split()) - 1 284 | for s in ca_w_start_loc_list] # end is inclusive 285 | 
combined_locs_ca_w = [(s, e) for s, e in zip(ca_w_start_loc_list, ca_w_end_loc_list)] 286 | for s, e in combined_locs_ca_w: 287 | assert s < cand_doc_len 288 | assert e < cand_doc_len 289 | cand_locs += combined_locs_ca_w 290 | 291 | path_for_cand_dict = {"he_docidx": he_docidx, 292 | "he_locs": he_locs, 293 | "e1wh_loc": e1_with_head_loc, 294 | "e1_docidx": e1_docidx, 295 | "e1_locs": e1_locs, 296 | "cand_docidx": cand_docidx, 297 | "cand_locs": cand_locs, 298 | "he_words": he_words, 299 | "e1wh": path['e1_with_head_ent'], 300 | "e1": e1, 301 | "cand_words": cand_words 302 | } 303 | return path_for_cand_dict 304 | 305 | 306 | def process_allpaths_for_cand(cand, path_data_for_cand, 307 | qtoks, docsents, max_num_paths): 308 | """ 309 | process all paths for a particular candidate 310 | :param cand: 311 | :param path_data_for_cand: 312 | :param qtoks: 313 | :param docsents: 314 | :param max_num_paths: 315 | :return: 316 | """ 317 | if len(qtoks) == 1 or len(path_data_for_cand) == 0: 318 | paths_for_cand = [find_backup_path(docsents, qtoks, cand)] 319 | return paths_for_cand 320 | 321 | paths_for_cand = [] 322 | for pathforcand in path_data_for_cand: 323 | path_for_cand_dict_ = process_path_for_cand(pathforcand, docsents) 324 | paths_for_cand.append(path_for_cand_dict_) 325 | 326 | if len(paths_for_cand) == 0: 327 | paths_for_cand = [find_backup_path(docsents, qtoks, cand)] 328 | if len(paths_for_cand) > max_num_paths: 329 | paths_for_cand = filter_paths(paths_for_cand, k=max_num_paths) 330 | return paths_for_cand 331 | 332 | 333 | def lemmatize_docsents(doc_sents: List[List[List[str]]]): 334 | doc_sents_lemma = [] 335 | for idx, docl in enumerate(doc_sents): 336 | doc_sent = doc_sents[idx] 337 | docsent_lemma = [] 338 | for senttoks in doc_sent: 339 | # startidx = len(sum(doc_sent[:sidx], [])) 340 | docsent_lemma.append([stemmer.stem(tok) for tok in senttoks]) 341 | doc_sents_lemma.append(docsent_lemma) 342 | return doc_sents_lemma 343 | 344 | 345 | def process_paths(d: Dict, pathdata: Dict, 346 | max_num_paths: int = 30, 347 | lemmatize: bool = False) -> Dict: 348 | """ 349 | Process paths 350 | :param d: data dictionary from extarcted paths 351 | :param pathdata: 352 | :param max_num_paths: 353 | :param lemmatize: 354 | :return: 355 | """ 356 | # ans = ' '.join(d['answer']) 357 | docsents = d['docsents'] # List[List[List[List[str]]] 358 | qid = d['id'] 359 | question = d['question'] 360 | qpos = d['qpos'] 361 | qner = d['qner'] 362 | docpostags = d['docpostags'] 363 | docners = d['docners'] 364 | 365 | mod_docsents = [] 366 | mod_docpostags = [] 367 | mod_docners = [] 368 | for docidx, doc in enumerate(docsents): 369 | sents = [] 370 | sentpos = [] 371 | sentner = [] 372 | docpos = docpostags[docidx] 373 | docner = docners[docidx] 374 | for sentidx, sent in enumerate(doc): 375 | sents.append(' '.join(sent).split()) 376 | stidx = len(sum(doc[:sentidx], [])) 377 | endidx = stidx + len(' '.join(sent).split()) 378 | sentpos.append(docpos[stidx:endidx]) 379 | sentner.append(docner[stidx:endidx]) 380 | mod_docsents.append(sents) 381 | mod_docpostags.append(sentpos) 382 | mod_docners.append(sentner) 383 | 384 | path_dict = {'id': qid, 'question': d['question'], 385 | 'qpos': qpos, 'qner': qner, 386 | "docsents": mod_docsents, 387 | "docpostags": mod_docpostags, 388 | "docners": mod_docners, 389 | 'answer': d['answer'], 390 | 'candidates': d['candidates']} 391 | 392 | if lemmatize: 393 | docsents = lemmatize_docsents(docsents) 394 | question = [stemmer.stem(qw) for qw in question] 395 | 396 | # 
print(qid) 397 | assert len(list(pathdata['pathlist'].keys())) == len(d['candidates']) 398 | path_for_all_cands = [] 399 | for cand in pathdata['pathlist'].keys(): 400 | path_data_for_cand_ = pathdata['pathlist'][cand] 401 | paths_for_cand_ = process_allpaths_for_cand(cand, path_data_for_cand_, 402 | question, docsents, 403 | max_num_paths=max_num_paths) 404 | path_for_all_cands.append(paths_for_cand_) 405 | 406 | path_dict['paths'] = path_for_all_cands 407 | 408 | return path_dict 409 | 410 | 411 | if __name__ == "__main__": 412 | parser = argparse.ArgumentParser() 413 | parser.add_argument('datadir', type=str, 414 | help='Path to preprocessed data dir') 415 | parser.add_argument('pathdir', type=str, 416 | help='Directory to read the paths') 417 | parser.add_argument('dumpdir', type=str, 418 | help='Directory to dump the paths') 419 | parser.add_argument('--mode', type=str, default='dev', 420 | help='Train/Dev') 421 | parser.add_argument('--maxnumpaths', type=int, default=30, 422 | help='How many maximum paths to consider') 423 | parser.add_argument('--lemmatize', type=bool, default=False, 424 | help='whether to lemmatize for path extraction') 425 | args = parser.parse_args() 426 | t0 = time.time() 427 | dumpdir = args.dumpdir 428 | if not os.path.isdir(args.dumpdir): 429 | os.makedirs(args.dumpdir) 430 | mode = args.mode 431 | max_num_paths = args.maxnumpaths 432 | lemmatize = args.lemmatize 433 | if mode == 'dev': 434 | dev_data = load_lines(args.datadir + '/dev-processed-spacy.txt') 435 | paths = load_lines(args.pathdir + '/dev-processed-spacy.txt.paths') 436 | print("All data loaded") 437 | 438 | with open(args.dumpdir + '/' + mode + '-path-lines.txt', 'w') as fp: 439 | if mode == "train": 440 | num_splits = 22 # check and modify how many splits were there for the train data paths 441 | else: 442 | num_splits = 1 443 | for sp in range(num_splits): 444 | if mode == 'train': 445 | dev_data = load_lines(args.datadir + '/train-split/split_' + str(sp) + '.json') 446 | paths = load_lines(args.pathdir + '/train-split/split_' + str(sp) + '.json.paths') 447 | print("Data loaded for split %d " % sp) 448 | assert len(dev_data) == len(paths) 449 | for dataidx, data in tqdm(enumerate(dev_data)): 450 | pathdata_ = paths[dataidx] 451 | path_dict_ = process_paths(data, pathdata_, 452 | max_num_paths=max_num_paths, 453 | lemmatize=lemmatize) 454 | fp.write(json.dumps(path_dict_) + '\n') 455 | 456 | print("Done!") 457 | print('Total time: %.4f (s)' % (time.time() - t0)) 458 | 459 | 460 | -------------------------------------------------------------------------------- /scripts/preprocess_obqa.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Script to run preprocessing 3 | 4 | export OPENBLAS_NUM_THREADS=1 5 | set -e 6 | 7 | ORIG_DATA_DIR=$1 8 | PREPROCESSED_DATA_DIR=$2 9 | 10 | for x in train dev test; do 11 | echo "Split - $x" 12 | # Change the filename accoridngly 13 | python scripts/prepro/preprocess_obqa.py $ORIG_DATA_DIR/obqa-commonsense-590k-wh-sorted100-${x} \ 14 | $PREPROCESSED_DATA_DIR \ 15 | --split $x --num-workers 6 16 | done 17 | 18 | echo "Creating split directory" 19 | mkdir -p ${PREPROCESSED_DATA_DIR}/train-split 20 | echo "Done" 21 | 22 | # break the train data 23 | python scripts/break_train_data_obqa.py $PREPROCESSED_DATA_DIR/train-processed-spacy.txt \ 24 | ${PREPROCESSED_DATA_DIR}/train-split -------------------------------------------------------------------------------- /scripts/preprocess_wikihop.sh: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Script to run preprocessing 4 | 5 | export OPENBLAS_NUM_THREADS=1 6 | set -e 7 | 8 | DATA_DIR=$1 9 | OUT_DIR=$2 10 | 11 | python scripts/break_orig_wikihop_train.py $DATA_DIR/train.json 12 | 13 | for x in $(seq 0 8); do 14 | python scripts/prepro/preprocess_wikihop.py $DATA_DIR $OUT_DIR \ 15 | --split train${x} --num-workers 8 16 | done 17 | 18 | if [ -f "$OUT_DIR/train-processed-spacy.txt" ] 19 | then 20 | rm $OUT_DIR/train-processed-spacy.txt 21 | fi 22 | 23 | for x in $(seq 0 8); do 24 | cat $OUT_DIR/train${x}-processed-spacy.txt >> $OUT_DIR/train-processed-spacy.txt 25 | done 26 | 27 | python scripts/prepro/preprocess_wikihop.py $DATA_DIR $OUT_DIR \ 28 | --split dev --num-workers 8 29 | 30 | #for x in train dev; do 31 | # echo "Split - $x" 32 | # python scripts/prepro/preprocess_wikihop.py $DATA_DIR $OUT_DIR \ 33 | # --split $x --num-workers 8 34 | #done 35 | 36 | echo "Creating split directory" 37 | mkdir -p ${OUT_DIR}/train-split 38 | echo "Done" 39 | 40 | # break the train data 41 | python scripts/break_train_data_wikihop.py $OUT_DIR/train-processed-spacy.txt ${OUT_DIR}/train-split 42 | -------------------------------------------------------------------------------- /scripts/run_full_obqa.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Script for running full system for OBQA dataset 3 | 4 | set -x 5 | ORIG_DATA_DIR=$1 6 | PREPROCESSED_DATA_DIR=$2 7 | PATH_DIR=$3 8 | PARAM_FILE=$4 9 | FINAL_PATH_DUMPDIR="data/datasets/OBQA/adjusted" 10 | mkdir -p $FINAL_PATH_DUMPDIR 11 | 12 | # Tokenization/tagging etc 13 | scripts/preprocess_obqa.sh $ORIG_DATA_DIR $PREPROCESSED_DATA_DIR 14 | 15 | # break the train data 16 | python scripts/break_train_data_obqa.py $PREPROCESSED_DATA_DIR/train-processed-spacy.txt ${PREPROCESSED_DATA_DIR}/train-split 17 | 18 | # path extraction 19 | scripts/path_finder_obqa.sh $PREPROCESSED_DATA_DIR $PATH_DIR 20 | 21 | # path adjustments 22 | scripts/path_adjustments_obqa.sh $PREPROCESSED_DATA_DIR $PATH_DIR $FINAL_PATH_DUMPDIR 23 | 24 | # Training 25 | MODELDIR="models/obqa" 26 | if [ -d $MODELDIR ] 27 | then 28 | rm -r $MODELDIR 29 | fi 30 | mkdir -p $MODELDIR 31 | allennlp train --file-friendly-logging -s $MODELDIR --include-package pathnet $PARAM_FILE -------------------------------------------------------------------------------- /scripts/run_full_wikihop.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Preprocessing + Training script for WikiHop 3 | 4 | set -x 5 | ORIG_DATA_DIR=$1 6 | PREPROCESSED_DATA_DIR=$2 7 | PATH_DIR=$3 8 | PARAM_FILE=$4 9 | FINAL_PATH_DUMPDIR="data/datasets/WikiHop/adjusted" 10 | mkdir -p $FINAL_PATH_DUMPDIR 11 | 12 | # Tokenization/tagging etc 13 | scripts/preprocess_wikihop.sh $ORIG_DATA_DIR $PREPROCESSED_DATA_DIR 14 | 15 | # break the train data 16 | python scripts/break_train_data_wikihop.py $PREPROCESSED_DATA_DIR/train-processed-spacy.txt ${PREPROCESSED_DATA_DIR}/train-split 17 | 18 | # path extraction 19 | scripts/path_finder_wikihop.sh $PREPROCESSED_DATA_DIR $PATH_DIR 20 | 21 | # path adjustments 22 | scripts/path_adjustments_wikihop.sh $PREPROCESSED_DATA_DIR $PATH_DIR $FINAL_PATH_DUMPDIR 23 | 24 | # Preparing Vocabulary 25 | allennlp make-vocab training_configs/config_wikihop_makevocab.json \ 26 | -s data/datasets/WikiHop/ --include-package pathnet 27 | 28 | # Training 29 | MODELDIR="models/wikihop" 30 | if [ -d $MODELDIR ] 31 | then 32 | rm -r $MODELDIR 
33 | fi 34 | mkdir -p $MODELDIR 35 | allennlp train --file-friendly-logging -s $MODELDIR --include-package pathnet $PARAM_FILE 36 | -------------------------------------------------------------------------------- /training_configs/config_obqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "obqa_data_reader", 4 | "lazy": false, 5 | "cut_context": false, 6 | "token_indexers": { 7 | "tokens": { 8 | "type": "single_id", 9 | "lowercase_tokens": true 10 | } 11 | }, 12 | "tokenizer": { 13 | "type": "word", 14 | "word_splitter": { 15 | "type": "just_spaces" 16 | }, 17 | "end_tokens": ["@@NULL@@"] 18 | } 19 | }, 20 | // "vocabulary": { 21 | // "directory_path": "data/datasets/OBQA/vocabulary/" 22 | // }, 23 | "train_data_path": "data/datasets/OBQA/adjusted/paths100/train-path-lines.txt", 24 | "validation_data_path": "data/datasets/OBQA/adjusted/paths100/dev-path-lines.txt", 25 | "test_data_path": "data/datasets/OBQA/adjusted/paths100/test-path-lines.txt", 26 | "evaluate_on_test": true, 27 | "model": { 28 | "type": "pathnet", 29 | "text_field_embedder": { 30 | "tokens": { 31 | "type": "embedding", 32 | "pretrained_file": "data/embeddings/glove.840B.300d.txt.gz", 33 | "embedding_dim": 300, 34 | "trainable": false 35 | } 36 | }, 37 | "embeddings_dropout_value": 0.5, 38 | "aggregate_feedforward": { 39 | "input_dim": 100, 40 | "num_layers": 1, 41 | "hidden_dims": 1, 42 | "activations": "linear" 43 | }, 44 | "question_encoder": { 45 | "type": "lstm", 46 | "bidirectional": true, 47 | "num_layers": 1, 48 | "input_size": 300, 49 | "hidden_size": 50 50 | }, 51 | "document_encoder": { 52 | "type": "lstm", 53 | "bidirectional": true, 54 | "num_layers": 1, 55 | "input_size": 300, 56 | "hidden_size": 50 57 | }, 58 | "choice_encoder": { 59 | "type": "lstm", 60 | "bidirectional": true, 61 | "num_layers": 1, 62 | "input_size": 300, 63 | "hidden_size": 50 64 | }, 65 | "he_e1wh_projector": { 66 | "input_dim": 400, 67 | "num_layers": 1, 68 | "hidden_dims": 100, 69 | "activations": "tanh", 70 | "dropout": 0.5 71 | }, 72 | "e1_ca_projector": { 73 | "input_dim": 400, 74 | "num_layers": 1, 75 | "hidden_dims": 100, 76 | "activations": "tanh", 77 | "dropout": 0.5 78 | }, 79 | "path_projector": { 80 | "input_dim": 200, 81 | "num_layers": 1, 82 | "hidden_dims": 100, 83 | "activations": "tanh", 84 | "dropout": 0.5 85 | }, 86 | "allchoice_projector": { 87 | "input_dim": 200, 88 | "num_layers": 1, 89 | "hidden_dims": 100, 90 | "activations": "tanh", 91 | "dropout": 0.5 92 | }, 93 | "question_projector": { 94 | "input_dim": 200, 95 | "num_layers": 1, 96 | "hidden_dims": 100, 97 | "activations": "tanh", 98 | "dropout": 0.5 99 | }, 100 | "combined_q_projector": { 101 | "input_dim": 100, 102 | "num_layers": 1, 103 | "hidden_dims": 100, 104 | "activations": "linear", 105 | "dropout": 0.5 106 | }, 107 | "combined_s_projector": { 108 | "input_dim": 100, 109 | "num_layers": 1, 110 | "hidden_dims": 100, 111 | "activations": "linear", 112 | "dropout": 0.5 113 | }, 114 | "joint_encoder": { 115 | "seq_encoder": { 116 | "type": "lstm", 117 | "bidirectional": true, 118 | "num_layers": 1, 119 | "input_size": 100, 120 | "hidden_size": 50 121 | } 122 | }, 123 | "doc_aggregator": { 124 | "projector": { 125 | "input_dim": 200, 126 | "num_layers": 1, 127 | "hidden_dims": 1, 128 | "activations": "linear", 129 | "dropout": 0.5 130 | }, 131 | "intermediate_projector": { 132 | "input_dim": 200, 133 | "num_layers": 1, 134 | "hidden_dims": 200, 135 | "activations": "linear", 136 | 
"dropout": 0.5 137 | } 138 | }, 139 | "choice_aggregator": { 140 | "projector": { 141 | "input_dim": 100, 142 | "num_layers": 1, 143 | "hidden_dims": 1, 144 | "activations": "linear", 145 | "dropout": 0.5 146 | }, 147 | "intermediate_projector": { 148 | "input_dim": 100, 149 | "num_layers": 1, 150 | "hidden_dims": 100, 151 | "activations": "linear", 152 | "dropout": 0.5 153 | } 154 | }, 155 | "path_aggregator": { 156 | "input_dim": 400, 157 | "num_layers": 1, 158 | "hidden_dims": 100, 159 | "activations": "linear", 160 | "dropout": 0.5 161 | }, 162 | "path_loc_aggregator": "mean", 163 | "allchoice_loc": false, 164 | "path_enc": true, 165 | "path_enc_doc_based": true, 166 | "path_enc_loc_based":true, 167 | "combine_scores": "add_cat", 168 | "span_extractor": { 169 | "type": "endpoint", 170 | "input_dim": 2, 171 | "combination": "x,y" 172 | }, 173 | "initializer": [ 174 | [".*linear_layers.*weight", {"type": "xavier_normal"}], 175 | [".*token_embedder_tokens._projection.*weight", {"type": "xavier_normal"}] 176 | ] 177 | }, 178 | "iterator": { 179 | "type": "bucket", 180 | "sorting_keys": [["candidates", "num_fields"],["documents", "list_num_tokens"]], 181 | "biggest_batch_first": true, 182 | "batch_size": 8, 183 | "cache_instances": false 184 | }, 185 | "trainer": { 186 | "num_epochs": 20, 187 | "patience": 7, 188 | "cuda_device": 0, 189 | "grad_norm": 5.0, 190 | "validation_metric": "+accuracy", 191 | "optimizer": { 192 | "type": "adam", 193 | "lr": 0.001 194 | }, 195 | "learning_rate_scheduler": { 196 | "type": "reduce_on_plateau", 197 | "factor": 0.5, 198 | "mode": "max", 199 | "patience": 2 200 | } 201 | } 202 | } -------------------------------------------------------------------------------- /training_configs/config_wikihop.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "wikihop_data_reader", 4 | "lazy": false, 5 | "cut_context": false, 6 | "token_indexers": { 7 | "tokens": { 8 | "type": "single_id", 9 | "lowercase_tokens": true 10 | } 11 | }, 12 | "tokenizer": { 13 | "type": "word", 14 | "word_splitter": { 15 | "type": "just_spaces" 16 | }, 17 | "end_tokens": ["@@NULL@@"] 18 | } 19 | }, 20 | "vocabulary": { 21 | "directory_path": "data/datasets/WikiHop/vocabulary/" 22 | }, 23 | "train_data_path": "data/datasets/WikiHop/adjusted/paths30/train-path-lines.txt", 24 | "validation_data_path": "data/datasets/WikiHop/adjusted/paths30/dev-path-lines.txt", 25 | "model": { 26 | "type": "pathnet", 27 | "text_field_embedder": { 28 | "tokens": { 29 | "type": "embedding", 30 | "pretrained_file": "data/embeddings/glove.840B.300d.txt.gz", 31 | "embedding_dim": 300, 32 | "trainable": false 33 | } 34 | }, 35 | "embeddings_dropout_value": 0.25, 36 | "aggregate_feedforward": { 37 | "input_dim": 100, 38 | "num_layers": 1, 39 | "hidden_dims": 1, 40 | "activations": "linear" 41 | }, 42 | "question_encoder": { 43 | "type": "lstm", 44 | "bidirectional": true, 45 | "num_layers": 1, 46 | "input_size": 300, 47 | "hidden_size": 50 48 | }, 49 | "document_encoder": { 50 | "type": "lstm", 51 | "bidirectional": true, 52 | "num_layers": 1, 53 | "input_size": 300, 54 | "hidden_size": 50 55 | }, 56 | "choice_encoder": { 57 | "type": "lstm", 58 | "bidirectional": true, 59 | "num_layers": 1, 60 | "input_size": 300, 61 | "hidden_size": 50 62 | }, 63 | "loc_path_encoder": { 64 | "type": "tree", 65 | "he_e1_comp": { 66 | "input_dim": 400, 67 | "num_layers": 1, 68 | "hidden_dims": 100, 69 | "activations": "tanh", 70 | "dropout": 0.25 71 | }, 72 | 
"e1_ca_comp": { 73 | "input_dim": 400, 74 | "num_layers": 1, 75 | "hidden_dims": 100, 76 | "activations": "tanh", 77 | "dropout": 0.25 78 | }, 79 | "r1r2_comp": { 80 | "input_dim": 200, 81 | "num_layers": 1, 82 | "hidden_dims": 100, 83 | "activations": "tanh", 84 | "dropout": 0.25 85 | }, 86 | "rnn_comp": { 87 | "type": "lstm", 88 | "bidirectional": false, 89 | "num_layers": 1, 90 | "input_size": 200, 91 | "hidden_size": 100, 92 | "dropout": 0.25 93 | } 94 | }, 95 | "allchoice_projector": { 96 | "input_dim": 200, 97 | "num_layers": 1, 98 | "hidden_dims": 100, 99 | "activations": "tanh", 100 | "dropout": 0.25 101 | }, 102 | "question_projector": { 103 | "input_dim": 200, 104 | "num_layers": 1, 105 | "hidden_dims": 100, 106 | "activations": "tanh", 107 | "dropout": 0.25 108 | }, 109 | "combined_q_projector": { 110 | "input_dim": 100, 111 | "num_layers": 1, 112 | "hidden_dims": 100, 113 | "activations": "linear", 114 | "dropout": 0.25 115 | }, 116 | "combined_s_projector": { 117 | "input_dim": 100, 118 | "num_layers": 1, 119 | "hidden_dims": 100, 120 | "activations": "linear", 121 | "dropout": 0.25 122 | }, 123 | "joint_encoder": { 124 | "seq_encoder": { 125 | "type": "lstm", 126 | "bidirectional": true, 127 | "num_layers": 1, 128 | "input_size": 100, 129 | "hidden_size": 50 130 | } 131 | }, 132 | "doc_aggregator": { 133 | "projector": { 134 | "input_dim": 200, 135 | "num_layers": 1, 136 | "hidden_dims": 1, 137 | "activations": "linear", 138 | "dropout": 0.25 139 | }, 140 | "intermediate_projector": { 141 | "input_dim": 200, 142 | "num_layers": 1, 143 | "hidden_dims": 200, 144 | "activations": "linear", 145 | "dropout": 0.25 146 | } 147 | }, 148 | "choice_aggregator": { 149 | "projector": { 150 | "input_dim": 100, 151 | "num_layers": 1, 152 | "hidden_dims": 1, 153 | "activations": "linear", 154 | "dropout": 0.25 155 | }, 156 | "intermediate_projector": { 157 | "input_dim": 100, 158 | "num_layers": 1, 159 | "hidden_dims": 100, 160 | "activations": "linear", 161 | "dropout": 0.25 162 | } 163 | }, 164 | "path_aggregator": { 165 | "input_dim": 400, 166 | "num_layers": 1, 167 | "hidden_dims": 100, 168 | "activations": "linear", 169 | "dropout": 0.25 170 | }, 171 | "path_loc_aggregator": "mean", 172 | "allchoice_loc": false, 173 | "path_enc": true, 174 | "path_enc_doc_based": true, 175 | "path_enc_loc_based": true, 176 | "combine_scores": "add_cat", 177 | "span_extractor": { 178 | "type": "endpoint", 179 | "input_dim": 2, 180 | "combination": "x,y" 181 | }, 182 | "initializer": [ 183 | [".*linear_layers.*weight", {"type": "xavier_normal"}], 184 | [".*token_embedder_tokens._projection.*weight", {"type": "xavier_normal"}] 185 | ] 186 | }, 187 | "iterator": { 188 | "type": "bucket", 189 | "sorting_keys": [["candidates", "num_fields"],["documents", "list_num_tokens"]], 190 | "biggest_batch_first": true, 191 | "batch_size": 8, 192 | "cache_instances": false 193 | }, 194 | "trainer": { 195 | "num_epochs": 25, 196 | "patience": 12, 197 | "cuda_device": 0, 198 | "grad_norm": 5.0, 199 | "validation_metric": "+accuracy", 200 | "optimizer": { 201 | "type": "adam", 202 | "lr": 0.001 203 | }, 204 | "learning_rate_scheduler": { 205 | "type": "reduce_on_plateau", 206 | "factor": 0.5, 207 | "mode": "max", 208 | "patience": 2 209 | } 210 | } 211 | } 212 | -------------------------------------------------------------------------------- /training_configs/config_wikihop_makevocab.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": 
"wikihop_data_reader", 4 | "lazy": false, 5 | "cut_context": false, 6 | "token_indexers": { 7 | "tokens": { 8 | "type": "single_id", 9 | "lowercase_tokens": true 10 | } 11 | }, 12 | "tokenizer": { 13 | "type": "word", 14 | "word_splitter": { 15 | "type": "just_spaces" 16 | }, 17 | "end_tokens": ["@@NULL@@"] 18 | } 19 | }, 20 | "train_data_path": "data/datasets/WikiHop/adjusted/paths30/train-path-lines.txt", 21 | "validation_data_path": "data/datasets/WikiHop/adjusted/paths30/dev-path-lines.txt", 22 | "model": { 23 | "type": "pathnet", 24 | "text_field_embedder": { 25 | "tokens": { 26 | "type": "embedding", 27 | "pretrained_file": "data/embeddings/glove.840B.300d.txt.gz", 28 | "embedding_dim": 300, 29 | "trainable": false 30 | } 31 | }, 32 | "embeddings_dropout_value": 0.25, 33 | "aggregate_feedforward": { 34 | "input_dim": 100, 35 | "num_layers": 1, 36 | "hidden_dims": 1, 37 | "activations": "linear" 38 | }, 39 | "question_encoder": { 40 | "type": "lstm", 41 | "bidirectional": true, 42 | "num_layers": 1, 43 | "input_size": 300, 44 | "hidden_size": 50 45 | }, 46 | "document_encoder": { 47 | "type": "lstm", 48 | "bidirectional": true, 49 | "num_layers": 1, 50 | "input_size": 300, 51 | "hidden_size": 50 52 | }, 53 | "choice_encoder": { 54 | "type": "lstm", 55 | "bidirectional": true, 56 | "num_layers": 1, 57 | "input_size": 300, 58 | "hidden_size": 50 59 | }, 60 | "loc_path_encoder": { 61 | "type": "tree", 62 | "he_e1_comp": { 63 | "input_dim": 400, 64 | "num_layers": 1, 65 | "hidden_dims": 100, 66 | "activations": "tanh", 67 | "dropout": 0.25 68 | }, 69 | "e1_ca_comp": { 70 | "input_dim": 400, 71 | "num_layers": 1, 72 | "hidden_dims": 100, 73 | "activations": "tanh", 74 | "dropout": 0.25 75 | }, 76 | "r1r2_comp": { 77 | "input_dim": 200, 78 | "num_layers": 1, 79 | "hidden_dims": 100, 80 | "activations": "tanh", 81 | "dropout": 0.25 82 | }, 83 | "rnn_comp": { 84 | "type": "lstm", 85 | "bidirectional": false, 86 | "num_layers": 1, 87 | "input_size": 200, 88 | "hidden_size": 100, 89 | "dropout": 0.25 90 | } 91 | }, 92 | "allchoice_projector": { 93 | "input_dim": 200, 94 | "num_layers": 1, 95 | "hidden_dims": 100, 96 | "activations": "tanh", 97 | "dropout": 0.25 98 | }, 99 | "question_projector": { 100 | "input_dim": 200, 101 | "num_layers": 1, 102 | "hidden_dims": 100, 103 | "activations": "tanh", 104 | "dropout": 0.25 105 | }, 106 | "combined_q_projector": { 107 | "input_dim": 100, 108 | "num_layers": 1, 109 | "hidden_dims": 100, 110 | "activations": "linear", 111 | "dropout": 0.25 112 | }, 113 | "combined_s_projector": { 114 | "input_dim": 100, 115 | "num_layers": 1, 116 | "hidden_dims": 100, 117 | "activations": "linear", 118 | "dropout": 0.25 119 | }, 120 | "joint_encoder": { 121 | "seq_encoder": { 122 | "type": "lstm", 123 | "bidirectional": true, 124 | "num_layers": 1, 125 | "input_size": 100, 126 | "hidden_size": 50 127 | } 128 | }, 129 | "doc_aggregator": { 130 | "projector": { 131 | "input_dim": 200, 132 | "num_layers": 1, 133 | "hidden_dims": 1, 134 | "activations": "linear", 135 | "dropout": 0.25 136 | }, 137 | "intermediate_projector": { 138 | "input_dim": 200, 139 | "num_layers": 1, 140 | "hidden_dims": 200, 141 | "activations": "linear", 142 | "dropout": 0.25 143 | } 144 | }, 145 | "choice_aggregator": { 146 | "projector": { 147 | "input_dim": 100, 148 | "num_layers": 1, 149 | "hidden_dims": 1, 150 | "activations": "linear", 151 | "dropout": 0.25 152 | }, 153 | "intermediate_projector": { 154 | "input_dim": 100, 155 | "num_layers": 1, 156 | "hidden_dims": 100, 157 | "activations": 
"linear", 158 | "dropout": 0.25 159 | } 160 | }, 161 | "path_aggregator": { 162 | "input_dim": 400, 163 | "num_layers": 1, 164 | "hidden_dims": 100, 165 | "activations": "linear", 166 | "dropout": 0.25 167 | }, 168 | "path_loc_aggregator": "mean", 169 | "allchoice_loc": false, 170 | "path_enc": true, 171 | "path_enc_doc_based": true, 172 | "path_enc_loc_based": true, 173 | "combine_scores": "add_cat", 174 | "span_extractor": { 175 | "type": "endpoint", 176 | "input_dim": 2, 177 | "combination": "x,y" 178 | }, 179 | "initializer": [ 180 | [".*linear_layers.*weight", {"type": "xavier_normal"}], 181 | [".*token_embedder_tokens._projection.*weight", {"type": "xavier_normal"}] 182 | ] 183 | }, 184 | "iterator": { 185 | "type": "bucket", 186 | "sorting_keys": [["candidates", "num_fields"],["documents", "list_num_tokens"]], 187 | "biggest_batch_first": true, 188 | "batch_size": 8, 189 | "cache_instances": false 190 | }, 191 | "trainer": { 192 | "num_epochs": 25, 193 | "patience": 12, 194 | "cuda_device": 0, 195 | "grad_norm": 5.0, 196 | "validation_metric": "+accuracy", 197 | "optimizer": { 198 | "type": "adam", 199 | "lr": 0.001 200 | }, 201 | "learning_rate_scheduler": { 202 | "type": "reduce_on_plateau", 203 | "factor": 0.5, 204 | "mode": "max", 205 | "patience": 2 206 | } 207 | } 208 | } 209 | --------------------------------------------------------------------------------