├── .gitignore ├── LICENSE ├── README.md ├── model ├── __init__.py ├── data.py ├── qa_model.py ├── requirements.txt ├── run.py └── run_wiqa_classifier.sh ├── requirements.txt ├── src ├── __init__.py ├── eval │ ├── __init__.py │ ├── eval_utils.py │ └── evaluation.py ├── helpers │ ├── ProparaExtendedPara.py │ ├── SGFileLoaders.py │ ├── __init__.py │ ├── collections_util.py │ ├── dataset_info.py │ ├── situation_graph.py │ └── whatif_metadata.py ├── third_party_utils │ ├── __init__.py │ ├── allennlp_cached_filepath.py │ ├── nltk_porter_stemmer.py │ └── spacy_stop_words.py └── wiqa_wrapper.py ├── tests └── simple_test.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # wiqa-dataset
2 |
3 | Code repository for the EMNLP 2019 WIQA dataset paper.
4 |
5 | ## Usage
6 |
7 | First, set up a virtual environment like this:
8 |
9 | ```
10 | virtualenv venv
11 | source venv/bin/activate
12 | pip install -r requirements.txt
13 | ```
14 |
15 | (You can also use Conda.)
16 |
17 | Create a simple program `retrieve.py` like this:
18 |
19 | ```
20 | from src.wiqa_wrapper import WIQADataPoint
21 |
22 | wimd = WIQADataPoint.get_default_whatif_metadata()
23 | sg = wimd.get_graph_for_id(graph_id="13")
24 | print(sg.to_json_v1())
25 | ```
26 |
27 | This program will read the What-If metadata (`wimd`), retrieve situation graph 13 (`sg`), and print a string representation in JSON format. To see the result, run it like this (in the virtual env):
28 |
29 | ```
30 | % PYTHONPATH=. python retrieve.py
31 | {"V": ["water is exposed to high heat", "water is not protected from high heat"], "Z": ["water is shielded from heat", ...
32 | ```
33 |
34 | ## Running tests
35 |
36 | Set up the virtual environment as above, then run the tests like this:
37 |
38 | ```
39 | export PYTHONPATH=.
40 | pytest
41 | ```
42 |
43 | ## Running the model
44 |
45 | ```
46 | pip install -r model/requirements.txt
47 | bash model/run_wiqa_classifier.sh
48 | ```
49 |
50 | Note: comment out the `--gpus` and `--accelerator` arguments in the script for CPU training.
51 |
-------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/wiqa-dataset/edeef924188a7caa4493c305209e0b20d20b375c/model/__init__.py -------------------------------------------------------------------------------- /model/data.py: --------------------------------------------------------------------------------
1 | """Data module and dataset for the WIQA question-answering classifier:
2 | reads the jsonl QA files (train/dev/test) and pairs each question stem with its paragraph.
3 | """
4 | import logging
5 | import pytorch_lightning as pl
6 | import torch
7 | from torch.utils.data import DataLoader
8 | from torch.utils.data import Dataset
9 | import pandas as pd
10 | from transformers import AutoTokenizer
11 | from tqdm import tqdm
12 | from collections import defaultdict
13 | # from src.data.creation.influence_graph import InfluenceGraph
14 |
15 | label_dict = {"less": 0, "attenuator": 0, "more": 1, "intensifier": 1, "no_effect": 2}
16 | rev_label_dict = defaultdict(list)
17 |
18 | for k, v in label_dict.items():
19 |     rev_label_dict[v].append(k)
20 |
21 | rev_label_dict = {k: "/".join(v) for k, v in rev_label_dict.items()}
22 |
23 | class GraphQaDataModule(pl.LightningDataModule):
24 |     def __init__(self, basedir: str, tokenizer_name: str, batch_size: int, num_workers: int = 16):
25 |         super().__init__()
26 |         self.basedir = basedir
27 |         self.batch_size = batch_size
28 |         self.num_workers = num_workers
29 |         self.tokenizer =
AutoTokenizer.from_pretrained(tokenizer_name, do_lower_case=True) 30 | 31 | def train_dataloader(self): 32 | dataset = GraphQADataset(tokenizer=self.tokenizer, 33 | qa_pth=f"{self.basedir}/train.jsonl", graph_pth=f"{self.basedir}/influence_graphs.jsonl") 34 | return DataLoader(dataset=dataset, batch_size=self.batch_size, 35 | shuffle=True, num_workers=self.num_workers, collate_fn=GraphQADataset.collate_pad) 36 | 37 | def val_dataloader(self): 38 | dataset = GraphQADataset(tokenizer=self.tokenizer, 39 | qa_pth=f"{self.basedir}/dev.jsonl", graph_pth=f"{self.basedir}/influence_graphs.jsonl") 40 | return DataLoader(dataset=dataset, batch_size=self.batch_size, 41 | shuffle=False, num_workers=self.num_workers, collate_fn=GraphQADataset.collate_pad) 42 | 43 | def test_dataloader(self): 44 | dataset = GraphQADataset(tokenizer=self.tokenizer, 45 | qa_pth=f"{self.basedir}/test.jsonl", graph_pth=f"{self.basedir}/influence_graphs.jsonl") 46 | return DataLoader(dataset=dataset, batch_size=self.batch_size, 47 | shuffle=False, num_workers=self.num_workers, collate_fn=GraphQADataset.collate_pad) 48 | 49 | 50 | class GraphQADataset(Dataset): 51 | def __init__(self, tokenizer, qa_pth: str, graph_pth: str) -> None: 52 | super().__init__() 53 | self.qa_pth = qa_pth 54 | self.graph_pth = graph_pth 55 | self.tokenizer = tokenizer 56 | # self.read_graphs() 57 | self.read_qa() 58 | 59 | # def read_graphs(self): 60 | # influence_graphs = pd.read_json( 61 | # self.graph_pth, orient='records', lines=True).to_dict(orient='records') 62 | # self.graphs = {} 63 | # for graph_dict in tqdm(influence_graphs, desc="Reading graphs", total=len(influence_graphs)): 64 | # self.graphs[str(graph_dict["graph_id"])] = graph_dict 65 | 66 | def read_qa(self): 67 | logging.info("Reading data from {}".format(self.qa_pth)) 68 | data = pd.read_json(self.qa_pth, orient="records", lines=True) 69 | self.questions, self.answer_labels, self.paragraphs = [], [], [] 70 | logging.info(f"Reading QA file from {self.qa_pth}") 71 | for i, row in tqdm(data.iterrows(), total=len(data), desc="Reading QA examples"): 72 | self.answer_labels.append(row["question"]["answer_label"].strip()) 73 | para = " ".join([p.strip() for p in row["question"]["para_steps"] if len(p) > 0]) 74 | question = row["question"]["stem"].strip() 75 | self.questions.append(question) 76 | self.paragraphs.append(para) 77 | # self.graph_ids.append(row["metadata"]["graph_id"]) 78 | 79 | encoded_input = self.tokenizer(self.questions, self.paragraphs) 80 | self.input_ids = encoded_input["input_ids"] 81 | if "token_type_ids" in encoded_input: 82 | self.token_type_ids = encoded_input["token_type_ids"] 83 | else: 84 | self.token_type_ids = [[0] * len(s) for s in encoded_input["input_ids"]] # only BERT uses it anyways, so just set it to 0 85 | def __len__(self) -> int: 86 | return len(self.questions) 87 | 88 | def __getitem__(self, i): 89 | # We’ll pad at the batch level. 
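# (Items are returned here as plain Python lists plus the raw answer-label string;
#  collate_pad below converts them into padded LongTensors, an attention mask, and integer labels.)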
90 |         return (self.input_ids[i], self.token_type_ids[i], self.answer_labels[i])
91 |
92 |     @staticmethod
93 |     def collate_pad(batch):
94 |         max_token_len = 0
95 |         num_elems = len(batch)
96 |         for i in range(num_elems):
97 |             tokens, _, _ = batch[i]
98 |             max_token_len = max(max_token_len, len(tokens))
99 |
100 |         tokens = torch.zeros(num_elems, max_token_len).long()
101 |         tokens_mask = torch.zeros(num_elems, max_token_len).long()
102 |         token_type_ids = torch.zeros(num_elems, max_token_len).long()
103 |         labels = torch.zeros(num_elems).long()
104 |         # graphs = []
105 |         for i in range(num_elems):
106 |             toks, type_ids, label = batch[i]
107 |             length = len(toks)
108 |             tokens[i, :length] = torch.LongTensor(toks)
109 |             token_type_ids[i, :length] = torch.LongTensor(type_ids)
110 |             tokens_mask[i, :length] = 1
111 |             # graphs.append(graph)
112 |             labels[i] = label_dict[label]
113 |         return [tokens, token_type_ids, tokens_mask, labels]
114 |
115 |
116 | # class InfluenceGraphNNData:
117 | #     """
118 | #     V   Z
119 | #     |  /
120 | #     - +
121 | #     | /
122 | #     X    U
123 | #     | \ |
124 | #     - + -
125 | #     | \ |
126 | #     W   Y
127 | #     | \ / |
128 | #     - + -
129 | #     | / \ |
130 | #     L     M
131 | #     """
132 | #     node_index = {
133 | #         "V": 0, "Z": 1, "X": 2, "U": 3, "W": 4, "Y": 5, "dec": 6, "acc": 7}
134 | #     index_node = {v: k for k, v in node_index.items()}
135 | #     edge_index = [[0, 1, 2, 2, 3, 4, 4, 5, 5],
136 | #                   [2, 2, 4, 5, 5, 6, 7, 6, 7]]
137 | #     EDGE_TYPE_HELPS, EDGE_TYPE_HURTS = 0, 1
138 | #     def __init__(self, data) -> None:
139 | #         super().__init__()
140 | #         self.data = data
141 |
142 | #     @staticmethod
143 | #     def make_data_from_dict(graph_dict: dict, tokenizer, max_length=30):
144 | #         igraph = InfluenceGraph(graph_dict)
145 | #         if igraph.graph["Y_affects_outcome"] == "more":
146 | #             # the final edges depend on para outcome
147 | #             edge_features = [InfluenceGraphNNData.EDGE_TYPE_HURTS, InfluenceGraphNNData.EDGE_TYPE_HELPS, InfluenceGraphNNData.EDGE_TYPE_HURTS, InfluenceGraphNNData.EDGE_TYPE_HELPS, InfluenceGraphNNData.EDGE_TYPE_HURTS, InfluenceGraphNNData.EDGE_TYPE_HELPS, InfluenceGraphNNData.EDGE_TYPE_HURTS, InfluenceGraphNNData.EDGE_TYPE_HURTS, InfluenceGraphNNData.EDGE_TYPE_HELPS]
148 | #         else:
149 | #             edge_features = [InfluenceGraphNNData.EDGE_TYPE_HURTS, InfluenceGraphNNData.EDGE_TYPE_HELPS, InfluenceGraphNNData.EDGE_TYPE_HURTS, InfluenceGraphNNData.EDGE_TYPE_HELPS, InfluenceGraphNNData.EDGE_TYPE_HURTS, InfluenceGraphNNData.EDGE_TYPE_HURTS, InfluenceGraphNNData.EDGE_TYPE_HELPS, InfluenceGraphNNData.EDGE_TYPE_HELPS, InfluenceGraphNNData.EDGE_TYPE_HURTS]
150 |
151 | #         node_sentences = []
152 | #         for node in InfluenceGraphNNData.node_index:
153 | #             if node in igraph.nodes_dict and len(igraph.nodes_dict[node]) > 0:
154 | #                 node_sentences.append(" [OR] ".join(igraph.nodes_dict[node]))
155 | #             else:
156 | #                 node_sentences.append(tokenizer.pad_token)
157 | #         encoding_dict = tokenizer(node_sentences, max_length=max_length, truncation=True)
158 | #         return Data(graph_id = str(igraph.graph_id), num_nodes = len(InfluenceGraphNNData.node_index), tokens=encoding_dict["input_ids"],
159 | #                     edge_index=torch.tensor(InfluenceGraphNNData.edge_index).long(), edge_attr=torch.tensor(edge_features).long())
160 |
161 | if __name__ == "__main__":
162 |     import sys
163 |     dm = GraphQaDataModule(
164 |         basedir=sys.argv[1], tokenizer_name=sys.argv[2], batch_size=32)
165 |     for tokens, token_type_ids, tokens_mask, labels in dm.train_dataloader():
166 |         print(tokens.shape, tokens_mask.shape, labels.shape)
167 |
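For illustration, a minimal sketch of how `GraphQADataset.collate_pad` batches variable-length examples. The token ids and labels below are invented, and the sketch assumes the repository root is on `PYTHONPATH` with the packages from `model/requirements.txt` installed:

```
# Toy demonstration of GraphQADataset.collate_pad; the token ids are arbitrary.
from model.data import GraphQADataset, rev_label_dict

# Each element mirrors what __getitem__ returns: (input_ids, token_type_ids, answer_label).
batch = [
    ([101, 2054, 2003, 102], [0, 0, 0, 0], "more"),
    ([101, 2129, 102], [0, 0, 0], "less"),
    ([101, 2054, 2003, 1996, 102], [0, 0, 0, 0, 0], "no_effect"),
]

tokens, token_type_ids, tokens_mask, labels = GraphQADataset.collate_pad(batch)

print(tokens.shape)        # torch.Size([3, 5]) -- padded to the longest sequence in the batch
print(tokens_mask[1])      # tensor([1, 1, 1, 0, 0]) -- 1s mark real tokens, 0s mark padding
print(labels)              # tensor([1, 0, 2]) -- answer labels mapped through label_dict
print(rev_label_dict[int(labels[0])])  # "more/intensifier"
```

Padding to the per-batch maximum (rather than a global maximum) keeps short batches cheap; the returned `tokens_mask` is what `GraphQaModel.forward_bert` consumes as `attention_mask`.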
-------------------------------------------------------------------------------- /model/qa_model.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | import torch 4 | import torch.nn as nn 5 | from pytorch_lightning.core.lightning import LightningModule 6 | from torch.optim import AdamW 7 | from transformers import AutoModel 8 | from transformers import BertConfig 9 | from transformers.models.bert.modeling_bert import BertPooler 10 | 11 | 12 | class GraphQaModel(LightningModule): 13 | def __init__(self, hparams): 14 | super().__init__() 15 | self.hparams = hparams 16 | self.save_hyperparameters() 17 | config = BertConfig() 18 | #self.model = BertForSequenceClassification.from_pretrained(self.hparams.model_name, num_labels=self.hparams.n_class) 19 | self.model = AutoModel.from_pretrained(self.hparams.model_name) 20 | self.pooler = BertPooler(config) 21 | # self.attention = MultiheadedAttention(h_dim=self.hparams.h_dim, kqv_dim=self.hparams.kqv_dim, n_heads=self.hparams.n_heads) 22 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 23 | self.classifier = nn.Linear(config.hidden_size, self.hparams.n_class) 24 | self.loss = nn.CrossEntropyLoss() 25 | 26 | @staticmethod 27 | def add_model_specific_args(parent_parser): 28 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 29 | parser.add_argument("--min_lr", default=0, type=float, 30 | help="Minimum learning rate.") 31 | parser.add_argument("--h_dim", type=int, 32 | help="Size of the hidden dimension.", default=768) 33 | parser.add_argument("--n_heads", type=int, 34 | help="Number of attention heads.", default=1) 35 | parser.add_argument("--kqv_dim", type=int, 36 | help="Dimensionality of the each attention head.", default=256) 37 | parser.add_argument("--n_class", type=float, 38 | help="Number of classes.", default=3) 39 | parser.add_argument("--lr", default=5e-4, type=float, 40 | help="Initial learning rate.") 41 | parser.add_argument("--weight_decay", default=0.01, type=float, 42 | help="Weight decay rate.") 43 | parser.add_argument("--warmup_prop", default=0., type=float, 44 | help="Warmup proportion.") 45 | parser.add_argument("--num_relations", default=2, type=int, 46 | help="The number of relations (edge types) in your graph.") 47 | parser.add_argument( 48 | "--model_name", default='bert-base-uncased', help="Model to use.") 49 | return parser 50 | 51 | def configure_optimizers(self): 52 | return AdamW(self.parameters(), lr=self.hparams.lr, betas=(0.9, 0.99), 53 | eps=1e-8) 54 | 55 | def forward(self, batch): 56 | question_tokens, question_type_ids, question_masks, labels = batch 57 | 58 | # step 1: encode the question/paragraph 59 | question_cls_embeddeding = self.forward_bert( 60 | input_ids=question_tokens, token_type_ids=question_type_ids, attention_mask=question_masks) 61 | 62 | 63 | logits = self.classifier(question_cls_embeddeding) 64 | predicted_labels = torch.argmax(logits, -1) 65 | acc = torch.true_divide( 66 | (predicted_labels == labels).sum(), labels.shape[0]) 67 | return logits, acc 68 | 69 | def forward_bert(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor = None): 70 | """Returns the pooled token from BERT 71 | """ 72 | outputs = self.model( 73 | input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, output_hidden_states=True) 74 | hidden_states = outputs["hidden_states"] 75 | cls_embeddeding = self.dropout(self.pooler(hidden_states[-1])) 76 | return 
cls_embeddeding 77 | 78 | def training_step(self, batch, batch_idx): 79 | # Load the data into variables 80 | logits, acc = self(batch) 81 | loss = self.loss(logits, batch[-1]) 82 | self.log('train_acc', acc, on_step=True, 83 | on_epoch=True, prog_bar=True, sync_dist=True) 84 | return {"loss": loss} 85 | 86 | 87 | def validation_step(self, batch, batch_idx): 88 | # Load the data into variables 89 | logits, acc = self(batch) 90 | 91 | loss_f = nn.CrossEntropyLoss() 92 | loss = loss_f(logits, batch[-1]) 93 | 94 | self.log('val_loss', loss, on_step=True, 95 | on_epoch=True, prog_bar=True, sync_dist=True) 96 | self.log('val_acc', acc, on_step=True, on_epoch=True, 97 | prog_bar=True, sync_dist=True) 98 | return {"loss": loss} 99 | 100 | def test_step(self, batch, batch_idx): 101 | # Load the data into variables 102 | logits, acc = self(batch) 103 | 104 | loss_f = nn.CrossEntropyLoss() 105 | loss = loss_f(logits, batch[-1]) 106 | return {"loss": loss} 107 | 108 | def get_progress_bar_dict(self): 109 | tqdm_dict = super().get_progress_bar_dict() 110 | tqdm_dict.pop("v_num", None) 111 | tqdm_dict.pop("val_loss_step", None) 112 | tqdm_dict.pop("val_acc_step", None) 113 | return tqdm_dict 114 | 115 | 116 | if __name__ == "__main__": 117 | sentences = ['This framework generates embeddings for each input sentence', 118 | 'Sentences are passed as a list of string.', 119 | 'The quick brown fox jumps over the lazy dog.'] 120 | -------------------------------------------------------------------------------- /model/requirements.txt: -------------------------------------------------------------------------------- 1 | # Automatically generated by https://github.com/damnever/pigar. 2 | 3 | # gnn_qa/data.py: 13 4 | # gnn_qa/infer.py: 4 5 | # gnn_qa/qa_model.py: 11 6 | # gnn_qa/run.py: 8 7 | numpy == 1.16.6 8 | 9 | # gnn_qa/data.py: 11 10 | pandas == 1.0.3 11 | 12 | # gnn_qa/data.py: 5 13 | # gnn_qa/ignn.py: 10 14 | # gnn_qa/infer.py: 5 15 | # gnn_qa/qa_model.py: 7 16 | # gnn_qa/run.py: 6,9,12,13 17 | pytorch_lightning == 1.1.2 18 | 19 | # gnn_qa/data.py: 6,7,9,10 20 | # gnn_qa/ignn.py: 3,5 21 | # gnn_qa/infer.py: 1 22 | # gnn_qa/qa_model.py: 4,8,9,12 23 | # gnn_qa/run.py: 3 24 | # gnn_qa/utils.py: 1,2 25 | torch == 1.6.0 26 | 27 | # gnn_qa/ignn.py: 4,6,9 28 | # gnn_qa/qa_model.py: 2,5,13 29 | torch_geometric == 1.6.3 30 | 31 | # gnn_qa/data.py: 17 32 | # gnn_qa/infer.py: 7 33 | tqdm == 4.48.0 34 | 35 | # gnn_qa/data.py: 14,16 36 | # gnn_qa/qa_model.py: 3,15 37 | transformers == 4.1.1 38 | -------------------------------------------------------------------------------- /model/run.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning 2 | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor 3 | import random 4 | import numpy as np 5 | import pytorch_lightning as pl 6 | import logging 7 | from argparse import ArgumentParser 8 | import resource 9 | from model.data import GraphQaDataModule 10 | from model.qa_model import GraphQaModel 11 | 12 | def get_train_steps(dm): 13 | total_devices = args.num_gpus * args.num_nodes 14 | train_batches = len(dm.train_dataloader()) // total_devices 15 | return (args.max_epochs * train_batches) // args.accumulate_grad_batches 16 | 17 | 18 | 19 | 20 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 21 | resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1])) 22 | # init: important to make sure every node initializes the same weights 23 | SEED = 43 24 | np.random.seed(SEED) 25 | 
random.seed(SEED) 26 | pl.utilities.seed.seed_everything(SEED) 27 | pytorch_lightning.seed_everything(SEED) 28 | 29 | 30 | # argparser 31 | parser = ArgumentParser() 32 | parser.add_argument('--num_gpus', type=int) 33 | parser.add_argument('--batch_size', type=int, default=16) 34 | parser.add_argument('--clip_grad', type=float, default=1.0) 35 | parser.add_argument("--dataset_basedir", help="Base directory where the dataset is located.", type=str) 36 | 37 | parser = pl.Trainer.add_argparse_args(parser) 38 | parser = GraphQaModel.add_model_specific_args(parser) 39 | 40 | args = parser.parse_args() 41 | args.num_gpus = len(str(args.gpus).split(",")) 42 | 43 | 44 | logging.basicConfig(level=logging.INFO) 45 | 46 | # Step 1: Init Data 47 | logging.info("Loading the data module") 48 | dm = GraphQaDataModule(basedir=args.dataset_basedir, tokenizer_name=args.model_name, batch_size=args.batch_size) 49 | 50 | # Step 2: Init Model 51 | logging.info("Initializing the model") 52 | model = GraphQaModel(hparams=args) 53 | model.hparams.warmup_steps = int(get_train_steps(dm) * model.hparams.warmup_prop) 54 | lr_monitor = LearningRateMonitor(logging_interval='step') 55 | 56 | # Step 3: Start 57 | logging.info("Starting the training") 58 | checkpoint_callback = ModelCheckpoint( 59 | filename='{epoch}-{step}-{val_acc_epoch:.2f}', 60 | save_top_k=3, 61 | verbose=True, 62 | monitor='val_acc_epoch', 63 | mode='max' 64 | ) 65 | 66 | trainer = pl.Trainer.from_argparse_args(args, callbacks=[checkpoint_callback], val_check_interval=0.5, gradient_clip_val=args.clip_grad, track_grad_norm=2) 67 | trainer.fit(model, dm) -------------------------------------------------------------------------------- /model/run_wiqa_classifier.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export TOKENIZERS_PARALLELISM=false 3 | python model/run.py --dataset_basedir data/wiqa-qa/ \ 4 | --lr 2e-5 --max_epochs 20 \ 5 | --gpus 2 \ 6 | --accelerator ddp -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | requests 3 | pytest 4 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/wiqa-dataset/edeef924188a7caa4493c305209e0b20d20b375c/src/__init__.py -------------------------------------------------------------------------------- /src/eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/wiqa-dataset/edeef924188a7caa4493c305209e0b20d20b375c/src/eval/__init__.py -------------------------------------------------------------------------------- /src/eval/eval_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Tuple, Any 2 | 3 | from src.third_party_utils.nltk_porter_stemmer import PorterStemmer 4 | from src.third_party_utils.spacy_stop_words import STOP_WORDS 5 | 6 | 7 | def get_label_from_question_metadata(correct_answer_key: str, 8 | question_dict: List[Dict[str, str]]) -> str: 9 | for item in question_dict: 10 | try: 11 | if correct_answer_key == item['label']: 12 | return item['text'] 13 | except KeyError as exc: 14 | raise KeyError(f"key='label' or 'text' absent in item:\n{item}.\nException = {exc}") 15 | raise 
KeyError(f"key='label' absent in metadata:\n{question_dict}") 16 | 17 | 18 | def split_question_cause_effect(question: str) -> Tuple[str, str]: 19 | question = question.lower() 20 | question_split = question.split("happens, how will it affect") 21 | cause_part = question_split[0].replace("suppose", "") 22 | effect_part = question_split[1] 23 | return cause_part, effect_part 24 | 25 | 26 | def get_most_similar_idx_word_overlap(p1s: List[str], p2:str): 27 | if not p1s or not p2: 28 | return -1 29 | max_idx = -1 30 | max_sim = 0.0 31 | k1 = set(get_content_words(p2)) 32 | for idx, p1 in enumerate(p1s): 33 | k2 = set(get_content_words(p1)) 34 | sim = 1.0 * len(k1.intersection(k2)) / (1.0 * (len(k1.union(k2)))) 35 | if sim > max_sim: 36 | max_sim = sim 37 | max_idx = idx 38 | return max_idx 39 | 40 | def predict_word_overlap_best(input_steps, input_cq, input_eq): 41 | xsentid = get_most_similar_idx_word_overlap(p1s=input_steps, p2=input_cq) 42 | ysentid = get_most_similar_idx_word_overlap(p1s=input_steps, p2=input_eq) 43 | return xsentid, ysentid 44 | 45 | 46 | def is_stop(cand): 47 | return cand.lower() in STOP_WORDS 48 | 49 | # your other hand => hand 50 | # the hand => hand 51 | def drop_leading_articles_and_stopwords(p): 52 | # other and another can only appear after the primary articles in first line. 53 | articles = ["a ", "an ", "the ", "your ", "his ", "their ", "my ", "this ", "that ", 54 | "another ", "other ", "more ", "less "] 55 | for article in articles: 56 | if p.lower().startswith(article): 57 | p = p[len(article):] 58 | words = p.split(" ") 59 | answer = "" 60 | for idx, w in enumerate(words): 61 | if is_stop(w): 62 | continue 63 | else: 64 | answer = " ".join(words[idx:]) 65 | break 66 | return answer 67 | 68 | 69 | def stem(w: str): 70 | if not w or len(w.strip()) == 0: 71 | return "" 72 | w_lower = w.lower() 73 | # Remove leading articles from the phrase (e.g., the rays => rays). 74 | # FIXME: change this logic to accept a list of leading articles. 
75 | if w_lower.startswith("a "): 76 | w_lower = w_lower[2:] 77 | elif w_lower.startswith("an "): 78 | w_lower = w_lower[3:] 79 | elif w_lower.startswith("the "): 80 | w_lower = w_lower[4:] 81 | elif w_lower.startswith("your "): 82 | w_lower = w_lower[5:] 83 | elif w_lower.startswith("his "): 84 | w_lower = w_lower[4:] 85 | elif w_lower.startswith("their "): 86 | w_lower = w_lower[6:] 87 | elif w_lower.startswith("my "): 88 | w_lower = w_lower[3:] 89 | elif w_lower.startswith("another "): 90 | w_lower = w_lower[8:] 91 | elif w_lower.startswith("other "): 92 | w_lower = w_lower[6:] 93 | elif w_lower.startswith("this "): 94 | w_lower = w_lower[5:] 95 | elif w_lower.startswith("that "): 96 | w_lower = w_lower[5:] 97 | # Porter stemmer: rays => ray 98 | return PorterStemmer().stem(w_lower).strip() 99 | 100 | 101 | def stem_words(words): 102 | return [stem(w) for w in words] 103 | 104 | 105 | def get_content_words(s): 106 | para_words_prev = s.strip().lower().split(" ") 107 | para_words = set() 108 | for word in para_words_prev: 109 | if not is_stop(word): 110 | para_words.add(word) 111 | return stem_words(list(para_words)) 112 | 113 | 114 | def find_max(input_list: List[float]) -> Tuple[Any, int]: 115 | max_item = max(input_list) 116 | return max_item, input_list.index(max_item) 117 | -------------------------------------------------------------------------------- /src/eval/evaluation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import enum 3 | import json 4 | from typing import List, Dict 5 | 6 | from src.helpers.dataset_info import wiqa_explanations_v1 7 | from src.wiqa_wrapper import WIQAQuesType, download_from_url_if_not_in_cache, Jsonl, WIQAExplanation, \ 8 | WIQAExplanationType 9 | 10 | 11 | class FineEvalMetrics(enum.Enum): 12 | EQDIR = 0 13 | XSENTID = 1 14 | YSENTID = 2 15 | XDIR = 3 16 | 17 | def to_json(self): 18 | return self.name 19 | 20 | 21 | class Precision: 22 | def __init__(self): 23 | self.tp = 0 24 | self.fp = 0 25 | 26 | def __str__(self): 27 | if self.tp + self.fp == 0: 28 | return "0" 29 | else: 30 | return self.tp / (self.tp + self.fp) 31 | 32 | def p(self): 33 | if self.tp + self.fp == 0.0: 34 | return 0.0 35 | else: 36 | return 1.0 * self.tp / (1.0 * (self.tp + self.fp)) 37 | 38 | 39 | class InputRequiredByFineEval: 40 | def __init__(self, 41 | graph_id: str, 42 | ques_type: WIQAQuesType, 43 | path_arr: List[str], 44 | metrics: Dict[FineEvalMetrics, bool]): 45 | self.graph_id = graph_id 46 | self.path_arr = path_arr 47 | self.ques_type = ques_type 48 | self.metrics = metrics 49 | 50 | def get_path_len(self): 51 | return len(self.path_arr) 52 | 53 | def __str__(self): 54 | return f"graphid={self.graph_id}\nques_type={self.ques_type.name}" \ 55 | f"path_arr={self.path_arr}\nmetrics={self.metrics}" 56 | 57 | # {"logits": [[0.7365949749946594, 0.16483880579471588, 0.38572263717651367, 0.33268705010414124, -0.6508089900016785, -0.950057864189148], [0.3361547291278839, -0.02215845137834549, 0.8630506992340088, 0.4753769040107727, -1.0981523990631104, 0.04984292760491371], [0.4498467743396759, 0.26323091983795166, 0.5597160458564758, 0.06369128823280334, -0.33793506026268005, -0.30190590023994446], [0.41394802927970886, 0.31742218136787415, 0.42982375621795654, -0.2891058027744293, -0.09577881544828415, -0.4486318528652191], [0.5242481231689453, -0.05186435207724571, 0.4505387544631958, -0.43092456459999084, -0.015227549709379673, 0.10361793637275696], [0.8527745604515076, 0.18845966458320618, 
0.6540948748588562, -0.06324845552444458, -0.03267676383256912, 0.058296892791986465], [0.40418609976768494, -0.24220454692840576, 0.0737631767988205, -0.8445389270782471, -0.12929767370224, 0.5813987851142883]], "class_probabilities": [[0.29663264751434326, 0.16745895147323608, 0.20885121822357178, 0.19806326925754547, 0.07407592236995697, 0.054918017238378525], [0.18079228699207306, 0.12634745240211487, 0.3062019348144531, 0.20779892802238464, 0.04307926073670387, 0.13578014075756073], [0.21968600153923035, 0.18228720128536224, 0.24519862234592438, 0.14931286871433258, 0.09992476552724838, 0.10359060019254684], [0.22513438761234283, 0.20441898703575134, 0.22873708605766296, 0.11145754158496857, 0.13522914052009583, 0.09502287209033966], [0.24298663437366486, 0.13657772541046143, 0.2257203906774521, 0.09348806738853455, 0.14167429506778717, 0.15955300629138947], [0.27786338329315186, 0.1429957151412964, 0.22779585421085358, 0.11117511987686157, 0.11462641507387161, 0.12554346024990082], [0.23202574253082275, 0.1215660497546196, 0.16673828661441803, 0.06656130403280258, 0.1360965520143509, 0.27701207995414734]], "metadata": {"question": {"question": {"stem": "suppose during boiling point happens, how will it affect more evaporation.", "choices": [{"text": "Correct effect", "label": "A"}, {"text": "Opposite effect", "label": "B"}, {"text": "No effect", "label": "C"}]}, "answerKey": "A", "explanation": " ['during boiling point', 'during sunshine'] ==> ['increase water temperatures at least 100 C'] ==> ['more vapors'] ==> ['MORE evaporation?']", "path_info": "is_distractor^False:is_labeled_tgt^True:path_nodes^[Z, X, Y, A]:path_label^Z->SituationLabel.RESULTS_IN->A", "more_info": {"tf_q_type": "EXOGENOUS_EFFECT", "prompt": "Describe the process of evaporation", "para_id": "127", "group_ids": {"NO_GROUPING": "is_distractor^False:is_labeled_tgt^True:path_nodes^[Z, X, Y, A]:path_label^Z->SituationLabel.RESULTS_IN->A:graph_id^12:id^influence_graph,127,12,39#0:tf_q_type^EXOGENOUS_EFFECT", "BY_SRC_DEST": "Z,A", "BY_SRC_LABEL_DEST": "Z->SituationLabel.RESULTS_IN->A", "BY_PROMPT": "Describe the process of evaporation", "BY_PARA": "127", "BY_FULL_PATH": "[Z, X, Y, A]", "BY_GROUNDING": "influence_graph,127,12,39", "BY_SRC_DEST_INTRA": "12,Z,A", "BY_SRC_DEST_STEM_INTRA": "12,Z,A,In the context of describe the process of evaporation, suppose during boiling point happens, how will it affect MORE evaporation?.", "BY_TF_Q_TYPE": "EXOGENOUS_EFFECT", "BY_PATH_LENGTH": 4}, "all_q_keys": ["is_distractor^False:is_labeled_tgt^True:path_nodes^[Z, X, Y, A]:path_label^Z->SituationLabel.RESULTS_IN->A:graph_id^12:id^influence_graph,127,12,39#0:tf_q_type^EXOGENOUS_EFFECT", "is_distractor^False:is_labeled_tgt^True:path_nodes^[Z, X, W, A]:path_label^Z->SituationLabel.RESULTS_IN->A:graph_id^12:id^influence_graph,127,12,38#0:tf_q_type^EXOGENOUS_EFFECT"]}, "para": "Water is exposed to heat energy, like sunlight. The water temperature is raised above 212 degrees fahrenheit. The heat breaks down the molecules in the water. These molecules escape from the water. The water becomes vapor. The vapor evaporates into the atmosphere. 
", "graph_id": "12", "para_id": "127", "prompt": "Describe the process of evaporation", "id": "influence_graph:127:12:39#0", "distractor_info": {}, "primary_question_key": "is_distractor^False:is_labeled_tgt^True:path_nodes^[Z, X, Y, A]:path_label^Z->SituationLabel.RESULTS_IN->A:graph_id^12:id^influence_graph,127,12,39#0:tf_q_type^EXOGENOUS_EFFECT"}, "kg": [{"from_node": "increase water temperatures at least 100 C", "to_node": "more ice", "label": "NOT_RESULTS_IN"}, {"from_node": "increase water temperatures at least 100 C", "to_node": "less sweating", "label": "NOT_RESULTS_IN"}, {"from_node": "increase water temperatures at least 100 C", "to_node": "more vapors", "label": "RESULTS_IN"}, {"from_node": "low water temperatures at most 100 C", "to_node": "more vapors", "label": "NOT_RESULTS_IN"}, {"from_node": "less water molecule colliding", "to_node": "more vapors", "label": "NOT_RESULTS_IN"}, {"from_node": "more ice", "to_node": "MORE evaporation?", "label": "NOT_RESULTS_IN"}, {"from_node": "less sweating", "to_node": "MORE evaporation?", "label": "NOT_RESULTS_IN"}, {"from_node": "more ice", "to_node": "LESS evaporation", "label": "RESULTS_IN"}, {"from_node": "less sweating", "to_node": "LESS evaporation", "label": "RESULTS_IN"}, {"from_node": "more vapors", "to_node": "MORE evaporation?", "label": "RESULTS_IN"}, {"from_node": "more vapors", "to_node": "LESS evaporation", "label": "NOT_RESULTS_IN"}], "explanation": {"steps": "(1. Water is exposed to heat energy, like sunlight. 2. The water temperature is raised above 212 degrees fahrenheit. 3. The heat breaks down the molecules in the water. 4. These molecules escape from the water. 5. The water becomes vapor. 6. The vapor evaporates into the atmosphere. 7. .)", "x_sent_id": 1, "y_sent_id": 4, "x_dir": "RESULTS_IN", "y_dir": "RESULTS_IN", "x_grounding": "increase water temperatures at least 100 C", "y_grounding": "more vapors", "is_valid_tuple": true, "eq_dir_orig_answer": "RESULTS_IN"}, "id": "NA"}, "tags": ["O", "E+", "E+", "E+", "O", "O", "E-"]} 58 | @staticmethod 59 | def from_(prediction_on_this_example: WIQAExplanation, 60 | json_from_question: Dict, 61 | expl_type: WIQAExplanationType): 62 | # get expected answer and whether predicted and expected match. 
63 | expected = WIQAExplanation.instantiate_from(json_from_question) 64 | 65 | metrics = {} 66 | metrics[FineEvalMetrics.EQDIR] = expected.de == prediction_on_this_example.de 67 | if expl_type == WIQAExplanationType.PARA_SENT_EXPL: 68 | metrics[FineEvalMetrics.XDIR] = expected.di == prediction_on_this_example.di 69 | metrics[FineEvalMetrics.XSENTID] = expected.i == prediction_on_this_example.i 70 | metrics[FineEvalMetrics.YSENTID] = expected.j == prediction_on_this_example.j 71 | return InputRequiredByFineEval(graph_id=json_from_question["metadata"]["graph_id"], 72 | ques_type=WIQAQuesType.from_str(json_from_question["metadata"]["question_type"]), 73 | path_arr=[], 74 | metrics=metrics) 75 | 76 | @staticmethod 77 | def accumulate_metrics(current_metrics: Dict[FineEvalMetrics, bool], 78 | all_metrics: Dict[FineEvalMetrics, Precision]): 79 | for k, is_correct in current_metrics.items(): 80 | if k not in all_metrics: 81 | all_metrics[k] = Precision() 82 | if is_correct: 83 | all_metrics[k].tp += 1 84 | else: 85 | all_metrics[k].fp += 1 86 | 87 | 88 | class MetricEvaluator: 89 | def __init__(self): 90 | self.entries: List[InputRequiredByFineEval] = [] 91 | 92 | def add_entry(self, entry: InputRequiredByFineEval): 93 | self.entries.append(entry) 94 | 95 | def group_by_ques_type(self): 96 | per_ques_type: Dict[WIQAQuesType, Dict[FineEvalMetrics, Precision]] = {} 97 | for e in self.entries: 98 | if e.ques_type not in per_ques_type: 99 | per_ques_type[e.ques_type] = {} 100 | InputRequiredByFineEval.accumulate_metrics(current_metrics=e.metrics, 101 | all_metrics=per_ques_type[e.ques_type]) 102 | return per_ques_type 103 | 104 | def group_by_path_len(self): 105 | per_path_len: Dict[int, Dict[FineEvalMetrics, Precision]] = {} 106 | for e in self.entries: 107 | if e.get_path_len() not in per_path_len: 108 | per_path_len[e.get_path_len()] = {} 109 | InputRequiredByFineEval.accumulate_metrics(current_metrics=e.metrics, 110 | all_metrics=per_path_len[e.get_path_len()]) 111 | return per_path_len 112 | 113 | 114 | if __name__ == '__main__': 115 | parser = argparse.ArgumentParser(description='Analyze BERT results per question and per question type', 116 | usage="\n\npython src/dataset_creation_of_wiqa/analysis/fine_eval_on_wiqa_models.py" 117 | "\n\t --pred_path dir_path/eval_test.json" 118 | "\n\t --out_path /tmp/bert_anal" 119 | "\n\t --from_model tagging_model" 120 | "\n\t --group_by questype" 121 | ) 122 | 123 | # ------------------------------------------------ 124 | # Mandatory arguments. 125 | # ------------------------------------------------ 126 | 127 | parser.add_argument('--pred_path', 128 | action='store', 129 | dest='pred_path', 130 | required=True, 131 | help='File path containing predictors json output.') 132 | 133 | parser.add_argument('--group_by', 134 | action='store', 135 | dest='group_by', 136 | required=True, 137 | help='questype|pathlen') 138 | 139 | parser.add_argument('--from_model', 140 | action='store', 141 | dest='from_model', 142 | required=True, 143 | help='your_model_name|no_explanation_model') 144 | 145 | parser.add_argument('--out_path', 146 | action='store', 147 | dest='out_path', 148 | required=True, 149 | help='File path to store output such as metrics.json') 150 | 151 | args = parser.parse_args() 152 | m = MetricEvaluator() 153 | 154 | # Compile input for metrics. 
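# Only the no_explanation_model path is fully wired up in the prediction loop below:
# wordoverlap_baseline_model merely re-points pred_path at the gold test file,
# vectoroverlap_baseline_model raises NotImplementedError, and emnlp19_model builds
# lookup maps but is not yet handled when reading predictions.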
155 | if args.from_model == "wordoverlap_baseline_model": 156 | args.pred_path = download_from_url_if_not_in_cache(cloud_path=wiqa_explanations_v1.cloud_path + "test.jsonl") 157 | if args.from_model =="vectoroverlap_baseline_model": 158 | raise NotImplementedError 159 | 160 | if args.from_model == "emnlp19_model": 161 | map_of_expected_keys = {} 162 | for x in Jsonl.load(download_from_url_if_not_in_cache(wiqa_explanations_v1.cloud_path + "test.jsonl")): 163 | key = x["question"]["id"] 164 | value = {"graph_id": x["question"]["graph_id"], 165 | "path_arr": x["question"]["path_info"], 166 | "qtype": x["question"]["more_info"]["tf_q_type"]} 167 | map_of_expected_keys[key] = value 168 | 169 | outfile = open(args.out_path, 'w') 170 | with open(args.pred_path) as infile: 171 | for line in infile: 172 | j = json.loads(line) 173 | if args.from_model == "no_explanation_model": 174 | entry = InputRequiredByFineEval.from_(expl_type=WIQAExplanationType.NO_EXPL, 175 | json_from_question=j["orig_question"], 176 | prediction_on_this_example=WIQAExplanation.instantiate_from(j) 177 | ) 178 | else: 179 | raise NotImplementedError(f"fine_eval script does not support model: {args.from_model}, " 180 | f"implement your own here in evaluation.py") 181 | m.add_entry(entry=entry) 182 | 183 | # Compute metrics. 184 | if args.group_by == "pathlen": 185 | metrics = m.group_by_path_len() 186 | elif args.group_by == "questype": 187 | metrics = m.group_by_ques_type() 188 | else: 189 | raise NotImplementedError(f"fine_eval script does not support group by: {args.group_by}") 190 | 191 | # Write metrics to file. 192 | overall = {} 193 | for k, v_dict in metrics.items(): 194 | for k2, v2 in v_dict.items(): 195 | if k2 not in overall: 196 | overall[k2] = Precision() 197 | overall[k2].tp += v2.tp 198 | overall[k2].fp += v2.fp 199 | outfile.write(f"{k.name}_{k2.name}:{v2.p():0.4}\n") 200 | 201 | outlines = [] 202 | for k, v in overall.items(): 203 | outline = f"{k.name}_overall:{v.p():0.4}" 204 | outlines.append(outline) 205 | outfile.write(f"{outline}\n") 206 | print(f"\nOutput is in {args.out_path}") 207 | print("\n".join(outlines)) 208 | outfile.close() 209 | -------------------------------------------------------------------------------- /src/helpers/ProparaExtendedPara.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import json 3 | from typing import List, Dict 4 | 5 | from src.helpers.dataset_info import propara_para_info, download_from_url_if_not_in_cache 6 | 7 | 8 | class ProparaExtendedParaEntry: 9 | def __init__(self, topic: str, prompt: str, paraid: str, 10 | s1: str, s2: str, s3: str, 11 | s4: str, s5: str, s6: str, 12 | s7: str, s8: str, s9: str, 13 | s10: str): 14 | self.topic = topic 15 | self.prompt = prompt 16 | self.paraid = paraid 17 | all_sentences = [s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] 18 | self.sentences = [x for x in all_sentences if x] 19 | 20 | def sent_at(self, sentidx_startingzero: int): 21 | assert len(self.sentences) > sentidx_startingzero 22 | return self.sentences[sentidx_startingzero] 23 | 24 | def get_sentence_arr(self): 25 | return self.sentences 26 | 27 | def as_json(self): 28 | return json.dumps(self.__dict__) 29 | 30 | @staticmethod 31 | def from_json(j): 32 | raise NotImplementedError("from json is not done yet.") 33 | 34 | @staticmethod 35 | def from_tsv(t): 36 | return ProparaExtendedParaEntry(*t) 37 | 38 | 39 | class ProparaExtendedParaMetadata: 40 | 41 | def __init__(self, 42 | 
extended_propara_para_fp=download_from_url_if_not_in_cache( 43 | propara_para_info.cloud_path)): 44 | self.para_map: Dict[str, ProparaExtendedParaEntry] = dict() 45 | for row_num, row_as_arr in enumerate(csv.reader(open(extended_propara_para_fp), delimiter="\t")): 46 | if row_num > 0: # skips header 47 | e: ProparaExtendedParaEntry = ProparaExtendedParaEntry.from_tsv(row_as_arr) 48 | self.para_map[e.paraid] = e 49 | 50 | def paraentry_for_id(self, para_id: str) -> ProparaExtendedParaEntry: 51 | return self.para_map.get(para_id, None) 52 | 53 | def sentences_for_paraid(self, paraid: str) -> List[str]: 54 | return self.paraentry_for_id(para_id=paraid).get_sentence_arr() 55 | 56 | def sent_at(self, paraid: str, sentidx_startingzero: int) -> str: 57 | return self.paraentry_for_id(para_id=paraid).sent_at(sentidx_startingzero=sentidx_startingzero) 58 | 59 | 60 | if __name__ == '__main__': 61 | o = ProparaExtendedParaMetadata() 62 | assert o.sent_at(paraid="2453", 63 | sentidx_startingzero=1) == "The heat from the reaction is used to create steam in water." 64 | assert o.sent_at(paraid="24", 65 | sentidx_startingzero=2) == "Depending on what is eroding the valley, the slope of the land, the type of rock or soil and the amount of time the land has been eroded." 66 | assert not o.sent_at(paraid="24", 67 | sentidx_startingzero=2) == "The heat from the reaction is used to create steam in water." 68 | -------------------------------------------------------------------------------- /src/helpers/SGFileLoaders.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from random import sample 4 | 5 | 6 | def compile_input_files(dir_or_file_path): 7 | input_is_directory = os.path.isdir(dir_or_file_path) 8 | input_files = [] 9 | if input_is_directory: 10 | input_dir = os.fsencode(dir_or_file_path) 11 | for infile_bytename in os.listdir(input_dir): 12 | infile_fullpath = os.path.join( 13 | input_dir.decode("utf-8"), infile_bytename.decode("utf-8")) 14 | input_files.append(infile_fullpath) 15 | else: 16 | input_files.append(dir_or_file_path) 17 | return input_files 18 | 19 | 20 | def load_grpkey_to_qkeys(path, meta_info): 21 | ''' 22 | { 23 | "group_key":"1,V,X,In the context of how does igneous rock form, suppose less magma is formed, it will not result in more magma is released", 24 | "question_keys":[ 25 | "is_distractor^False:is_labeled_tgt^False:path_nodes^[V, X]:path_label^V->SituationLabel.NOT_RESULTS_IN->X:graph_id^1:id^34,1,1#0#0:tf_q_type^EXOGENOUS_EFFECT" 26 | ] 27 | } 28 | :param path: 29 | :param meta_info: 30 | :return: 31 | ''' 32 | grpkey_to_qkeys = {} 33 | for in_fp in compile_input_files(path): 34 | with open(in_fp, 'r') as infile: 35 | for line in infile: 36 | j = json.loads(line) 37 | grpkey_to_qkeys[j["group_key"]] = j["question_keys"] 38 | return grpkey_to_qkeys 39 | 40 | 41 | def load_qkey_to_grpkey(path, meta_info): 42 | ''' 43 | { 44 | "group_key":"1,V,X,In the context of how does igneous rock form, suppose less magma is formed, it will not result in more magma is released", 45 | "question_keys":[ 46 | "is_distractor^False:is_labeled_tgt^False:path_nodes^[V, X]:path_label^V->SituationLabel.NOT_RESULTS_IN->X:graph_id^1:id^34,1,1#0#0:tf_q_type^EXOGENOUS_EFFECT" 47 | ] 48 | } 49 | :param path: 50 | :param meta_info: 51 | :return: 52 | ''' 53 | qkey_to_grpkey = {} 54 | for in_fp in compile_input_files(path): 55 | with open(in_fp, 'r') as infile: 56 | for line in infile: 57 | j = json.loads(line) 58 | grp_key = j["group_key"] 59 
| for q_key in j["question_keys"]: 60 | qkey_to_grpkey[q_key] = grp_key 61 | return qkey_to_grpkey 62 | 63 | 64 | def load_qkey_to_qjson(path, meta_info): 65 | ''' 66 | { 67 | "question":{ 68 | "stem":"In the context of how does igneous rock form, suppose less magma is formed, it will not result in more magma is released", 69 | "choices":[ 70 | { 71 | "text":"True", 72 | "label":"C" 73 | }, 74 | { 75 | "text":"False", 76 | "label":"D" 77 | } 78 | ] 79 | }, 80 | "answerKey":"C", 81 | "explanation":"less magma is formed=>more magma is released", 82 | "path_info":"is_distractor^False:is_labeled_tgt^False:path_nodes^[V, X]:path_label^V->SituationLabel.NOT_RESULTS_IN->X", 83 | "more_info":{ 84 | "tf_q_type":"EXOGENOUS_EFFECT", 85 | "prompt":"How does igneous rock form?", 86 | "para_id":"34", 87 | "group_ids":{ 88 | "NO_GROUPING":"is_distractor^False:is_labeled_tgt^False:path_nodes^[V, X]:path_label^V->SituationLabel.NOT_RESULTS_IN->X:graph_id^1:id^34,1,1#0#0:tf_q_type^EXOGENOUS_EFFECT", 89 | "BY_SRC_DEST":"V,X", 90 | "BY_SRC_LABEL_DEST":"V->SituationLabel.NOT_RESULTS_IN->X", 91 | "BY_PROMPT":"How does igneous rock form?", 92 | "BY_PARA":"34", 93 | "BY_FULL_PATH":"[V, X]", 94 | "BY_GROUNDING":"34,1,1,0", 95 | "BY_SRC_DEST_INTRA":"1,V,X", 96 | "BY_SRC_DEST_STEM_INTRA":"1,V,X,In the context of how does igneous rock form, suppose less magma is formed, it will not result in more magma is released" 97 | }, 98 | "all_q_keys":[ 99 | "is_distractor^False:is_labeled_tgt^False:path_nodes^[V, X]:path_label^V->SituationLabel.NOT_RESULTS_IN->X:graph_id^1:id^34,1,1#0#0:tf_q_type^EXOGENOUS_EFFECT" 100 | ] 101 | }, 102 | "para":"Volcanos contain magma. The magma is very hot. The magma rises toward the surface of the volcano. The magma cools. The magma starts to harden as it cools. The magma is sometimes released from the volcano as lava. The magma or lava becomes a hard rock as it solidifies.", 103 | "graph_id":"1", 104 | "id":"34:1:1#0#0", 105 | "primary_question_key":"is_distractor^False:is_labeled_tgt^False:path_nodes^[V, X]:path_label^V->SituationLabel.NOT_RESULTS_IN->X:graph_id^1:id^34,1,1#0#0:tf_q_type^EXOGENOUS_EFFECT" 106 | } 107 | :param path: 108 | :param meta_info: 109 | :return: 110 | ''' 111 | qkey_to_qjson = {} 112 | for in_fp in compile_input_files(path): 113 | with open(in_fp, 'r') as infile: 114 | for line in infile: 115 | j = json.loads(line) 116 | qkey_to_qjson[j["primary_question_key"]] = j 117 | return qkey_to_qjson 118 | 119 | 120 | def load_qkey_to_ans(path, meta_info): 121 | ''' 122 | { 123 | "id":"34:1:1#0#0", 124 | "primary_question_key":"is_distractor^False:is_labeled_tgt^False:path_nodes^[V, X]:path_label^V->SituationLabel.NOT_RESULTS_IN->X:graph_id^1:id^34,1,1#0#0:tf_q_type^EXOGENOUS_EFFECT", 125 | "answerKey":"C" 126 | } 127 | :param path: 128 | :param meta_info: 129 | :return: 130 | ''' 131 | qkey_to_ans = {} 132 | for in_fp in compile_input_files(path): 133 | with open(in_fp, 'r') as infile: 134 | for line in infile: 135 | j = json.loads(line) 136 | qkey_to_ans[j["primary_question_key"]] = j["answerKey"] 137 | return qkey_to_ans 138 | 139 | 140 | def load_allennlp_qkey_to_ans(path, qkey_to_qjson_map, meta_info): 141 | ''' 142 | { 143 | "id":"131", 144 | "question":"NA", 145 | "question_text":"It is observed that scavengers increase in number happens. In the context of How are fossils formed? 
, what is the least likely cause?", 146 | "choice_text_list":[ 147 | "habitat is destroyed", 148 | "humans hunt less animals", 149 | "", 150 | "" 151 | ], 152 | "correct_answer_index":0, 153 | "label_logits":[ 154 | -6.519074440002441, 155 | -6.476490497589111, 156 | -6.935580730438232, 157 | -6.935580730438232 158 | ], 159 | "label_probs":[ 160 | 0.29742464423179626, 161 | 0.31036368012428284, 162 | 0.19610585272312164, 163 | 0.19610585272312164 164 | ], 165 | "answer_index":1 166 | } 167 | :param meta_info: 168 | :return: 169 | ''' 170 | # step 1 : map allennlp_qkey_to_qkey 171 | # step 2 : existing functionality then. 172 | qkey_from_allennlp_key_map = {v["id"]: k for k, v in qkey_to_qjson_map.items()} 173 | qkey_to_ans = {} 174 | for in_fp in compile_input_files(path): 175 | with open(in_fp, 'r') as infile: 176 | for line in infile: 177 | j = json.loads(line) 178 | makeshift_id = j["id"] 179 | qkey_from_allennlp_key = qkey_from_allennlp_key_map[makeshift_id] 180 | qkey_to_ans[qkey_from_allennlp_key] = j["choice_text_list"][int(j["answer_index"])] 181 | print(f"Loaded {len(qkey_to_ans)} system answers from {path}") 182 | return qkey_to_ans 183 | 184 | 185 | def serialize_whole(out_file_path, items, randomize_order=False, header=""): 186 | print(f"Writing {len(items)} items (e.g., question jsons) to directory: {out_file_path}") 187 | curr_file = open(out_file_path, 'w') 188 | if header: 189 | curr_file.write(header) 190 | if "\n" not in header: 191 | curr_file.write("\n") 192 | 193 | items_randomized = items 194 | if randomize_order: 195 | items_randomized = sample(items, len(items)) 196 | for item_num, item in enumerate(items_randomized): 197 | curr_file.write(item) 198 | if "\n" not in item: 199 | curr_file.write("\n") 200 | 201 | curr_file.close() 202 | 203 | 204 | def serialize_in_pieces(out_dir_path, max_items_in_a_piece, items, randomize_order=False, header=""): 205 | print(f"Writing {len(items)} items (e.g., question jsons) to directory: {out_dir_path}") 206 | if not os.path.exists(out_dir_path): 207 | os.makedirs(out_dir_path) 208 | 209 | curr_file_num = 1 210 | curr_file = open(f"{out_dir_path}/{curr_file_num}.json", 'w') 211 | if header: 212 | curr_file.write(header) 213 | 214 | items_randomized = items 215 | if randomize_order: 216 | items_randomized = sample(items, len(items)) 217 | 218 | for item_num, item in enumerate(items_randomized): 219 | if item_num % max_items_in_a_piece == 0 and item_num > 1: 220 | # close the old file. 221 | curr_file.close() 222 | # open a new file. 
223 | curr_file_num += 1 224 | curr_file = open(f"{out_dir_path}/{curr_file_num}.json", "w") 225 | if header: 226 | curr_file.write(header) 227 | if "\n" not in header: 228 | curr_file.write("\n") 229 | 230 | curr_file.write(item) 231 | if "\n" not in item: 232 | curr_file.write("\n") 233 | 234 | if not curr_file.closed: 235 | curr_file.close() 236 | -------------------------------------------------------------------------------- /src/helpers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/wiqa-dataset/edeef924188a7caa4493c305209e0b20d20b375c/src/helpers/__init__.py -------------------------------------------------------------------------------- /src/helpers/collections_util.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | 4 | def add_key_to_map_arr(key, value, map_): 5 | if key not in map_: 6 | map_[key] = [] 7 | map_[key].append(value) 8 | 9 | 10 | def getElem(arr: List, elem_idx:int, defaultValue): 11 | if not arr or len(arr) <= elem_idx: 12 | return defaultValue 13 | return arr[elem_idx] -------------------------------------------------------------------------------- /src/helpers/dataset_info.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains the latest dataset related files 3 | for EMNLP2019 4 | """ 5 | 6 | from collections import namedtuple 7 | 8 | from src.third_party_utils.allennlp_cached_filepath import cached_path 9 | 10 | DatasetInfo = namedtuple('DatasetInfo', 11 | ['beaker_link', 'cloud_path', 'local_path', 'data_reader', 'metadata_info', 'readme']) 12 | 13 | # EMNLP dataset: "with explanations" in json format. 14 | wiqa_explanations_v1 = DatasetInfo( 15 | beaker_link="https://beaker.org/ds/ds_jsdpme8ixz86/", # an unpublished, baseline model. 16 | cloud_path="https://public-aristo-processes.s3-us-west-2.amazonaws.com/wiqa_dataset_with_explanation/", 17 | # [cloud_path + partition for partition in ["train.jsonl", "dev.jsonl", "test.jsonl"]] 18 | local_path="", 19 | data_reader="BertMCQAReaderSentAnnotated", # an unpublished, baseline model. 20 | metadata_info="WhatifMetadata", 21 | readme="" 22 | ) 23 | 24 | # EMNLP dataset: "no explanations" in json format. 25 | wiqa_no_explanations_vetted_v1 = DatasetInfo( 26 | beaker_link="https://beaker.org/ds/ds_zyukvhb9ezqa/", 27 | cloud_path="https://public-aristo-processes.s3-us-west-2.amazonaws.com/wiqa_dataset_no_explanation_v2/", 28 | # [cloud_path + partition for partition in ["train.jsonl", "dev.jsonl", "test.jsonl"]] 29 | local_path="", 30 | data_reader="BertMCQAReaderPara", 31 | metadata_info="WhatifMetadata", 32 | readme="" 33 | ) 34 | 35 | # All the turked influence graphs in json format. 36 | influence_graphs_v1 = DatasetInfo( 37 | beaker_link="", 38 | cloud_path="https://public-aristo-processes.s3-us-west-2.amazonaws.com/wiqa_influence_graphs.jsonl", 39 | local_path="", 40 | data_reader="SituationGraph", 41 | metadata_info="WhatifMetadata", 42 | readme="All influence graphs" 43 | ) 44 | 45 | # Train/dev/test in ProPara dataset is partitioned by topics. Every topic in test is therefore, novel. 46 | # In the WIQA dataset, we continue partitioning by topic. The following data contains para id and and partition. 
47 | para_partition_info = DatasetInfo( 48 | beaker_link="", 49 | cloud_path="https://public-aristo-processes.s3-us-west-2.amazonaws.com/metadata/para_id.prompt.topic.partition.tsv", 50 | local_path="", 51 | data_reader="", 52 | metadata_info="WhatifMetadata", 53 | readme="Para id, topic, and the partition type" 54 | ) 55 | 56 | # The ProPara dataset (includes paragraph id and paragraph sentences (incl. metadata: paragraph prompt/title & topic) 57 | propara_para_info = DatasetInfo( 58 | beaker_link="", 59 | cloud_path="https://public-aristo-processes.s3-us-west-2.amazonaws.com/metadata/propara-extended-para.tsv", 60 | local_path="", 61 | data_reader="", 62 | metadata_info="WhatifMetadata", 63 | readme="Para id, topic, and the sentences" 64 | ) 65 | 66 | 67 | def download_from_url_if_not_in_cache(cloud_path: str, cache_dir: str = None): 68 | """ 69 | :param cloud_path: e.g., https://public-aristo-processes.s3-us-west-2.amazonaws.com/wiqa-model.tar.gz 70 | :param to_dir: will be regarded as a cache. 71 | :return: the path of file to which the file is downloaded. 72 | """ 73 | return cached_path(url_or_filename=cloud_path, cache_dir=cache_dir) 74 | -------------------------------------------------------------------------------- /src/helpers/situation_graph.py: -------------------------------------------------------------------------------- 1 | # Tasks: 2 | # input = graph in some format 3 | # output1 = graph in json format 4 | # output2 = tf_ques_given_a_json_record 5 | # output3 = generate explanation candidates for_a_tf_ques 6 | 7 | import copy 8 | import enum 9 | import json 10 | from typing import List 11 | 12 | 13 | def strip_special_char(a_string): 14 | return "".join([x for x in a_string if (ord('a') <= ord(x) <= ord('z')) or (ord('A') <= ord(x) <= ord('Z'))]).lower() 15 | 16 | 17 | class SituationNode: 18 | def __init__(self, node_id: str, 19 | the_groundings: [], 20 | is_decision_node: bool = False, 21 | node_semantics: str = ""): 22 | ''' 23 | :param node_id e.g., "VX" 24 | :param node_semantics e.g., for A/D node, semantics is "accelerates", "decelerates"; 25 | or, x/y node can be "causal" or u,v,w,z "indirect". 26 | :param the_groundings: is an array of groundings. 27 | :param is_decision_node: e.g., Accelerates or Decelerates nodes are decision nodes currently 28 | ''' 29 | self.id = node_id 30 | self.node_semantics = node_semantics 31 | self.groundings = [t for t in the_groundings] # we need a copy and not a reference of the supplied array. 
32 | self.is_decision_node = is_decision_node 33 | 34 | def join_groundings(self, separator): 35 | if not self.groundings: 36 | return "" 37 | return separator.join(self.groundings) 38 | 39 | def remove_grounding(self, the_grounding_to_remove): 40 | try: 41 | self.groundings.remove(the_grounding_to_remove) 42 | except ValueError: 43 | # Continue, do not stop 44 | print(f"Warning: Removal of a non-existent grounding: '{the_grounding_to_remove}' from node {self.id}") 45 | 46 | def get_grounding_pairs(self): 47 | if not self.groundings or len(self.groundings) < 2: 48 | return [] 49 | return zip(*[self.groundings[i:] for i in range(2)]) 50 | 51 | def is_empty(self): 52 | return not self.groundings or len(self.groundings) == 0 53 | 54 | def __repr__(self): 55 | return self.id 56 | 57 | def __eq__(self, other): 58 | return self.id == other.id 59 | 60 | def __hash__(self): 61 | return hash(self.id) 62 | 63 | def get_specific_grounding(self, specific_grounding=None): 64 | ''' 65 | :param specific_grounding: if set to None, then all groundings are used {source_node.groundings} otherwise the string supplied. 66 | :return: 67 | ''' 68 | return f"{self.groundings}" if not specific_grounding else specific_grounding 69 | 70 | def get_first_grounding(self): 71 | return "" if not self.groundings or len(self.groundings) < 1 else self.groundings[0] 72 | 73 | 74 | class SituationLabel(str, enum.Enum): 75 | NOT_RESULTS_IN ="NOT_RESULTS_IN" 76 | RESULTS_IN = "RESULTS_IN" 77 | NO_EFFECT = "NO_EFFECT" 78 | MARKED_NOISE = "MARKED_NOISE" 79 | 80 | def get_sign(self): 81 | if self == SituationLabel.RESULTS_IN: 82 | return '+' 83 | elif self == SituationLabel.NOT_RESULTS_IN: 84 | return '-' 85 | else: 86 | return '.' 87 | 88 | def get_sign_str(self): 89 | if self == SituationLabel.RESULTS_IN: 90 | return 'MORE' 91 | elif self == SituationLabel.NOT_RESULTS_IN: 92 | return 'LESS' 93 | else: 94 | return 'NOEFFECT' 95 | 96 | def as_less_more(self): 97 | if self == SituationLabel.RESULTS_IN: 98 | return 'more' 99 | elif self == SituationLabel.NOT_RESULTS_IN: 100 | return 'less' 101 | else: 102 | return 'no_effect' 103 | 104 | def get_nickname(self): 105 | if self == SituationLabel.RESULTS_IN: 106 | return 'a' 107 | elif self == SituationLabel.NOT_RESULTS_IN: 108 | return 'd' 109 | else: 110 | return '-' 111 | 112 | def get_opposite_label(self): 113 | if self == SituationLabel.RESULTS_IN: 114 | return SituationLabel.NOT_RESULTS_IN 115 | elif self == SituationLabel.NOT_RESULTS_IN: 116 | return SituationLabel.RESULTS_IN 117 | else: 118 | return self 119 | 120 | def get_emnlp_test_choice(self): 121 | if self == SituationLabel.RESULTS_IN: 122 | return 'A' 123 | elif self == SituationLabel.NOT_RESULTS_IN: 124 | return 'B' 125 | else: 126 | return 'C' 127 | 128 | @staticmethod 129 | def from_str(sl): 130 | if not sl: 131 | raise ValueError( 132 | f"({sl}) is not a valid Enum SituationLabel") 133 | sl = sl.lower().replace('_', ' ').strip() 134 | 135 | if sl in ['-', 'not results in', 'opposite', 'opp', 'results in opp', 'opp effect', 'opposite effect', 'less']: 136 | return SituationLabel.NOT_RESULTS_IN 137 | elif sl in ['+', 'positive', 'results in', 'correct', 'correct effect', 'more']: 138 | return SituationLabel.RESULTS_IN 139 | elif sl in ['.', 'none', 'no effect']: 140 | return SituationLabel.NO_EFFECT 141 | else: 142 | print(f"WARNING: ({sl}) is not a valid Enum SituationLabel") 143 | return SituationLabel.MARKED_NOISE 144 | 145 | def to_readable_str(self): 146 | if self == SituationLabel.RESULTS_IN: 147 | return 
'RESULTS_IN' 148 | elif self == SituationLabel.NOT_RESULTS_IN: 149 | return 'RESULTS_IN_OPP' 150 | else: 151 | return 'NO_EFFECT' 152 | 153 | def to_json(self): 154 | return self.name 155 | 156 | 157 | class SituationEdge: 158 | def __init__(self, from_node: SituationNode, to_node: SituationNode, label: SituationLabel, wt=1.0): 159 | self.from_node = from_node 160 | self.to_node = to_node 161 | self.wt = wt 162 | self.label = label 163 | self.id = (from_node.id, to_node.id) 164 | self.is_noise = False 165 | 166 | def mark_noise(self): 167 | # If marked as noise, then we will drop this edge completely 168 | self.is_noise = True 169 | 170 | def reset_label_to(self, new_label: SituationLabel): 171 | # If marked as noise, then we will drop this edge completely 172 | self.label = new_label 173 | 174 | def ground_edge(self): 175 | grounded_edges = [] 176 | # node A (A1, A2) => node B (B1, B2) 177 | for s1 in self.from_node.groundings: 178 | for s2 in self.to_node.groundings: 179 | grounded_edges.append({"from_node": s1, "to_node": s2, "label": self.label.name}) 180 | return grounded_edges 181 | 182 | def __repr__(self): 183 | return self.from_node.id + "-" + self.to_node.id 184 | 185 | 186 | class SituationGraph: 187 | # graph structure is a json of edges and their labels. 188 | # e.g., [(V,X,not) 189 | def __init__(self, situation_nodes: List[SituationNode], situation_edges: List[SituationEdge], 190 | other_properties: dict): 191 | ''' 192 | :param situation_nodes: an array of situation nodes. 193 | :param other_properties: e.g., other_graph_provenance, 194 | or, para_outcome, 195 | or, graph_provenance: paraid__blocknum__paraprompt etc. 196 | ''' 197 | self.nodes = situation_nodes 198 | self.situation_edges = situation_edges 199 | self.other_properties = {k: v for k, v in other_properties.items()} 200 | self.cleanup() 201 | 202 | def cleanup(self, remove_no_effects=True): 203 | """ 204 | This function can also be called at a later point when vetting data is available on a situation graph. 205 | @:param remove_no_effects : unless we extend our graphs to contain no effect nodes, the default should be 206 | True 207 | :return: 208 | """ 209 | self.remove_empty_nodes() 210 | edges_to_remove = [e for e in self.situation_edges 211 | if e.is_noise or e.label == SituationLabel.MARKED_NOISE or ( 212 | remove_no_effects and e.label == SituationLabel.NO_EFFECT)] 213 | for e in edges_to_remove: 214 | self.remove_an_edge(edge_to_remove=e) 215 | 216 | def copy(self): 217 | copied = SituationGraph( 218 | situation_nodes=copy.deepcopy(self.nodes), 219 | situation_edges=copy.deepcopy(self.situation_edges), 220 | other_properties=copy.deepcopy(self.other_properties) 221 | ) 222 | copied.cleanup() 223 | return copied 224 | 225 | def get_empty_nodes(self): 226 | return [n for n in self.nodes if n.is_empty()] 227 | 228 | def remove_empty_nodes(self): 229 | nodes_to_remove = self.get_empty_nodes() 230 | for cand_node in nodes_to_remove: 231 | self.remove_a_node(node_to_remove=cand_node) 232 | 233 | def get_all_node_grounding_pairs(self): 234 | lists = [] 235 | for x in self.nodes: 236 | pairs = x.get_grounding_pairs() 237 | if pairs: 238 | lists.extend(pairs) 239 | return lists 240 | 241 | def remove_a_node(self, node_to_remove): 242 | if not node_to_remove: 243 | return 244 | try: 245 | self.nodes.remove(node_to_remove) 246 | # also remove all edges it occurs in. 
247 | edges_to_remove = [e for e in self.situation_edges 248 | if e.from_node.id == node_to_remove.id or e.to_node.id == node_to_remove.id] 249 | for e in edges_to_remove: 250 | self.remove_an_edge(edge_to_remove=e) 251 | except ValueError: 252 | print(f"Warning: Removal of a non-existent node: '{node_to_remove}'") 253 | 254 | def get_exogenous_nodes(self, exogenous_ids=["Z", "V"]): 255 | return [self.lookup_node(node_id=x) for x in exogenous_ids] 256 | 257 | def remove_an_edge(self, edge_to_remove): 258 | if not edge_to_remove: 259 | return 260 | try: 261 | self.situation_edges.remove(edge_to_remove) 262 | except ValueError: 263 | print(f"Warning: Removal of a non-existent edge: '{edge_to_remove}'") 264 | 265 | def lookup_node(self, node_id) -> SituationNode: 266 | for t in self.nodes: 267 | if t.id == node_id: 268 | return t 269 | return None 270 | 271 | def lookup_node_from_grounding(self, grounding) -> SituationNode: 272 | for t in self.nodes: 273 | # There is a very small chance that a grounding in two different nodes is same. 274 | # For those cases, this function just picks the nodes that comes first in the self.nodes array 275 | if grounding in t.groundings: 276 | return t 277 | 278 | # Search unsuccessful, so try removing lowercasing/uppercasing and tokenization mismatches (drop special chars) 279 | stripped_grounding = strip_special_char(a_string=grounding) 280 | if not stripped_grounding: 281 | return None 282 | for t in self.nodes: 283 | if stripped_grounding in [strip_special_char(x) for x in t.groundings]: 284 | return t 285 | 286 | return None 287 | 288 | def lookup_edge(self, edge_source_node, edge_target_node) -> SituationEdge: 289 | for t in self.situation_edges: 290 | if t.from_node == edge_source_node and t.to_node == edge_target_node: 291 | return t 292 | return None 293 | 294 | def speak_path(self, path, path_label, as_symbols=False, specific_grounding=None): 295 | if not path or len(path) < 1: 296 | return "" 297 | return self.speak_path_with_symbols(path=path, path_label=path_label, specific_grounding=specific_grounding) \ 298 | if as_symbols else self.speak_path_in_sentences(path=path, path_label=path_label, 299 | specific_grounding=specific_grounding) 300 | 301 | def construct_path_from_str_arr(self, str_node_arr): 302 | path = [self.lookup_node(node_str.strip()) for node_str in str_node_arr] 303 | path_ok = True 304 | for node in path: 305 | if node is None or node.is_empty(): 306 | path_ok = False 307 | return path if path_ok else None 308 | 309 | def get_graph_id(self): 310 | return self.other_properties["graph_id"] 311 | 312 | def speak_path_with_symbols(self, path, path_label, specific_grounding=None): 313 | ''' 314 | 315 | :param path: array of Situation nodes e.g. V, X, A 316 | :param path_label: array of path_len - 1 nodes, e.g., [RESULTS_IN, NOT_RESULTS_IN] 317 | :return: 318 | ''' 319 | return f"{path[0].get_specific_grounding(specific_grounding)}->" + \ 320 | ("->".join([l.name + "->" + f"{p.get_specific_grounding(specific_grounding)}" 321 | for p, l in zip(path[1:], path_label)])) 322 | 323 | def speak_path_in_sentences(self, path, path_label, specific_grounding=None): 324 | ''' 325 | Symbolic forms are very confusing: 326 | 327 | ```A not B not C 328 | A =/> B =/> C``` 329 | 330 | So instead express it pairwise as multiple sentences (each explanation answer option can be list of sentences): 331 | 332 | ```If A happens, then B will not happen. 333 | If B happens, then C will not happen.``` 334 | 335 | :param path: array of Situation nodes e.g. 
V, X, A 336 | :param path_label: array of path_len - 1 nodes, e.g., [RESULTS_IN, NOT_RESULTS_IN] 337 | :return: 338 | ''' 339 | 340 | sentences = [] 341 | for edge_num, (n1, n2) in enumerate(zip(path, path[1:])): 342 | speak_label = "will" if path_label[edge_num] == SituationLabel.RESULTS_IN else "will not" 343 | sentences.append(f"If {n1.get_specific_grounding(specific_grounding)} happens, then " 344 | f"{n2.get_specific_grounding(specific_grounding)} {speak_label} happen") 345 | 346 | return ". ".join(sentences).strip() # strip removes last space 347 | # return f"{path[0].get_specific_grounding(specific_grounding)}->" + \ 348 | # ("->".join([l.name + "->" + f"{p.get_specific_grounding(specific_grounding)}" 349 | # for p, l in zip(path[1:], path_label)])) 350 | 351 | def all_labels_in(self, path): 352 | return [self.lookup_edge(src, dest).label if self.lookup_edge(src, 353 | dest) is not None else SituationLabel.NOT_RESULTS_IN 354 | for src, dest in zip(path, path[1:])] 355 | # raw_path_labels = [self.lookup_edge(src, dest).label for src, dest in zip(path, path[1:])] 356 | # fixed_all_labels = raw_path_labels[:-1] 357 | # # FIXME this is possibly a bug and needs to be fixed when end label in path != decision node. 358 | # # and if y=/>a that means flip the inherent label. 359 | # # end node in the path is {accelerate/ decelerate} 360 | # if path[-1].is_decision_node and not self.other_properties.get("is_positive_outcm", True): 361 | # fixed_all_labels.append(raw_path_labels[0].get_opposite_label()) 362 | # else: 363 | # fixed_all_labels.append(raw_path_labels[-1]) 364 | # # fixed_all_labels.append(self.end_label_for_path(path=path)) 365 | # return fixed_all_labels 366 | 367 | # one process per html page (9 situation graphs at most). 368 | # For perl: https://metacpan.org/pod/distribution/Graph-Easy/bin/graph-easy 369 | # For python, see: https://pypi.org/project/graphviz/ 370 | # For javascript, see: http://www.webgraphviz.com/ 371 | @staticmethod 372 | def to_html(graphviz_data_map, html_title, html_outfile_path=None): 373 | ''' 374 | # 375 | # 376 | # Paragraph 1 377 | # 378 | # 379 | # 380 | # 381 | #

html_page_top_str 382 | # 383 | #
384 | # 385 | # 399 | # 400 | # 403 | # 404 | # 405 | # 406 | :param graphviz_data_map: 407 | :param html_title: 408 | :param html_outfile_path: 409 | :return: 410 | ''' 411 | htmls = list() 412 | assert graphviz_data_map is not None and len(graphviz_data_map) > 0, \ 413 | f"Input to graphviz is empty. This function call must be fixed: " \ 414 | f"\nSituationGraph.to_html(graphviz_data_map=empty_array, html_title={html_title})" 415 | htmls.append(f" \n \n \n {html_title} \n") 416 | htmls.append('\n\n') 417 | htmls.append(f"\n") 418 | htmls.append(f"\n

") 419 | htmls.append("\n".join([f'\n\n

\n
' for x_idx, _ in 420 | enumerate(graphviz_data_map.keys())])) 421 | htmls.append(f"\n\n\n\n\n\n") 427 | html_str = "\n".join(htmls) 428 | if html_outfile_path is not None: 429 | html_outfile = open(html_outfile_path, 'w') 430 | html_outfile.write(html_str) 431 | html_outfile.close() 432 | return html_str 433 | 434 | def as_graphviz(self, statement_separator): 435 | ''' 436 | :param statement_separator either "\n" (human readable) or "\t" (for html) 437 | ## Sample digraph 438 | ## color: "implies" : green, "not implies": red. 439 | ## node[style=filled, color=cornflowerblue, fontcolor=white, fontsize=10, fontname='Helvetica'] 440 | ## edge[arrowhead=vee, arrowtail=inv, arrowsize=.7, color=maroon, fontsize=10, fontcolor=navy] 441 | # digraph G { 442 | # "X" -> "Y"[label = "implies"] 443 | # "X" -> "W"[label = "not implies"] 444 | # "U" -> "Y"[label = "not implies"] 445 | # "The sun was not in the sky\nThe sun was in the sky" -> "X"[label = "implies"] 446 | # "V" -> "X"[label = "not implies"] 447 | # "Y" -> "A"[label = "implies"] 448 | # "Y" -> "D"[label = "not implies"] 449 | # "W" -> "A"[label = "not implies"] 450 | # "W" -> "D"[label = "implies"] 451 | # } 452 | :return: 453 | ''' 454 | g = list() 455 | printed_newline = "________" # "\\\n" 456 | g.append("digraph G {") 457 | for e in self.situation_edges: 458 | edge_color = "green" if e.label == SituationLabel.RESULTS_IN else ( 459 | "red" if e.label == SituationLabel.NOT_RESULTS_IN else "yellow") 460 | node1 = e.from_node.join_groundings(separator=printed_newline) 461 | node2 = e.to_node.join_groundings(separator=printed_newline) 462 | edge = f"\"{node1}\" -> \"{node2}\" [color={edge_color}]" 463 | g.append(edge.replace(printed_newline, ("\\\\" + "n"))) 464 | g.append("}") 465 | return statement_separator.join(g) 466 | 467 | def as_graphviz_with_labels(self, statement_separator): 468 | """ This is like as_graphviz, but with labels for each node. 
""" 469 | g = list() 470 | printed_newline = "________" # "\\\n" 471 | g.append("digraph G {") 472 | for e in self.situation_edges: 473 | edge_color = "green" if e.label == SituationLabel.RESULTS_IN else ( 474 | "red" if e.label == SituationLabel.NOT_RESULTS_IN else "yellow") 475 | node1 = "Node " + e.from_node.id + printed_newline + e.from_node.join_groundings(separator=printed_newline) 476 | node2 = "Node " + e.to_node.id + printed_newline + e.to_node.join_groundings(separator=printed_newline) 477 | edge = f"\"{node1}\" -> \"{node2}\" [color={edge_color}]" 478 | g.append(edge.replace(printed_newline, ("\\\\" + "n"))) 479 | g.append("}") 480 | return statement_separator.join(g) 481 | 482 | def end_label_for_path(self, path): 483 | # Suppose, X ==> Y 484 | # X =/=> W 485 | # U =/=> Y 486 | # Z ==> X 487 | # V =/=> X 488 | # if we encounter even number of not's then answer is True else False 489 | # This will work for all cases even the intermediate paths V =/=> X =/=> W 490 | raw_path_labels = [self.lookup_edge(src, dest).label for src, dest in zip(path, path[1:])] 491 | return SituationLabel.RESULTS_IN \ 492 | if len([x for x in raw_path_labels if x == SituationLabel.NOT_RESULTS_IN]) % 2 == 0 \ 493 | else SituationLabel.NOT_RESULTS_IN 494 | # 495 | # 496 | # num_negations = 0 497 | # for l in raw_path_labels: 498 | # if l == SituationLabel.NOT_RESULTS_IN: 499 | # num_negations += 1 500 | # if num_negations % 2 == 0: 501 | # return SituationLabel.RESULTS_IN 502 | # return SituationLabel.NOT_RESULTS_IN 503 | 504 | # previous end_label_for_path implementation 505 | # def end_label_for_path_old(self, path): 506 | # # Suppose, X ==> Y 507 | # # X =/=> W 508 | # # U =/=> Y 509 | # # Z ==> X 510 | # # V =/=> X 511 | # # Then, V,X,W,A path ... 512 | # # not, not, not 513 | # # so, if ever encounter "not" then answer is "not" 514 | # # Other example, V,X,Y,A 515 | # # so, 516 | # raw_path_labels = [self.lookup_edge(src, dest).label for src, dest in zip(path, path[1:])] 517 | # for l in raw_path_labels: 518 | # if l == SituationLabel.NOT_RESULTS_IN: 519 | # return SituationLabel.NOT_RESULTS_IN 520 | # return SituationLabel.RESULTS_IN 521 | 522 | def nodes_from(self, source_node): 523 | return [t.to_node for t in self.situation_edges if t.from_node.id == source_node.id] 524 | 525 | def dfs_between(self, source_node, target_node, visited, current_path, all_paths, min_path_len): 526 | visited[source_node.id] = True 527 | current_path.append(source_node) 528 | 529 | if source_node.id == target_node.id: 530 | # We found one path to the target, record it. 531 | if len(current_path) >= min_path_len: # x=>y=>z has path length 3, which is accepted if minlen is 3 532 | all_paths.append(current_path.copy()) 533 | 534 | else: 535 | # We haven't arrived at the target node yet, keep searching. 536 | for e in self.nodes_from(source_node=source_node): 537 | if not visited[e.id]: 538 | self.dfs_between(source_node=e, 539 | target_node=target_node, 540 | current_path=current_path, 541 | visited=visited, 542 | all_paths=all_paths, 543 | min_path_len=min_path_len 544 | ) 545 | 546 | # target_node shouldn't be marked visited forever 547 | # current node may be part of another path so allow it to be visited again. 548 | visited[source_node.id] = False 549 | current_path.pop() # remove current node. 
550 | 551 | def paths_between(self, source_node, target_node, min_path_len): 552 | all_paths = [] 553 | visited = {t.id: False for t in self.nodes} 554 | self.dfs_between(source_node=source_node, 555 | target_node=target_node, 556 | visited=visited, 557 | current_path=[], 558 | all_paths=all_paths, 559 | min_path_len=min_path_len 560 | ) 561 | return all_paths 562 | 563 | # def distractor_paths(self, paths): 564 | # invalid_paths = [] 565 | # valid_paths = self.paths_between(source_node=source_node, target_node=target_node, min_path_len=min_path_len) 566 | # 567 | # # source nodes can never be decision nodes. 568 | # query_source_nodes = [t for t in self.nodes if not t.is_decision_node] 569 | # # target nodes may or may not be decision nodes. 570 | # query_target_nodes = [t for t in self.nodes if not t.is_decision_node] 571 | # for src in query_source_nodes: 572 | # for tgt in query_target_nodes: 573 | # if src != tgt: 574 | # paths_between = self.paths_between(source_node=src, target_node=tgt, min_path_len=min_path_len) 575 | # if not paths_between: 576 | # all_queries.append(paths_between) 577 | # return invalid_paths 578 | 579 | def all_query_paths(self, min_path_len, target_node_has_to_be_decision_node, hardcoded_source_nodes_strs=None): 580 | ''' 581 | 582 | :param hardcoded_source_nodes: 583 | :param min_path_len: e.g. 3 would mean that Z=>X would be ignored but Z,,W 584 | :param target_node_has_to_be_decision_node: if set to False, then both source and tgt nodes are internal nodes (and not A or D nodes) 585 | :return: 586 | ''' 587 | all_queries = [] 588 | # source nodes can never be decision nodes. 589 | cleaned_source_nodes = self.construct_path_from_str_arr( 590 | str_node_arr=hardcoded_source_nodes_strs) if hardcoded_source_nodes_strs else [] 591 | if hardcoded_source_nodes_strs is not None and not cleaned_source_nodes: 592 | return all_queries 593 | query_source_nodes = cleaned_source_nodes or [t for t in self.nodes if not t.is_decision_node] 594 | # target nodes may or may not be decision nodes. 595 | query_target_nodes = [t for t in self.nodes if ( 596 | (t.is_decision_node or not target_node_has_to_be_decision_node) 597 | # and t not in query_source_nodes 598 | ) 599 | ] 600 | for src in query_source_nodes: 601 | for tgt in query_target_nodes: 602 | if src.id != tgt.id: 603 | paths_between = self.paths_between(source_node=src, target_node=tgt, min_path_len=min_path_len) 604 | # valid_paths = self.get_valid_paths(cand_paths=paths_between, source_nodes=query_source_nodes) 605 | valid_paths = self.get_valid_paths(cand_paths=paths_between, source_nodes=cleaned_source_nodes) 606 | if valid_paths: 607 | all_queries.append(valid_paths) 608 | return all_queries 609 | 610 | # given a situationgraph id we need (list of {s,o,p} dict). 
611 | def get_grounded_edges(self): 612 | grounded_edges = [] 613 | # node A (A1, A2) => node B (B1, B2) 614 | for edge in self.situation_edges: 615 | for edge_grounding in edge.ground_edge(): 616 | grounded_edges.append(edge_grounding) 617 | return grounded_edges 618 | 619 | def to_json_v1(self): 620 | return json.dumps(self.to_struct_v1()) 621 | 622 | def to_struct_v1(self): 623 | struct = {} 624 | 625 | for n in self.nodes: 626 | the_name = n.id 627 | if the_name == "A": 628 | the_name = "para_outcome_accelerate" 629 | if the_name == "D": 630 | the_name = "para_outcome_decelerate" 631 | struct[the_name] = n.groundings 632 | 633 | # (not used) "para_outcome": "oil formation", 634 | # (not used) "Y_is_outcome": "" 635 | # "Y_affects_outcome": "-", 636 | y_to_a_label = self.lookup_edge(edge_source_node=self.lookup_node(node_id="Y"), 637 | edge_target_node=self.lookup_node(node_id="A")).label 638 | struct["Y_affects_outcome"] = y_to_a_label.get_nickname() 639 | struct["paragraph"] = self.other_properties.get("paragraph", "") 640 | struct["prompt"] = self.other_properties.get("prompt", "") 641 | struct["para_id"] = self.other_properties.get("para_id", "") 642 | struct["Y_is_outcome"] = self.other_properties.get("y_is_outcome", "") 643 | 644 | return struct 645 | 646 | @staticmethod 647 | def from_json_v1(json_string: str): 648 | # decode the JSON string into a structure 649 | return SituationGraph.from_struct_v1(struct=json.loads(json_string)) 650 | 651 | @staticmethod 652 | def from_struct_v1(struct: dict): 653 | ''' 654 | 655 | :param struct (for data_version "v1"): a data structure that looks like this: 656 | { 657 | "para_id":"propara_pilot1_task12_p1.txt", 658 | "prompt":"How does oil form?", 659 | "paragraph":"Algae and plankton die. 660 | The dead algae and plankton end up part of sediment on a seafloor. 661 | The sediment breaks down. 662 | The bottom layers of sediment become compacted by pressure. 663 | Higher pressure causes the sediment to heat up. 664 | The heat causes chemical processes. 665 | The material becomes a liquid. 666 | Is known as oil. 667 | Oil moves up through rock.", 668 | "X":"pressure on sea floor increases", 669 | "Y":"sediment becomes hotter", 670 | "W":[ 671 | "sediment becomes cooler" 672 | ], 673 | "U":[ 674 | 675 | ], 676 | "Z":[ 677 | "ocean levels rise", 678 | "more plankton then normal die" 679 | ], 680 | "V":[ 681 | "oceans evaporate" 682 | ], 683 | "para_outcome":"oil formation", 684 | "para_outcome_accelerate": 685 | "More oil forms" 686 | , 687 | "para_outcome_decelerate": 688 | "Less oil forms" 689 | , 690 | "Y_affects_outcome":"-", 691 | "Y_is_outcome":"" 692 | } 693 | 694 | # graph_version "v1" assumes that: 695 | # X ==> Y 696 | # X =/=> W 697 | # U =/=> Y 698 | # Z ==> X 699 | # V =/=> X 700 | # W =/=> A 701 | # W ==> D 702 | # Y ==> A or Y ==> D 703 | 704 | :return: 705 | SituationGraph object. 706 | 707 | ''' 708 | 709 | # Fill an empty node because the graph structure is fixed. 
710 | for expected_inner_node in ["Z", "V", "X", "Y", "W", "U"]: 711 | if expected_inner_node not in struct: 712 | struct[expected_inner_node] = [] 713 | 714 | node_v = SituationNode(node_id="V", the_groundings=struct["V"]) 715 | node_z = SituationNode(node_id="Z", the_groundings=struct["Z"]) 716 | node_x = SituationNode(node_id="X", 717 | the_groundings=struct["X"] if isinstance(struct["X"], list) else [ 718 | struct["X"]]) 719 | node_u = SituationNode(node_id="U", the_groundings=struct["U"]) 720 | node_w = SituationNode(node_id="W", the_groundings=struct["W"]) 721 | node_y = SituationNode(node_id="Y", 722 | the_groundings=struct["Y"] if isinstance(struct["Y"], list) else [ 723 | struct["Y"]]) 724 | outcm_accelerates = "accelerates process" if "para_outcome_accelerate" not in struct or not struct[ 725 | "para_outcome_accelerate"] else struct["para_outcome_accelerate"] 726 | outcm_decelerates = "decelerates process" if "para_outcome_decelerate" not in struct or not struct[ 727 | "para_outcome_decelerate"] else struct["para_outcome_decelerate"] 728 | node_a = SituationNode(node_id="A", 729 | the_groundings=outcm_accelerates if isinstance(outcm_accelerates, list) else [ 730 | outcm_accelerates], 731 | is_decision_node=True, 732 | node_semantics="accelerates") 733 | node_d = SituationNode(node_id="D", 734 | the_groundings=outcm_decelerates if isinstance(outcm_decelerates, list) else [ 735 | outcm_decelerates], 736 | is_decision_node=True, 737 | node_semantics="decelerates") 738 | 739 | nodes = [node_v, node_z, node_x, node_u, node_w, node_y, node_a, node_d] 740 | 741 | ######################### 742 | # BEGIN-hardcoding 743 | ######################### 744 | outcm = "Y_affects_outcome" 745 | is_positive_outcm = True # True if y=>a 746 | # If the following two were provided in input_json, then no pre-processing is needed 747 | ya_label = SituationLabel.MARKED_NOISE 748 | yd_label = SituationLabel.MARKED_NOISE 749 | if struct[outcm] == 'a' or struct[outcm] == True or struct[outcm] == 'more': 750 | ya_label = SituationLabel.RESULTS_IN 751 | yd_label = SituationLabel.NOT_RESULTS_IN 752 | elif struct[outcm] == 'd' or struct[outcm] == False or struct[outcm] == 'less': 753 | ya_label = SituationLabel.NOT_RESULTS_IN 754 | yd_label = SituationLabel.RESULTS_IN 755 | is_positive_outcm = False 756 | else: 757 | ya_label = SituationLabel.NO_EFFECT 758 | yd_label = SituationLabel.NO_EFFECT 759 | 760 | if ya_label == (SituationLabel.MARKED_NOISE or SituationLabel.NO_EFFECT) \ 761 | or yd_label == (SituationLabel.MARKED_NOISE or SituationLabel.NO_EFFECT): 762 | return None 763 | ######################### 764 | # END-hardcoding 765 | ######################### 766 | 767 | edges = [ 768 | SituationEdge(from_node=node_v, to_node=node_x, label=SituationLabel.NOT_RESULTS_IN), 769 | SituationEdge(from_node=node_z, to_node=node_x, label=SituationLabel.RESULTS_IN), 770 | SituationEdge(from_node=node_x, to_node=node_w, label=SituationLabel.NOT_RESULTS_IN), 771 | SituationEdge(from_node=node_x, to_node=node_y, label=SituationLabel.RESULTS_IN), 772 | SituationEdge(from_node=node_u, to_node=node_y, label=SituationLabel.NOT_RESULTS_IN), 773 | # SituationEdge(from_node=node_w, to_node=node_a, label=SituationLabel.NOT_RESULTS_IN), 774 | # SituationEdge(from_node=node_w, to_node=node_d, label=SituationLabel.RESULTS_IN), 775 | SituationEdge(from_node=node_w, to_node=node_a, label=yd_label), 776 | SituationEdge(from_node=node_w, to_node=node_d, label=ya_label), 777 | SituationEdge(from_node=node_y, to_node=node_a, 
label=ya_label), 778 | SituationEdge(from_node=node_y, to_node=node_d, label=yd_label) 779 | ] 780 | paragraph = struct["paragraph"] 781 | paragraph = paragraph.replace("\"", "'") 782 | 783 | return SituationGraph( 784 | situation_nodes=nodes, 785 | situation_edges=edges, 786 | other_properties={ 787 | "para_id": struct["para_id"], 788 | "prompt": struct["prompt"], 789 | "paragraph": paragraph, 790 | "outcome": struct.get("para_outcome", ""), 791 | "y_is_outcome": struct["Y_is_outcome"], 792 | "is_cyclic_process": struct.get("is_cyclic_process", False), 793 | "is_positive_outcm": is_positive_outcm, 794 | "graph_id": struct.get("graph_id", "") 795 | } 796 | ) 797 | 798 | def get_valid_paths(self, cand_paths, source_nodes): 799 | if not cand_paths: 800 | return [] 801 | return [p for p in cand_paths if not self.is_any_source_node_in_path(path=p, sources_nodes=source_nodes)] 802 | 803 | def is_any_source_node_in_path(self, path, sources_nodes): 804 | ''' 805 | 806 | :param path: [Z,X,Y,A] 807 | :param sources_nodes: [V,Z,U,Y] 808 | :return: in this example since Y is present in the non-starting node of the path so returns True 809 | ''' 810 | source_nodes_ids = [x.id for x in sources_nodes] 811 | for p in path[1:]: 812 | if p.id in source_nodes_ids: 813 | return True 814 | return False 815 | -------------------------------------------------------------------------------- /src/helpers/whatif_metadata.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import json 3 | 4 | 5 | from src.helpers import SGFileLoaders, collections_util 6 | from src.helpers.collections_util import add_key_to_map_arr 7 | from src.helpers.situation_graph import SituationGraph 8 | 9 | 10 | class WhatifMetadata: 11 | def __init__(self, 12 | para_ids_metainfo_fp: str, 13 | situation_graphs_fp: str): 14 | self.graph_id_to_graph = dict() 15 | self.graph_id_to_graphjson = dict() 16 | self.para_id_to_graphids = dict() 17 | self.prompt_to_paraids = dict() 18 | self.topic_to_paraids = dict() 19 | self.topic_to_graph_ids = dict() 20 | self.paraid_to_topic = dict() 21 | self.paraid_to_para = dict() 22 | self.paraid_to_partition = dict() 23 | self.topic_to_partition = dict() 24 | 25 | for in_fp in SGFileLoaders.compile_input_files(situation_graphs_fp): 26 | with open(in_fp) as infile: 27 | for line in infile: 28 | j = json.loads(line) 29 | g = SituationGraph.from_struct_v1(struct=j) 30 | graph_id = j["graph_id"] 31 | self.graph_id_to_graphjson[graph_id] = j 32 | self.graph_id_to_graph[graph_id] = g 33 | para_id = j["para_id"] 34 | paragraph = j["paragraph"] 35 | self.paraid_to_para[para_id] = paragraph 36 | # k = self.get_topic_for_paraid(para_id=para_id) 37 | # if not k: 38 | # # FIXME This is to make the code work under a temporary bug of ids with prefix propara_ 39 | # para_id = "propara_" + para_id 40 | # k = self.get_topic_for_paraid(para_id=para_id) 41 | 42 | add_key_to_map_arr(key=para_id, 43 | value=graph_id, 44 | map_=self.para_id_to_graphids) 45 | # add_key_to_map_arr(key=k, 46 | # value=graph_id, 47 | # map_=self.topic_to_graph_ids) 48 | 49 | for in_fp in SGFileLoaders.compile_input_files(para_ids_metainfo_fp): 50 | # para_id, prompt, topic, partition 51 | with open(in_fp) as infile: 52 | reader = csv.DictReader(infile, delimiter='\t') 53 | for row in reader: 54 | para_id = row["para_id"] 55 | # FIXME temp fix due to id bug that inconsisently contains propara_ prefix sometimes. 
56 | if "propara_" in para_id: 57 | para_id = para_id.replace("propara_", "") 58 | # Cleanup paraids from partition for which we have no graph ids. 59 | # This is a dataset issue that causes runtime errors if not cleaned up. 60 | if para_id in self.paraid_to_para.keys(): 61 | topic = row["topic"] 62 | add_key_to_map_arr(key=row["prompt"], value=para_id, map_=self.prompt_to_paraids) 63 | add_key_to_map_arr(key=topic, value=para_id, map_=self.topic_to_paraids) 64 | self.paraid_to_topic[para_id] = topic 65 | self.paraid_to_partition[para_id] = row["partition"] 66 | self.topic_to_partition[topic] = row["partition"] 67 | for graph_id in self.para_id_to_graphids[para_id]: 68 | # Repeated entries but maps takes care of it. 69 | # Moving it to previous for loop block from sg file creates problems because 70 | # some para ids are present in prompt file but absent in graphs file. 71 | add_key_to_map_arr(key=topic, 72 | value=graph_id, 73 | map_=self.topic_to_graph_ids) 74 | 75 | def get_paraids_for_prompt(self, prompt_str): 76 | assert prompt_str in self.prompt_to_paraids, f"no paraid present {prompt_str} in prompt to paraid" 77 | return [] if prompt_str not in self.prompt_to_paraids else self.prompt_to_paraids[prompt_str] 78 | 79 | def get_graph_for_id(self, graph_id) -> SituationGraph: 80 | assert graph_id in self.graph_id_to_graph, f"no graphid present {graph_id} in graphid to graph" 81 | return self.graph_id_to_graph[graph_id] 82 | 83 | def get_graphjson_for_id(self, graph_id) -> SituationGraph: 84 | assert graph_id in self.graph_id_to_graphjson, f"no graphid present {graph_id} in graphid to graphjson" 85 | return self.graph_id_to_graphjson[graph_id] 86 | 87 | def get_paraids_for_topic(self, topic_str): 88 | assert topic_str in self.topic_to_paraids, f"no topic present {topic_str} in paraids for topic" 89 | return [] if topic_str not in self.topic_to_paraids else self.topic_to_paraids[topic_str] 90 | 91 | def get_graphids_for_topic(self, topic_str): 92 | assert topic_str in self.topic_to_graph_ids, f"no topic present {topic_str} in topic_to_graph_ids" 93 | return [] if topic_str not in self.topic_to_graph_ids else self.topic_to_graph_ids[topic_str] 94 | 95 | def get_topic_for_graphid(self, graph_id): 96 | para_id = self.get_graph_for_id(graph_id=graph_id).other_properties["para_id"] 97 | return self.get_topic_for_paraid(para_id=para_id) 98 | 99 | def get_topic_for_paraid(self, para_id): 100 | assert para_id in self.paraid_to_topic, f"no paraid present {para_id} in paraid to topic" 101 | return "" if para_id not in self.paraid_to_topic else self.paraid_to_topic[para_id] 102 | 103 | def get_para_for_paraid(self, para_id): 104 | assert para_id in self.paraid_to_para, f"no paraid present {para_id} in paraid to para" 105 | return "" if para_id not in self.paraid_to_para else self.paraid_to_para[para_id] 106 | 107 | def get_partition_for_graphid(self, graph_id): 108 | para_id = self.get_graph_for_id(graph_id=graph_id).other_properties["para_id"] 109 | return self.get_partition_for_paraid(para_id=para_id) 110 | 111 | def get_partition_for_paraid(self, para_id): 112 | assert para_id in self.paraid_to_partition, f"no paraid present {para_id} in paraid to partition map" 113 | return "" if para_id not in self.paraid_to_partition else self.paraid_to_partition[para_id] 114 | 115 | def get_partition_for_prompt(self, prompt): 116 | paraid = collections_util.getElem(arr=self.get_paraids_for_prompt(prompt_str=prompt), elem_idx=0, 117 | defaultValue="") 118 | return self.get_partition_for_paraid(para_id=paraid) 
119 | 120 | def get_partition_for_topic(self, topic): 121 | assert topic in self.topic_to_partition, f"topic not present: {topic} in topic to partition map" 122 | return self.topic_to_partition[topic] 123 | 124 | def get_all_topics(self): 125 | # Note that self.topic_to_paraids.keys() would provide ALL paragraphs, not just those 126 | # for which we have situation graphs. 127 | return self.topic_to_graph_ids.keys() 128 | 129 | def get_all_topics_in_partition(self, partition_reqd): 130 | # Note that self.topic_to_paraids.keys() would provide ALL paragraphs, not just those 131 | # for which we have situation graphs. 132 | return [x for x, partition in self.topic_to_partition.items() if partition == partition_reqd] 133 | 134 | def get_graphids_for_paraid(self, para_id): 135 | assert para_id in self.para_id_to_graphids, f"no paraid present {para_id} in para->graphids map" 136 | return self.para_id_to_graphids[para_id] 137 | 138 | def get_paraid_for_graph(self, graph): 139 | return graph.other_properties["para_id"] 140 | 141 | def get_paraid_for_graphid(self, graph_id): 142 | return self.get_graph_for_id(graph_id=graph_id).other_properties["para_id"] 143 | 144 | def get_prompt_for_graphid(self, graph_id): 145 | return self.get_graph_for_id(graph_id=graph_id).other_properties["prompt"] 146 | 147 | def get_paragraph_for_graphid(self, graph_id): 148 | return self.get_graph_for_id(graph_id=graph_id).other_properties["paragraph"] 149 | 150 | def get_paragraph_for_paraid(self, para_id): 151 | return self.paraid_to_para[para_id] 152 | 153 | def get_graph_nodes_as_text(self, graph_id): 154 | ''' 155 | Text = Bag of groundings from all nodes in the graph. Lowercases the resulting text 156 | :param graph_id: 157 | :return: 158 | ''' 159 | graph_as_text = " ".join([" ".join(x.groundings) for x in self.get_graph_for_id(graph_id=graph_id).nodes]) 160 | return graph_as_text.lower() 161 | 162 | def get_all_graphids(self): 163 | return self.graph_id_to_graph.keys() 164 | 165 | def get_all_graphids_for_partition(self, partition): 166 | ''' 167 | 168 | :param partition: train, dev, test 169 | :return: all_graphids (list of str) in that partition 170 | ''' 171 | graph_ids = [] 172 | for topic in self.get_all_topics_in_partition(partition_reqd=partition): 173 | graph_ids.extend(self.get_graphids_for_topic(topic_str=topic)) 174 | return graph_ids 175 | -------------------------------------------------------------------------------- /src/third_party_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/wiqa-dataset/edeef924188a7caa4493c305209e0b20d20b375c/src/third_party_utils/__init__.py -------------------------------------------------------------------------------- /src/third_party_utils/allennlp_cached_filepath.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | # Code adapted from Allennlp file_utils. 5 | import shutil 6 | import tempfile 7 | from hashlib import sha256 8 | 9 | import requests 10 | from tqdm import tqdm 11 | 12 | 13 | def url_to_filename(url, etag=None): 14 | """ 15 | Convert `url` into a hashed filename in a repeatable way. 16 | If `etag` is specified, append its hash to the url's, delimited 17 | by a period. 18 | """ 19 | url_bytes = url.encode('utf-8') 20 | url_hash = sha256(url_bytes) 21 | filename = url_hash.hexdigest() 22 | 23 | if etag: 24 | etag_bytes = etag.encode('utf-8') 25 | etag_hash = sha256(etag_bytes) 26 | filename += '.' 
+ etag_hash.hexdigest() 27 | 28 | return filename 29 | 30 | 31 | def http_get(url, temp_file): 32 | req = requests.get(url, stream=True) 33 | content_length = req.headers.get('Content-Length') 34 | total = int(content_length) if content_length is not None else None 35 | progress = tqdm(unit="B", total=total) 36 | for chunk in req.iter_content(chunk_size=1024): 37 | if chunk: # filter out keep-alive new chunks 38 | progress.update(len(chunk)) 39 | temp_file.write(chunk) 40 | progress.close() 41 | 42 | 43 | def get_from_cache(url, cache_dir=None): 44 | """ 45 | Given a URL, look for the corresponding dataset in the local cache. 46 | If it's not there, download it. Then return the path to the cached file. 47 | """ 48 | response = requests.head(url, allow_redirects=True) 49 | if response.status_code != 200: 50 | raise IOError("HEAD request failed for url {} with status code {}" 51 | .format(url, response.status_code)) 52 | etag = response.headers.get("ETag") 53 | filename = url_to_filename(url, etag) 54 | 55 | # get cache path to put the file 56 | cache_path = os.path.join(cache_dir, filename) 57 | 58 | if not os.path.exists(cache_path): 59 | # Download to temporary file, then copy to cache dir once finished. 60 | # Otherwise you get corrupt cache entries if the download gets interrupted. 61 | with tempfile.NamedTemporaryFile() as temp_file: 62 | print("%s not found in cache, downloading to %s", url, temp_file.name) 63 | 64 | http_get(url, temp_file) 65 | 66 | # we are copying the file before closing it, so flush to avoid truncation 67 | temp_file.flush() 68 | # shutil.copyfileobj() starts at the current position, so go to the start 69 | temp_file.seek(0) 70 | 71 | print("copying %s to cache at %s", temp_file.name, cache_path) 72 | with open(cache_path, 'wb') as cache_file: 73 | shutil.copyfileobj(temp_file, cache_file) 74 | 75 | print("creating metadata file for %s", cache_path) 76 | meta = {'url': url, 'etag': etag} 77 | meta_path = cache_path + '.json' 78 | with open(meta_path, 'w', encoding="utf-8") as meta_file: 79 | json.dump(meta, meta_file) 80 | 81 | print("removing temp file %s", temp_file.name) 82 | 83 | return cache_path 84 | 85 | 86 | def cached_path(url_or_filename, cache_dir=None): 87 | """ 88 | Given something that might be a URL (or might be a local path), 89 | determine which. If it's a URL, download the file and cache it, and 90 | return the path to the cached file. If it's already a local path, 91 | make sure the file exists and then return the path. 92 | """ 93 | if cache_dir is None: 94 | cache_dir = "/tmp/wiqa-cache/" 95 | if not os.path.exists(cache_dir): 96 | os.makedirs(cache_dir) 97 | if url_or_filename.startswith('http'): 98 | # URL, so get it from the cache (downloading if necessary) 99 | return get_from_cache(url_or_filename, cache_dir) 100 | elif os.path.exists(url_or_filename): 101 | # File, and it exists. 102 | return url_or_filename 103 | else: 104 | raise ValueError("unable to parse {} as a URL or valid local path".format(url_or_filename)) 105 | -------------------------------------------------------------------------------- /src/third_party_utils/nltk_porter_stemmer.py: -------------------------------------------------------------------------------- 1 | 2 | class PorterStemmer: 3 | """ 4 | A word stemmer based on the Porter stemming algorithm. 5 | 6 | Porter, M. "An algorithm for suffix stripping." 7 | Program 14.3 (1980): 130-137. 8 | 9 | See http://www.tartarus.org/~martin/PorterStemmer/ for the homepage 10 | of the algorithm. 
11 | 12 | Martin Porter has endorsed several modifications to the Porter 13 | algorithm since writing his original paper, and those extensions are 14 | included in the implementations on his website. Additionally, others 15 | have proposed further improvements to the algorithm, including NLTK 16 | contributors. There are thus three modes that can be selected by 17 | passing the appropriate constant to the class constructor's `mode` 18 | attribute: 19 | 20 | PorterStemmer.ORIGINAL_ALGORITHM 21 | - Implementation that is faithful to the original paper. 22 | 23 | Note that Martin Porter has deprecated this version of the 24 | algorithm. Martin distributes implementations of the Porter 25 | Stemmer in many languages, hosted at: 26 | 27 | http://www.tartarus.org/~martin/PorterStemmer/ 28 | 29 | and all of these implementations include his extensions. He 30 | strongly recommends against using the original, published 31 | version of the algorithm; only use this mode if you clearly 32 | understand why you are choosing to do so. 33 | 34 | PorterStemmer.MARTIN_EXTENSIONS 35 | - Implementation that only uses the modifications to the 36 | algorithm that are included in the implementations on Martin 37 | Porter's website. He has declared Porter frozen, so the 38 | behaviour of those implementations should never change. 39 | 40 | PorterStemmer.NLTK_EXTENSIONS (default) 41 | - Implementation that includes further improvements devised by 42 | NLTK contributors or taken from other modified implementations 43 | found on the web. 44 | 45 | For the best stemming, you should use the default NLTK_EXTENSIONS 46 | version. However, if you need to get the same results as either the 47 | original algorithm or one of Martin Porter's hosted versions for 48 | compatibility with an existing implementation or dataset, you can use 49 | one of the other modes instead. 50 | """ 51 | 52 | # Modes the Stemmer can be instantiated in 53 | NLTK_EXTENSIONS = 'NLTK_EXTENSIONS' 54 | MARTIN_EXTENSIONS = 'MARTIN_EXTENSIONS' 55 | ORIGINAL_ALGORITHM = 'ORIGINAL_ALGORITHM' 56 | 57 | def __init__(self, mode=NLTK_EXTENSIONS): 58 | if mode not in ( 59 | self.NLTK_EXTENSIONS, 60 | self.MARTIN_EXTENSIONS, 61 | self.ORIGINAL_ALGORITHM, 62 | ): 63 | raise ValueError( 64 | "Mode must be one of PorterStemmer.NLTK_EXTENSIONS, " 65 | "PorterStemmer.MARTIN_EXTENSIONS, or " 66 | "PorterStemmer.ORIGINAL_ALGORITHM" 67 | ) 68 | 69 | self.mode = mode 70 | 71 | if self.mode == self.NLTK_EXTENSIONS: 72 | # This is a table of irregular forms. It is quite short, 73 | # but still reflects the errors actually drawn to Martin 74 | # Porter's attention over a 20 year period! 75 | irregular_forms = { 76 | "sky": ["sky", "skies"], 77 | "die": ["dying"], 78 | "lie": ["lying"], 79 | "tie": ["tying"], 80 | "news": ["news"], 81 | "inning": ["innings", "inning"], 82 | "outing": ["outings", "outing"], 83 | "canning": ["cannings", "canning"], 84 | "howe": ["howe"], 85 | "proceed": ["proceed"], 86 | "exceed": ["exceed"], 87 | "succeed": ["succeed"], 88 | } 89 | 90 | self.pool = {} 91 | for key in irregular_forms: 92 | for val in irregular_forms[key]: 93 | self.pool[val] = key 94 | 95 | self.vowels = frozenset(['a', 'e', 'i', 'o', 'u']) 96 | 97 | def _is_consonant(self, word, i): 98 | """Returns True if word[i] is a consonant, False otherwise 99 | 100 | A consonant is defined in the paper as follows: 101 | 102 | A consonant in a word is a letter other than A, E, I, O or 103 | U, and other than Y preceded by a consonant. 
(The fact that 104 | the term `consonant' is defined to some extent in terms of 105 | itself does not make it ambiguous.) So in TOY the consonants 106 | are T and Y, and in SYZYGY they are S, Z and G. If a letter 107 | is not a consonant it is a vowel. 108 | """ 109 | if word[i] in self.vowels: 110 | return False 111 | if word[i] == 'y': 112 | if i == 0: 113 | return True 114 | else: 115 | return not self._is_consonant(word, i - 1) 116 | return True 117 | 118 | def _measure(self, stem): 119 | """Returns the 'measure' of stem, per definition in the paper 120 | 121 | From the paper: 122 | 123 | A consonant will be denoted by c, a vowel by v. A list 124 | ccc... of length greater than 0 will be denoted by C, and a 125 | list vvv... of length greater than 0 will be denoted by V. 126 | Any word, or part of a word, therefore has one of the four 127 | forms: 128 | 129 | CVCV ... C 130 | CVCV ... V 131 | VCVC ... C 132 | VCVC ... V 133 | 134 | These may all be represented by the single form 135 | 136 | [C]VCVC ... [V] 137 | 138 | where the square brackets denote arbitrary presence of their 139 | contents. Using (VC){m} to denote VC repeated m times, this 140 | may again be written as 141 | 142 | [C](VC){m}[V]. 143 | 144 | m will be called the \measure\ of any word or word part when 145 | represented in this form. The case m = 0 covers the null 146 | word. Here are some examples: 147 | 148 | m=0 TR, EE, TREE, Y, BY. 149 | m=1 TROUBLE, OATS, TREES, IVY. 150 | m=2 TROUBLES, PRIVATE, OATEN, ORRERY. 151 | """ 152 | cv_sequence = '' 153 | 154 | # Construct a string of 'c's and 'v's representing whether each 155 | # character in `stem` is a consonant or a vowel. 156 | # e.g. 'falafel' becomes 'cvcvcvc', 157 | # 'architecture' becomes 'vcccvcvccvcv' 158 | for i in range(len(stem)): 159 | if self._is_consonant(stem, i): 160 | cv_sequence += 'c' 161 | else: 162 | cv_sequence += 'v' 163 | 164 | # Count the number of 'vc' occurences, which is equivalent to 165 | # the number of 'VC' occurrences in Porter's reduced form in the 166 | # docstring above, which is in turn equivalent to `m` 167 | return cv_sequence.count('vc') 168 | 169 | def _has_positive_measure(self, stem): 170 | return self._measure(stem) > 0 171 | 172 | def _contains_vowel(self, stem): 173 | """Returns True if stem contains a vowel, else False""" 174 | for i in range(len(stem)): 175 | if not self._is_consonant(stem, i): 176 | return True 177 | return False 178 | 179 | def _ends_double_consonant(self, word): 180 | """Implements condition *d from the paper 181 | 182 | Returns True if word ends with a double consonant 183 | """ 184 | return ( 185 | len(word) >= 2 186 | and word[-1] == word[-2] 187 | and self._is_consonant(word, len(word) - 1) 188 | ) 189 | 190 | def _ends_cvc(self, word): 191 | """Implements condition *o from the paper 192 | 193 | From the paper: 194 | 195 | *o - the stem ends cvc, where the second c is not W, X or Y 196 | (e.g. -WIL, -HOP). 
197 | """ 198 | return ( 199 | len(word) >= 3 200 | and self._is_consonant(word, len(word) - 3) 201 | and not self._is_consonant(word, len(word) - 2) 202 | and self._is_consonant(word, len(word) - 1) 203 | and word[-1] not in ('w', 'x', 'y') 204 | ) or ( 205 | self.mode == self.NLTK_EXTENSIONS 206 | and len(word) == 2 207 | and not self._is_consonant(word, 0) 208 | and self._is_consonant(word, 1) 209 | ) 210 | 211 | def _replace_suffix(self, word, suffix, replacement): 212 | """Replaces `suffix` of `word` with `replacement""" 213 | assert word.endswith(suffix), "Given word doesn't end with given suffix" 214 | if suffix == '': 215 | return word + replacement 216 | else: 217 | return word[: -len(suffix)] + replacement 218 | 219 | def _apply_rule_list(self, word, rules): 220 | """Applies the first applicable suffix-removal rule to the word 221 | 222 | Takes a word and a list of suffix-removal rules represented as 223 | 3-tuples, with the first element being the suffix to remove, 224 | the second element being the string to replace it with, and the 225 | final element being the condition for the rule to be applicable, 226 | or None if the rule is unconditional. 227 | """ 228 | for rule in rules: 229 | suffix, replacement, condition = rule 230 | if suffix == '*d' and self._ends_double_consonant(word): 231 | stem = word[:-2] 232 | if condition is None or condition(stem): 233 | return stem + replacement 234 | else: 235 | # Don't try any further rules 236 | return word 237 | if word.endswith(suffix): 238 | stem = self._replace_suffix(word, suffix, '') 239 | if condition is None or condition(stem): 240 | return stem + replacement 241 | else: 242 | # Don't try any further rules 243 | return word 244 | 245 | return word 246 | 247 | def _step1a(self, word): 248 | """Implements Step 1a from "An algorithm for suffix stripping" 249 | 250 | From the paper: 251 | 252 | SSES -> SS caresses -> caress 253 | IES -> I ponies -> poni 254 | ties -> ti 255 | SS -> SS caress -> caress 256 | S -> cats -> cat 257 | """ 258 | # this NLTK-only rule extends the original algorithm, so 259 | # that 'flies'->'fli' but 'dies'->'die' etc 260 | if self.mode == self.NLTK_EXTENSIONS: 261 | if word.endswith('ies') and len(word) == 4: 262 | return self._replace_suffix(word, 'ies', 'ie') 263 | 264 | return self._apply_rule_list( 265 | word, 266 | [ 267 | ('sses', 'ss', None), # SSES -> SS 268 | ('ies', 'i', None), # IES -> I 269 | ('ss', 'ss', None), # SS -> SS 270 | ('s', '', None), # S -> 271 | ], 272 | ) 273 | 274 | def _step1b(self, word): 275 | """Implements Step 1b from "An algorithm for suffix stripping" 276 | 277 | From the paper: 278 | 279 | (m>0) EED -> EE feed -> feed 280 | agreed -> agree 281 | (*v*) ED -> plastered -> plaster 282 | bled -> bled 283 | (*v*) ING -> motoring -> motor 284 | sing -> sing 285 | 286 | If the second or third of the rules in Step 1b is successful, 287 | the following is done: 288 | 289 | AT -> ATE conflat(ed) -> conflate 290 | BL -> BLE troubl(ed) -> trouble 291 | IZ -> IZE siz(ed) -> size 292 | (*d and not (*L or *S or *Z)) 293 | -> single letter 294 | hopp(ing) -> hop 295 | tann(ed) -> tan 296 | fall(ing) -> fall 297 | hiss(ing) -> hiss 298 | fizz(ed) -> fizz 299 | (m=1 and *o) -> E fail(ing) -> fail 300 | fil(ing) -> file 301 | 302 | The rule to map to a single letter causes the removal of one of 303 | the double letter pair. The -E is put back on -AT, -BL and -IZ, 304 | so that the suffixes -ATE, -BLE and -IZE can be recognised 305 | later. This E may be removed in step 4. 
306 | """ 307 | # this NLTK-only block extends the original algorithm, so that 308 | # 'spied'->'spi' but 'died'->'die' etc 309 | if self.mode == self.NLTK_EXTENSIONS: 310 | if word.endswith('ied'): 311 | if len(word) == 4: 312 | return self._replace_suffix(word, 'ied', 'ie') 313 | else: 314 | return self._replace_suffix(word, 'ied', 'i') 315 | 316 | # (m>0) EED -> EE 317 | if word.endswith('eed'): 318 | stem = self._replace_suffix(word, 'eed', '') 319 | if self._measure(stem) > 0: 320 | return stem + 'ee' 321 | else: 322 | return word 323 | 324 | rule_2_or_3_succeeded = False 325 | 326 | for suffix in ['ed', 'ing']: 327 | if word.endswith(suffix): 328 | intermediate_stem = self._replace_suffix(word, suffix, '') 329 | if self._contains_vowel(intermediate_stem): 330 | rule_2_or_3_succeeded = True 331 | break 332 | 333 | if not rule_2_or_3_succeeded: 334 | return word 335 | 336 | return self._apply_rule_list( 337 | intermediate_stem, 338 | [ 339 | ('at', 'ate', None), # AT -> ATE 340 | ('bl', 'ble', None), # BL -> BLE 341 | ('iz', 'ize', None), # IZ -> IZE 342 | # (*d and not (*L or *S or *Z)) 343 | # -> single letter 344 | ( 345 | '*d', 346 | intermediate_stem[-1], 347 | lambda stem: intermediate_stem[-1] not in ('l', 's', 'z'), 348 | ), 349 | # (m=1 and *o) -> E 350 | ( 351 | '', 352 | 'e', 353 | lambda stem: (self._measure(stem) == 1 and self._ends_cvc(stem)), 354 | ), 355 | ], 356 | ) 357 | 358 | def _step1c(self, word): 359 | """Implements Step 1c from "An algorithm for suffix stripping" 360 | 361 | From the paper: 362 | 363 | Step 1c 364 | 365 | (*v*) Y -> I happy -> happi 366 | sky -> sky 367 | """ 368 | 369 | def nltk_condition(stem): 370 | """ 371 | This has been modified from the original Porter algorithm so 372 | that y->i is only done when y is preceded by a consonant, 373 | but not if the stem is only a single consonant, i.e. 374 | 375 | (*c and not c) Y -> I 376 | 377 | So 'happy' -> 'happi', but 378 | 'enjoy' -> 'enjoy' etc 379 | 380 | This is a much better rule. Formerly 'enjoy'->'enjoi' and 381 | 'enjoyment'->'enjoy'. Step 1c is perhaps done too soon; but 382 | with this modification that no longer really matters. 383 | 384 | Also, the removal of the contains_vowel(z) condition means 385 | that 'spy', 'fly', 'try' ... stem to 'spi', 'fli', 'tri' and 386 | conflate with 'spied', 'tried', 'flies' ... 
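        Concretely: 'happy' -> 'happi' (the y follows the consonant p and the
        stem 'happ' is longer than one letter), 'enjoy' stays 'enjoy' (the y
        follows a vowel), and 'by' stays 'by' (the stem would be the single
        consonant 'b').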
387 | """ 388 | return len(stem) > 1 and self._is_consonant(stem, len(stem) - 1) 389 | 390 | def original_condition(stem): 391 | return self._contains_vowel(stem) 392 | 393 | return self._apply_rule_list( 394 | word, 395 | [ 396 | ( 397 | 'y', 398 | 'i', 399 | nltk_condition 400 | if self.mode == self.NLTK_EXTENSIONS 401 | else original_condition, 402 | ) 403 | ], 404 | ) 405 | 406 | def _step2(self, word): 407 | """Implements Step 2 from "An algorithm for suffix stripping" 408 | 409 | From the paper: 410 | 411 | Step 2 412 | 413 | (m>0) ATIONAL -> ATE relational -> relate 414 | (m>0) TIONAL -> TION conditional -> condition 415 | rational -> rational 416 | (m>0) ENCI -> ENCE valenci -> valence 417 | (m>0) ANCI -> ANCE hesitanci -> hesitance 418 | (m>0) IZER -> IZE digitizer -> digitize 419 | (m>0) ABLI -> ABLE conformabli -> conformable 420 | (m>0) ALLI -> AL radicalli -> radical 421 | (m>0) ENTLI -> ENT differentli -> different 422 | (m>0) ELI -> E vileli - > vile 423 | (m>0) OUSLI -> OUS analogousli -> analogous 424 | (m>0) IZATION -> IZE vietnamization -> vietnamize 425 | (m>0) ATION -> ATE predication -> predicate 426 | (m>0) ATOR -> ATE operator -> operate 427 | (m>0) ALISM -> AL feudalism -> feudal 428 | (m>0) IVENESS -> IVE decisiveness -> decisive 429 | (m>0) FULNESS -> FUL hopefulness -> hopeful 430 | (m>0) OUSNESS -> OUS callousness -> callous 431 | (m>0) ALITI -> AL formaliti -> formal 432 | (m>0) IVITI -> IVE sensitiviti -> sensitive 433 | (m>0) BILITI -> BLE sensibiliti -> sensible 434 | """ 435 | 436 | if self.mode == self.NLTK_EXTENSIONS: 437 | # Instead of applying the ALLI -> AL rule after '(a)bli' per 438 | # the published algorithm, instead we apply it first, and, 439 | # if it succeeds, run the result through step2 again. 440 | if word.endswith('alli') and self._has_positive_measure( 441 | self._replace_suffix(word, 'alli', '') 442 | ): 443 | return self._step2(self._replace_suffix(word, 'alli', 'al')) 444 | 445 | bli_rule = ('bli', 'ble', self._has_positive_measure) 446 | abli_rule = ('abli', 'able', self._has_positive_measure) 447 | 448 | rules = [ 449 | ('ational', 'ate', self._has_positive_measure), 450 | ('tional', 'tion', self._has_positive_measure), 451 | ('enci', 'ence', self._has_positive_measure), 452 | ('anci', 'ance', self._has_positive_measure), 453 | ('izer', 'ize', self._has_positive_measure), 454 | abli_rule if self.mode == self.ORIGINAL_ALGORITHM else bli_rule, 455 | ('alli', 'al', self._has_positive_measure), 456 | ('entli', 'ent', self._has_positive_measure), 457 | ('eli', 'e', self._has_positive_measure), 458 | ('ousli', 'ous', self._has_positive_measure), 459 | ('ization', 'ize', self._has_positive_measure), 460 | ('ation', 'ate', self._has_positive_measure), 461 | ('ator', 'ate', self._has_positive_measure), 462 | ('alism', 'al', self._has_positive_measure), 463 | ('iveness', 'ive', self._has_positive_measure), 464 | ('fulness', 'ful', self._has_positive_measure), 465 | ('ousness', 'ous', self._has_positive_measure), 466 | ('aliti', 'al', self._has_positive_measure), 467 | ('iviti', 'ive', self._has_positive_measure), 468 | ('biliti', 'ble', self._has_positive_measure), 469 | ] 470 | 471 | if self.mode == self.NLTK_EXTENSIONS: 472 | rules.append(('fulli', 'ful', self._has_positive_measure)) 473 | 474 | # The 'l' of the 'logi' -> 'log' rule is put with the stem, 475 | # so that short stems like 'geo' 'theo' etc work like 476 | # 'archaeo' 'philo' etc. 
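            # For example, 'geologi' (what step 1c produces for 'geology') is
            # tested on 'geol' (m=1) rather than 'geo' (m=0), so the rule
            # still fires and yields 'geolog'.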
477 | rules.append( 478 | ("logi", "log", lambda stem: self._has_positive_measure(word[:-3])) 479 | ) 480 | 481 | if self.mode == self.MARTIN_EXTENSIONS: 482 | rules.append(("logi", "log", self._has_positive_measure)) 483 | 484 | return self._apply_rule_list(word, rules) 485 | 486 | def _step3(self, word): 487 | """Implements Step 3 from "An algorithm for suffix stripping" 488 | 489 | From the paper: 490 | 491 | Step 3 492 | 493 | (m>0) ICATE -> IC triplicate -> triplic 494 | (m>0) ATIVE -> formative -> form 495 | (m>0) ALIZE -> AL formalize -> formal 496 | (m>0) ICITI -> IC electriciti -> electric 497 | (m>0) ICAL -> IC electrical -> electric 498 | (m>0) FUL -> hopeful -> hope 499 | (m>0) NESS -> goodness -> good 500 | """ 501 | return self._apply_rule_list( 502 | word, 503 | [ 504 | ('icate', 'ic', self._has_positive_measure), 505 | ('ative', '', self._has_positive_measure), 506 | ('alize', 'al', self._has_positive_measure), 507 | ('iciti', 'ic', self._has_positive_measure), 508 | ('ical', 'ic', self._has_positive_measure), 509 | ('ful', '', self._has_positive_measure), 510 | ('ness', '', self._has_positive_measure), 511 | ], 512 | ) 513 | 514 | def _step4(self, word): 515 | """Implements Step 4 from "An algorithm for suffix stripping" 516 | 517 | Step 4 518 | 519 | (m>1) AL -> revival -> reviv 520 | (m>1) ANCE -> allowance -> allow 521 | (m>1) ENCE -> inference -> infer 522 | (m>1) ER -> airliner -> airlin 523 | (m>1) IC -> gyroscopic -> gyroscop 524 | (m>1) ABLE -> adjustable -> adjust 525 | (m>1) IBLE -> defensible -> defens 526 | (m>1) ANT -> irritant -> irrit 527 | (m>1) EMENT -> replacement -> replac 528 | (m>1) MENT -> adjustment -> adjust 529 | (m>1) ENT -> dependent -> depend 530 | (m>1 and (*S or *T)) ION -> adoption -> adopt 531 | (m>1) OU -> homologou -> homolog 532 | (m>1) ISM -> communism -> commun 533 | (m>1) ATE -> activate -> activ 534 | (m>1) ITI -> angulariti -> angular 535 | (m>1) OUS -> homologous -> homolog 536 | (m>1) IVE -> effective -> effect 537 | (m>1) IZE -> bowdlerize -> bowdler 538 | 539 | The suffixes are now removed. All that remains is a little 540 | tidying up. 
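        Note (illustrative): m is measured on what remains after the suffix is
        removed. For instance, 'revival' drops AL because the stem 'reviv' has
        m=2, while a shorter word such as 'oral' keeps its ending, since 'or'
        only has m=1.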
541 |         """
542 |         measure_gt_1 = lambda stem: self._measure(stem) > 1
543 | 
544 |         return self._apply_rule_list(
545 |             word,
546 |             [
547 |                 ('al', '', measure_gt_1),
548 |                 ('ance', '', measure_gt_1),
549 |                 ('ence', '', measure_gt_1),
550 |                 ('er', '', measure_gt_1),
551 |                 ('ic', '', measure_gt_1),
552 |                 ('able', '', measure_gt_1),
553 |                 ('ible', '', measure_gt_1),
554 |                 ('ant', '', measure_gt_1),
555 |                 ('ement', '', measure_gt_1),
556 |                 ('ment', '', measure_gt_1),
557 |                 ('ent', '', measure_gt_1),
558 |                 # (m>1 and (*S or *T)) ION ->
559 |                 (
560 |                     'ion',
561 |                     '',
562 |                     lambda stem: self._measure(stem) > 1 and stem[-1] in ('s', 't'),
563 |                 ),
564 |                 ('ou', '', measure_gt_1),
565 |                 ('ism', '', measure_gt_1),
566 |                 ('ate', '', measure_gt_1),
567 |                 ('iti', '', measure_gt_1),
568 |                 ('ous', '', measure_gt_1),
569 |                 ('ive', '', measure_gt_1),
570 |                 ('ize', '', measure_gt_1),
571 |             ],
572 |         )
573 | 
574 |     def _step5a(self, word):
575 |         """Implements Step 5a from "An algorithm for suffix stripping"
576 | 
577 |         From the paper:
578 | 
579 |         Step 5a
580 | 
581 |         (m>1) E -> probate -> probat
582 |         rate -> rate
583 |         (m=1 and not *o) E -> cease -> ceas
584 |         """
585 |         # Note that Martin's test vocabulary and reference
586 |         # implementations are inconsistent in how they handle the case
587 |         # where two rules both refer to a suffix that matches the word
588 |         # to be stemmed, but only the condition of the second one is
589 |         # true.
590 |         # Earlier in step 1b we had the rules:
591 |         #     (m>0) EED -> EE
592 |         #     (*v*) ED ->
593 |         # but the examples in the paper included "feed"->"feed", even
594 |         # though (*v*) is true for "fe" and therefore the second rule
595 |         # alone would map "feed"->"fe".
596 |         # However, in THIS case, we need to handle the consecutive rules
597 |         # differently and try both conditions (obviously; the second
598 |         # rule here would be redundant otherwise). Martin's paper makes
599 |         # no explicit mention of the inconsistency; you have to infer it
600 |         # from the examples.
601 |         # For this reason, we can't use _apply_rule_list here.
602 |         if word.endswith('e'):
603 |             stem = self._replace_suffix(word, 'e', '')
604 |             if self._measure(stem) > 1:
605 |                 return stem
606 |             if self._measure(stem) == 1 and not self._ends_cvc(stem):
607 |                 return stem
608 |         return word
609 | 
610 |     def _step5b(self, word):
611 |         """Implements Step 5b from "An algorithm for suffix stripping"
612 | 
613 |         From the paper:
614 | 
615 |         Step 5b
616 | 
617 |         (m > 1 and *d and *L) -> single letter
618 |         controll -> control
619 |         roll -> roll
620 |         """
621 |         return self._apply_rule_list(
622 |             word, [('ll', 'l', lambda stem: self._measure(word[:-1]) > 1)]
623 |         )
624 | 
625 |     def stem(self, word):
626 |         stem = word.lower()
627 | 
628 |         if self.mode == self.NLTK_EXTENSIONS and word in self.pool:
629 |             return self.pool[word]
630 | 
631 |         if self.mode != self.ORIGINAL_ALGORITHM and len(word) <= 2:
632 |             # With this line, strings of length 1 or 2 don't go through
633 |             # the stemming process, although no mention is made of this
634 |             # in the published algorithm.
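            # e.g. 'am', 'is' or a single letter is returned unchanged in the
            # NLTK_EXTENSIONS and MARTIN_EXTENSIONS modes.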
635 | return word 636 | 637 | stem = self._step1a(stem) 638 | stem = self._step1b(stem) 639 | stem = self._step1c(stem) 640 | stem = self._step2(stem) 641 | stem = self._step3(stem) 642 | stem = self._step4(stem) 643 | stem = self._step5a(stem) 644 | stem = self._step5b(stem) 645 | 646 | return stem 647 | 648 | 649 | -------------------------------------------------------------------------------- /src/third_party_utils/spacy_stop_words.py: -------------------------------------------------------------------------------- 1 | # coding: utf8 2 | # from __future__ import unicode_literals 3 | 4 | 5 | # Stop words 6 | STOP_WORDS = set( 7 | """ 8 | a about above across after afterwards again against all almost alone along 9 | already also although always am among amongst amount an and another any anyhow 10 | anyone anything anyway anywhere are around as at 11 | 12 | back be became because become becomes becoming been before beforehand behind 13 | being below beside besides between beyond both bottom but by 14 | 15 | call can cannot ca could 16 | 17 | did do does doing done down due during 18 | 19 | each eight either eleven else elsewhere empty enough even ever every 20 | everyone everything everywhere except 21 | 22 | few fifteen fifty first five for former formerly forty four from front full 23 | further 24 | 25 | get give go 26 | 27 | had has have he hence her here hereafter hereby herein hereupon hers herself 28 | him himself his how however hundred 29 | 30 | i if in indeed into is it its itself 31 | 32 | keep 33 | 34 | last latter latterly least less 35 | 36 | just 37 | 38 | made make many may me meanwhile might mine more moreover most mostly move much 39 | must my myself 40 | 41 | name namely neither never nevertheless next nine no nobody none noone nor not 42 | nothing now nowhere 43 | 44 | of off often on once one only onto or other others otherwise our ours ourselves 45 | out over own 46 | 47 | part per perhaps please put 48 | 49 | quite 50 | 51 | rather re really regarding 52 | 53 | same say see seem seemed seeming seems serious several she should show side 54 | since six sixty so some somehow someone something sometime sometimes somewhere 55 | still such 56 | 57 | take ten than that the their them themselves then thence there thereafter 58 | thereby therefore therein thereupon these they third this those though three 59 | through throughout thru thus to together too top toward towards twelve twenty 60 | two 61 | 62 | under until up unless upon us used using 63 | 64 | various very very via was we well were what whatever when whence whenever where 65 | whereafter whereas whereby wherein whereupon wherever whether which while 66 | whither who whoever whole whom whose why will with within without would 67 | 68 | yet you your yours yourself yourselves 69 | """.split() 70 | ) 71 | 72 | contractions = ["n't", "'d", "'ll", "'m", "'re", "'s", "'ve"] 73 | STOP_WORDS.update(contractions) 74 | 75 | for apostrophe in ["‘", "’"]: 76 | for stopword in contractions: 77 | STOP_WORDS.add(stopword.replace("'", apostrophe)) 78 | -------------------------------------------------------------------------------- /src/wiqa_wrapper.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import enum 3 | import json 4 | import os 5 | from os import listdir 6 | from os.path import isfile, join 7 | from typing import Any, Dict, List 8 | 9 | from tqdm import tqdm 10 | 11 | from src.helpers.ProparaExtendedPara import ProparaExtendedParaMetadata 12 | from 
src.helpers.dataset_info import * 13 | from src.helpers.situation_graph import SituationLabel 14 | from src.helpers.whatif_metadata import WhatifMetadata 15 | 16 | 17 | class Jsonl: 18 | @staticmethod 19 | def load(in_filepath: str) -> List[Dict[str, Any]]: 20 | d: List[Dict[str, Any]] = [] 21 | if not os.path.exists(in_filepath): 22 | print(f"JSONL Path does not exist: {in_filepath}") 23 | return d 24 | with open(in_filepath, 'r') as infile: 25 | for line in infile: 26 | if line and not Jsonl.is_comment_line(line=line): 27 | d.append(json.loads(line.strip())) 28 | return d 29 | 30 | @staticmethod 31 | def is_comment_line(line: str): 32 | return line.strip().startswith("#") 33 | 34 | 35 | class WIQAUtils: 36 | @staticmethod 37 | def get_split_para(passage_with_sent_id: str) -> List[str]: 38 | split_para = passage_with_sent_id.split('. ') 39 | text_split_para = split_para[1::2] 40 | return text_split_para 41 | 42 | @staticmethod 43 | def filenames_in_folder(folder): 44 | return [f for f in listdir(folder) if isfile(join(folder, f))] 45 | 46 | @staticmethod 47 | def strip_special_char(a_string): 48 | return "".join( 49 | [x for x in a_string if (ord('a') <= ord(x) <= ord('z')) or (ord('A') <= ord(x) <= ord('Z'))]).lower() 50 | 51 | # Constants for the dataset wrapper 52 | LABELS = [{"label": "A", "text": "more"}, 53 | {"label": "B", "text": "less"}, 54 | {"label": "C", "text": "no effect"}] 55 | 56 | 57 | class WIQAQuesType(str, enum.Enum): 58 | OTHER = "OTHER" 59 | INPARA_EFFECT = "INPARA_EFFECT" # From same if-then block, source node is from { X, Y, U, W } 60 | EXOGENOUS_EFFECT = "EXOGENOUS_EFFECT" # From same if-then block but source node is Z or V 61 | INPARA_DISTRACTOR = "INPARA_DISTRACTOR" # non-path from same if-then block 62 | OUTOFPARA_DISTRACTOR = "OUTOFPARA_DISTRACTOR" # will be useful when we get out-of-para distractor annotations 63 | 64 | @staticmethod 65 | def from_str(qtype_str): 66 | if not qtype_str: 67 | raise ValueError( 68 | f"TF question type must not be empty or None-- input to WIQAQuesType from_str: ({qtype_str}) ") 69 | qtype_str = qtype_str.lower().replace('_', ' ').strip() 70 | 71 | if qtype_str in ['in para effect', 'inpara effect', 'direct']: 72 | return WIQAQuesType.INPARA_EFFECT 73 | elif qtype_str in ['exogeneous effect', 'exogenous effect', 'indirect']: 74 | return WIQAQuesType.EXOGENOUS_EFFECT 75 | elif qtype_str in ['inpara distractor', 'in para distractor']: 76 | return WIQAQuesType.INPARA_DISTRACTOR 77 | elif qtype_str in ['outofpara distractor', 'out of para distractor']: 78 | return WIQAQuesType.OUTOFPARA_DISTRACTOR 79 | else: 80 | print(f"WARNING: tf question type: {qtype_str} not identified") 81 | return WIQAQuesType.OTHER 82 | 83 | @staticmethod 84 | def from_path(path, is_distractor, in_para=True): 85 | if is_distractor: 86 | if in_para: 87 | return WIQAQuesType.INPARA_DISTRACTOR 88 | else: 89 | return WIQAQuesType.OUTOFPARA_DISTRACTOR 90 | 91 | start_node = path[0] if path and len(path) > 1 else "" 92 | if start_node.id == "Z" or start_node.id == "V": # Out of para situations causing changes in the process 93 | return WIQAQuesType.EXOGENOUS_EFFECT 94 | 95 | return WIQAQuesType.INPARA_EFFECT # Changes to in-para events causing changes in rest of the process 96 | 97 | def to_json(self): 98 | return self.name 99 | 100 | 101 | class WIQAExplanationType(str, enum.Enum): 102 | NO_EXPL = "NO_EXPL" 103 | PARA_SENT_EXPL = "PARA_SENT_EXPL" 104 | 105 | @staticmethod 106 | def from_str(sl): 107 | if not sl: 108 | raise ValueError( 109 | f"({sl}) is not a valid 
Enum WIQAExplanationType") 110 | sl = sl.lower().replace('_', ' ').strip() 111 | if sl in ['no expl', 'no explanation', 'no exp']: 112 | return WIQAExplanationType.NO_EXPL 113 | elif sl in ['with exp', 'with expl', 'para sent expl', 'paragraph sentence explanation', 'expl', 'exp']: 114 | return WIQAExplanationType.PARA_SENT_EXPL 115 | else: 116 | raise Exception(f"WARNING: ({sl}) is not a valid choice for {WIQAExplanationType.__dict__}") 117 | 118 | def to_json(self): 119 | return self.name 120 | 121 | 122 | class WIQAExplanation(object): 123 | 124 | def __init__(self, di: SituationLabel, dj: SituationLabel, de: SituationLabel, i: int, j: int): 125 | self.di = di 126 | self.dj = dj 127 | self.de = de 128 | self.i = i 129 | self.j = j 130 | 131 | @staticmethod 132 | def instantiate_from_old_json_version(json_data: Dict[str, Any]): 133 | assert 'explanation' in json_data, f"WIQA explanation cannot be instantiated due to missing keys[explanation] in json: {json_data}" 134 | assert 'di' in json_data[ 135 | 'explanation'], f"WIQA explanation cannot be instantiated due to missing keys[explanation][di] in json: {json_data}" 136 | return WIQAExplanation( 137 | di=SituationLabel.from_str(json_data['explanation']['di']), 138 | dj=SituationLabel.from_str(json_data['explanation']['di']), 139 | de=SituationLabel.from_str(json_data['explanation']['dj'] if "de" not in json_data["explanation"] else json_data['explanation']['de']), 140 | i=json_data['explanation']['i'], 141 | j=json_data['explanation']['j'] 142 | ) 143 | 144 | @staticmethod 145 | def instantiate_from(json_data: Dict[str, Any]): 146 | ''' 147 | :param json_data: must contain keys: 148 | "explanations": { 149 | "de" : answer_label, 150 | "di": optional_supporting_sent_label, 151 | "i": optional_sentidx_or_None, 152 | "j": optional_sentidx_or_None 153 | } 154 | :return: WIQAExplanation object. 
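        An illustrative value for this dict, mirroring the prediction format
        exercised in tests/simple_test.py (note that the top-level key is
        'explanation', singular, as the assertions below require):
            {"explanation": {"di": "RESULTS_IN", "dj": "RESULTS_IN",
                             "de": "RESULTS_IN", "i": 1, "j": 4}}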
155 |         '''
156 |         assert 'explanation' in json_data, f"WIQA explanation cannot be instantiated due to missing keys[explanation] in json: {json_data}"
157 |         assert 'di' in json_data[
158 |             'explanation'], f"WIQA explanation cannot be instantiated due to missing keys[explanation][di] in json: {json_data}"
159 |         return WIQAExplanation(
160 |             di=None if 'di' not in json_data['explanation'] or not json_data['explanation']['di'] else SituationLabel.from_str(json_data['explanation']['di']),
161 |             dj=None if 'di' not in json_data['explanation'] or not json_data['explanation']['di'] else SituationLabel.from_str(json_data['explanation']['di']),
162 |             de=SituationLabel.from_str(json_data['explanation']['de']),
163 |             i=None if 'i' not in json_data['explanation'] else json_data['explanation']['i'],
164 |             j=None if 'j' not in json_data['explanation'] else json_data['explanation']['j']
165 |         )
166 | 
167 |     @staticmethod
168 |     def deserialize(json_data: Dict[str, Any]):
169 |         return WIQAExplanation(di=SituationLabel.from_str(json_data['di']),
170 |                                dj=SituationLabel.from_str(json_data['dj']),
171 |                                de=SituationLabel.from_str(json_data['de']),
172 |                                i=json_data['i'],
173 |                                j=json_data['j']
174 |                                )
175 | 
176 | 
177 | class WIQAQuestion(object):
178 |     def __init__(self, stem: str,
179 |                  para_steps: List[str],
180 |                  answer_label: str,
181 |                  answer_label_as_choice: str,
182 |                  choices: List[Dict[str, str]] = WIQAUtils.LABELS):
183 |         self.stem = stem
184 |         self.para_steps = para_steps
185 |         self.answer_label = answer_label
186 |         self.answer_label_as_choice = answer_label_as_choice
187 |         self.choices = choices
188 | 
189 |     @staticmethod
190 |     def instantiate_from(json_data: Dict[str, Any]):
191 |         assert 'explanation' in json_data, f"WIQA question cannot be instantiated due to missing keys[explanation] in json: {json_data}"
192 |         chosen_label = SituationLabel.from_str(json_data['explanation']['dj'])
193 |         return WIQAQuestion(stem=json_data['question']['question'],
194 |                             para_steps=json_data['steps'],
195 |                             answer_label=chosen_label.as_less_more(),
196 |                             answer_label_as_choice=SituationLabel.get_emnlp_test_choice(chosen_label))
197 | 
198 |     @staticmethod
199 |     def deserialize(json_data: Dict[str, Any]):
200 |         return WIQAQuestion(stem=json_data['stem'],
201 |                             para_steps=json_data['para_steps'],
202 |                             answer_label=json_data['answer_label'],
203 |                             answer_label_as_choice=json_data['answer_label_as_choice'])
204 | 
205 | 
206 | class WIQAQuesMetadata(object):
207 |     def __init__(self, ques_id, graph_id: str, para_id: str, question_type: WIQAQuesType):
208 |         self.ques_id = ques_id
209 |         self.graph_id = graph_id
210 |         self.para_id = para_id
211 |         self.question_type = question_type
212 | 
213 |     @staticmethod
214 |     def instantiate_from(json_data: Dict[str, Any]):
215 |         assert 'question' in json_data, f"WIQA QuesMetadata cannot be instantiated due to missing keys[question] in json: {json_data}"
216 |         return WIQAQuesMetadata(
217 |             ques_id=json_data['id'],
218 |             graph_id=json_data['metadata']['graph_id'],
219 |             para_id=json_data['metadata']['para_id'],
220 |             question_type=WIQAQuesType.from_str(json_data['metadata']['question_type'])
221 |         )
222 | 
223 |     @staticmethod
224 |     def deserialize(json_data: Dict[str, Any]):
225 |         return WIQAQuesMetadata(ques_id=json_data['ques_id'],
226 |                                 graph_id=json_data['graph_id'],
227 |                                 para_id=json_data['para_id'],
228 |                                 question_type=WIQAQuesType.from_str(json_data['question_type'])
229 |                                 )
230 | 
231 | 
232 | class WIQADataPoint(object):
233 |     """
234 |     Holds a single WIQA data sample (question, optional explanation, and metadata).
235 |     """
236 | 
237 |     def 
__init__(self, question: WIQAQuestion, explanation: WIQAExplanation, metadata: WIQAQuesMetadata): 238 | self.question = question 239 | self.explanation = explanation 240 | self.metadata = metadata 241 | 242 | @staticmethod 243 | def instantiate_from(explanation_type: WIQAExplanationType, json_data: Dict[str, Any]): 244 | if explanation_type == WIQAExplanationType.NO_EXPL: 245 | explanation = None 246 | else: 247 | explanation = WIQAExplanation.instantiate_from_old_json_version(json_data=json_data) 248 | return WIQADataPoint( 249 | question=WIQAQuestion.instantiate_from(json_data=json_data), 250 | explanation=explanation, 251 | metadata=WIQAQuesMetadata.instantiate_from(json_data=json_data) 252 | ) 253 | 254 | def to_json(self, explanation_type: WIQAExplanationType= WIQAExplanationType.PARA_SENT_EXPL): 255 | # ensure that no object contains the 256 | if explanation_type == WIQAExplanationType.NO_EXPL: 257 | return {'question': self.question.__dict__, 258 | 'metadata': self.metadata.__dict__ 259 | } 260 | else: 261 | return {'question': self.question.__dict__, 262 | 'explanation': self.explanation.__dict__, 263 | 'metadata': self.metadata.__dict__ 264 | } 265 | 266 | @staticmethod 267 | def deserialize(json_data: Dict[str, Any]): 268 | return WIQADataPoint(question=WIQAQuestion.deserialize(json_data=json_data['question']), 269 | metadata=WIQAQuesMetadata.deserialize(json_data=json_data['metadata']), 270 | explanation=WIQAExplanation.deserialize(json_data=json_data['explanation']) 271 | ) 272 | 273 | @staticmethod 274 | def get_default_whatif_metadata( 275 | para_ids_metainfo_fp=download_from_url_if_not_in_cache( 276 | para_partition_info.cloud_path), 277 | situation_graphs_fp=download_from_url_if_not_in_cache( 278 | influence_graphs_v1.cloud_path)): 279 | return WhatifMetadata(para_ids_metainfo_fp=para_ids_metainfo_fp, situation_graphs_fp=situation_graphs_fp) 280 | 281 | @staticmethod 282 | def get_default_propara_paragraphs_metadata(extended_propara_para_fp=download_from_url_if_not_in_cache( 283 | propara_para_info.cloud_path)): 284 | return ProparaExtendedParaMetadata(extended_propara_para_fp=extended_propara_para_fp) 285 | 286 | @staticmethod 287 | def load_all_in_jsonl(jsonl_filepath): 288 | for j in Jsonl.load(in_filepath=jsonl_filepath): 289 | yield WIQADataPoint.deserialize(json_data=j) 290 | 291 | def get_steps(self): 292 | return self.question.para_steps 293 | 294 | def get_provenance_influence_graph(self, whatif_metadata: WhatifMetadata): 295 | return whatif_metadata.get_graph_for_id(graph_id=self.metadata.graph_id) 296 | 297 | def get_orig_propara_paragraph(self, 298 | orig_propara: ProparaExtendedParaMetadata): 299 | return orig_propara.paraentry_for_id(para_id=self.metadata.para_id) 300 | 301 | def get_other_paragraphs_under_this_topic(self, whatif_metadata: WhatifMetadata, 302 | orig_propara: ProparaExtendedParaMetadata): 303 | return [orig_propara.paraentry_for_id(para_id=x) for x in whatif_metadata.get_paraids_for_topic( 304 | topic_str=whatif_metadata.get_topic_for_paraid(para_id=self.metadata.para_id))] 305 | 306 | 307 | def create_concise_dataset(input_filepaths, output_filepaths, explanation_type: WIQAExplanationType): 308 | """ 309 | :param input_filepaths: 310 | :param output_filepaths: 311 | :param explanation_type: 312 | :usage ```partitions = ["train.jsonl", "dev.jsonl", "test.jsonl"] 313 | create_concise_dataset( 314 | input_filepaths=[download_from_url_if_not_in_cache(wiqa_explanations_v1.cloud_path + partition) for partition in 315 | partitions], 316 | 
output_filepaths=["/tmp/od/" + x for x in partitions],
317 |                              explanation_type=WIQAExplanationType.PARA_SENT_EXPL)```
318 |     :return:
319 |     """
320 |     print(f"\nInput file paths: {input_filepaths}")
321 |     assert input_filepaths and output_filepaths and len(input_filepaths) == len(output_filepaths), \
322 |         "The input and output file path lists for creating wiqa wrapper files must be non-empty and of equal length."
323 | 
324 |     for file_num, input_filepath in enumerate(input_filepaths):
325 |         print(f"\nGenerating reformatted data for .... {input_filepath}")
326 |         with open(input_filepath, 'r') as in_file:
327 |             output_filepath = output_filepaths[file_num]
328 | 
329 |             # ensure that the output directory exists; if not, create it.
330 |             outdir = "/".join(output_filepath.split("/")[0:-1])
331 |             if not os.path.exists(outdir):
332 |                 os.makedirs(outdir)
333 | 
334 |             with open(output_filepath, 'w') as out_file:
335 |                 for line in tqdm(in_file):
336 |                     json_data = json.loads(line)
337 |                     data_object = WIQADataPoint.instantiate_from(explanation_type=explanation_type,
338 |                                                                  json_data=json_data)
339 |                     dump_it = data_object.to_json(explanation_type=explanation_type)
340 |                     out_file.write(json.dumps(dump_it))
341 |                     out_file.write('\n')
342 | 
343 | 
344 | 
345 | if __name__ == '__main__':
346 |     parser = argparse.ArgumentParser()
347 |     subparsers = parser.add_subparsers(dest='subcommand')
348 | 
349 |     # --------------------------------------------------
350 |     # ################ Download dataset ################
351 |     # --------------------------------------------------
352 |     parser_download = subparsers.add_parser('download')
353 |     parser_download.add_argument('--input_dirpath',
354 |                                  action='store',
355 |                                  dest='input_dirpath',
356 |                                  required=True,
357 |                                  help='Input dataset directory')
358 |     parser_download.add_argument('--output_dirpath',
359 |                                  action='store',
360 |                                  dest='output_dirpath',
361 |                                  required=True,
362 |                                  help='folder to store output')
363 |     parser_download.add_argument('--explanation_type',
364 |                                  action='store',
365 |                                  dest='explanation_type',
366 |                                  required=True,
367 |                                  help='with_expl|no_expl')
368 | 
369 |     # --------------------------------------------------
370 |     # ################ Load json data ################
371 |     # --------------------------------------------------
372 |     args = parser.parse_args()
373 | 
374 |     if args.subcommand == "download":
375 |         print(f"Input {[args.input_dirpath + x for x in WIQAUtils.filenames_in_folder(args.input_dirpath)]}")
376 |         create_concise_dataset(
377 |             input_filepaths=[args.input_dirpath + x for x in WIQAUtils.filenames_in_folder(args.input_dirpath)],
378 |             output_filepaths=[args.output_dirpath + x for x in WIQAUtils.filenames_in_folder(args.input_dirpath)],
379 |             explanation_type=WIQAExplanationType.from_str(args.explanation_type))
380 | 
--------------------------------------------------------------------------------
/tests/simple_test.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import tempfile
4 | from unittest import TestCase
5 | 
6 | from src.eval.evaluation import InputRequiredByFineEval, FineEvalMetrics
7 | from src.helpers.dataset_info import download_from_url_if_not_in_cache, wiqa_explanations_v1
8 | from src.wiqa_wrapper import WIQAExplanation, WIQAExplanationType, create_concise_dataset, WIQADataPoint
9 | 
10 | 
11 | class TestWIQAWrapper(TestCase):
12 | 
13 |     def setUp(self) -> None:
14 |         self.init_dir = tempfile.mkdtemp()
15 |         self.init_dir += "/" if "/" in self.init_dir else "\\"
16 | 
17 |     # 1. (done) reformat dataset (this is a command line argument.)
18 |     # 2. 
load (reformatted) dataset ... includes other operations on it. (this is not command line) 19 | # 3. Evaluation code. 20 | def test_1_create_concise_dataset(self): 21 | partitions = ["train.jsonl", "dev.jsonl", "test.jsonl"] 22 | create_concise_dataset( 23 | input_filepaths=[download_from_url_if_not_in_cache(wiqa_explanations_v1.cloud_path + partition) for 24 | partition in 25 | partitions], 26 | output_filepaths=[self.init_dir + x for x in partitions], 27 | explanation_type=WIQAExplanationType.PARA_SENT_EXPL) 28 | for outfp in [self.init_dir + x for x in partitions]: 29 | assert os.path.exists(outfp) 30 | 31 | # Load the dataset from jsonl files. 32 | def test_2_load_dataset(self): 33 | for x in WIQADataPoint.load_all_in_jsonl(jsonl_filepath=self.init_dir + "dev.jsonl"): 34 | j_str = json.dumps(x.to_json()) 35 | assert j_str is not None 36 | break 37 | 38 | def test_3_sample_model_output_eval(self): 39 | # Note: j_str contains a mix of ' and " 40 | json_data = """{"logits": [[0.7365949749946594, 0.16483880579471588, 0.38572263717651367, 0.33268705010414124, -0.6508089900016785, -0.950057864189148], [0.3361547291278839, -0.02215845137834549, 0.8630506992340088, 0.4753769040107727, -1.0981523990631104, 0.04984292760491371], [0.4498467743396759, 0.26323091983795166, 0.5597160458564758, 0.06369128823280334, -0.33793506026268005, -0.30190590023994446], [0.41394802927970886, 0.31742218136787415, 0.42982375621795654, -0.2891058027744293, -0.09577881544828415, -0.4486318528652191], [0.5242481231689453, -0.05186435207724571, 0.4505387544631958, -0.43092456459999084, -0.015227549709379673, 0.10361793637275696], [0.8527745604515076, 0.18845966458320618, 0.6540948748588562, -0.06324845552444458, -0.03267676383256912, 0.058296892791986465], [0.40418609976768494, -0.24220454692840576, 0.0737631767988205, -0.8445389270782471, -0.12929767370224, 0.5813987851142883]], "class_probabilities": [[0.29663264751434326, 0.16745895147323608, 0.20885121822357178, 0.19806326925754547, 0.07407592236995697, 0.054918017238378525], [0.18079228699207306, 0.12634745240211487, 0.3062019348144531, 0.20779892802238464, 0.04307926073670387, 0.13578014075756073], [0.21968600153923035, 0.18228720128536224, 0.24519862234592438, 0.14931286871433258, 0.09992476552724838, 0.10359060019254684], [0.22513438761234283, 0.20441898703575134, 0.22873708605766296, 0.11145754158496857, 0.13522914052009583, 0.09502287209033966], [0.24298663437366486, 0.13657772541046143, 0.2257203906774521, 0.09348806738853455, 0.14167429506778717, 0.15955300629138947], [0.27786338329315186, 0.1429957151412964, 0.22779585421085358, 0.11117511987686157, 0.11462641507387161, 0.12554346024990082], [0.23202574253082275, 0.1215660497546196, 0.16673828661441803, 0.06656130403280258, 0.1360965520143509, 0.27701207995414734]], "question": {"stem": "suppose squirrels get sick happens, how will it affect squirrels need more food.", "para_steps": ["Squirrels try to eat as much as possible", "Squirrel gains weight and fat", "Squirrel also hides food in or near its den", "Squirrels also grow a thicker coat as the weather gets colder", "Squirrel lives off of its excess body fat", "Squirrel uses its food stores in the winter..)"], "answer_label": "more", "answer_label_as_choice": "A", "choices": [{"label": "A", "text": "more"}, {"label": "B", "text": "less"}, {"label": "C", "text": "no effect"}]}, "explanation": {"di": "RESULTS_IN", "dj": "RESULTS_IN", "de": "RESULTS_IN", "i": 1, "j": 4}, "orig_answer": {"explanation": {"di": "RESULTS_IN", "dj": "RESULTS_IN", "de": 
"NOT_RESULTS_IN", "i": 2, "j": 3}, "metadata": {"ques_id": "influence_graph:1310:156:83#3", "graph_id": "156", "para_id": "1310", "question_type": "EXOGENOUS_EFFECT"}} }""" 41 | json_obj = json.loads(json_data) 42 | answer_obj = InputRequiredByFineEval.from_( 43 | prediction_on_this_example=WIQAExplanation.instantiate_from(json_data=json_obj), 44 | json_from_question=json_obj["orig_answer"], 45 | expl_type=WIQAExplanationType.PARA_SENT_EXPL 46 | ) 47 | assert answer_obj.metrics[FineEvalMetrics.XDIR] \ 48 | and not answer_obj.metrics[FineEvalMetrics.EQDIR] \ 49 | and not answer_obj.metrics[FineEvalMetrics.XSENTID] \ 50 | and not answer_obj.metrics[FineEvalMetrics.YSENTID] 51 | 52 | # all influence graphs 53 | # 54 | # 55 | # 56 | # graph structure is (note: nice diagram depicting node names coming soon): 57 | # X ==> Y 58 | # X =/=> W 59 | # U =/=> Y 60 | # Z ==> X 61 | # V =/=> X 62 | # W =/=> A 63 | # W ==> D 64 | # Y ==> A(ccelerate) or Y ==> D(ecelerate) 65 | def test_4_influence_graphs(self): 66 | igs = WIQADataPoint.get_default_whatif_metadata() 67 | for gid in igs.get_all_graphids(): 68 | ig = igs.get_graph_for_id(graph_id=gid) 69 | assert ig.to_json_v1() is not None 70 | break 71 | 72 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | import datetime 4 | import numpy as np 5 | 6 | # Function to calculate the accuracy of our predictions vs labels 7 | from typing import List, Dict 8 | 9 | 10 | def flat_accuracy(preds, labels): 11 | pred_flat = np.argmax(preds, axis=1).flatten() 12 | labels_flat = labels.flatten() 13 | return np.sum(pred_flat == labels_flat) / len(labels_flat) 14 | 15 | 16 | def format_time(elapsed): 17 | # Round to the nearest second. 18 | elapsed_rounded = int(round((elapsed))) 19 | # Format as hh:mm:ss 20 | return str(datetime.timedelta(seconds=elapsed_rounded)) 21 | 22 | def read_jsonl(input_file: str) -> List[Dict]: 23 | output: List[Dict] = [] 24 | with open(input_file, 'r') as open_file: 25 | for line in open_file: 26 | output.append(json.loads(line)) 27 | return output --------------------------------------------------------------------------------