├── .gitignore
├── LICENSE
├── README.md
├── model
│   ├── __init__.py
│   ├── data.py
│   ├── qa_model.py
│   ├── requirements.txt
│   ├── run.py
│   └── run_wiqa_classifier.sh
├── requirements.txt
├── src
│   ├── __init__.py
│   ├── eval
│   │   ├── __init__.py
│   │   ├── eval_utils.py
│   │   └── evaluation.py
│   ├── helpers
│   │   ├── ProparaExtendedPara.py
│   │   ├── SGFileLoaders.py
│   │   ├── __init__.py
│   │   ├── collections_util.py
│   │   ├── dataset_info.py
│   │   ├── situation_graph.py
│   │   └── whatif_metadata.py
│   ├── third_party_utils
│   │   ├── __init__.py
│   │   ├── allennlp_cached_filepath.py
│   │   ├── nltk_porter_stemmer.py
│   │   └── spacy_stop_words.py
│   └── wiqa_wrapper.py
├── tests
│   └── simple_test.py
└── utils.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |                                  Apache License
2 |                            Version 2.0, January 2004
3 |                         http://www.apache.org/licenses/
4 |
5 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 |    1. Definitions.
8 |
9 |       "License" shall mean the terms and conditions for use, reproduction,
10 |       and distribution as defined by Sections 1 through 9 of this document.
11 |
12 |       "Licensor" shall mean the copyright owner or entity authorized by
13 |       the copyright owner that is granting the License.
14 |
15 |       "Legal Entity" shall mean the union of the acting entity and all
16 |       other entities that control, are controlled by, or are under common
17 |       control with that entity. For the purposes of this definition,
18 |       "control" means (i) the power, direct or indirect, to cause the
19 |       direction or management of such entity, whether by contract or
20 |       otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 |       outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 |       "You" (or "Your") shall mean an individual or Legal Entity
24 |       exercising permissions granted by this License.
25 |
26 |       "Source" form shall mean the preferred form for making modifications,
27 |       including but not limited to software source code, documentation
28 |       source, and configuration files.
29 |
30 |       "Object" form shall mean any form resulting from mechanical
31 |       transformation or translation of a Source form, including but
32 |       not limited to compiled object code, generated documentation,
33 |       and conversions to other media types.
34 |
35 |       "Work" shall mean the work of authorship, whether in Source or
36 |       Object form, made available under the License, as indicated by a
37 |       copyright notice that is included in or attached to the work
38 |       (an example is provided in the Appendix below).
39 |
40 |       "Derivative Works" shall mean any work, whether in Source or Object
41 |       form, that is based on (or derived from) the Work and for which the
42 |       editorial revisions, annotations, elaborations, or other modifications
43 |       represent, as a whole, an original work of authorship. For the purposes
44 |       of this License, Derivative Works shall not include works that remain
45 |       separable from, or merely link (or bind by name) to the interfaces of,
46 |       the Work and Derivative Works thereof.
47 |
48 |       "Contribution" shall mean any work of authorship, including
49 |       the original version of the Work and any modifications or additions
50 |       to that Work or Derivative Works thereof, that is intentionally
51 |       submitted to Licensor for inclusion in the Work by the copyright owner
52 |       or by an individual or Legal Entity authorized to submit on behalf of
53 |       the copyright owner. For the purposes of this definition, "submitted"
54 |       means any form of electronic, verbal, or written communication sent
55 |       to the Licensor or its representatives, including but not limited to
56 |       communication on electronic mailing lists, source code control systems,
57 |       and issue tracking systems that are managed by, or on behalf of, the
58 |       Licensor for the purpose of discussing and improving the Work, but
59 |       excluding communication that is conspicuously marked or otherwise
60 |       designated in writing by the copyright owner as "Not a Contribution."
61 |
62 |       "Contributor" shall mean Licensor and any individual or Legal Entity
63 |       on behalf of whom a Contribution has been received by Licensor and
64 |       subsequently incorporated within the Work.
65 |
66 |    2. Grant of Copyright License. Subject to the terms and conditions of
67 |       this License, each Contributor hereby grants to You a perpetual,
68 |       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 |       copyright license to reproduce, prepare Derivative Works of,
70 |       publicly display, publicly perform, sublicense, and distribute the
71 |       Work and such Derivative Works in Source or Object form.
72 |
73 |    3. Grant of Patent License. Subject to the terms and conditions of
74 |       this License, each Contributor hereby grants to You a perpetual,
75 |       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 |       (except as stated in this section) patent license to make, have made,
77 |       use, offer to sell, sell, import, and otherwise transfer the Work,
78 |       where such license applies only to those patent claims licensable
79 |       by such Contributor that are necessarily infringed by their
80 |       Contribution(s) alone or by combination of their Contribution(s)
81 |       with the Work to which such Contribution(s) was submitted. If You
82 |       institute patent litigation against any entity (including a
83 |       cross-claim or counterclaim in a lawsuit) alleging that the Work
84 |       or a Contribution incorporated within the Work constitutes direct
85 |       or contributory patent infringement, then any patent licenses
86 |       granted to You under this License for that Work shall terminate
87 |       as of the date such litigation is filed.
88 |
89 |    4. Redistribution. You may reproduce and distribute copies of the
90 |       Work or Derivative Works thereof in any medium, with or without
91 |       modifications, and in Source or Object form, provided that You
92 |       meet the following conditions:
93 |
94 |       (a) You must give any other recipients of the Work or
95 |           Derivative Works a copy of this License; and
96 |
97 |       (b) You must cause any modified files to carry prominent notices
98 |           stating that You changed the files; and
99 |
100 |       (c) You must retain, in the Source form of any Derivative Works
101 |           that You distribute, all copyright, patent, trademark, and
102 |           attribution notices from the Source form of the Work,
103 |           excluding those notices that do not pertain to any part of
104 |           the Derivative Works; and
105 |
106 |       (d) If the Work includes a "NOTICE" text file as part of its
107 |           distribution, then any Derivative Works that You distribute must
108 |           include a readable copy of the attribution notices contained
109 |           within such NOTICE file, excluding those notices that do not
110 |           pertain to any part of the Derivative Works, in at least one
111 |           of the following places: within a NOTICE text file distributed
112 |           as part of the Derivative Works; within the Source form or
113 |           documentation, if provided along with the Derivative Works; or,
114 |           within a display generated by the Derivative Works, if and
115 |           wherever such third-party notices normally appear. The contents
116 |           of the NOTICE file are for informational purposes only and
117 |           do not modify the License. You may add Your own attribution
118 |           notices within Derivative Works that You distribute, alongside
119 |           or as an addendum to the NOTICE text from the Work, provided
120 |           that such additional attribution notices cannot be construed
121 |           as modifying the License.
122 |
123 |       You may add Your own copyright statement to Your modifications and
124 |       may provide additional or different license terms and conditions
125 |       for use, reproduction, or distribution of Your modifications, or
126 |       for any such Derivative Works as a whole, provided Your use,
127 |       reproduction, and distribution of the Work otherwise complies with
128 |       the conditions stated in this License.
129 |
130 |    5. Submission of Contributions. Unless You explicitly state otherwise,
131 |       any Contribution intentionally submitted for inclusion in the Work
132 |       by You to the Licensor shall be under the terms and conditions of
133 |       this License, without any additional terms or conditions.
134 |       Notwithstanding the above, nothing herein shall supersede or modify
135 |       the terms of any separate license agreement you may have executed
136 |       with Licensor regarding such Contributions.
137 |
138 |    6. Trademarks. This License does not grant permission to use the trade
139 |       names, trademarks, service marks, or product names of the Licensor,
140 |       except as required for reasonable and customary use in describing the
141 |       origin of the Work and reproducing the content of the NOTICE file.
142 |
143 |    7. Disclaimer of Warranty. Unless required by applicable law or
144 |       agreed to in writing, Licensor provides the Work (and each
145 |       Contributor provides its Contributions) on an "AS IS" BASIS,
146 |       WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 |       implied, including, without limitation, any warranties or conditions
148 |       of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 |       PARTICULAR PURPOSE. You are solely responsible for determining the
150 |       appropriateness of using or redistributing the Work and assume any
151 |       risks associated with Your exercise of permissions under this License.
152 |
153 |    8. Limitation of Liability. In no event and under no legal theory,
154 |       whether in tort (including negligence), contract, or otherwise,
155 |       unless required by applicable law (such as deliberate and grossly
156 |       negligent acts) or agreed to in writing, shall any Contributor be
157 |       liable to You for damages, including any direct, indirect, special,
158 |       incidental, or consequential damages of any character arising as a
159 |       result of this License or out of the use or inability to use the
160 |       Work (including but not limited to damages for loss of goodwill,
161 |       work stoppage, computer failure or malfunction, or any and all
162 |       other commercial damages or losses), even if such Contributor
163 |       has been advised of the possibility of such damages.
164 |
165 |    9. Accepting Warranty or Additional Liability. While redistributing
166 |       the Work or Derivative Works thereof, You may choose to offer,
167 |       and charge a fee for, acceptance of support, warranty, indemnity,
168 |       or other liability obligations and/or rights consistent with this
169 |       License. However, in accepting such obligations, You may act only
170 |       on Your own behalf and on Your sole responsibility, not on behalf
171 |       of any other Contributor, and only if You agree to indemnify,
172 |       defend, and hold each Contributor harmless for any liability
173 |       incurred by, or claims asserted against, such Contributor by reason
174 |       of your accepting any such warranty or additional liability.
175 |
176 |    END OF TERMS AND CONDITIONS
177 |
178 |    APPENDIX: How to apply the Apache License to your work.
179 |
180 |       To apply the Apache License to your work, attach the following
181 |       boilerplate notice, with the fields enclosed by brackets "[]"
182 |       replaced with your own identifying information. (Don't include
183 |       the brackets!) The text should be enclosed in the appropriate
184 |       comment syntax for the file format. We also recommend that a
185 |       file or class name and description of purpose be included on the
186 |       same "printed page" as the copyright notice for easier
187 |       identification within third-party archives.
188 |
189 |    Copyright [yyyy] [name of copyright owner]
190 |
191 |    Licensed under the Apache License, Version 2.0 (the "License");
192 |    you may not use this file except in compliance with the License.
193 |    You may obtain a copy of the License at
194 |
195 |        http://www.apache.org/licenses/LICENSE-2.0
196 |
197 |    Unless required by applicable law or agreed to in writing, software
198 |    distributed under the License is distributed on an "AS IS" BASIS,
199 |    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 |    See the License for the specific language governing permissions and
201 |    limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # wiqa-dataset
2 |
3 | Code repository for the EMNLP 2019 WIQA dataset paper.
4 |
5 | ## Usage
6 |
7 | First, set up a virtual environment like this:
8 |
9 | ```
10 | virtualenv venv
11 | source venv/bin/activate
12 | pip install -r requirements.txt
13 | ```
14 |
15 | (You can also use Conda.)
16 |
17 | Create a simple program `retrieve.py` like this:
18 |
19 | ```
20 | from src.wiqa_wrapper import WIQADataPoint
21 |
22 | wimd = WIQADataPoint.get_default_whatif_metadata()
23 | sg = wimd.get_graph_for_id(graph_id="13")
24 | print(sg.to_json_v1())
25 | ```
26 |
27 | This program will read the What-If metadata (`wimd`), retrieve situation graph 13 (`sg`), and print a string representation in JSON format. To see the result, run it like this (in the virtual env):
28 |
29 | ```
30 | % PYTHONPATH=. python retrieve.py
31 | {"V": ["water is exposed to high heat", "water is not protected from high heat"], "Z": ["water is shielded from heat", ...
32 | ```
33 |
34 | ## Running tests
35 |
36 | Set up the virtual environment as above, then run the tests like this:
37 |
38 | ```
39 | export PYTHONPATH=.
40 | pytest
41 | ```
42 |
43 | ## Running Model
44 |
45 | ```
46 | pip install -r model/requirements.txt
47 | bash model/run_wiqa_classifier.sh
48 | ```
49 |
50 | Note: comment out the `--gpus` and `--accelerator` arguments in the script for CPU training.
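For a programmatic look at the QA data itself, here is a minimal, hedged sketch (assuming the same `PYTHONPATH=.` setup): `wiqa_explanations_v1`, `download_from_url_if_not_in_cache`, and `Jsonl` all come from this repo, and the `["question"]["id"]` access mirrors what `src/eval/evaluation.py` does; treat any other field access as an assumption.

```
from src.helpers.dataset_info import wiqa_explanations_v1
from src.wiqa_wrapper import Jsonl, download_from_url_if_not_in_cache

# Fetch (and locally cache) the test partition, then peek at a few questions.
test_path = download_from_url_if_not_in_cache(wiqa_explanations_v1.cloud_path + "test.jsonl")
for i, question_json in enumerate(Jsonl.load(test_path)):
    print(question_json["question"]["id"])  # same access pattern as src/eval/evaluation.py
    if i == 2:
        break
```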
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/wiqa-dataset/edeef924188a7caa4493c305209e0b20d20b375c/model/__init__.py
--------------------------------------------------------------------------------
/model/data.py:
--------------------------------------------------------------------------------
1 | """Data module and dataset for the WIQA question classifier: reads jsonl
2 | question files (train/dev/test) together with the influence graphs file.
3 | """
4 | import logging
5 | import pytorch_lightning as pl
6 | import torch
7 | from torch.utils.data import DataLoader
8 | from torch.utils.data import Dataset
9 | import pandas as pd
10 | from transformers import AutoTokenizer
11 | from tqdm import tqdm
12 | from collections import defaultdict
13 | # from src.data.creation.influence_graph import InfluenceGraph
14 |
15 | label_dict = {"less": 0, "attenuator": 0, "more": 1, "intensifier": 1, "no_effect": 2}
16 | rev_label_dict = defaultdict(list)
17 |
18 | for k, v in label_dict.items():
19 |     rev_label_dict[v].append(k)
20 |
21 | rev_label_dict = {k: "/".join(v) for k, v in rev_label_dict.items()}
22 |
23 | class GraphQaDataModule(pl.LightningDataModule):
24 |     def __init__(self, basedir: str, tokenizer_name: str, batch_size: int, num_workers: int = 16):
25 |         super().__init__()
26 |         self.basedir = basedir
27 |         self.batch_size = batch_size
28 |         self.num_workers = num_workers
29 |         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, do_lower_case=True)
30 |
31 |     def train_dataloader(self):
32 |         dataset = GraphQADataset(tokenizer=self.tokenizer,
33 |                                  qa_pth=f"{self.basedir}/train.jsonl", graph_pth=f"{self.basedir}/influence_graphs.jsonl")
34 |         return DataLoader(dataset=dataset, batch_size=self.batch_size,
35 |                           shuffle=True, num_workers=self.num_workers, collate_fn=GraphQADataset.collate_pad)
36 |
37 |     def val_dataloader(self):
38 |         dataset = GraphQADataset(tokenizer=self.tokenizer,
39 |                                  qa_pth=f"{self.basedir}/dev.jsonl", graph_pth=f"{self.basedir}/influence_graphs.jsonl")
40 |         return DataLoader(dataset=dataset, batch_size=self.batch_size,
41 |                           shuffle=False, num_workers=self.num_workers, collate_fn=GraphQADataset.collate_pad)
42 |
43 |     def test_dataloader(self):
44 |         dataset = GraphQADataset(tokenizer=self.tokenizer,
45 |                                  qa_pth=f"{self.basedir}/test.jsonl", graph_pth=f"{self.basedir}/influence_graphs.jsonl")
46 |         return DataLoader(dataset=dataset, batch_size=self.batch_size,
47 |                           shuffle=False, num_workers=self.num_workers, collate_fn=GraphQADataset.collate_pad)
48 |
49 |
50 | class GraphQADataset(Dataset):
51 |     def __init__(self, tokenizer, qa_pth: str, graph_pth: str) -> None:
52 |         super().__init__()
53 |         self.qa_pth = qa_pth
54 |         self.graph_pth = graph_pth
55 |         self.tokenizer = tokenizer
56 |         # self.read_graphs()
57 |         self.read_qa()
58 |
59 |     # def read_graphs(self):
60 |     #     influence_graphs = pd.read_json(
61 |     #         self.graph_pth, orient='records', lines=True).to_dict(orient='records')
62 |     #     self.graphs = {}
63 |     #     for graph_dict in tqdm(influence_graphs, desc="Reading graphs", total=len(influence_graphs)):
64 |     #         self.graphs[str(graph_dict["graph_id"])] = graph_dict
65 |
66 |     def read_qa(self):
67 |         logging.info("Reading data from {}".format(self.qa_pth))
68 |         data = pd.read_json(self.qa_pth, orient="records", lines=True)
69 |         self.questions, self.answer_labels, self.paragraphs = [], [], []
70 |         logging.info(f"Reading QA file from {self.qa_pth}")
71 |         for i, row in tqdm(data.iterrows(), total=len(data), desc="Reading QA examples"):
72 |             self.answer_labels.append(row["question"]["answer_label"].strip())
73 |             para = " ".join([p.strip() for p in row["question"]["para_steps"] if len(p) > 0])
74 |             question = row["question"]["stem"].strip()
75 |             self.questions.append(question)
76 |             self.paragraphs.append(para)
77 |             # self.graph_ids.append(row["metadata"]["graph_id"])
78 |
79 |         encoded_input = self.tokenizer(self.questions, self.paragraphs)
80 |         self.input_ids = encoded_input["input_ids"]
81 |         if "token_type_ids" in encoded_input:
82 |             self.token_type_ids = encoded_input["token_type_ids"]
83 |         else:
84 |             self.token_type_ids = [[0] * len(s) for s in encoded_input["input_ids"]]  # only BERT uses it anyway, so just set it to 0
85 |     def __len__(self) -> int:
86 |         return len(self.questions)
87 |
88 |     def __getitem__(self, i):
89 |         # We'll pad at the batch level.
90 |         return (self.input_ids[i], self.token_type_ids[i], self.answer_labels[i])
91 |
92 |     @staticmethod
93 |     def collate_pad(batch):
94 |         max_token_len = 0
95 |         num_elems = len(batch)
96 |         for i in range(num_elems):
97 |             tokens, _, _ = batch[i]
98 |             max_token_len = max(max_token_len, len(tokens))
99 |
100 |         tokens = torch.zeros(num_elems, max_token_len).long()
101 |         tokens_mask = torch.zeros(num_elems, max_token_len).long()
102 |         token_type_ids = torch.zeros(num_elems, max_token_len).long()
103 |         labels = torch.zeros(num_elems).long()
104 |         # graphs = []
105 |         for i in range(num_elems):
106 |             toks, type_ids, label = batch[i]
107 |             length = len(toks)
108 |             tokens[i, :length] = torch.LongTensor(toks)
109 |             token_type_ids[i, :length] = torch.LongTensor(type_ids)
110 |             tokens_mask[i, :length] = 1
111 |             # graphs.append(graph)
112 |             labels[i] = label_dict[label]
113 |         return [tokens, token_type_ids, tokens_mask, labels]
114 |
115 |
116 | # class InfluenceGraphNNData:
117 | #     """
118 | #     V   Z
119 | #     |  /
120 | #     - +
121 | #     | /
122 | #     X   U
123 | #     | \ |
124 | #     - + -
125 | #     |  \|
126 | #     W   Y
127 | #     | \ / |
128 | #     - + - 
129 | #     | / \ |
130 | #     L     M
131 | #     """
132 | #     node_index = {
133 | #         "V": 0, "Z": 1, "X": 2, "U": 3, "W": 4, "Y": 5, "dec": 6, "acc": 7}
134 | #     index_node = {v: k for k, v in node_index.items()}
135 | #     edge_index = [[0, 1, 2, 2, 3, 4, 4, 5, 5],
136 | #                   [2, 2, 4, 5, 5, 6, 7, 6, 7]]
137 | #     EDGE_TYPE_HELPS, EDGE_TYPE_HURTS = 0, 1
138 | #     def __init__(self, data) -> None:
139 | #         super().__init__()
140 | #         self.data = data
141 |
142 | #     @staticmethod
143 | #     def make_data_from_dict(graph_dict: dict, tokenizer, max_length=30):
144 | #         igraph = InfluenceGraph(graph_dict)
145 | #         if igraph.graph["Y_affects_outcome"] == "more":
146 | #             # the final edges depend on para outcome
147 | #             edge_features = [InfluenceGraphNNData.EDGE_TYPE_HURTS, InfluenceGraphNNData.EDGE_TYPE_HELPS, InfluenceGraphNNData.EDGE_TYPE_HURTS, InfluenceGraphNNData.EDGE_TYPE_HELPS, InfluenceGraphNNData.EDGE_TYPE_HURTS, InfluenceGraphNNData.EDGE_TYPE_HELPS, InfluenceGraphNNData.EDGE_TYPE_HURTS, InfluenceGraphNNData.EDGE_TYPE_HURTS, InfluenceGraphNNData.EDGE_TYPE_HELPS]
148 | #         else:
149 | #             edge_features = [InfluenceGraphNNData.EDGE_TYPE_HURTS, InfluenceGraphNNData.EDGE_TYPE_HELPS, InfluenceGraphNNData.EDGE_TYPE_HURTS, InfluenceGraphNNData.EDGE_TYPE_HELPS, InfluenceGraphNNData.EDGE_TYPE_HURTS, InfluenceGraphNNData.EDGE_TYPE_HURTS, InfluenceGraphNNData.EDGE_TYPE_HELPS, InfluenceGraphNNData.EDGE_TYPE_HELPS, InfluenceGraphNNData.EDGE_TYPE_HURTS]
150 |
151 | #         node_sentences = []
152 | #         for node in InfluenceGraphNNData.node_index:
153 | #             if node in igraph.nodes_dict and len(igraph.nodes_dict[node]) > 0:
154 | #                 node_sentences.append(" [OR] ".join(igraph.nodes_dict[node]))
155 | #             else:
156 | #                 node_sentences.append(tokenizer.pad_token)
157 | #         encoding_dict = tokenizer(node_sentences, max_length=max_length, truncation=True)
158 | #         return Data(graph_id = str(igraph.graph_id), num_nodes = len(InfluenceGraphNNData.node_index), tokens=encoding_dict["input_ids"],
159 | #                     edge_index=torch.tensor(InfluenceGraphNNData.edge_index).long(), edge_attr=torch.tensor(edge_features).long())
160 |
161 | if __name__ == "__main__":
162 |     import sys
163 |     dm = GraphQaDataModule(
164 |         basedir=sys.argv[1], tokenizer_name=sys.argv[2], batch_size=32)
165 |     for tokens, token_type_ids, tokens_mask, labels in dm.train_dataloader():
166 |         print(tokens.shape, tokens_mask.shape)
167 |
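A minimal, hedged sketch of what `GraphQADataset.collate_pad` above returns (the token ids are toy values, not real tokenizer output; the label strings must be keys of `label_dict`):

```
# Toy batch of (input_ids, token_type_ids, label) tuples, padded to the longest row.
from model.data import GraphQADataset

batch = [
    ([101, 2054, 102], [0, 0, 0], "more"),
    ([101, 2054, 2003, 2009, 102], [0, 0, 0, 0, 0], "less"),
]
tokens, token_type_ids, tokens_mask, labels = GraphQADataset.collate_pad(batch)
print(tokens.shape)        # torch.Size([2, 5]) -- zero-padded to the longest sequence
print(tokens_mask.sum(1))  # tensor([3, 5])     -- number of real tokens per row
print(labels)              # tensor([1, 0])     -- via label_dict ("more" -> 1, "less" -> 0)
```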
--------------------------------------------------------------------------------
/model/qa_model.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | import torch
4 | import torch.nn as nn
5 | from pytorch_lightning.core.lightning import LightningModule
6 | from torch.optim import AdamW
7 | from transformers import AutoModel
8 | from transformers import BertConfig
9 | from transformers.models.bert.modeling_bert import BertPooler
10 |
11 |
12 | class GraphQaModel(LightningModule):
13 |     def __init__(self, hparams):
14 |         super().__init__()
15 |         self.hparams = hparams
16 |         self.save_hyperparameters()
17 |         config = BertConfig()
18 |         #self.model = BertForSequenceClassification.from_pretrained(self.hparams.model_name, num_labels=self.hparams.n_class)
19 |         self.model = AutoModel.from_pretrained(self.hparams.model_name)
20 |         self.pooler = BertPooler(config)
21 |         # self.attention = MultiheadedAttention(h_dim=self.hparams.h_dim, kqv_dim=self.hparams.kqv_dim, n_heads=self.hparams.n_heads)
22 |         self.dropout = nn.Dropout(config.hidden_dropout_prob)
23 |         self.classifier = nn.Linear(config.hidden_size, self.hparams.n_class)
24 |         self.loss = nn.CrossEntropyLoss()
25 |
26 |     @staticmethod
27 |     def add_model_specific_args(parent_parser):
28 |         parser = ArgumentParser(parents=[parent_parser], add_help=False)
29 |         parser.add_argument("--min_lr", default=0, type=float,
30 |                             help="Minimum learning rate.")
31 |         parser.add_argument("--h_dim", type=int,
32 |                             help="Size of the hidden dimension.", default=768)
33 |         parser.add_argument("--n_heads", type=int,
34 |                             help="Number of attention heads.", default=1)
35 |         parser.add_argument("--kqv_dim", type=int,
36 |                             help="Dimensionality of each attention head.", default=256)
37 |         parser.add_argument("--n_class", type=int,
38 |                             help="Number of classes.", default=3)
39 |         parser.add_argument("--lr", default=5e-4, type=float,
40 |                             help="Initial learning rate.")
41 |         parser.add_argument("--weight_decay", default=0.01, type=float,
42 |                             help="Weight decay rate.")
43 |         parser.add_argument("--warmup_prop", default=0., type=float,
44 |                             help="Warmup proportion.")
45 |         parser.add_argument("--num_relations", default=2, type=int,
46 |                             help="The number of relations (edge types) in your graph.")
47 |         parser.add_argument(
48 |             "--model_name", default='bert-base-uncased', help="Model to use.")
49 |         return parser
50 |
51 |     def configure_optimizers(self):
52 |         return AdamW(self.parameters(), lr=self.hparams.lr, betas=(0.9, 0.99),
53 |                      eps=1e-8)
54 |
55 |     def forward(self, batch):
56 |         question_tokens, question_type_ids, question_masks, labels = batch
57 |
58 |         # step 1: encode the question/paragraph
59 |         question_cls_embedding = self.forward_bert(
60 |             input_ids=question_tokens, token_type_ids=question_type_ids, attention_mask=question_masks)
61 |
62 |
63 |         logits = self.classifier(question_cls_embedding)
64 |         predicted_labels = torch.argmax(logits, -1)
65 |         acc = torch.true_divide(
66 |             (predicted_labels == labels).sum(), labels.shape[0])
67 |         return logits, acc
68 |
69 |     def forward_bert(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor = None):
70 |         """Returns the pooled [CLS] representation from BERT.
71 |         """
72 |         outputs = self.model(
73 |             input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, output_hidden_states=True)
74 |         hidden_states = outputs["hidden_states"]
75 |         cls_embedding = self.dropout(self.pooler(hidden_states[-1]))
76 |         return cls_embedding
77 |
78 |     def training_step(self, batch, batch_idx):
79 |         # Load the data into variables
80 |         logits, acc = self(batch)
81 |         loss = self.loss(logits, batch[-1])
82 |         self.log('train_acc', acc, on_step=True,
83 |                  on_epoch=True, prog_bar=True, sync_dist=True)
84 |         return {"loss": loss}
85 |
86 |
87 |     def validation_step(self, batch, batch_idx):
88 |         # Load the data into variables
89 |         logits, acc = self(batch)
90 |
91 |         loss_f = nn.CrossEntropyLoss()
92 |         loss = loss_f(logits, batch[-1])
93 |
94 |         self.log('val_loss', loss, on_step=True,
95 |                  on_epoch=True, prog_bar=True, sync_dist=True)
96 |         self.log('val_acc', acc, on_step=True, on_epoch=True,
97 |                  prog_bar=True, sync_dist=True)
98 |         return {"loss": loss}
99 |
100 |     def test_step(self, batch, batch_idx):
101 |         # Load the data into variables
102 |         logits, acc = self(batch)
103 |
104 |         loss_f = nn.CrossEntropyLoss()
105 |         loss = loss_f(logits, batch[-1])
106 |         return {"loss": loss}
107 |
108 |     def get_progress_bar_dict(self):
109 |         tqdm_dict = super().get_progress_bar_dict()
110 |         tqdm_dict.pop("v_num", None)
111 |         tqdm_dict.pop("val_loss_step", None)
112 |         tqdm_dict.pop("val_acc_step", None)
113 |         return tqdm_dict
114 |
115 |
116 | if __name__ == "__main__":
117 |     sentences = ['This framework generates embeddings for each input sentence',
118 |                  'Sentences are passed as a list of string.',
119 |                  'The quick brown fox jumps over the lazy dog.']
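A hedged sketch of constructing `GraphQaModel` by hand with a minimal `Namespace` (normally `model/run.py` builds the hparams via argparse; this assumes `pytorch_lightning==1.1.x` as pinned in `model/requirements.txt`, and downloads `bert-base-uncased` on first run):

```
from argparse import Namespace
import torch
from model.qa_model import GraphQaModel

hparams = Namespace(model_name="bert-base-uncased", n_class=3, lr=5e-4)
model = GraphQaModel(hparams=hparams)

# A fake batch shaped like collate_pad's output: (tokens, type_ids, mask, labels).
tokens = torch.randint(0, 30000, (2, 8))
batch = [tokens, torch.zeros_like(tokens), torch.ones_like(tokens), torch.tensor([0, 2])]
logits, acc = model(batch)
print(logits.shape)  # torch.Size([2, 3]) -- one score per class
```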
--------------------------------------------------------------------------------
/model/requirements.txt:
--------------------------------------------------------------------------------
1 | # Automatically generated by https://github.com/damnever/pigar.
2 |
3 | # gnn_qa/data.py: 13
4 | # gnn_qa/infer.py: 4
5 | # gnn_qa/qa_model.py: 11
6 | # gnn_qa/run.py: 8
7 | numpy == 1.16.6
8 |
9 | # gnn_qa/data.py: 11
10 | pandas == 1.0.3
11 |
12 | # gnn_qa/data.py: 5
13 | # gnn_qa/ignn.py: 10
14 | # gnn_qa/infer.py: 5
15 | # gnn_qa/qa_model.py: 7
16 | # gnn_qa/run.py: 6,9,12,13
17 | pytorch_lightning == 1.1.2
18 |
19 | # gnn_qa/data.py: 6,7,9,10
20 | # gnn_qa/ignn.py: 3,5
21 | # gnn_qa/infer.py: 1
22 | # gnn_qa/qa_model.py: 4,8,9,12
23 | # gnn_qa/run.py: 3
24 | # gnn_qa/utils.py: 1,2
25 | torch == 1.6.0
26 |
27 | # gnn_qa/ignn.py: 4,6,9
28 | # gnn_qa/qa_model.py: 2,5,13
29 | torch_geometric == 1.6.3
30 |
31 | # gnn_qa/data.py: 17
32 | # gnn_qa/infer.py: 7
33 | tqdm == 4.48.0
34 |
35 | # gnn_qa/data.py: 14,16
36 | # gnn_qa/qa_model.py: 3,15
37 | transformers == 4.1.1
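Before the training entry point below, a sketch of the directory that `--dataset_basedir` is expected to point at (the file names are taken from `model/data.py`; the `data/wiqa-qa/` root is the value assumed in `model/run_wiqa_classifier.sh`):

```
data/wiqa-qa/
├── train.jsonl              # read by train_dataloader()
├── dev.jsonl                # read by val_dataloader()
├── test.jsonl               # read by test_dataloader()
└── influence_graphs.jsonl   # path is passed through; parsing is currently commented out
```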
--------------------------------------------------------------------------------
/model/run.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning
2 | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
3 | import random
4 | import numpy as np
5 | import pytorch_lightning as pl
6 | import logging
7 | from argparse import ArgumentParser
8 | import resource
9 | from model.data import GraphQaDataModule
10 | from model.qa_model import GraphQaModel
11 |
12 | def get_train_steps(dm):
13 |     total_devices = args.num_gpus * args.num_nodes
14 |     train_batches = len(dm.train_dataloader()) // total_devices
15 |     return (args.max_epochs * train_batches) // args.accumulate_grad_batches
16 |
17 |
18 |
19 |
20 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
21 | resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
22 | # init: important to make sure every node initializes the same weights
23 | SEED = 43
24 | np.random.seed(SEED)
25 | random.seed(SEED)
26 | pl.utilities.seed.seed_everything(SEED)
27 | pytorch_lightning.seed_everything(SEED)
28 |
29 |
30 | # argparser
31 | parser = ArgumentParser()
32 | parser.add_argument('--num_gpus', type=int)
33 | parser.add_argument('--batch_size', type=int, default=16)
34 | parser.add_argument('--clip_grad', type=float, default=1.0)
35 | parser.add_argument("--dataset_basedir", help="Base directory where the dataset is located.", type=str)
36 |
37 | parser = pl.Trainer.add_argparse_args(parser)
38 | parser = GraphQaModel.add_model_specific_args(parser)
39 |
40 | args = parser.parse_args()
41 | args.num_gpus = args.gpus if isinstance(args.gpus, int) else len(str(args.gpus).split(","))
42 |
43 |
44 | logging.basicConfig(level=logging.INFO)
45 |
46 | # Step 1: Init Data
47 | logging.info("Loading the data module")
48 | dm = GraphQaDataModule(basedir=args.dataset_basedir, tokenizer_name=args.model_name, batch_size=args.batch_size)
49 |
50 | # Step 2: Init Model
51 | logging.info("Initializing the model")
52 | model = GraphQaModel(hparams=args)
53 | model.hparams.warmup_steps = int(get_train_steps(dm) * model.hparams.warmup_prop)
54 | lr_monitor = LearningRateMonitor(logging_interval='step')
55 |
56 | # Step 3: Start
57 | logging.info("Starting the training")
58 | checkpoint_callback = ModelCheckpoint(
59 |     filename='{epoch}-{step}-{val_acc_epoch:.2f}',
60 |     save_top_k=3,
61 |     verbose=True,
62 |     monitor='val_acc_epoch',
63 |     mode='max'
64 | )
65 |
66 | trainer = pl.Trainer.from_argparse_args(args, callbacks=[checkpoint_callback], val_check_interval=0.5, gradient_clip_val=args.clip_grad, track_grad_norm=2)
67 | trainer.fit(model, dm)
--------------------------------------------------------------------------------
/model/run_wiqa_classifier.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export TOKENIZERS_PARALLELISM=false
3 | python model/run.py --dataset_basedir data/wiqa-qa/ \
4 |     --lr 2e-5 --max_epochs 20 \
5 |     --gpus 2 \
6 |     --accelerator ddp
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm
2 | requests
3 | pytest
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/wiqa-dataset/edeef924188a7caa4493c305209e0b20d20b375c/src/__init__.py
--------------------------------------------------------------------------------
/src/eval/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/wiqa-dataset/edeef924188a7caa4493c305209e0b20d20b375c/src/eval/__init__.py
--------------------------------------------------------------------------------
/src/eval/eval_utils.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict, Tuple, Any
2 |
3 | from src.third_party_utils.nltk_porter_stemmer import PorterStemmer
4 | from src.third_party_utils.spacy_stop_words import STOP_WORDS
5 |
6 |
7 | def get_label_from_question_metadata(correct_answer_key: str,
8 |                                      question_dict: List[Dict[str, str]]) -> str:
9 |     for item in question_dict:
10 |         try:
11 |             if correct_answer_key == item['label']:
12 |                 return item['text']
13 |         except KeyError as exc:
14 |             raise KeyError(f"key='label' or 'text' absent in item:\n{item}.\nException = {exc}")
KeyError(f"key='label' absent in metadata:\n{question_dict}") 16 | 17 | 18 | def split_question_cause_effect(question: str) -> Tuple[str, str]: 19 | question = question.lower() 20 | question_split = question.split("happens, how will it affect") 21 | cause_part = question_split[0].replace("suppose", "") 22 | effect_part = question_split[1] 23 | return cause_part, effect_part 24 | 25 | 26 | def get_most_similar_idx_word_overlap(p1s: List[str], p2:str): 27 | if not p1s or not p2: 28 | return -1 29 | max_idx = -1 30 | max_sim = 0.0 31 | k1 = set(get_content_words(p2)) 32 | for idx, p1 in enumerate(p1s): 33 | k2 = set(get_content_words(p1)) 34 | sim = 1.0 * len(k1.intersection(k2)) / (1.0 * (len(k1.union(k2)))) 35 | if sim > max_sim: 36 | max_sim = sim 37 | max_idx = idx 38 | return max_idx 39 | 40 | def predict_word_overlap_best(input_steps, input_cq, input_eq): 41 | xsentid = get_most_similar_idx_word_overlap(p1s=input_steps, p2=input_cq) 42 | ysentid = get_most_similar_idx_word_overlap(p1s=input_steps, p2=input_eq) 43 | return xsentid, ysentid 44 | 45 | 46 | def is_stop(cand): 47 | return cand.lower() in STOP_WORDS 48 | 49 | # your other hand => hand 50 | # the hand => hand 51 | def drop_leading_articles_and_stopwords(p): 52 | # other and another can only appear after the primary articles in first line. 53 | articles = ["a ", "an ", "the ", "your ", "his ", "their ", "my ", "this ", "that ", 54 | "another ", "other ", "more ", "less "] 55 | for article in articles: 56 | if p.lower().startswith(article): 57 | p = p[len(article):] 58 | words = p.split(" ") 59 | answer = "" 60 | for idx, w in enumerate(words): 61 | if is_stop(w): 62 | continue 63 | else: 64 | answer = " ".join(words[idx:]) 65 | break 66 | return answer 67 | 68 | 69 | def stem(w: str): 70 | if not w or len(w.strip()) == 0: 71 | return "" 72 | w_lower = w.lower() 73 | # Remove leading articles from the phrase (e.g., the rays => rays). 74 | # FIXME: change this logic to accept a list of leading articles. 
75 |     if w_lower.startswith("a "):
76 |         w_lower = w_lower[2:]
77 |     elif w_lower.startswith("an "):
78 |         w_lower = w_lower[3:]
79 |     elif w_lower.startswith("the "):
80 |         w_lower = w_lower[4:]
81 |     elif w_lower.startswith("your "):
82 |         w_lower = w_lower[5:]
83 |     elif w_lower.startswith("his "):
84 |         w_lower = w_lower[4:]
85 |     elif w_lower.startswith("their "):
86 |         w_lower = w_lower[6:]
87 |     elif w_lower.startswith("my "):
88 |         w_lower = w_lower[3:]
89 |     elif w_lower.startswith("another "):
90 |         w_lower = w_lower[8:]
91 |     elif w_lower.startswith("other "):
92 |         w_lower = w_lower[6:]
93 |     elif w_lower.startswith("this "):
94 |         w_lower = w_lower[5:]
95 |     elif w_lower.startswith("that "):
96 |         w_lower = w_lower[5:]
97 |     # Porter stemmer: rays => ray
98 |     return PorterStemmer().stem(w_lower).strip()
99 |
100 |
101 | def stem_words(words):
102 |     return [stem(w) for w in words]
103 |
104 |
105 | def get_content_words(s):
106 |     para_words_prev = s.strip().lower().split(" ")
107 |     para_words = set()
108 |     for word in para_words_prev:
109 |         if not is_stop(word):
110 |             para_words.add(word)
111 |     return stem_words(list(para_words))
112 |
113 |
114 | def find_max(input_list: List[float]) -> Tuple[Any, int]:
115 |     max_item = max(input_list)
116 |     return max_item, input_list.index(max_item)
117 |
--------------------------------------------------------------------------------
/src/eval/evaluation.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import enum
3 | import json
4 | from typing import List, Dict
5 |
6 | from src.helpers.dataset_info import wiqa_explanations_v1
7 | from src.wiqa_wrapper import WIQAQuesType, download_from_url_if_not_in_cache, Jsonl, WIQAExplanation, \
8 |     WIQAExplanationType
9 |
10 |
11 | class FineEvalMetrics(enum.Enum):
12 |     EQDIR = 0
13 |     XSENTID = 1
14 |     YSENTID = 2
15 |     XDIR = 3
16 |
17 |     def to_json(self):
18 |         return self.name
19 |
20 |
21 | class Precision:
22 |     def __init__(self):
23 |         self.tp = 0
24 |         self.fp = 0
25 |
26 |     def __str__(self):
27 |         if self.tp + self.fp == 0:
28 |             return "0"
29 |         else:
30 |             return str(self.tp / (self.tp + self.fp))
31 |
32 |     def p(self):
33 |         if self.tp + self.fp == 0.0:
34 |             return 0.0
35 |         else:
36 |             return 1.0 * self.tp / (1.0 * (self.tp + self.fp))
37 |
38 |
39 | class InputRequiredByFineEval:
40 |     def __init__(self,
41 |                  graph_id: str,
42 |                  ques_type: WIQAQuesType,
43 |                  path_arr: List[str],
44 |                  metrics: Dict[FineEvalMetrics, bool]):
45 |         self.graph_id = graph_id
46 |         self.path_arr = path_arr
47 |         self.ques_type = ques_type
48 |         self.metrics = metrics
49 |
50 |     def get_path_len(self):
51 |         return len(self.path_arr)
52 |
53 |     def __str__(self):
54 |         return f"graphid={self.graph_id}\nques_type={self.ques_type.name}\n" \
55 |                f"path_arr={self.path_arr}\nmetrics={self.metrics}"
56 |
57 |     # {"logits": [[0.7365949749946594, 0.16483880579471588, 0.38572263717651367, 0.33268705010414124, -0.6508089900016785, -0.950057864189148], [0.3361547291278839, -0.02215845137834549, 0.8630506992340088, 0.4753769040107727, -1.0981523990631104, 0.04984292760491371], [0.4498467743396759, 0.26323091983795166, 0.5597160458564758, 0.06369128823280334, -0.33793506026268005, -0.30190590023994446], [0.41394802927970886, 0.31742218136787415, 0.42982375621795654, -0.2891058027744293, -0.09577881544828415, -0.4486318528652191], [0.5242481231689453, -0.05186435207724571, 0.4505387544631958, -0.43092456459999084, -0.015227549709379673, 0.10361793637275696], [0.8527745604515076, 0.18845966458320618, 0.6540948748588562, -0.06324845552444458, -0.03267676383256912, 0.058296892791986465], [0.40418609976768494, -0.24220454692840576, 0.0737631767988205, -0.8445389270782471, -0.12929767370224, 0.5813987851142883]], "class_probabilities": [[0.29663264751434326, 0.16745895147323608, 0.20885121822357178, 0.19806326925754547, 0.07407592236995697, 0.054918017238378525], [0.18079228699207306, 0.12634745240211487, 0.3062019348144531, 0.20779892802238464, 0.04307926073670387, 0.13578014075756073], [0.21968600153923035, 0.18228720128536224, 0.24519862234592438, 0.14931286871433258, 0.09992476552724838, 0.10359060019254684], [0.22513438761234283, 0.20441898703575134, 0.22873708605766296, 0.11145754158496857, 0.13522914052009583, 0.09502287209033966], [0.24298663437366486, 0.13657772541046143, 0.2257203906774521, 0.09348806738853455, 0.14167429506778717, 0.15955300629138947], [0.27786338329315186, 0.1429957151412964, 0.22779585421085358, 0.11117511987686157, 0.11462641507387161, 0.12554346024990082], [0.23202574253082275, 0.1215660497546196, 0.16673828661441803, 0.06656130403280258, 0.1360965520143509, 0.27701207995414734]], "metadata": {"question": {"question": {"stem": "suppose during boiling point happens, how will it affect more evaporation.", "choices": [{"text": "Correct effect", "label": "A"}, {"text": "Opposite effect", "label": "B"}, {"text": "No effect", "label": "C"}]}, "answerKey": "A", "explanation": " ['during boiling point', 'during sunshine'] ==> ['increase water temperatures at least 100 C'] ==> ['more vapors'] ==> ['MORE evaporation?']", "path_info": "is_distractor^False:is_labeled_tgt^True:path_nodes^[Z, X, Y, A]:path_label^Z->SituationLabel.RESULTS_IN->A", "more_info": {"tf_q_type": "EXOGENOUS_EFFECT", "prompt": "Describe the process of evaporation", "para_id": "127", "group_ids": {"NO_GROUPING": "is_distractor^False:is_labeled_tgt^True:path_nodes^[Z, X, Y, A]:path_label^Z->SituationLabel.RESULTS_IN->A:graph_id^12:id^influence_graph,127,12,39#0:tf_q_type^EXOGENOUS_EFFECT", "BY_SRC_DEST": "Z,A", "BY_SRC_LABEL_DEST": "Z->SituationLabel.RESULTS_IN->A", "BY_PROMPT": "Describe the process of evaporation", "BY_PARA": "127", "BY_FULL_PATH": "[Z, X, Y, A]", "BY_GROUNDING": "influence_graph,127,12,39", "BY_SRC_DEST_INTRA": "12,Z,A", "BY_SRC_DEST_STEM_INTRA": "12,Z,A,In the context of describe the process of evaporation, suppose during boiling point happens, how will it affect MORE evaporation?.", "BY_TF_Q_TYPE": "EXOGENOUS_EFFECT", "BY_PATH_LENGTH": 4}, "all_q_keys": ["is_distractor^False:is_labeled_tgt^True:path_nodes^[Z, X, Y, A]:path_label^Z->SituationLabel.RESULTS_IN->A:graph_id^12:id^influence_graph,127,12,39#0:tf_q_type^EXOGENOUS_EFFECT", "is_distractor^False:is_labeled_tgt^True:path_nodes^[Z, X, W, A]:path_label^Z->SituationLabel.RESULTS_IN->A:graph_id^12:id^influence_graph,127,12,38#0:tf_q_type^EXOGENOUS_EFFECT"]}, "para": "Water is exposed to heat energy, like sunlight. The water temperature is raised above 212 degrees fahrenheit. The heat breaks down the molecules in the water. These molecules escape from the water. The water becomes vapor. The vapor evaporates into the atmosphere. ", "graph_id": "12", "para_id": "127", "prompt": "Describe the process of evaporation", "id": "influence_graph:127:12:39#0", "distractor_info": {}, "primary_question_key": "is_distractor^False:is_labeled_tgt^True:path_nodes^[Z, X, Y, A]:path_label^Z->SituationLabel.RESULTS_IN->A:graph_id^12:id^influence_graph,127,12,39#0:tf_q_type^EXOGENOUS_EFFECT"}, "kg": [{"from_node": "increase water temperatures at least 100 C", "to_node": "more ice", "label": "NOT_RESULTS_IN"}, {"from_node": "increase water temperatures at least 100 C", "to_node": "less sweating", "label": "NOT_RESULTS_IN"}, {"from_node": "increase water temperatures at least 100 C", "to_node": "more vapors", "label": "RESULTS_IN"}, {"from_node": "low water temperatures at most 100 C", "to_node": "more vapors", "label": "NOT_RESULTS_IN"}, {"from_node": "less water molecule colliding", "to_node": "more vapors", "label": "NOT_RESULTS_IN"}, {"from_node": "more ice", "to_node": "MORE evaporation?", "label": "NOT_RESULTS_IN"}, {"from_node": "less sweating", "to_node": "MORE evaporation?", "label": "NOT_RESULTS_IN"}, {"from_node": "more ice", "to_node": "LESS evaporation", "label": "RESULTS_IN"}, {"from_node": "less sweating", "to_node": "LESS evaporation", "label": "RESULTS_IN"}, {"from_node": "more vapors", "to_node": "MORE evaporation?", "label": "RESULTS_IN"}, {"from_node": "more vapors", "to_node": "LESS evaporation", "label": "NOT_RESULTS_IN"}], "explanation": {"steps": "(1. Water is exposed to heat energy, like sunlight. 2. The water temperature is raised above 212 degrees fahrenheit. 3. The heat breaks down the molecules in the water. 4. These molecules escape from the water. 5. The water becomes vapor. 6. The vapor evaporates into the atmosphere. 7. .)", "x_sent_id": 1, "y_sent_id": 4, "x_dir": "RESULTS_IN", "y_dir": "RESULTS_IN", "x_grounding": "increase water temperatures at least 100 C", "y_grounding": "more vapors", "is_valid_tuple": true, "eq_dir_orig_answer": "RESULTS_IN"}, "id": "NA"}, "tags": ["O", "E+", "E+", "E+", "O", "O", "E-"]}
", "graph_id": "12", "para_id": "127", "prompt": "Describe the process of evaporation", "id": "influence_graph:127:12:39#0", "distractor_info": {}, "primary_question_key": "is_distractor^False:is_labeled_tgt^True:path_nodes^[Z, X, Y, A]:path_label^Z->SituationLabel.RESULTS_IN->A:graph_id^12:id^influence_graph,127,12,39#0:tf_q_type^EXOGENOUS_EFFECT"}, "kg": [{"from_node": "increase water temperatures at least 100 C", "to_node": "more ice", "label": "NOT_RESULTS_IN"}, {"from_node": "increase water temperatures at least 100 C", "to_node": "less sweating", "label": "NOT_RESULTS_IN"}, {"from_node": "increase water temperatures at least 100 C", "to_node": "more vapors", "label": "RESULTS_IN"}, {"from_node": "low water temperatures at most 100 C", "to_node": "more vapors", "label": "NOT_RESULTS_IN"}, {"from_node": "less water molecule colliding", "to_node": "more vapors", "label": "NOT_RESULTS_IN"}, {"from_node": "more ice", "to_node": "MORE evaporation?", "label": "NOT_RESULTS_IN"}, {"from_node": "less sweating", "to_node": "MORE evaporation?", "label": "NOT_RESULTS_IN"}, {"from_node": "more ice", "to_node": "LESS evaporation", "label": "RESULTS_IN"}, {"from_node": "less sweating", "to_node": "LESS evaporation", "label": "RESULTS_IN"}, {"from_node": "more vapors", "to_node": "MORE evaporation?", "label": "RESULTS_IN"}, {"from_node": "more vapors", "to_node": "LESS evaporation", "label": "NOT_RESULTS_IN"}], "explanation": {"steps": "(1. Water is exposed to heat energy, like sunlight. 2. The water temperature is raised above 212 degrees fahrenheit. 3. The heat breaks down the molecules in the water. 4. These molecules escape from the water. 5. The water becomes vapor. 6. The vapor evaporates into the atmosphere. 7. .)", "x_sent_id": 1, "y_sent_id": 4, "x_dir": "RESULTS_IN", "y_dir": "RESULTS_IN", "x_grounding": "increase water temperatures at least 100 C", "y_grounding": "more vapors", "is_valid_tuple": true, "eq_dir_orig_answer": "RESULTS_IN"}, "id": "NA"}, "tags": ["O", "E+", "E+", "E+", "O", "O", "E-"]} 58 | @staticmethod 59 | def from_(prediction_on_this_example: WIQAExplanation, 60 | json_from_question: Dict, 61 | expl_type: WIQAExplanationType): 62 | # get expected answer and whether predicted and expected match. 
63 |         expected = WIQAExplanation.instantiate_from(json_from_question)
64 |
65 |         metrics = {}
66 |         metrics[FineEvalMetrics.EQDIR] = expected.de == prediction_on_this_example.de
67 |         if expl_type == WIQAExplanationType.PARA_SENT_EXPL:
68 |             metrics[FineEvalMetrics.XDIR] = expected.di == prediction_on_this_example.di
69 |             metrics[FineEvalMetrics.XSENTID] = expected.i == prediction_on_this_example.i
70 |             metrics[FineEvalMetrics.YSENTID] = expected.j == prediction_on_this_example.j
71 |         return InputRequiredByFineEval(graph_id=json_from_question["metadata"]["graph_id"],
72 |                                        ques_type=WIQAQuesType.from_str(json_from_question["metadata"]["question_type"]),
73 |                                        path_arr=[],
74 |                                        metrics=metrics)
75 |
76 |     @staticmethod
77 |     def accumulate_metrics(current_metrics: Dict[FineEvalMetrics, bool],
78 |                            all_metrics: Dict[FineEvalMetrics, Precision]):
79 |         for k, is_correct in current_metrics.items():
80 |             if k not in all_metrics:
81 |                 all_metrics[k] = Precision()
82 |             if is_correct:
83 |                 all_metrics[k].tp += 1
84 |             else:
85 |                 all_metrics[k].fp += 1
86 |
87 |
88 | class MetricEvaluator:
89 |     def __init__(self):
90 |         self.entries: List[InputRequiredByFineEval] = []
91 |
92 |     def add_entry(self, entry: InputRequiredByFineEval):
93 |         self.entries.append(entry)
94 |
95 |     def group_by_ques_type(self):
96 |         per_ques_type: Dict[WIQAQuesType, Dict[FineEvalMetrics, Precision]] = {}
97 |         for e in self.entries:
98 |             if e.ques_type not in per_ques_type:
99 |                 per_ques_type[e.ques_type] = {}
100 |             InputRequiredByFineEval.accumulate_metrics(current_metrics=e.metrics,
101 |                                                        all_metrics=per_ques_type[e.ques_type])
102 |         return per_ques_type
103 |
104 |     def group_by_path_len(self):
105 |         per_path_len: Dict[int, Dict[FineEvalMetrics, Precision]] = {}
106 |         for e in self.entries:
107 |             if e.get_path_len() not in per_path_len:
108 |                 per_path_len[e.get_path_len()] = {}
109 |             InputRequiredByFineEval.accumulate_metrics(current_metrics=e.metrics,
110 |                                                        all_metrics=per_path_len[e.get_path_len()])
111 |         return per_path_len
112 |
113 |
114 | if __name__ == '__main__':
115 |     parser = argparse.ArgumentParser(description='Analyze BERT results per question and per question type',
116 |                                      usage="\n\npython src/eval/evaluation.py"
117 |                                            "\n\t --pred_path dir_path/eval_test.json"
118 |                                            "\n\t --out_path /tmp/bert_anal"
119 |                                            "\n\t --from_model tagging_model"
120 |                                            "\n\t --group_by questype"
121 |                                      )
122 |
123 |     # ------------------------------------------------
124 |     # Mandatory arguments.
125 |     # ------------------------------------------------
126 |
127 |     parser.add_argument('--pred_path',
128 |                         action='store',
129 |                         dest='pred_path',
130 |                         required=True,
131 |                         help='File path containing predictors json output.')
132 |
133 |     parser.add_argument('--group_by',
134 |                         action='store',
135 |                         dest='group_by',
136 |                         required=True,
137 |                         help='questype|pathlen')
138 |
139 |     parser.add_argument('--from_model',
140 |                         action='store',
141 |                         dest='from_model',
142 |                         required=True,
143 |                         help='your_model_name|no_explanation_model')
144 |
145 |     parser.add_argument('--out_path',
146 |                         action='store',
147 |                         dest='out_path',
148 |                         required=True,
149 |                         help='File path to store output such as metrics.json')
150 |
151 |     args = parser.parse_args()
152 |     m = MetricEvaluator()
153 |
154 |     # Compile input for metrics.
155 |     if args.from_model == "wordoverlap_baseline_model":
156 |         args.pred_path = download_from_url_if_not_in_cache(cloud_path=wiqa_explanations_v1.cloud_path + "test.jsonl")
157 |     if args.from_model == "vectoroverlap_baseline_model":
158 |         raise NotImplementedError
159 |
160 |     if args.from_model == "emnlp19_model":
161 |         map_of_expected_keys = {}
162 |         for x in Jsonl.load(download_from_url_if_not_in_cache(wiqa_explanations_v1.cloud_path + "test.jsonl")):
163 |             key = x["question"]["id"]
164 |             value = {"graph_id": x["question"]["graph_id"],
165 |                      "path_arr": x["question"]["path_info"],
166 |                      "qtype": x["question"]["more_info"]["tf_q_type"]}
167 |             map_of_expected_keys[key] = value
168 |
169 |     outfile = open(args.out_path, 'w')
170 |     with open(args.pred_path) as infile:
171 |         for line in infile:
172 |             j = json.loads(line)
173 |             if args.from_model == "no_explanation_model":
174 |                 entry = InputRequiredByFineEval.from_(expl_type=WIQAExplanationType.NO_EXPL,
175 |                                                       json_from_question=j["orig_question"],
176 |                                                       prediction_on_this_example=WIQAExplanation.instantiate_from(j)
177 |                                                       )
178 |             else:
179 |                 raise NotImplementedError(f"fine_eval script does not support model: {args.from_model}, "
180 |                                           f"implement your own here in evaluation.py")
181 |             m.add_entry(entry=entry)
182 |
183 |     # Compute metrics.
184 |     if args.group_by == "pathlen":
185 |         metrics = m.group_by_path_len()
186 |     elif args.group_by == "questype":
187 |         metrics = m.group_by_ques_type()
188 |     else:
189 |         raise NotImplementedError(f"fine_eval script does not support group by: {args.group_by}")
190 |
191 |     # Write metrics to file.
192 |     overall = {}
193 |     for k, v_dict in metrics.items():
194 |         for k2, v2 in v_dict.items():
195 |             if k2 not in overall:
196 |                 overall[k2] = Precision()
197 |             overall[k2].tp += v2.tp
198 |             overall[k2].fp += v2.fp
199 |             outfile.write(f"{getattr(k, 'name', k)}_{k2.name}:{v2.p():0.4}\n")
200 |
201 |     outlines = []
202 |     for k, v in overall.items():
203 |         outline = f"{k.name}_overall:{v.p():0.4}"
204 |         outlines.append(outline)
205 |         outfile.write(f"{outline}\n")
206 |     print(f"\nOutput is in {args.out_path}")
207 |     print("\n".join(outlines))
208 |     outfile.close()
209 |
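A small, hedged illustration of the plumbing above: `accumulate_metrics` folds per-question booleans into `Precision` counters (module paths assume the repo root is on `PYTHONPATH`):

```
# Three questions, two of which got EQDIR right -> precision 2/3.
from src.eval.evaluation import FineEvalMetrics, InputRequiredByFineEval

all_metrics = {}
for is_correct in [True, True, False]:
    InputRequiredByFineEval.accumulate_metrics(
        current_metrics={FineEvalMetrics.EQDIR: is_correct},
        all_metrics=all_metrics)
print(all_metrics[FineEvalMetrics.EQDIR].p())  # 0.666...
```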
--------------------------------------------------------------------------------
/src/helpers/ProparaExtendedPara.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import json
3 | from typing import List, Dict
4 |
5 | from src.helpers.dataset_info import propara_para_info, download_from_url_if_not_in_cache
6 |
7 |
8 | class ProparaExtendedParaEntry:
9 |     def __init__(self, topic: str, prompt: str, paraid: str,
10 |                  s1: str, s2: str, s3: str,
11 |                  s4: str, s5: str, s6: str,
12 |                  s7: str, s8: str, s9: str,
13 |                  s10: str):
14 |         self.topic = topic
15 |         self.prompt = prompt
16 |         self.paraid = paraid
17 |         all_sentences = [s1, s2, s3, s4, s5, s6, s7, s8, s9, s10]
18 |         self.sentences = [x for x in all_sentences if x]
19 |
20 |     def sent_at(self, sentidx_startingzero: int):
21 |         assert len(self.sentences) > sentidx_startingzero
22 |         return self.sentences[sentidx_startingzero]
23 |
24 |     def get_sentence_arr(self):
25 |         return self.sentences
26 |
27 |     def as_json(self):
28 |         return json.dumps(self.__dict__)
29 |
30 |     @staticmethod
31 |     def from_json(j):
32 |         raise NotImplementedError("from json is not done yet.")
33 |
34 |     @staticmethod
35 |     def from_tsv(t):
36 |         return ProparaExtendedParaEntry(*t)
37 |
38 |
39 | class ProparaExtendedParaMetadata:
40 |
41 |     def __init__(self, extended_propara_para_fp=None):
42 |         # Resolve the default lazily so that importing this module does not trigger a download.
43 |         extended_propara_para_fp = extended_propara_para_fp or download_from_url_if_not_in_cache(propara_para_info.cloud_path)
44 |         self.para_map: Dict[str, ProparaExtendedParaEntry] = dict()
45 |         for row_num, row_as_arr in enumerate(csv.reader(open(extended_propara_para_fp), delimiter="\t")):
46 |             if row_num > 0:  # skips header
47 |                 e: ProparaExtendedParaEntry = ProparaExtendedParaEntry.from_tsv(row_as_arr)
48 |                 self.para_map[e.paraid] = e
49 |
50 |     def paraentry_for_id(self, para_id: str) -> ProparaExtendedParaEntry:
51 |         return self.para_map.get(para_id, None)
52 |
53 |     def sentences_for_paraid(self, paraid: str) -> List[str]:
54 |         return self.paraentry_for_id(para_id=paraid).get_sentence_arr()
55 |
56 |     def sent_at(self, paraid: str, sentidx_startingzero: int) -> str:
57 |         return self.paraentry_for_id(para_id=paraid).sent_at(sentidx_startingzero=sentidx_startingzero)
58 |
59 |
60 | if __name__ == '__main__':
61 |     o = ProparaExtendedParaMetadata()
62 |     assert o.sent_at(paraid="2453",
63 |                      sentidx_startingzero=1) == "The heat from the reaction is used to create steam in water."
64 |     assert o.sent_at(paraid="24",
65 |                      sentidx_startingzero=2) == "Depending on what is eroding the valley, the slope of the land, the type of rock or soil and the amount of time the land has been eroded."
66 |     assert not o.sent_at(paraid="24",
67 |                          sentidx_startingzero=2) == "The heat from the reaction is used to create steam in water."
--------------------------------------------------------------------------------
/src/helpers/SGFileLoaders.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from random import sample
4 |
5 |
6 | def compile_input_files(dir_or_file_path):
7 |     input_is_directory = os.path.isdir(dir_or_file_path)
8 |     input_files = []
9 |     if input_is_directory:
10 |         input_dir = os.fsencode(dir_or_file_path)
11 |         for infile_bytename in os.listdir(input_dir):
12 |             infile_fullpath = os.path.join(
13 |                 input_dir.decode("utf-8"), infile_bytename.decode("utf-8"))
14 |             input_files.append(infile_fullpath)
15 |     else:
16 |         input_files.append(dir_or_file_path)
17 |     return input_files
18 |
19 |
20 | def load_grpkey_to_qkeys(path, meta_info):
21 |     '''
22 |     {
23 |         "group_key":"1,V,X,In the context of how does igneous rock form, suppose less magma is formed, it will not result in more magma is released",
24 |         "question_keys":[
25 |             "is_distractor^False:is_labeled_tgt^False:path_nodes^[V, X]:path_label^V->SituationLabel.NOT_RESULTS_IN->X:graph_id^1:id^34,1,1#0#0:tf_q_type^EXOGENOUS_EFFECT"
26 |         ]
27 |     }
28 |     :param path:
29 |     :param meta_info:
30 |     :return:
31 |     '''
32 |     grpkey_to_qkeys = {}
33 |     for in_fp in compile_input_files(path):
34 |         with open(in_fp, 'r') as infile:
35 |             for line in infile:
36 |                 j = json.loads(line)
37 |                 grpkey_to_qkeys[j["group_key"]] = j["question_keys"]
38 |     return grpkey_to_qkeys
39 |
40 |
41 | def load_qkey_to_grpkey(path, meta_info):
42 |     '''
43 |     {
44 |         "group_key":"1,V,X,In the context of how does igneous rock form, suppose less magma is formed, it will not result in more magma is released",
45 |         "question_keys":[
46 |             "is_distractor^False:is_labeled_tgt^False:path_nodes^[V, X]:path_label^V->SituationLabel.NOT_RESULTS_IN->X:graph_id^1:id^34,1,1#0#0:tf_q_type^EXOGENOUS_EFFECT"
47 |         ]
48 |     }
49 |     :param path:
50 |     :param meta_info:
51 |     :return:
52 |     '''
53 |     qkey_to_grpkey = {}
54 |     for in_fp in compile_input_files(path):
55 |         with open(in_fp, 'r') as infile:
56 |             for line in infile:
57 |                 j = json.loads(line)
58 |                 grp_key = j["group_key"]
59 |                 for q_key in j["question_keys"]:
60 |                     qkey_to_grpkey[q_key] = grp_key
61 |     return qkey_to_grpkey
62 |
63 |
64 | def load_qkey_to_qjson(path, meta_info):
65 |     '''
66 |     {
67 |         "question":{
68 |             "stem":"In the context of how does igneous rock form, suppose less magma is formed, it will not result in more magma is released",
69 |             "choices":[
70 |                 {
71 |                     "text":"True",
72 |                     "label":"C"
73 |                 },
74 |                 {
75 |                     "text":"False",
76 |                     "label":"D"
77 |                 }
78 |             ]
79 |         },
80 |         "answerKey":"C",
81 |         "explanation":"less magma is formed=>more magma is released",
82 |         "path_info":"is_distractor^False:is_labeled_tgt^False:path_nodes^[V, X]:path_label^V->SituationLabel.NOT_RESULTS_IN->X",
83 |         "more_info":{
84 |             "tf_q_type":"EXOGENOUS_EFFECT",
85 |             "prompt":"How does igneous rock form?",
86 |             "para_id":"34",
87 |             "group_ids":{
88 |                 "NO_GROUPING":"is_distractor^False:is_labeled_tgt^False:path_nodes^[V, X]:path_label^V->SituationLabel.NOT_RESULTS_IN->X:graph_id^1:id^34,1,1#0#0:tf_q_type^EXOGENOUS_EFFECT",
89 |                 "BY_SRC_DEST":"V,X",
90 |                 "BY_SRC_LABEL_DEST":"V->SituationLabel.NOT_RESULTS_IN->X",
91 |                 "BY_PROMPT":"How does igneous rock form?",
92 |                 "BY_PARA":"34",
93 |                 "BY_FULL_PATH":"[V, X]",
94 |                 "BY_GROUNDING":"34,1,1,0",
95 |                 "BY_SRC_DEST_INTRA":"1,V,X",
96 |                 "BY_SRC_DEST_STEM_INTRA":"1,V,X,In the context of how does igneous rock form, suppose less magma is formed, it will not result in more magma is released"
97 |             },
98 |             "all_q_keys":[
99 |                 "is_distractor^False:is_labeled_tgt^False:path_nodes^[V, X]:path_label^V->SituationLabel.NOT_RESULTS_IN->X:graph_id^1:id^34,1,1#0#0:tf_q_type^EXOGENOUS_EFFECT"
100 |             ]
101 |         },
102 |         "para":"Volcanos contain magma. The magma is very hot. The magma rises toward the surface of the volcano. The magma cools. The magma starts to harden as it cools. The magma is sometimes released from the volcano as lava. The magma or lava becomes a hard rock as it solidifies.",
103 |         "graph_id":"1",
104 |         "id":"34:1:1#0#0",
105 |         "primary_question_key":"is_distractor^False:is_labeled_tgt^False:path_nodes^[V, X]:path_label^V->SituationLabel.NOT_RESULTS_IN->X:graph_id^1:id^34,1,1#0#0:tf_q_type^EXOGENOUS_EFFECT"
106 |     }
107 |     :param path:
108 |     :param meta_info:
109 |     :return:
110 |     '''
111 |     qkey_to_qjson = {}
112 |     for in_fp in compile_input_files(path):
113 |         with open(in_fp, 'r') as infile:
114 |             for line in infile:
115 |                 j = json.loads(line)
116 |                 qkey_to_qjson[j["primary_question_key"]] = j
117 |     return qkey_to_qjson
118 |
119 |
120 | def load_qkey_to_ans(path, meta_info):
121 |     '''
122 |     {
123 |         "id":"34:1:1#0#0",
124 |         "primary_question_key":"is_distractor^False:is_labeled_tgt^False:path_nodes^[V, X]:path_label^V->SituationLabel.NOT_RESULTS_IN->X:graph_id^1:id^34,1,1#0#0:tf_q_type^EXOGENOUS_EFFECT",
125 |         "answerKey":"C"
126 |     }
127 |     :param path:
128 |     :param meta_info:
129 |     :return:
130 |     '''
131 |     qkey_to_ans = {}
132 |     for in_fp in compile_input_files(path):
133 |         with open(in_fp, 'r') as infile:
134 |             for line in infile:
135 |                 j = json.loads(line)
136 |                 qkey_to_ans[j["primary_question_key"]] = j["answerKey"]
137 |     return qkey_to_ans
138 |
139 |
140 | def load_allennlp_qkey_to_ans(path, qkey_to_qjson_map, meta_info):
141 |     '''
142 |     {
143 |         "id":"131",
144 |         "question":"NA",
145 |         "question_text":"It is observed that scavengers increase in number happens. In the context of How are fossils formed? , what is the least likely cause?",
146 |        "choice_text_list":[
147 |           "habitat is destroyed",
148 |           "humans hunt less animals",
149 |           "",
150 |           ""
151 |        ],
152 |        "correct_answer_index":0,
153 |        "label_logits":[
154 |           -6.519074440002441,
155 |           -6.476490497589111,
156 |           -6.935580730438232,
157 |           -6.935580730438232
158 |        ],
159 |        "label_probs":[
160 |           0.29742464423179626,
161 |           0.31036368012428284,
162 |           0.19610585272312164,
163 |           0.19610585272312164
164 |        ],
165 |        "answer_index":1
166 |     }
167 |     :param meta_info:
168 |     :return:
169 |     '''
170 |     # step 1 : map allennlp_qkey_to_qkey
171 |     # step 2 : then apply the existing qkey -> answer mapping.
172 |     qkey_from_allennlp_key_map = {v["id"]: k for k, v in qkey_to_qjson_map.items()}
173 |     qkey_to_ans = {}
174 |     for in_fp in compile_input_files(path):
175 |         with open(in_fp, 'r') as infile:
176 |             for line in infile:
177 |                 j = json.loads(line)
178 |                 makeshift_id = j["id"]
179 |                 qkey_from_allennlp_key = qkey_from_allennlp_key_map[makeshift_id]
180 |                 qkey_to_ans[qkey_from_allennlp_key] = j["choice_text_list"][int(j["answer_index"])]
181 |     print(f"Loaded {len(qkey_to_ans)} system answers from {path}")
182 |     return qkey_to_ans
183 | 
184 | 
185 | def serialize_whole(out_file_path, items, randomize_order=False, header=""):
186 |     print(f"Writing {len(items)} items (e.g., question jsons) to file: {out_file_path}")
187 |     curr_file = open(out_file_path, 'w')
188 |     if header:
189 |         curr_file.write(header)
190 |         if "\n" not in header:
191 |             curr_file.write("\n")
192 | 
193 |     items_randomized = items
194 |     if randomize_order:
195 |         items_randomized = sample(items, len(items))
196 |     for item_num, item in enumerate(items_randomized):
197 |         curr_file.write(item)
198 |         if "\n" not in item:
199 |             curr_file.write("\n")
200 | 
201 |     curr_file.close()
202 | 
203 | 
204 | def serialize_in_pieces(out_dir_path, max_items_in_a_piece, items, randomize_order=False, header=""):
205 |     print(f"Writing {len(items)} items (e.g., question jsons) to directory: {out_dir_path}")
206 |     if not os.path.exists(out_dir_path):
207 |         os.makedirs(out_dir_path)
208 | 
209 |     curr_file_num = 1
210 |     curr_file = open(f"{out_dir_path}/{curr_file_num}.json", 'w')
211 |     if header:
212 |         curr_file.write(header if "\n" in header else header + "\n")  # keep newline handling consistent with later pieces
213 | 
214 |     items_randomized = items
215 |     if randomize_order:
216 |         items_randomized = sample(items, len(items))
217 | 
218 |     for item_num, item in enumerate(items_randomized):
219 |         if item_num % max_items_in_a_piece == 0 and item_num > 0:  # > 0 (not > 1), so rollover also works when max_items_in_a_piece == 1
220 |             # close the old file.
221 |             curr_file.close()
222 |             # open a new file.
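            # (illustrative note, not in the original source: with
            # max_items_in_a_piece == 3, the guard above fires at item_num 3, 6, 9, ...,
            # so 1.json holds items 0-2, 2.json holds items 3-5, and so on.)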
223 |             curr_file_num += 1
224 |             curr_file = open(f"{out_dir_path}/{curr_file_num}.json", "w")
225 |             if header:
226 |                 curr_file.write(header)
227 |                 if "\n" not in header:
228 |                     curr_file.write("\n")
229 | 
230 |         curr_file.write(item)
231 |         if "\n" not in item:
232 |             curr_file.write("\n")
233 | 
234 |     if not curr_file.closed:
235 |         curr_file.close()
236 | 
--------------------------------------------------------------------------------
/src/helpers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/wiqa-dataset/edeef924188a7caa4493c305209e0b20d20b375c/src/helpers/__init__.py
--------------------------------------------------------------------------------
/src/helpers/collections_util.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | 
3 | 
4 | def add_key_to_map_arr(key, value, map_):
5 |     if key not in map_:
6 |         map_[key] = []
7 |     map_[key].append(value)
8 | 
9 | 
10 | def getElem(arr: List, elem_idx: int, defaultValue):
11 |     if not arr or len(arr) <= elem_idx:
12 |         return defaultValue
13 |     return arr[elem_idx]
--------------------------------------------------------------------------------
/src/helpers/dataset_info.py:
--------------------------------------------------------------------------------
1 | """
2 | This file lists the latest dataset-related files
3 | for EMNLP 2019.
4 | """
5 | 
6 | from collections import namedtuple
7 | 
8 | from src.third_party_utils.allennlp_cached_filepath import cached_path
9 | 
10 | DatasetInfo = namedtuple('DatasetInfo',
11 |                          ['beaker_link', 'cloud_path', 'local_path', 'data_reader', 'metadata_info', 'readme'])
12 | 
13 | # EMNLP dataset: "with explanations" in json format.
14 | wiqa_explanations_v1 = DatasetInfo(
15 |     beaker_link="https://beaker.org/ds/ds_jsdpme8ixz86/",  # an unpublished, baseline model.
16 |     cloud_path="https://public-aristo-processes.s3-us-west-2.amazonaws.com/wiqa_dataset_with_explanation/",
17 |     # [cloud_path + partition for partition in ["train.jsonl", "dev.jsonl", "test.jsonl"]]
18 |     local_path="",
19 |     data_reader="BertMCQAReaderSentAnnotated",  # an unpublished, baseline model.
20 |     metadata_info="WhatifMetadata",
21 |     readme=""
22 | )
23 | 
24 | # EMNLP dataset: "no explanations" in json format.
25 | wiqa_no_explanations_vetted_v1 = DatasetInfo(
26 |     beaker_link="https://beaker.org/ds/ds_zyukvhb9ezqa/",
27 |     cloud_path="https://public-aristo-processes.s3-us-west-2.amazonaws.com/wiqa_dataset_no_explanation_v2/",
28 |     # [cloud_path + partition for partition in ["train.jsonl", "dev.jsonl", "test.jsonl"]]
29 |     local_path="",
30 |     data_reader="BertMCQAReaderPara",
31 |     metadata_info="WhatifMetadata",
32 |     readme=""
33 | )
34 | 
35 | # All the turked influence graphs in json format.
36 | influence_graphs_v1 = DatasetInfo(
37 |     beaker_link="",
38 |     cloud_path="https://public-aristo-processes.s3-us-west-2.amazonaws.com/wiqa_influence_graphs.jsonl",
39 |     local_path="",
40 |     data_reader="SituationGraph",
41 |     metadata_info="WhatifMetadata",
42 |     readme="All influence graphs"
43 | )
44 | 
45 | # Train/dev/test in the ProPara dataset is partitioned by topic; every topic in test is therefore novel.
46 | # In the WIQA dataset, we continue partitioning by topic. The following data contains para id and partition.
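# A minimal usage sketch, not part of the original file: `_example_partition_map` is a
# hypothetical name, and the column order (para id, prompt, topic, partition) is assumed
# from the TSV filename. `para_partition_info` and `download_from_url_if_not_in_cache`
# are defined below in this module and are resolved only when the function is called.
def _example_partition_map():
    import csv
    local_fp = download_from_url_if_not_in_cache(para_partition_info.cloud_path)
    with open(local_fp) as f:
        # maps para id -> partition (train/dev/test)
        return {row[0]: row[-1] for row in csv.reader(f, delimiter="\t")}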
47 | para_partition_info = DatasetInfo(
48 |     beaker_link="",
49 |     cloud_path="https://public-aristo-processes.s3-us-west-2.amazonaws.com/metadata/para_id.prompt.topic.partition.tsv",
50 |     local_path="",
51 |     data_reader="",
52 |     metadata_info="WhatifMetadata",
53 |     readme="Para id, topic, and the partition type"
54 | )
55 | 
56 | # The ProPara dataset: paragraph id and paragraph sentences (incl. metadata: paragraph prompt/title & topic).
57 | propara_para_info = DatasetInfo(
58 |     beaker_link="",
59 |     cloud_path="https://public-aristo-processes.s3-us-west-2.amazonaws.com/metadata/propara-extended-para.tsv",
60 |     local_path="",
61 |     data_reader="",
62 |     metadata_info="WhatifMetadata",
63 |     readme="Para id, topic, and the sentences"
64 | )
65 | 
66 | 
67 | def download_from_url_if_not_in_cache(cloud_path: str, cache_dir: str = None):
68 |     """
69 |     :param cloud_path: e.g., https://public-aristo-processes.s3-us-west-2.amazonaws.com/wiqa-model.tar.gz
70 |     :param cache_dir: will be regarded as a cache directory.
71 |     :return: the path of the file to which the download is written.
72 |     """
73 |     return cached_path(url_or_filename=cloud_path, cache_dir=cache_dir)
74 | 
--------------------------------------------------------------------------------
/src/helpers/situation_graph.py:
--------------------------------------------------------------------------------
1 | # Tasks:
2 | # input = graph in some format
3 | # output1 = graph in json format
4 | # output2 = tf_ques_given_a_json_record
5 | # output3 = generate explanation candidates for_a_tf_ques
6 | 
7 | import copy
8 | import enum
9 | import json
10 | from typing import List
11 | 
12 | 
13 | def strip_special_char(a_string):
14 |     return "".join([x for x in a_string if (ord('a') <= ord(x) <= ord('z')) or (ord('A') <= ord(x) <= ord('Z'))]).lower()
15 | 
16 | 
17 | class SituationNode:
18 |     def __init__(self, node_id: str,
19 |                  the_groundings: List[str],
20 |                  is_decision_node: bool = False,
21 |                  node_semantics: str = ""):
22 |         '''
23 |         :param node_id: e.g., "VX"
24 |         :param node_semantics: e.g., for an A/D node, the semantics is "accelerates"/"decelerates";
25 |                an x/y node can be "causal", and u, v, w, z nodes "indirect".
26 |         :param the_groundings: an array of groundings.
27 |         :param is_decision_node: e.g., Accelerates or Decelerates nodes are currently the decision nodes
28 |         '''
29 |         self.id = node_id
30 |         self.node_semantics = node_semantics
31 |         self.groundings = [t for t in the_groundings]  # we need a copy and not a reference of the supplied array.
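        # (illustrative aside, not in the original source: list(the_groundings) would
        # copy equally well; the copy matters because remove_grounding() below mutates
        # self.groundings and must never touch the caller's list.)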
32 |         self.is_decision_node = is_decision_node
33 | 
34 |     def join_groundings(self, separator):
35 |         if not self.groundings:
36 |             return ""
37 |         return separator.join(self.groundings)
38 | 
39 |     def remove_grounding(self, the_grounding_to_remove):
40 |         try:
41 |             self.groundings.remove(the_grounding_to_remove)
42 |         except ValueError:
43 |             # Continue, do not stop
44 |             print(f"Warning: Removal of a non-existent grounding: '{the_grounding_to_remove}' from node {self.id}")
45 | 
46 |     def get_grounding_pairs(self):
47 |         if not self.groundings or len(self.groundings) < 2:
48 |             return []
49 |         return zip(*[self.groundings[i:] for i in range(2)])
50 | 
51 |     def is_empty(self):
52 |         return not self.groundings or len(self.groundings) == 0
53 | 
54 |     def __repr__(self):
55 |         return self.id
56 | 
57 |     def __eq__(self, other):
58 |         return self.id == other.id
59 | 
60 |     def __hash__(self):
61 |         return hash(self.id)
62 | 
63 |     def get_specific_grounding(self, specific_grounding=None):
64 |         '''
65 |         :param specific_grounding: if None, the full grounding list {self.groundings} is rendered; otherwise the supplied string is returned.
66 |         :return:
67 |         '''
68 |         return f"{self.groundings}" if not specific_grounding else specific_grounding
69 | 
70 |     def get_first_grounding(self):
71 |         return "" if not self.groundings or len(self.groundings) < 1 else self.groundings[0]
72 | 
73 | 
74 | class SituationLabel(str, enum.Enum):
75 |     NOT_RESULTS_IN = "NOT_RESULTS_IN"
76 |     RESULTS_IN = "RESULTS_IN"
77 |     NO_EFFECT = "NO_EFFECT"
78 |     MARKED_NOISE = "MARKED_NOISE"
79 | 
80 |     def get_sign(self):
81 |         if self == SituationLabel.RESULTS_IN:
82 |             return '+'
83 |         elif self == SituationLabel.NOT_RESULTS_IN:
84 |             return '-'
85 |         else:
86 |             return '.'
87 | 
88 |     def get_sign_str(self):
89 |         if self == SituationLabel.RESULTS_IN:
90 |             return 'MORE'
91 |         elif self == SituationLabel.NOT_RESULTS_IN:
92 |             return 'LESS'
93 |         else:
94 |             return 'NOEFFECT'
95 | 
96 |     def as_less_more(self):
97 |         if self == SituationLabel.RESULTS_IN:
98 |             return 'more'
99 |         elif self == SituationLabel.NOT_RESULTS_IN:
100 |             return 'less'
101 |         else:
102 |             return 'no_effect'
103 | 
104 |     def get_nickname(self):
105 |         if self == SituationLabel.RESULTS_IN:
106 |             return 'a'
107 |         elif self == SituationLabel.NOT_RESULTS_IN:
108 |             return 'd'
109 |         else:
110 |             return '-'
111 | 
112 |     def get_opposite_label(self):
113 |         if self == SituationLabel.RESULTS_IN:
114 |             return SituationLabel.NOT_RESULTS_IN
115 |         elif self == SituationLabel.NOT_RESULTS_IN:
116 |             return SituationLabel.RESULTS_IN
117 |         else:
118 |             return self
119 | 
120 |     def get_emnlp_test_choice(self):
121 |         if self == SituationLabel.RESULTS_IN:
122 |             return 'A'
123 |         elif self == SituationLabel.NOT_RESULTS_IN:
124 |             return 'B'
125 |         else:
126 |             return 'C'
127 | 
128 |     @staticmethod
129 |     def from_str(sl):
130 |         if not sl:
131 |             raise ValueError(
132 |                 f"({sl}) is not a valid Enum SituationLabel")
133 |         sl = sl.lower().replace('_', ' ').strip()
134 | 
135 |         if sl in ['-', 'not results in', 'opposite', 'opp', 'results in opp', 'opp effect', 'opposite effect', 'less']:
136 |             return SituationLabel.NOT_RESULTS_IN
137 |         elif sl in ['+', 'positive', 'results in', 'correct', 'correct effect', 'more']:
138 |             return SituationLabel.RESULTS_IN
139 |         elif sl in ['.', 'none', 'no effect']:
140 |             return SituationLabel.NO_EFFECT
141 |         else:
142 |             print(f"WARNING: ({sl}) is not a valid Enum SituationLabel")
143 |             return SituationLabel.MARKED_NOISE
144 | 
145 |     def to_readable_str(self):
146 |         if self == SituationLabel.RESULTS_IN:
147 |             return 'RESULTS_IN'
148 |         elif self == SituationLabel.NOT_RESULTS_IN:
149 |             return 'RESULTS_IN_OPP'
150 |         else:
151 |             return 'NO_EFFECT'
152 | 
153 |     def to_json(self):
154 |         return self.name
155 | 
156 | 
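# A minimal sketch, not part of the original file, of how SituationLabel round-trips
# between surface strings and enum members; each assertion follows directly from
# from_str(), get_opposite_label(), get_sign(), and as_less_more() above:
#   assert SituationLabel.from_str("results in opp") == SituationLabel.NOT_RESULTS_IN
#   assert SituationLabel.RESULTS_IN.get_opposite_label().get_sign() == '-'
#   assert SituationLabel.from_str("+").as_less_more() == "more"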
157 | class SituationEdge:
158 |     def __init__(self, from_node: SituationNode, to_node: SituationNode, label: SituationLabel, wt=1.0):
159 |         self.from_node = from_node
160 |         self.to_node = to_node
161 |         self.wt = wt
162 |         self.label = label
163 |         self.id = (from_node.id, to_node.id)
164 |         self.is_noise = False
165 | 
166 |     def mark_noise(self):
167 |         # If marked as noise, then we will drop this edge completely
168 |         self.is_noise = True
169 | 
170 |     def reset_label_to(self, new_label: SituationLabel):
171 |         # Overwrite this edge's label, e.g., when vetting data corrects it.
172 |         self.label = new_label
173 | 
174 |     def ground_edge(self):
175 |         grounded_edges = []
176 |         # node A (A1, A2) => node B (B1, B2)
177 |         for s1 in self.from_node.groundings:
178 |             for s2 in self.to_node.groundings:
179 |                 grounded_edges.append({"from_node": s1, "to_node": s2, "label": self.label.name})
180 |         return grounded_edges
181 | 
182 |     def __repr__(self):
183 |         return self.from_node.id + "-" + self.to_node.id
184 | 
185 | 
186 | class SituationGraph:
187 |     # graph structure is a json of edges and their labels,
188 |     # e.g., [(V, X, NOT_RESULTS_IN), ...]
189 |     def __init__(self, situation_nodes: List[SituationNode], situation_edges: List[SituationEdge],
190 |                  other_properties: dict):
191 |         '''
192 |         :param situation_nodes: an array of situation nodes.
193 |         :param other_properties: e.g., other_graph_provenance,
194 |                or, para_outcome,
195 |                or, graph_provenance: paraid__blocknum__paraprompt etc.
196 |         '''
197 |         self.nodes = situation_nodes
198 |         self.situation_edges = situation_edges
199 |         self.other_properties = {k: v for k, v in other_properties.items()}
200 |         self.cleanup()
201 | 
202 |     def cleanup(self, remove_no_effects=True):
203 |         """
204 |         This function can also be called at a later point, when vetting data is available on a situation graph.
205 |         :param remove_no_effects: unless we extend our graphs to contain no-effect nodes, the default should be
206 |         True
207 |         :return:
208 |         """
209 |         self.remove_empty_nodes()
210 |         edges_to_remove = [e for e in self.situation_edges
211 |                            if e.is_noise or e.label == SituationLabel.MARKED_NOISE or (
212 |                                remove_no_effects and e.label == SituationLabel.NO_EFFECT)]
213 |         for e in edges_to_remove:
214 |             self.remove_an_edge(edge_to_remove=e)
215 | 
216 |     def copy(self):
217 |         copied = SituationGraph(
218 |             situation_nodes=copy.deepcopy(self.nodes),
219 |             situation_edges=copy.deepcopy(self.situation_edges),
220 |             other_properties=copy.deepcopy(self.other_properties)
221 |         )
222 |         copied.cleanup()
223 |         return copied
224 | 
225 |     def get_empty_nodes(self):
226 |         return [n for n in self.nodes if n.is_empty()]
227 | 
228 |     def remove_empty_nodes(self):
229 |         nodes_to_remove = self.get_empty_nodes()
230 |         for cand_node in nodes_to_remove:
231 |             self.remove_a_node(node_to_remove=cand_node)
232 | 
233 |     def get_all_node_grounding_pairs(self):
234 |         lists = []
235 |         for x in self.nodes:
236 |             pairs = x.get_grounding_pairs()
237 |             if pairs:
238 |                 lists.extend(pairs)
239 |         return lists
240 | 
241 |     def remove_a_node(self, node_to_remove):
242 |         if not node_to_remove:
243 |             return
244 |         try:
245 |             self.nodes.remove(node_to_remove)
246 |             # also remove all edges it occurs in.
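            # (illustrative note, not in the original source: for a graph
            # V -> X -> A, removing node X drops both edges (V, X) and (X, A),
            # leaving V and A disconnected.)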
247 |             edges_to_remove = [e for e in self.situation_edges
248 |                                if e.from_node.id == node_to_remove.id or e.to_node.id == node_to_remove.id]
249 |             for e in edges_to_remove:
250 |                 self.remove_an_edge(edge_to_remove=e)
251 |         except ValueError:
252 |             print(f"Warning: Removal of a non-existent node: '{node_to_remove}'")
253 | 
254 |     def get_exogenous_nodes(self, exogenous_ids=["Z", "V"]):
255 |         return [self.lookup_node(node_id=x) for x in exogenous_ids]
256 | 
257 |     def remove_an_edge(self, edge_to_remove):
258 |         if not edge_to_remove:
259 |             return
260 |         try:
261 |             self.situation_edges.remove(edge_to_remove)
262 |         except ValueError:
263 |             print(f"Warning: Removal of a non-existent edge: '{edge_to_remove}'")
264 | 
265 |     def lookup_node(self, node_id) -> SituationNode:
266 |         for t in self.nodes:
267 |             if t.id == node_id:
268 |                 return t
269 |         return None
270 | 
271 |     def lookup_node_from_grounding(self, grounding) -> SituationNode:
272 |         for t in self.nodes:
273 |             # There is a very small chance that the same grounding appears in two different nodes.
274 |             # In that case, this function simply returns the node that comes first in the self.nodes array.
275 |             if grounding in t.groundings:
276 |                 return t
277 | 
278 |         # Search unsuccessful, so retry while ignoring case and tokenization mismatches (drop special chars)
279 |         stripped_grounding = strip_special_char(a_string=grounding)
280 |         if not stripped_grounding:
281 |             return None
282 |         for t in self.nodes:
283 |             if stripped_grounding in [strip_special_char(x) for x in t.groundings]:
284 |                 return t
285 | 
286 |         return None
287 | 
288 |     def lookup_edge(self, edge_source_node, edge_target_node) -> SituationEdge:
289 |         for t in self.situation_edges:
290 |             if t.from_node == edge_source_node and t.to_node == edge_target_node:
291 |                 return t
292 |         return None
293 | 
294 |     def speak_path(self, path, path_label, as_symbols=False, specific_grounding=None):
295 |         if not path or len(path) < 1:
296 |             return ""
297 |         return self.speak_path_with_symbols(path=path, path_label=path_label, specific_grounding=specific_grounding) \
298 |             if as_symbols else self.speak_path_in_sentences(path=path, path_label=path_label,
299 |                                                             specific_grounding=specific_grounding)
300 | 
301 |     def construct_path_from_str_arr(self, str_node_arr):
302 |         path = [self.lookup_node(node_str.strip()) for node_str in str_node_arr]
303 |         path_ok = True
304 |         for node in path:
305 |             if node is None or node.is_empty():
306 |                 path_ok = False
307 |         return path if path_ok else None
308 | 
309 |     def get_graph_id(self):
310 |         return self.other_properties["graph_id"]
311 | 
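    # A minimal sketch, not part of the original file, of what the two renderings below
    # produce for path [V, X, A] with labels [RESULTS_IN, NOT_RESULTS_IN], writing <N>
    # for a node's get_specific_grounding() value:
    #   speak_path_with_symbols(...)  -> "<V>->RESULTS_IN-><X>->NOT_RESULTS_IN-><A>"
    #   speak_path_in_sentences(...)  -> "If <V> happens, then <X> will happen.
    #                                     If <X> happens, then <A> will not happen"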
312 |     def speak_path_with_symbols(self, path, path_label, specific_grounding=None):
313 |         '''
314 | 
315 |         :param path: array of Situation nodes, e.g., V, X, A
316 |         :param path_label: array of path_len - 1 labels, e.g., [RESULTS_IN, NOT_RESULTS_IN]
317 |         :return:
318 |         '''
319 |         return f"{path[0].get_specific_grounding(specific_grounding)}->" + \
320 |                ("->".join([l.name + "->" + f"{p.get_specific_grounding(specific_grounding)}"
321 |                            for p, l in zip(path[1:], path_label)]))
322 | 
323 |     def speak_path_in_sentences(self, path, path_label, specific_grounding=None):
324 |         '''
325 |         Symbolic forms are very confusing:
326 | 
327 |         ```A not B not C
328 |         A =/> B =/> C```
329 | 
330 |         So instead express it pairwise as multiple sentences (each explanation answer option can be a list of sentences):
331 | 
332 |         ```If A happens, then B will not happen.
333 |         If B happens, then C will not happen.```
334 | 
335 |         :param path: array of Situation nodes, e.g., V, X, A
336 |         :param path_label: array of path_len - 1 labels, e.g., [RESULTS_IN, NOT_RESULTS_IN]
337 |         :return:
338 |         '''
339 | 
340 |         sentences = []
341 |         for edge_num, (n1, n2) in enumerate(zip(path, path[1:])):
342 |             speak_label = "will" if path_label[edge_num] == SituationLabel.RESULTS_IN else "will not"
343 |             sentences.append(f"If {n1.get_specific_grounding(specific_grounding)} happens, then "
344 |                              f"{n2.get_specific_grounding(specific_grounding)} {speak_label} happen")
345 | 
346 |         return ". ".join(sentences).strip()  # strip removes last space
347 |         # return f"{path[0].get_specific_grounding(specific_grounding)}->" + \
348 |         #        ("->".join([l.name + "->" + f"{p.get_specific_grounding(specific_grounding)}"
349 |         #                    for p, l in zip(path[1:], path_label)]))
350 | 
351 |     def all_labels_in(self, path):
352 |         return [self.lookup_edge(src, dest).label if self.lookup_edge(src,
353 |                                                                       dest) is not None else SituationLabel.NOT_RESULTS_IN
354 |                 for src, dest in zip(path, path[1:])]
355 |         # raw_path_labels = [self.lookup_edge(src, dest).label for src, dest in zip(path, path[1:])]
356 |         # fixed_all_labels = raw_path_labels[:-1]
357 |         # # FIXME this is possibly a bug and needs to be fixed when end label in path != decision node.
358 |         # # and if y=/>a that means flip the inherent label.
359 |         # # end node in the path is {accelerate/ decelerate}
360 |         # if path[-1].is_decision_node and not self.other_properties.get("is_positive_outcm", True):
361 |         #     fixed_all_labels.append(raw_path_labels[0].get_opposite_label())
362 |         # else:
363 |         #     fixed_all_labels.append(raw_path_labels[-1])
364 |         # # fixed_all_labels.append(self.end_label_for_path(path=path))
365 |         # return fixed_all_labels
366 | 
367 |     # one process per html page (9 situation graphs at most).
368 |     # For perl: https://metacpan.org/pod/distribution/Graph-Easy/bin/graph-easy
369 |     # For python, see: https://pypi.org/project/graphviz/
370 |     # For javascript, see: http://www.webgraphviz.com/
371 |     @staticmethod
372 |     def to_html(graphviz_data_map, html_title, html_outfile_path=None):
373 |         '''
374 |         #
375 |         #
376 | #html_page_top_str
382 | #