├── .dockerignore ├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── commaqa ├── configs │ ├── README.md │ ├── __init__.py │ ├── dataset_build_config.py │ ├── entities_config.py │ ├── predicate_config.py │ ├── predicate_language_config.py │ ├── step_config.py │ ├── theory_config.py │ └── utils.py ├── dataset │ ├── README.md │ ├── __init__.py │ ├── build_dataset.py │ ├── build_submodel_datasets.py │ ├── generate_decomposition_predictions.py │ ├── generate_decompositions_from_chains.py │ └── utils.py ├── execution │ ├── README.md │ ├── __init__.py │ ├── constants.py │ ├── kblookup.py │ ├── math_model.py │ ├── model_executer.py │ ├── operation_executer.py │ └── utils.py ├── inference │ ├── README.md │ ├── __init__.py │ ├── configurable_inference.py │ ├── constants.py │ ├── dataset_readers.py │ ├── model_search.py │ ├── participant_execution.py │ ├── participant_qgen.py │ ├── participant_util.py │ └── utils.py └── models │ └── generator.py ├── configs ├── commaqav1 │ ├── explicit │ │ ├── entities.libsonnet │ │ ├── movies1.jsonnet │ │ ├── movies1_compgen.jsonnet │ │ ├── movies2.jsonnet │ │ ├── movies2_compgen.jsonnet │ │ ├── predicate_language.libsonnet │ │ ├── table_predicates.libsonnet │ │ ├── text_predicates.libsonnet │ │ ├── theories.libsonnet │ │ └── theories_compgen.libsonnet │ ├── implicit │ │ ├── entities.libsonnet │ │ ├── items0.jsonnet │ │ ├── items1.jsonnet │ │ ├── items2.jsonnet │ │ ├── items3.jsonnet │ │ ├── items4.jsonnet │ │ ├── items5.jsonnet │ │ ├── kb_predicates.libsonnet │ │ ├── predicate_language.libsonnet │ │ ├── text_predicates.libsonnet │ │ └── theories.libsonnet │ └── numeric │ │ ├── entities.libsonnet │ │ ├── predicate_language.libsonnet │ │ ├── sports.jsonnet │ │ ├── sports_compgen.jsonnet │ │ ├── table_predicates.libsonnet │ │ ├── text_predicates.libsonnet │ │ ├── theories.libsonnet │ │ └── theories_compgen.libsonnet └── inference │ ├── commaqav1_beam_search.jsonnet │ ├── commaqav1_brute_force.jsonnet │ ├── 
commaqav1_greedy_search.jsonnet │ └── commaqav1_sample_search.jsonnet ├── requirements.txt └── scripts ├── build_commaqav1.sh ├── build_datasets.sh ├── build_decompositions.sh ├── build_docker_image.sh └── drop_eval.py /.dockerignore: -------------------------------------------------------------------------------- 1 | output/ -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # IntelliJ specific 2 | *.iml 3 | .idea/ 4 | 5 | # Python specific 6 | *.pyc 7 | 8 | # Output dumps 9 | output/ 10 | 11 | # Scratch folder 12 | scratch/ 13 | 14 | # Wandb 15 | wandb/ 16 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.8 2 | 3 | ENV LC_ALL=C.UTF-8 4 | ENV LANG=C.UTF-8 5 | 6 | ENV LD_LIBRARY_PATH /usr/local/nvidia/lib:/usr/local/nvidia/lib64 7 | 8 | # Tell nvidia-docker the driver spec that we need as well as to 9 | # use all available devices, which are mounted at /usr/local/nvidia. 10 | # The LABEL supports an older version of nvidia-docker, the env 11 | # variables a newer one. 12 | ENV NVIDIA_VISIBLE_DEVICES all 13 | ENV NVIDIA_DRIVER_CAPABILITIES compute,utility 14 | LABEL com.nvidia.volumes.needed="nvidia_driver" 15 | 16 | RUN apt-get update --fix-missing && apt-get install -y \ 17 | gettext-base && \ 18 | rm -rf /var/lib/apt/lists/* 19 | 20 | WORKDIR /stage/ 21 | 22 | COPY requirements.txt . 
23 | RUN pip install -r requirements.txt 24 | RUN python -m nltk.downloader stopwords 25 | RUN python -m nltk.downloader punkt 26 | COPY commaqa/ commaqa/ 27 | 28 | CMD ["/bin/bash"] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CommaQA: *Comm*unicating with *A*gents for *QA* 2 | 3 | CommaQA Dataset is a QA benchmark for learning to communicate with agents. It consists of three 4 | datasets capturing three forms of multi-hop reasoning -- explicit(CommaQA-E), implicit(CommaQA-I), 5 | and numeric(CommaQA-N). 
6 | 7 | **Paper Link**: 8 | [Semantic Scholar](https://api.semanticscholar.org/CorpusID:239016681) 9 | 10 | **Citation**: 11 | ``` 12 | @article{Khot2021LearningTS, 13 | title={Learning to Solve Complex Tasks by Talking to Agents}, 14 | author={Tushar Khot and Kyle Richardson and Daniel Khashabi and Ashish Sabharwal}, 15 | journal={ArXiv}, 16 | year={2021}, 17 | volume={abs/2110.08542} 18 | } 19 | ``` 20 | 21 | 22 | Table of Contents 23 | =============== 24 | 25 | * [Dataset](#Dataset) 26 | * [Download](#Download) 27 | * [Formats](#Formats) 28 | * [Models](#Models) 29 | * [Code](#Code) 30 | 31 | ## Dataset 32 | 33 | ### Download 34 | 35 | Download the datasets: 36 | 37 | * [Explicit](https://ai2-public-datasets.s3.amazonaws.com/commaqa/v1/commaqa_explicit.zip) 38 | * [Implicit](https://ai2-public-datasets.s3.amazonaws.com/commaqa/v1/commaqa_implicit.zip) 39 | * [Numeric](https://ai2-public-datasets.s3.amazonaws.com/commaqa/v1/commaqa_numeric.zip) 40 | * [Compositional Generalization (test only)](https://ai2-public-datasets.s3.amazonaws.com/commaqa/v1/commaqa_compgen.zip) 41 | 42 | ### Formats 43 | 44 | Each dataset contains three formats: 45 | 46 | * commaqa/: This is default CommaQA format that contains the raw facts, verbalized sentences, 47 | agent language specification, QA pairs and associated theories. Each JSON file consists of list of 48 | items with each item containing a KB and a group of questions. 
The JSON file uses the following 49 | key structure: 50 | 51 | * `kb`: a map between predicate names and the facts associated with each predicate 52 | * `context`: verbalized context used by black-box models 53 | * `per_fact_context`: map between the facts in the `kb` and the corresponding verbalized 54 | sentence in `context` 55 | * `pred_lang_config`: map between the agent name and the questions answerable by this agent 56 | (see [ModelQuestionConfig](commaqa/configs/predicate_language_config.py) for more details) 57 | * `qa_pairs`: list of QA pairs associated with this context using the keys: 58 | * `id`: question id 59 | * `question`: question 60 | * `answer`: numeric or list answer 61 | * `config`: specific theory config used to construct this example (see 62 | [TheoryConfig](commaqa/configs/theory_config.py) for more details) 63 | * `assignment`: assignment to the variables in the theory 64 | * `decomposition`: agent-based decomposition to answer this question 65 | * `facts_used`: facts needed to answer this question 66 | 67 | * drop/: This contains the dataset in the default [DROP dataset](https://allennlp.org/drop) format 68 | by converting the `context` and `qa_pairs` from the commaqa/ format. 69 | 70 | * seq2seq/: This contains the dataset in a simple text format (one example per line) by converting 71 | the `context` and `qa_pairs` from the commaqa/ format into ` Q: A: `. 72 | 73 | 74 | ### Additional Datasets 75 | 76 | * Decompositions: The training data for the decomposer can be downloaded from 77 | [here](https://ai2-public-datasets.s3.amazonaws.com/commaqa/v1/commaqa_decompositions.zip). 78 | The data uses the JSONL format with `train_seqs` field containing the decompositions for each 79 | question. Each entry in the `train_seqs` array corresponds to one step in the decomposition chain. 80 | E.g. 81 | ``` 82 | QC: What awards have the actors of the Hallowcock winning movies received? 
83 | QI: (select) [table] The award Hallowcock has been awarded to which movies? A: [\"Clenestration\"] 84 | QS: (project_values_flat_unique) [text] Who all acted in the movie #1? 85 | ``` 86 | can be used to train a model to generate the question 87 | `(project_values_flat_unique) [text] Who all acted in the movie #1?` (string following `QS:`) given 88 | the previous questions and answers. 89 | 90 | * Language: The language specification for agents in each dataset can be downloaded from 91 | [here](https://ai2-public-datasets.s3.amazonaws.com/commaqa/v1/commaqa_language.zip). It contains 92 | two files: 93 | * `operations.txt`: File containing the list of operations for CommaQA 94 | * `model_questions.tsv`: A TSV file where the first field corresponds to the model name and all 95 | the subsequent fields contain valid questions that can be asked to this model. 96 | 97 | ## Models 98 | We also provide the T5-Large models trained to produce the next sub-question based on the oracle 99 | decompositions. These models can be used to perform inference as described 100 | [here](commaqa/inference/README.md) to reproduce the TMN results.
101 | * [CommaQA-E Model](https://ai2-public-datasets.s3.amazonaws.com/commaqa/v1/oracle_tmns/commaqa_e_oracle_model.zip) 102 | * [CommaQA-I Model](https://ai2-public-datasets.s3.amazonaws.com/commaqa/v1/oracle_tmns/commaqa_i_oracle_model.zip) 103 | * [CommaQA-N Model](https://ai2-public-datasets.s3.amazonaws.com/commaqa/v1/oracle_tmns/commaqa_n_oracle_model.zip) 104 | 105 | ## Code 106 | 107 | Refer to the individual READMEs in each package for instructions on: 108 | 109 | * [Config Format](commaqa/configs/README.md) 110 | * [Building Datasets](commaqa/dataset/README.md) 111 | * [Building Agents/Operations](commaqa/execution/README.md) 112 | * [Running Inference](commaqa/inference/README.md) 113 | -------------------------------------------------------------------------------- /commaqa/configs/README.md: -------------------------------------------------------------------------------- 1 | # Dataset Configs 2 | 3 | To create your own dataset with new KB, Agents and Theories, you would need to specify your own 4 | dataset config file. This README describes the format of such a config file and the individual 5 | objects within the config file. Refer to the [README](../dataset/README.md) for details about 6 | building the dataset. 7 | 8 | ### Defining your own config 9 | To define a new dataset configuration, create a JSONNET file with the following format: 10 | ```jsonnet 11 | { 12 | version: 3.0, 13 | entities: Entities, 14 | predicates: Predicates, 15 | predicate_language: Predicate_Language, 16 | theories: Theories, 17 | } 18 | ``` 19 | Each object in this config is described below. You may choose to format the jsonnet file using imported libsonnet files, e.g., in 20 | [CommaQA-E](../../configs/commaqav1/explicit). 21 | 22 | ### Config Objects 23 | 24 | #### Entities 25 | A dictionary of `EntityType` to a list of strings that correspond to entities of this type. 26 | E.g. 
27 | ```jsonnet 28 | { 29 | entities: { 30 | movie: [ 31 | "Godfather", 32 | "Kill Bill", 33 | ], 34 | actor: [ 35 | "Marlon Brando", 36 | "Uma Thurman", 37 | ], 38 | year: [ 39 | "1990", 40 | "1995", 41 | ] 42 | } 43 | } 44 | ``` 45 | 46 | 47 | #### Predicates 48 | Each predicate defines a relation between `EntityType`s in the list of `entities`. 49 | E.g. 50 | ```jsonnet 51 | { 52 | acted_in: { 53 | args: ["movie", "actor"], 54 | nary: ["n", "1"], 55 | language: ["movie: $1 ; actor: $2", "$2 acted in $1."], 56 | } 57 | } 58 | ``` 59 | Here the `args` specify the arguments of the `acted_in` predicate. `nary` specifies that the 60 | relation is 1-to-many. Specifically, when this KB is grounded (via sampling), each movie can appear 61 | multiple times in the KB relations but the actor can only appear once. This will essentially cause 62 | each movie to have multiple actors and each actor to appear in only one movie. Changing `nary` to 63 | `["n", "n"]` would make it a many-to-many relation. Note that it is also possible to specify a tree (e.g. Isa) 64 | or chain structure (e.g. successor) for your KB by setting `type` to `tree` or `chain` instead of the 65 | `nary` field. The `language` field specifies how the relation will be verbalized in text. This is not relevant for 66 | the agents but defines the input context for black-box models. One of the verbalizations from the 67 | `language` is randomly sampled for each KB fact when generating the context. 68 | 69 | 70 | The `predicates` in the configuration file is a dictionary of `PredicateName` to the properties of 71 | the predicate as described above. E.g.
72 | ```jsonnet 73 | { 74 | predicates: { 75 | acted_in: { 76 | args: ["movie", "actor"], 77 | nary: ["n", "1"], 78 | language: ["movie: $1 ; actor: $2", "$2 acted in $1."], 79 | }, 80 | released_in: { 81 | args: ["movie", "year"], 82 | nary: ["1", "n"], 83 | language: ["movie: $1 ; year: $2"], 84 | } 85 | } 86 | } 87 | ``` 88 | 89 | #### Predicate Language 90 | This field defines the questions that can be answered by each agent and how each agent answers these 91 | questions given their KB. The configuration specifies a helper predicate, what `EntityType`s to use 92 | for the question, the agent name, how the questions should be phrased and how the questions are 93 | answered using the KB. For example, 94 | ```jsonnet 95 | { 96 | "acted_a($1, ?)": { 97 | "init": { 98 | "$1": "movie" 99 | }, 100 | "model": "text", 101 | "questions": [ 102 | "Who all acted in the movie $1?", 103 | "Who are the actors in the movie $1?" 104 | ], 105 | "steps": [ 106 | { 107 | "answer": "#1", 108 | "operation": "select", 109 | "question": "text_actor($1, ?)" 110 | } 111 | ] 112 | } 113 | } 114 | ``` 115 | In this example: 116 | * `acted_a($1, ?)`: This is a helper predicate that will be used to define the theories later. 117 | This question takes one input argument `$1` and returns the second argument `?` as the answer. 118 | * `init`: This dictionary specifies the entity type for each argument in the predicate name. In 119 | this case the question takes one argument which is a movie. 120 | * `model`: This field specifies the agent name -- `text` (i.e. TextQA agent) in this case 121 | * `questions`: This list specifies the different formulations of this question that can be 122 | answered by the agent. Here the `text` agent can answer questions about the actors of a movie using 123 | either of these forms. Note that our synthetic dataset uses symbolic agents that need one of these 124 | formulations to be used exactly. 
125 | * `steps`: This field can be used to describe a multi-step procedure to answer these questions. In 126 | this case, we defined a single-step procedure where the agent will take the input movie `$1` and 127 | lookup the `text_actor` relation in the KB. The first argument of this KB lookup will be grounded to 128 | the movie (`$1`) and the second argument (`?`) will be returned as the answer. The answer is named 129 | as `#1` (this name can be used in future steps to refer to this answer). Refer to the paper for an 130 | explanation about the `select` operation (and other possible operations). 131 | 132 | 133 | #### Theories 134 | Finally we define the theories that will be used to create the questions for the complex tasks based 135 | on the components described above. Consider the following sample theory: 136 | ```jsonnet 137 | { 138 | "init": { 139 | "$1": "nation" 140 | }, 141 | "questions": [ 142 | "What movies have people from the country $1 acted in?" 143 | ], 144 | "steps": [ 145 | { 146 | "answer": "#1", 147 | "operation": "select", 148 | "question": "nation_p($1, ?)" 149 | }, 150 | { 151 | "answer": "#2", 152 | "operation": "project_values_flat_unique", 153 | "question": "acted_m(#1, ?)" 154 | } 155 | ] 156 | } 157 | ``` 158 | * `init`: Again refers to the `EntityType` used to create the question 159 | * `questions`: List of possible formalizations of this question 160 | * `steps`: Multi-step procedure to execute this question. Refer to the explanation of 161 | `steps` [above](#predicate-language). For explanation about `operation`, refer to our paper. 
The 162 | `question` field refers to the helper predicate names introduced in the 163 | [Predicate Language](#predicate-language) section -------------------------------------------------------------------------------- /commaqa/configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/CommaQA/663fda8fa280306297a805aefb671b04661aef74/commaqa/configs/__init__.py -------------------------------------------------------------------------------- /commaqa/configs/dataset_build_config.py: -------------------------------------------------------------------------------- 1 | from commaqa.configs.entities_config import EntitiesConfig 2 | from commaqa.configs.predicate_config import PredicateConfig 3 | from commaqa.configs.predicate_language_config import PredicateLanguageConfig 4 | from commaqa.configs.theory_config import TheoryConfig 5 | 6 | 7 | class DatasetBuildConfig:  # parsed top-level dataset config: version, entities, predicates, theories, predicate_language 8 | def __init__(self, input_json): 9 | self.version = input_json["version"] 10 | self.entities = EntitiesConfig(input_json["entities"]) 11 | self.predicates = [PredicateConfig(x) for x in input_json["predicates"].items()] 12 | self.theories = [TheoryConfig(x) for x in input_json["theories"]] 13 | self.pred_lang_config = PredicateLanguageConfig(input_json["predicate_language"]) 14 | -------------------------------------------------------------------------------- /commaqa/configs/entities_config.py: -------------------------------------------------------------------------------- 1 | import random 2 | from math import ceil 3 | from typing import Dict, Any, List 4 | 5 | 6 | class EntitiesConfig: 7 | def __init__(self, entities_json: Dict[str, List[str]]): 8 | self.entity_type_map = entities_json 9 | 10 | def subsample(self, num_ents):  # num_ents <= 1 (including the int 1) is a fraction of each type's entities; otherwise a per-type count 11 | new_ent_map = {} 12 | for etype, elist in self.entity_type_map.items(): 13 | # if fraction passed, sample ratio 14 | if num_ents <= 1: 15 | new_ent_map[etype] = random.sample(elist,
ceil(len(elist) * num_ents)) 16 | else: 17 | new_ent_map[etype] = random.sample(elist, num_ents)  # NOTE(review): random.sample raises ValueError if num_ents > len(elist) 18 | 19 | return EntitiesConfig(new_ent_map) 20 | 21 | def __getitem__(self, item: str): 22 | return self.entity_type_map[item] 23 | -------------------------------------------------------------------------------- /commaqa/configs/predicate_config.py: -------------------------------------------------------------------------------- 1 | import random 2 | from copy import deepcopy 3 | from typing import List, Dict 4 | 5 | from commaqa.configs.entities_config import EntitiesConfig 6 | from commaqa.dataset.utils import get_predicate_args 7 | 8 | 9 | class PredicateConfig: 10 | def __init__(self, pred_json): 11 | self.pred_name = pred_json[0] 12 | self.args = pred_json[1]["args"] 13 | self.nary = pred_json[1].get("nary") 14 | self.graph_type = pred_json[1].get("type") 15 | self.language = pred_json[1].get("language") 16 | 17 | def populate_chains(self, entity_config: EntitiesConfig) -> List[str]:  # links all entities of the arg type into one random-order chain pred(e1, e2), pred(e2, e3), ... 18 | if len(self.args) != 2 or self.args[0] != self.args[1]: 19 | raise ValueError("Chains KB can only be created with binary predicates having the same " 20 | "arg types. Change args for {}".format(self.pred_name)) 21 | kb = [] 22 | entity_list = deepcopy(entity_config.entity_type_map[self.args[0]]) 23 | last_entity = None 24 | while len(entity_list) > 0: 25 | if last_entity is None: 26 | last_entity = random.choice(entity_list) 27 | entity_list.remove(last_entity) 28 | next_entity = random.choice(entity_list)  # NOTE(review): raises IndexError when the entity type has a single entity (list is empty after the removal above) 29 | entity_arr = [last_entity, next_entity] 30 | fact = self.pred_name + "(" + ", ".join(entity_arr) + ")" 31 | kb.append(fact) 32 | last_entity = next_entity 33 | entity_list.remove(last_entity) 34 | return kb 35 | 36 | def populate_trees(self, entity_config: EntitiesConfig) -> List[str]: 37 | if len(self.args) != 2 or self.args[0] != self.args[1]: 38 | raise ValueError("Trees KB can only be created with binary predicates having the same " 39 | "arg types.
Change args for {}".format(self.pred_name)) 40 | if len(self.nary) is None or "1" not in self.nary: 41 | raise ValueError("Nary needs to be set with at least one index set to 1 to produce" 42 | "a tree structure kb. Pred: {}".format(self.pred_name)) 43 | kb = [] 44 | entity_list = deepcopy(entity_config.entity_type_map[self.args[0]]) 45 | # create the root node 46 | open_entities = random.sample(entity_list, 1) 47 | entity_list.remove(open_entities[0]) 48 | unique_idx = self.nary.index("1") 49 | while len(entity_list) > 0: 50 | new_open_entities = [] 51 | # for each open node 52 | for open_entity in open_entities: 53 | # select children 54 | if len(entity_list) > 2: 55 | children = random.sample(entity_list, 2) 56 | else: 57 | children = entity_list 58 | # add edge between child and open node 59 | for child in children: 60 | if unique_idx == 1: 61 | entity_arr = [open_entity, child] 62 | else: 63 | entity_arr = [child, open_entity] 64 | # remove child from valid nodes to add 65 | entity_list.remove(child) 66 | # add it to the next set of open nodes 67 | new_open_entities.append(child) 68 | fact = self.pred_name + "(" + ", ".join(entity_arr) + ")" 69 | kb.append(fact) 70 | open_entities = new_open_entities 71 | return kb 72 | 73 | def populate_kb(self, entity_config: EntitiesConfig) -> List[str]: 74 | if self.graph_type == "chain": 75 | return self.populate_chains(entity_config) 76 | elif self.graph_type == "tree": 77 | return self.populate_tree(entity_config) 78 | elif self.graph_type is not None: 79 | raise ValueError("Unknown graph type: {}".format(self.graph_type)) 80 | if self.nary is None: 81 | raise ValueError("At least one of nary or type needs to be set for predicate" 82 | " {}".format(self.pred_name)) 83 | 84 | return self.populate_relations(entity_config) 85 | 86 | def populate_relations(self, entity_config: EntitiesConfig) -> List[str]: 87 | kb = set() 88 | arg_counts = [] 89 | arg_pos_list = [] 90 | for arg in self.args: 91 | if arg not in 
entity_config.entity_type_map: 92 | raise ValueError("No entity list defined for {}." 93 | "Needed for predicate: {}".format(arg, self.pred_name)) 94 | arg_counts.append(len(entity_config.entity_type_map[arg])) 95 | arg_pos_list.append(deepcopy(entity_config.entity_type_map[arg])) 96 | 97 | max_attempts = 2 * max(arg_counts) 98 | orig_arg_pos_list = deepcopy(arg_pos_list) 99 | while max_attempts > 0: 100 | entity_arr = [] 101 | max_attempts -= 1 102 | for idx in range(len(self.args)): 103 | ent = random.choice(arg_pos_list[idx]) 104 | # assume relations can never be reflexive 105 | if ent in entity_arr: 106 | entity_arr = None 107 | break 108 | entity_arr.append(ent) 109 | if entity_arr is None: 110 | continue 111 | for idx, ent in enumerate(entity_arr): 112 | if self.nary[idx] == "1": 113 | arg_pos_list[idx].remove(ent) 114 | if len(arg_pos_list[idx]) == 0: 115 | max_attempts = 0 116 | elif self.nary[idx] == "n": 117 | arg_pos_list[idx].remove(ent) 118 | # once all entities have been used once, reset to the original list 119 | if len(arg_pos_list[idx]) == 0: 120 | arg_pos_list[idx] = deepcopy(orig_arg_pos_list[idx]) 121 | fact = self.pred_name + "(" + ", ".join(entity_arr) + ")" 122 | if fact not in kb: 123 | kb.add(fact) 124 | return list(kb) 125 | 126 | def generate_kb_fact_map(self, kb: Dict[str, List[str]]) -> Dict[str, str]: 127 | kb_fact_map = {} 128 | for kb_item in kb[self.pred_name]: 129 | if self.language: 130 | pred, args = get_predicate_args(kb_item) 131 | sentence = self.language if isinstance(self.language, str) \ 132 | else random.choice(self.language) 133 | for argidx, arg in enumerate(args): 134 | sentence = sentence.replace("$" + str(argidx + 1), arg) 135 | else: 136 | pred_name, fields = get_predicate_args(kb_item) 137 | if len(fields) != 2: 138 | sentence = kb_item 139 | else: 140 | sentence = fields[0] + " " + pred_name + " " + " ".join(fields[1:]) 141 | kb_fact_map[kb_item] = sentence + "." 
142 | return kb_fact_map 143 | 144 | def generate_context(self, kb: Dict[str, List[str]]) -> str: 145 | kb_fact_map = self.generate_kb_fact_map(kb) 146 | return " ".join(kb_fact_map.values()) 147 | -------------------------------------------------------------------------------- /commaqa/configs/predicate_language_config.py: -------------------------------------------------------------------------------- 1 | from commaqa.configs.step_config import StepConfig 2 | from commaqa.dataset.utils import get_predicate_args 3 | 4 | 5 | class ModelQuestionConfig: 6 | def __init__(self, config_json): 7 | self.steps = [StepConfig(x) for x in 8 | config_json["steps"]] if "steps" in config_json else [] 9 | self.questions = config_json.get("questions") 10 | self.init = config_json["init"] 11 | self.model = config_json["model"] 12 | self.predicate = config_json["predicate"] 13 | 14 | def to_json(self): 15 | return { 16 | "steps": [x.to_json() for x in self.steps], 17 | "questions": self.questions, 18 | "init": self.init, 19 | "model": self.model, 20 | "predicate": self.predicate 21 | } 22 | 23 | 24 | class PredicateLanguageConfig: 25 | def __init__(self, pred_lang_config): 26 | # import json 27 | # print(json.dumps(pred_lang_config, indent=2)) 28 | self.predicate_config = {} 29 | self.model_config = {} 30 | for predicate, config in pred_lang_config.items(): 31 | config["predicate"] = predicate 32 | question_config = ModelQuestionConfig(config) 33 | self.predicate_config[predicate] = question_config 34 | model = config["model"] 35 | if model not in self.model_config: 36 | self.model_config[model] = [] 37 | self.model_config[model].append(question_config) 38 | 39 | def model_config_as_json(self): 40 | return {model: [config.to_json() for config in configs] 41 | for model, configs in self.model_config.items()} 42 | 43 | def find_model(self, question_predicate): 44 | matching_configs = self.find_valid_configs(question_predicate) 45 | if len(matching_configs) == 0: 46 | return None 47 | 
matching_models = {x.model for x in matching_configs} 48 | if len(matching_models) != 1: 49 | raise ValueError("Unexpected number of matching models: {} for {}. " 50 | "Expected one model".format(matching_models, question_predicate)) 51 | return matching_models.pop() 52 | 53 | def find_valid_configs(self, question_predicate): 54 | qpred, qargs = get_predicate_args(question_predicate) 55 | matching_configs = [] 56 | for key, config in self.predicate_config.items(): 57 | config_qpred, config_qargs = get_predicate_args(key) 58 | if config_qpred == qpred: 59 | assert len(qargs) == len(config_qargs), \ 60 | "{} {}\n{}".format(qargs, config_qargs, question_predicate) 61 | mismatch = False 62 | for qarg, cqarg in zip(qargs, config_qargs): 63 | if (cqarg == "?") ^ (qarg == "?"): 64 | mismatch = True 65 | if not mismatch: 66 | matching_configs.append(config) 67 | return matching_configs 68 | -------------------------------------------------------------------------------- /commaqa/configs/step_config.py: -------------------------------------------------------------------------------- 1 | class StepConfig: 2 | def __init__(self, step_json): 3 | self.operation = step_json["operation"] 4 | self.question = step_json["question"] 5 | self.answer = step_json["answer"] 6 | 7 | def to_json(self): 8 | return self.__dict__ 9 | -------------------------------------------------------------------------------- /commaqa/configs/theory_config.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import random 4 | import string 5 | from typing import Dict, List 6 | 7 | from commaqa.configs.step_config import StepConfig 8 | from commaqa.configs.utils import execute_steps 9 | from commaqa.dataset.utils import dict_product, align_assignments, nonempty_answer, is_question_var 10 | from commaqa.execution.model_executer import ModelExecutor 11 | from commaqa.execution.operation_executer import OperationExecuter 12 | 13 | logger = 
logging.getLogger(__name__) 14 | 15 | 16 | class TheoryConfig: 17 | def __init__(self, theory_json): 18 | self.steps = [StepConfig(x) for x in theory_json["steps"]] 19 | self.questions = theory_json.get("questions") 20 | self.init = theory_json["init"] 21 | 22 | def to_json(self): 23 | return { 24 | "steps": [x.to_json() for x in self.steps], 25 | "questions": self.questions, 26 | "init": self.init 27 | } 28 | 29 | def to_str(self): 30 | return json.dumps(self.to_json()) 31 | 32 | def get_possible_assignments(self, entities: Dict[str, List[str]], 33 | model_library: Dict[str, ModelExecutor], 34 | pred_lang_config): 35 | assignment_dict = {} 36 | for key, ent_type in self.init.items(): 37 | assignment_dict[key] = entities[ent_type] 38 | possible_assignments = dict_product(assignment_dict) 39 | # assume no duplicates in assignments 40 | possible_assignments = [assignment for assignment in possible_assignments 41 | if len(set(assignment.values())) == len(assignment.values())] 42 | op_executor = OperationExecuter(model_library=model_library) 43 | output_assignments = [] 44 | for curr_assignment in possible_assignments: 45 | # print(self.to_json()) 46 | new_assignment = execute_steps(steps=self.steps, input_assignments=curr_assignment, 47 | executer=op_executor, pred_lang_config=pred_lang_config, 48 | input_model=None) 49 | if new_assignment: 50 | output_assignments.append(new_assignment) 51 | 52 | if len(output_assignments) < 2: 53 | logger.debug("Few assignments: {} found for theory: {} given kb:\n {}".format( 54 | json.dumps(output_assignments, indent=2), self.to_str(), 55 | json.dumps(list(model_library.values())[0].kblookup.kb, indent=2))) 56 | return output_assignments 57 | 58 | def create_decompositions(self, pred_lang_config, assignment): 59 | decomposition = [] 60 | for step in self.steps: 61 | valid_configs = pred_lang_config.find_valid_configs(step.question) 62 | if len(valid_configs) == 0: 63 | raise ValueError("No predicate config matches 
{}".format(step.question)) 64 | 65 | # # model less operation 66 | # model = "N/A" 67 | # print(step.question) 68 | # question = step.question 69 | # for k, v in assignment.items(): 70 | # if k.startswith("$"): 71 | # question = question.replace(k, v) 72 | # else: 73 | lang_conf = random.choice(valid_configs) 74 | model = lang_conf.model 75 | question = random.choice(lang_conf.questions) 76 | _, assignment_map = align_assignments(lang_conf.predicate, step.question, 77 | assignment) 78 | for lang_pred_arg, question_pred_arg in assignment_map.items(): 79 | if is_question_var(question_pred_arg): 80 | question = question.replace(lang_pred_arg, assignment[question_pred_arg]) 81 | else: 82 | # replace the question idx with the appropriate answer idx in the theory 83 | question = question.replace(lang_pred_arg, question_pred_arg) 84 | answer = assignment[step.answer] 85 | decomposition.append({ 86 | "m": model, 87 | "q": question, 88 | "a": answer, 89 | "op": step.operation 90 | }) 91 | return decomposition 92 | 93 | def create_questions(self, entities: Dict[str, List[str]], pred_lang_config, model_library): 94 | possible_assignments = self.get_possible_assignments(entities=entities, 95 | pred_lang_config=pred_lang_config, 96 | model_library=model_library) 97 | qa = [] 98 | for assignment in possible_assignments: 99 | decomposition = self.create_decompositions(pred_lang_config=pred_lang_config, 100 | assignment=assignment) 101 | # move facts_used out of the assignment structure 102 | facts_used = list(set(assignment["facts_used"])) 103 | del assignment["facts_used"] 104 | question = random.choice(self.questions) 105 | answer = assignment[self.steps[-1].answer] 106 | for p, f in assignment.items(): 107 | if p in question: 108 | question = question.replace(p, f) 109 | if decomposition[-1]["a"] != answer: 110 | raise ValueError("Answer to the last question in decomposition not the same as the " 111 | "final answer!.\n Decomposition:{} \n Question: {} \n Answer: {}" 112 | 
"".format(decomposition, question, answer)) 113 | # ignore questions with no valid answers 114 | if not nonempty_answer(answer): 115 | continue 116 | # ignore questions with too many answers 117 | if isinstance(answer, list) and len(answer) > 5: 118 | continue 119 | qa.append({ 120 | "question": question, 121 | "answer": answer, 122 | "assignment": assignment, 123 | "config": self.to_json(), 124 | "decomposition": decomposition, 125 | "facts_used": facts_used, 126 | "id": "".join([random.choice(string.hexdigits) for n in range(16)]).lower() 127 | }) 128 | return qa 129 | -------------------------------------------------------------------------------- /commaqa/configs/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from copy import deepcopy 3 | from typing import List, Dict 4 | 5 | from commaqa.configs.predicate_language_config import PredicateLanguageConfig 6 | from commaqa.configs.step_config import StepConfig 7 | from commaqa.dataset.utils import is_question_var, nonempty_answer 8 | from commaqa.execution.operation_executer import OperationExecuter 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def execute_steps(steps: List[StepConfig], input_assignments: Dict[str, str], 14 | executer: OperationExecuter, pred_lang_config: PredicateLanguageConfig = None, 15 | input_model: str = None): 16 | curr_assignment = deepcopy(input_assignments) 17 | if "facts_used" not in curr_assignment: 18 | curr_assignment["facts_used"] = [] 19 | 20 | for step in steps: 21 | if input_model is None: 22 | model = pred_lang_config.find_model(step.question) 23 | if model is None: 24 | raise ValueError("No model found for {}".format(step.question)) 25 | else: 26 | model = input_model 27 | 28 | new_question = step.question 29 | for k, v in curr_assignment.items(): 30 | # only replace question variables($1, $2). 
Answer variables (#1, #2) used by executer 31 | if is_question_var(k): 32 | new_question = new_question.replace(k, v) 33 | answers, curr_facts = executer.execute_operation(operation=step.operation, 34 | model=model, 35 | question=new_question, 36 | assignments=curr_assignment) 37 | if answers is None: 38 | # execution failed 39 | return None 40 | elif nonempty_answer(answers): 41 | curr_assignment[step.answer] = answers 42 | curr_assignment["facts_used"].extend(curr_facts) 43 | else: 44 | logger.debug("Stopped Execution. Empty answer: {}\n" 45 | "Question: {}\n Step: {}\n Assignment: {}".format( 46 | answers, new_question, step.to_json(), curr_assignment)) 47 | return {} 48 | return curr_assignment 49 | -------------------------------------------------------------------------------- /commaqa/dataset/README.md: -------------------------------------------------------------------------------- 1 | # Building Datasets 2 | 3 | ## Building QA Dataset 4 | To build the CommaQA dataset, you can use the [build_commaqav1](../../scripts/build_commaqav1.sh) 5 | script. It will generate the dataset in the `output/commaqav1` folder using the configs from 6 | `configs/commaqav1` directory. Note that this will be a different sample of the synthetic 7 | dataset and won't be exactly the same as the dataset used in our paper. Use the dataset links 8 | provided in the [main README](../../README.md) to get the same dataset as the one used in the paper. 9 | 10 | 11 | To build your own dataset, create the configuration file as per the specifications described in 12 | [configs README](../configs/README.md). You can then use the `build_datasets` script with your 13 | new configuration file: 14 | ```shell 15 | sh scripts/build_datasets.sh \ 16 | [YOUR NEW CONFIG FILE].jsonnet \ 17 | [OUTPUT DIRECTORY] 18 | ``` 19 | The script will generate JSON files in CommaQA format with train/dev/test splits in the output 20 | directory. 
21 | 22 | ## Generating Decompositions 23 | To produce the decompositions for any generated dataset in CommaQA format, you can use the 24 | [build_decompositions](../../scripts/build_decompositions.sh) script. e.g. 25 | ```shell 26 | sh scripts/build_decompositions.sh \ 27 | output/commaqav1/explicit \ 28 | output/commaqav1/explicit_decomp 29 | ``` 30 | will generate the decompositions for the CommaQA-E dataset. For each train/dev/test.json file in the 31 | input, there is a corresponding file in the output folder (`output/commaqav1/explicit_decomp` here). 32 | The decompositions are added as new keys: `train_seqs` for each example in the JSON. E.g. 33 | ```jsonnet 34 | { 35 | "train_seqs": [ 36 | " QC: What awards have the actors of the Hallowcock winning movies received? QS: (select) [table] The award Hallowcock has been awarded to which movies?", 37 | " QC: What awards have the actors of the Hallowcock winning movies received? QI: (select) [table] The award Hallowcock has been awarded to which movies? A: [\"Clenestration\"] QS: (project_values_flat_unique) [text] Who all acted in the movie #1?", 38 | " QC: What awards have the actors of the Hallowcock winning movies received? QI: (select) [table] The award Hallowcock has been awarded to which movies? A: [\"Clenestration\"] QI: (project_values_flat_unique) [text] Who all acted in the movie #1? A: [\"Huckberryberry\", \"Sapien\"] QS: (project_values_flat_unique) [table] #2 has been awarded which awards?", 39 | " QC: What awards have the actors of the Hallowcock winning movies received? QI: (select) [table] The award Hallowcock has been awarded to which movies? A: [\"Clenestration\"] QI: (project_values_flat_unique) [text] Who all acted in the movie #1? A: [\"Huckberryberry\", \"Sapien\"] QI: (project_values_flat_unique) [table] #2 has been awarded which awards? 
A: [\"Custodio\", \"Lameze\"] QS: [EOQ]" 40 | ] 41 | } 42 | ``` 43 | captures the decomposition of "What awards have the actors of the Hallowcock winning movies 44 | received?". Each entry in `train_seqs` captures the next question to be generated given the previous 45 | decomposition history. This format can be used to train a `NextGen` model for TMNs which is also 46 | trained to generate the next question given the history. 47 | -------------------------------------------------------------------------------- /commaqa/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/CommaQA/663fda8fa280306297a805aefb671b04661aef74/commaqa/dataset/__init__.py -------------------------------------------------------------------------------- /commaqa/dataset/build_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import os 5 | import random 6 | from math import ceil 7 | from random import shuffle 8 | from shutil import copyfile 9 | from typing import List 10 | 11 | import _jsonnet 12 | 13 | from commaqa.configs.dataset_build_config import DatasetBuildConfig 14 | from commaqa.execution.utils import build_models 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def parse_arguments(): 20 | arg_parser = argparse.ArgumentParser(description='Build a CommaQA dataset from inputs') 21 | arg_parser.add_argument('--input_json', type=str, required=True, 22 | help="Input JSON configuration files " 23 | "(comma-separated for multiple files)") 24 | arg_parser.add_argument('--output', "-o", type=str, required=True, help="Output folder") 25 | arg_parser.add_argument('--num_groups', type=int, required=False, default=100, 26 | help="Number of example groups to create") 27 | arg_parser.add_argument('--num_examples_per_group', type=int, required=False, default=10, 28 | help="Number of examples per 
group") 29 | arg_parser.add_argument('--entity_percent', type=float, required=False, default=1.0, 30 | help="Percentage of entities to sample for each group") 31 | arg_parser.add_argument('--debug', action="store_true", default=False, 32 | help="Enable debug logging") 33 | 34 | return arg_parser.parse_args() 35 | 36 | 37 | class DatasetBuilder: 38 | def __init__(self, configs: List[DatasetBuildConfig]): 39 | self.configs = configs 40 | 41 | def build_dataset(self, 42 | num_entities_per_group=5, 43 | num_groups: int = 10, num_examples_per_group: int = 10): 44 | data = [] 45 | numqs_per_theory = {} 46 | logger.info("Creating examples with {} questions per group. #Groups: {}. " 47 | "#Entities per group: {}".format(num_examples_per_group, num_groups, 48 | num_entities_per_group)) 49 | # Distribution over input configs to sample equal #examples from each config 50 | # Start with equal number (num_groups) and reduce by len(configs) each time 51 | # This will ensure that we get num_groups/len(configs) groups per config as well as 52 | # a peaky distribution that samples examples from rarely used configs 53 | config_distribution = [num_groups for x in self.configs] 54 | group_idx = 0 55 | num_attempts = 0 56 | while group_idx < num_groups: 57 | num_attempts += 1 58 | config_idx = random.choices(range(len(self.configs)), config_distribution)[0] 59 | current_config = self.configs[config_idx] 60 | # sample entities based on the current config 61 | entities = current_config.entities.subsample(num_entities_per_group) 62 | 63 | # build a KB based on the entities 64 | complete_kb = {} 65 | complete_kb_fact_map = {} 66 | for pred in current_config.predicates: 67 | complete_kb[pred.pred_name] = pred.populate_kb(entities) 68 | curr_pred_kb_fact_map = pred.generate_kb_fact_map(complete_kb) 69 | for k, v in curr_pred_kb_fact_map.items(): 70 | complete_kb_fact_map[k] = v 71 | 72 | questions_per_theory = {} 73 | context = " ".join(complete_kb_fact_map.values()) 74 | 75 | output_data = 
{ 76 | "kb": complete_kb, 77 | "context": context, 78 | "per_fact_context": complete_kb_fact_map, 79 | "pred_lang_config": current_config.pred_lang_config.model_config_as_json() 80 | } 81 | 82 | # build questions using KB and language config 83 | model_library = build_models(current_config.pred_lang_config.model_config, complete_kb) 84 | for theory in current_config.theories: 85 | theory_qs = theory.create_questions(entities=entities.entity_type_map, 86 | pred_lang_config=current_config.pred_lang_config, 87 | model_library=model_library) 88 | theory_key = theory.to_str() 89 | if theory_key not in numqs_per_theory: 90 | numqs_per_theory[theory_key] = 0 91 | numqs_per_theory[theory_key] += len(theory_qs) 92 | questions_per_theory[theory_key] = theory_qs 93 | all_questions = [qa for qa_per_theory in questions_per_theory.values() 94 | for qa in qa_per_theory] 95 | if len(all_questions) < num_examples_per_group: 96 | # often happens when a configuration has only one theory, skip print statement 97 | if len(current_config.theories) != 1: 98 | logger.warning("Insufficient examples: {} generated. 
Sizes:{} KB:\n{}".format( 99 | len(all_questions), 100 | [(tidx, len(final_questions)) for (tidx, final_questions) in 101 | questions_per_theory.items()], 102 | json.dumps(complete_kb, indent=2) 103 | )) 104 | logger.debug("Skipping config: {} Total #questions: {}".format(config_idx, 105 | len(all_questions))) 106 | continue 107 | 108 | # subsample questions to equalize #questions per theory 109 | min_size = min([len(qa) for qa in questions_per_theory.values()]) 110 | subsampled_questions = [] 111 | for qa_per_theory in questions_per_theory.values(): 112 | subsampled_questions.extend(random.sample(qa_per_theory, min_size)) 113 | if len(subsampled_questions) < num_examples_per_group: 114 | logger.warning("Skipping config: {} Sub-sampled questions: {}".format( 115 | config_idx, len(subsampled_questions))) 116 | continue 117 | final_questions = random.sample(subsampled_questions, num_examples_per_group) 118 | output_data["all_qa"] = all_questions 119 | output_data["qa_pairs"] = final_questions 120 | data.append(output_data) 121 | group_idx += 1 122 | # update distribution over configs 123 | config_distribution[config_idx] -= len(self.configs) 124 | if group_idx % 100 == 0: 125 | logger.info("Created {} groups. 
Attempted: {}".format(group_idx, 126 | num_attempts)) 127 | for theory_key, numqs in numqs_per_theory.items(): 128 | logger.debug("Theory: <{}> \n NumQs: [{}]".format(theory_key, numqs)) 129 | return data 130 | 131 | 132 | if __name__ == '__main__': 133 | args = parse_arguments() 134 | dataset_configs = [] 135 | counter = 0 136 | if args.debug: 137 | logging.basicConfig(level=logging.DEBUG) 138 | else: 139 | logging.basicConfig(level=logging.INFO) 140 | for filename in args.input_json.split(","): 141 | counter += 1 142 | output_dir = "" 143 | # if output is a json file 144 | if args.output.endswith(".json"): 145 | output_dir = os.path.dirname(args.output) 146 | else: 147 | output_dir = args.output 148 | 149 | if filename.endswith(".jsonnet"): 150 | data = json.loads(_jsonnet.evaluate_file(filename)) 151 | # dump the configuration as a the source file 152 | with open(output_dir + "/source{}.json".format(counter), "w") as output_fp: 153 | json.dump(data, output_fp, indent=2) 154 | dataset_config = DatasetBuildConfig(data) 155 | dataset_configs.append(dataset_config) 156 | else: 157 | # dump the configuration as a the source file 158 | copyfile(filename, output_dir + "/source{}.json".format(counter)) 159 | with open(filename, "r") as input_fp: 160 | input_json = json.load(input_fp) 161 | dataset_config = DatasetBuildConfig(input_json) 162 | dataset_configs.append(dataset_config) 163 | 164 | builder = DatasetBuilder(dataset_configs) 165 | data = builder.build_dataset(num_groups=args.num_groups, 166 | num_entities_per_group=args.entity_percent, 167 | num_examples_per_group=args.num_examples_per_group) 168 | num_examples = len(data) 169 | print("Number of example groups: {}".format(num_examples)) 170 | if args.output.endswith(".json"): 171 | print("Single file output name provided (--output file ends with .json)") 172 | print("Dumping examples into a single file instead of train/dev/test splits") 173 | with open(args.output, "w") as output_fp: 174 | json.dump(data, 
output_fp, indent=4) 175 | else: 176 | shuffle(data) 177 | train_ex = ceil(num_examples * 0.8) 178 | dev_ex = ceil(num_examples * 0.1) 179 | test_ex = num_examples - train_ex - dev_ex 180 | print("Train/Dev/Test: {}/{}/{}".format(train_ex, dev_ex, test_ex)) 181 | files = [args.output + "/train.json", args.output + "/dev.json", args.output + "/test.json"] 182 | datasets = [data[:train_ex], data[train_ex:train_ex + dev_ex], data[train_ex + dev_ex:]] 183 | for file, dataset in zip(files, datasets): 184 | with open(file, "w") as output_fp: 185 | json.dump(dataset, output_fp, indent=4) 186 | -------------------------------------------------------------------------------- /commaqa/dataset/build_submodel_datasets.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import random 5 | import re 6 | import string 7 | from math import ceil 8 | from pathlib import Path 9 | from shutil import copyfile 10 | from typing import List 11 | 12 | import _jsonnet 13 | from tqdm import tqdm 14 | 15 | from commaqa.configs.dataset_build_config import DatasetBuildConfig 16 | from commaqa.dataset.utils import get_predicate_args, dict_product, nonempty_answer 17 | from commaqa.execution.utils import build_models 18 | 19 | 20 | def parse_arguments(): 21 | arg_parser = argparse.ArgumentParser(description='Build a ReModeL dataset from inputs') 22 | arg_parser.add_argument('--input_json', type=str, required=True, 23 | help="Input JSON configuration files " 24 | "(comma-separated for multiple files)") 25 | arg_parser.add_argument('--output', "-o", type=str, required=True, help="Output folder") 26 | arg_parser.add_argument('--num_groups', type=int, required=False, default=500, 27 | help="Number of example groups to create") 28 | arg_parser.add_argument('--num_examples_per_group', type=int, required=False, default=10, 29 | help="Number of examples per group") 30 | arg_parser.add_argument('--entity_percent', 
type=float, required=False, default=0.25, 31 | help="Percentage of entities to sample for each group") 32 | 33 | return arg_parser.parse_args() 34 | 35 | 36 | class SubDatasetBuilder: 37 | def __init__(self, configs: List[DatasetBuildConfig]): 38 | self.configs = configs 39 | 40 | def build_entities(self, entities, ent_type): 41 | m = re.match("list\((.*)\)", ent_type) 42 | if m: 43 | # too many possible permutations, only build a list of size ent_type*2 44 | returned_list = [] 45 | ent_type = m.group(1) 46 | for i in range(len(entities[ent_type]) * 2): 47 | sample_size = random.choice(range(2, 5)) 48 | sampled_ents = random.sample(entities[ent_type], sample_size) 49 | returned_list.append(json.dumps(sampled_ents)) 50 | return returned_list 51 | else: 52 | return entities[ent_type] 53 | 54 | def build_sub_dataset(self, 55 | num_entities_per_group=5, 56 | num_groups: int = 10, num_examples_per_group: int = 10): 57 | per_model_dataset = {} 58 | for g in tqdm(range(num_groups)): 59 | config = random.choice(self.configs) 60 | entities = config.entities.subsample(num_entities_per_group) 61 | complete_kb = {} 62 | for pred in config.predicates: 63 | complete_kb[pred.pred_name] = pred.populate_kb(entities) 64 | 65 | model_library = build_models(config.pred_lang_config.model_config, complete_kb) 66 | # per_model_qa = {} 67 | # per_model_kb = {} 68 | for model, model_configs in config.pred_lang_config.model_config.items(): 69 | all_qa = {} 70 | gold_kb = {} 71 | for model_config in model_configs: 72 | if model_config.init is None: 73 | raise ValueError("Initialization needs to be specified to build the " 74 | "sub-model dataset for {}".format(model_config)) 75 | 76 | # Add the model-specific kb based on the steps 77 | for step in model_config.steps: 78 | qpred, qargs = get_predicate_args(step.question) 79 | if qpred not in gold_kb: 80 | gold_kb[qpred] = complete_kb[qpred] 81 | context = "" 82 | gold_context = "" 83 | for pred in config.predicates: 84 | context_rep = 
pred.generate_context(complete_kb) 85 | context += context_rep 86 | if pred.pred_name in gold_kb: 87 | gold_context += context_rep 88 | output_data = { 89 | "all_kb": complete_kb, 90 | "kb": gold_kb, 91 | "context": gold_context, 92 | "all_context": context, 93 | } 94 | # Generate questions 95 | assignment_dict = {} 96 | # Initialize question arguments 97 | for key, ent_type in model_config.init.items(): 98 | assignment_dict[key] = self.build_entities(entities, ent_type) 99 | # For each assignment, generate a question 100 | for assignment in dict_product(assignment_dict): 101 | if isinstance(model_config.questions, str): 102 | questions = [model_config.questions] 103 | else: 104 | questions = model_config.questions 105 | # for each question format, generate a question 106 | for question in questions: 107 | source_question = question 108 | for key, val in assignment.items(): 109 | question = question.replace(key, val) 110 | answers, facts_used = model_library[model].ask_question(question) 111 | if nonempty_answer(answers): 112 | if source_question not in all_qa: 113 | all_qa[source_question] = [] 114 | all_qa[source_question].append({ 115 | "question": question, 116 | "answer": answers, 117 | "facts_used": facts_used, 118 | "assignment": assignment, 119 | "config": model_config.to_json(), 120 | "id": "".join( 121 | [random.choice(string.hexdigits) for n in 122 | range(16)]).lower() 123 | }) 124 | # subsample questions to equalize #questions per theory 125 | min_size = min([len(qa) for qa in all_qa.values()]) 126 | subsampled_questions = [] 127 | for qa_per_sourceq in all_qa.values(): 128 | subsampled_questions.extend(random.sample(qa_per_sourceq, min_size)) 129 | qa = random.sample(subsampled_questions, num_examples_per_group) 130 | output_data["all_qa"] = [qa for qa_per_sourceq in all_qa.values() 131 | for qa in qa_per_sourceq] 132 | output_data["qa_pairs"] = qa 133 | if model not in per_model_dataset: 134 | per_model_dataset[model] = [] 135 | 
per_model_dataset[model].append(output_data) 136 | return per_model_dataset 137 | 138 | 139 | if __name__ == '__main__': 140 | args = parse_arguments() 141 | dataset_configs = [] 142 | counter = 0 143 | 144 | for filename in args.input_json.split(","): 145 | counter += 1 146 | output_dir = "" 147 | if args.output.endswith(".json"): 148 | output_dir = os.path.dirname(args.output) 149 | else: 150 | output_dir = args.output 151 | 152 | if filename.endswith(".jsonnet"): 153 | data = json.loads(_jsonnet.evaluate_file(filename)) 154 | with open(output_dir + "/source{}.json".format(counter), "w") as output_fp: 155 | json.dump(data, output_fp, indent=2) 156 | dataset_config = DatasetBuildConfig(data) 157 | dataset_configs.append(dataset_config) 158 | else: 159 | copyfile(filename, output_dir + "/source{}.json".format(counter)) 160 | with open(filename, "r") as input_fp: 161 | input_json = json.load(input_fp) 162 | dataset_config = DatasetBuildConfig(input_json) 163 | dataset_configs.append(dataset_config) 164 | 165 | builder = SubDatasetBuilder(dataset_configs) 166 | per_model_dataset = builder.build_sub_dataset(num_groups=args.num_groups, 167 | num_entities_per_group=args.entity_percent, 168 | num_examples_per_group=args.num_examples_per_group) 169 | for model, data in per_model_dataset.items(): 170 | num_examples = len(data) 171 | print("Model: {}".format(model)) 172 | print("Number of example groups: {}".format(num_examples)) 173 | train_ex = ceil(num_examples * 0.8) 174 | dev_ex = ceil(num_examples * 0.1) 175 | test_ex = num_examples - train_ex - dev_ex 176 | print("Train/Dev/Test: {}/{}/{}".format(train_ex, dev_ex, test_ex)) 177 | output_dir = args.output + "/" + model 178 | Path(output_dir).mkdir(parents=True, exist_ok=True) 179 | files = [output_dir + "/train.json", output_dir + "/dev.json", output_dir + "/test.json"] 180 | datasets = [data[:train_ex], data[train_ex:train_ex + dev_ex], data[train_ex + dev_ex:]] 181 | for file, dataset in zip(files, datasets): 182 | 
with open(file, "w") as output_fp: 183 | json.dump(dataset, output_fp, indent=4) 184 | -------------------------------------------------------------------------------- /commaqa/dataset/generate_decomposition_predictions.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from copy import deepcopy 4 | from math import ceil 5 | from random import shuffle 6 | 7 | from commaqa.configs.predicate_language_config import ModelQuestionConfig 8 | from commaqa.dataset.utils import nonempty_answer 9 | from commaqa.execution.operation_executer import OperationExecuter 10 | from commaqa.execution.utils import build_models 11 | 12 | 13 | def parse_arguments(): 14 | arg_parser = argparse.ArgumentParser(description='Solve a ReModeL dataset using composition') 15 | arg_parser.add_argument('--input_json', type=str, required=True, 16 | help="Input JSON dataset files") 17 | arg_parser.add_argument('--pred_json', type=str, required=False, help="Output predictions") 18 | arg_parser.add_argument('--decomp_json', type=str, required=False, help="Output decompositions") 19 | arg_parser.add_argument('--max_examples', type=float, required=False, default=1.0, 20 | help="Maximum number of examples to use. 
" 21 | "If set to <=1.0, use as fraction.") 22 | return arg_parser.parse_args() 23 | 24 | 25 | def build_chain(prev_chain, operation, model, question): 26 | return prev_chain + " QS: ({}) [{}] {}".format(operation, model, question) 27 | 28 | 29 | if __name__ == '__main__': 30 | args = parse_arguments() 31 | with open(args.input_json, "r") as input_fp: 32 | input_json = json.load(input_fp) 33 | 34 | pred_json = {} 35 | decomp_json = [] 36 | for input_item in input_json: 37 | kb = input_item["kb"] 38 | model_configurations = {} 39 | for model_name, configs in input_item["pred_lang_config"].items(): 40 | model_configurations[model_name] = [ModelQuestionConfig(config) for config in configs] 41 | model_lib = build_models(model_configurations, kb) 42 | 43 | executor = OperationExecuter(model_lib) 44 | for qa_pair in input_item["qa_pairs"]: 45 | qid = qa_pair["id"] 46 | # use oracle decomposition 47 | curr_assignment = {} 48 | last_answer = "" 49 | train_seqs = [] 50 | prev_chain = " QC: " + qa_pair["question"] 51 | for idx, step in enumerate(qa_pair["decomposition"]): 52 | train_seq = build_chain(prev_chain=prev_chain, 53 | operation=step["op"], 54 | model=step["m"], 55 | question=step["q"]) 56 | train_seqs.append(train_seq) 57 | answers, facts_used = executor.execute_operation(operation=step["op"], 58 | model=step["m"], 59 | question=step["q"], 60 | assignments=curr_assignment) 61 | last_answer = answers 62 | if not nonempty_answer(answers): 63 | print("no answer!") 64 | print(step, curr_assignment, kb) 65 | break 66 | prev_chain = train_seq.replace(" QS: ", " QI: ") + " A: " + json.dumps(answers) 67 | curr_assignment["#" + str(idx + 1)] = answers 68 | train_seqs.append(prev_chain + " QS: [EOQ]") 69 | decomp = deepcopy(qa_pair) 70 | decomp["train_seqs"] = train_seqs 71 | decomp_json.append(decomp) 72 | if isinstance(last_answer, list): 73 | pred_json[qid] = last_answer 74 | else: 75 | pred_json[qid] = str(last_answer) 76 | 77 | if args.pred_json: 78 | with 
open(args.pred_json, "w") as output_fp: 79 | json.dump(pred_json, output_fp, indent=2) 80 | if args.decomp_json: 81 | # sample examples here as they will be ungrouped 82 | if args.max_examples < 1.0: 83 | shuffle(decomp_json) 84 | decomp_json = decomp_json[:ceil(len(decomp_json) * args.max_examples)] 85 | elif args.max_examples > 1.0: 86 | shuffle(decomp_json) 87 | decomp_json = decomp_json[:args.max_examples] 88 | 89 | with open(args.decomp_json, "w") as output_fp: 90 | for decomp in decomp_json: 91 | output_fp.write(json.dumps(decomp) + "\n") 92 | -------------------------------------------------------------------------------- /commaqa/dataset/generate_decompositions_from_chains.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from copy import deepcopy 4 | from math import ceil 5 | from random import shuffle 6 | 7 | from commaqa.inference.utils import LIST_JOINER, EOQ_MARKER, INTERQ_MARKER, ANSWER_MARKER, \ 8 | SIMPQ_MARKER 9 | 10 | 11 | def parse_arguments(): 12 | arg_parser = argparse.ArgumentParser(description='Solve a ReModeL dataset using composition') 13 | arg_parser.add_argument('--input_json', type=str, required=True, 14 | help="Input JSON dataset files") 15 | arg_parser.add_argument('--chains', type=str, required=True, 16 | help="Input chains TSV file") 17 | arg_parser.add_argument('--decomp_json', type=str, required=False, help="Output decompositions") 18 | arg_parser.add_argument('--max_examples', type=float, required=False, default=1.0, 19 | help="Maximum number of examples to use. 
" 20 | "If set to <=1.0, use as fraction.") 21 | return arg_parser.parse_args() 22 | 23 | 24 | def is_valid_answer(predicted_answer, gold_answer): 25 | if isinstance(gold_answer, list): 26 | gold_answer_str = LIST_JOINER.join(sorted(gold_answer)) 27 | else: 28 | gold_answer_str = str(gold_answer) 29 | 30 | if isinstance(predicted_answer, list): 31 | predicted_answer_str = LIST_JOINER.join(sorted([str(s) for s in predicted_answer])) 32 | else: 33 | predicted_answer_str = str(gold_answer) 34 | # print(predicted_answer_str, gold_answer_str) 35 | return predicted_answer_str == gold_answer_str 36 | 37 | 38 | def build_train_seqs(question_seq): 39 | question_seq = question_seq.strip() + " " + EOQ_MARKER 40 | train_seqs = [question_seq] 41 | while INTERQ_MARKER in question_seq: 42 | answer_idx = question_seq.rfind(ANSWER_MARKER) 43 | question_seq = question_seq[:answer_idx].strip() 44 | interq_idx = question_seq.rfind(INTERQ_MARKER) 45 | question_seq = question_seq[:interq_idx] + SIMPQ_MARKER + question_seq[ 46 | interq_idx + len(INTERQ_MARKER):] 47 | train_seqs.append(question_seq) 48 | return train_seqs 49 | 50 | 51 | if __name__ == '__main__': 52 | args = parse_arguments() 53 | with open(args.input_json, "r") as input_fp: 54 | input_json = json.load(input_fp) 55 | 56 | predictions_per_qid = {} 57 | with open(args.chains, "r") as chains_fp: 58 | for line in chains_fp: 59 | fields = line.strip().split("\t") 60 | qid = fields[0] 61 | if qid not in predictions_per_qid: 62 | predictions_per_qid[qid] = [] 63 | predictions_per_qid[qid].append(fields[1:]) 64 | 65 | decomp_json = [] 66 | num_chains_correct_answer = 0 67 | num_questions_correct_chains = 0 68 | num_question_no_chains = 0 69 | num_questions = 0 70 | num_chains = 0 71 | for input_item in input_json: 72 | for qa_pair in input_item["qa_pairs"]: 73 | qid = qa_pair["id"] 74 | num_questions += 1 75 | if qid not in predictions_per_qid: 76 | # print(qid) 77 | num_question_no_chains += 1 78 | continue 79 | found_match = 
False 80 | for potential_seq in predictions_per_qid[qid]: 81 | num_chains += 1 82 | if is_valid_answer(json.loads(potential_seq[1]), qa_pair["answer"]): 83 | found_match = True 84 | num_chains_correct_answer += 1 85 | train_seqs = build_train_seqs(potential_seq[0]) 86 | decomp = deepcopy(qa_pair) 87 | decomp["train_seqs"] = train_seqs 88 | decomp_json.append(decomp) 89 | if found_match: 90 | num_questions_correct_chains += 1 91 | 92 | num_questions_with_chains = (num_questions - num_question_no_chains) 93 | print("Num Questions: {}".format(num_questions)) 94 | print("Num Questions with no chains: {} ({:.2f}%)".format( 95 | num_question_no_chains, (num_question_no_chains * 100 / num_questions))) 96 | print("Num Questions with chains: {} ({:.2f}%)".format( 97 | num_questions_with_chains, (num_questions_with_chains * 100 / num_questions))) 98 | print("Num Questions with at least one correct chain: {}" 99 | "({:.2f}% of predicted, {:.2f}% of total)".format( 100 | num_questions_correct_chains, 101 | (num_questions_correct_chains * 100 / num_questions_with_chains), 102 | (num_questions_correct_chains * 100 / num_questions))) 103 | print("Num Chains: {}({:.2f} c per predicted, {:.2f} c per total)".format( 104 | num_chains, num_chains / num_questions_with_chains, num_chains / num_questions)) 105 | print("Num Chains with correct answer: {}({:.2f}%)".format( 106 | num_chains_correct_answer, (num_chains_correct_answer * 100 / num_chains))) 107 | 108 | if args.decomp_json: 109 | # sample examples here as they will be ungrouped 110 | if args.max_examples < 1.0: 111 | shuffle(decomp_json) 112 | decomp_json = decomp_json[:ceil(len(decomp_json) * args.max_examples)] 113 | elif args.max_examples > 1.0: 114 | shuffle(decomp_json) 115 | decomp_json = decomp_json[:args.max_examples] 116 | 117 | with open(args.decomp_json, "w") as output_fp: 118 | for decomp in decomp_json: 119 | output_fp.write(json.dumps(decomp) + "\n") 120 | 
-------------------------------------------------------------------------------- /commaqa/dataset/utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import re 3 | 4 | pred_match = re.compile("(.*)\((.*)\)$") 5 | 6 | 7 | def get_answer_indices(question_str): 8 | return [int(m.group(1)) for m in re.finditer("#(\d)", question_str)] 9 | 10 | 11 | def get_question_indices(question_str): 12 | return [int(m.group(1)) for m in re.finditer("\$(\d)", question_str)] 13 | 14 | 15 | def is_question_var(var_name): 16 | return var_name.startswith("$") 17 | 18 | 19 | def get_predicate_args(predicate_str): 20 | mat = pred_match.match(predicate_str) 21 | if mat is None: 22 | return None, None 23 | predicate = mat.group(1) 24 | pred_args = mat.group(2).split(", ") if " | " not in mat.group(2) else mat.group(2).split(" | ") 25 | return predicate, pred_args 26 | 27 | 28 | def flatten_list(input_list): 29 | output_list = [] 30 | for item in input_list: 31 | if isinstance(item, list): 32 | output_list.extend(flatten_list(item)) 33 | else: 34 | output_list.append(item) 35 | return output_list 36 | 37 | 38 | def align_assignments(target_predicate, source_predicate, source_assignments): 39 | """ 40 | Returns a (map from target_predicate arg name to the assignment in source_assignments), 41 | (map from target_predicate arg name to the source predicate arg) 42 | """ 43 | target_pred, target_args = get_predicate_args(target_predicate) 44 | source_pred, source_args = get_predicate_args(source_predicate) 45 | if target_pred != source_pred: 46 | raise ValueError("Source predicate: {} does not match target predicate: {}".format( 47 | source_predicate, target_predicate 48 | )) 49 | if len(target_args) != len(source_args): 50 | raise ValueError("Number of target arguments: {} don't match source arguments: {}".format( 51 | target_args, source_args 52 | )) 53 | target_assignment = {} 54 | target_assignment_map = {} 55 | for target_arg, 
source_arg in zip(target_args, source_args): 56 | if source_arg == "?": 57 | if target_arg != "?": 58 | raise ValueError("Source ({}) and Target ({}) predicates have mismatch" 59 | " on '?'".format(source_predicate, target_predicate)) 60 | continue 61 | if source_arg not in source_assignments: 62 | raise ValueError("No assignment for {} in input assignments: {}".format( 63 | source_arg, source_assignments 64 | )) 65 | target_assignment[target_arg] = source_assignments[source_arg] 66 | target_assignment_map[target_arg] = source_arg 67 | return target_assignment, target_assignment_map 68 | 69 | 70 | def dict_product(dicts): 71 | return (dict(zip(dicts, x)) for x in itertools.product(*dicts.values())) 72 | 73 | 74 | def nonempty_answer(answer): 75 | if isinstance(answer, list) and len(answer) == 0: 76 | return False 77 | if isinstance(answer, str) and answer == "": 78 | return False 79 | return True 80 | 81 | 82 | NOANSWER = None 83 | 84 | 85 | def valid_answer(answer): 86 | return answer is not None 87 | -------------------------------------------------------------------------------- /commaqa/execution/README.md: -------------------------------------------------------------------------------- 1 | # Operations and Agents 2 | 3 | ## Building a new agent 4 | Currently our agents are purely defined by the dataset configuration. Specifically, the 5 | [Predicate Language](../configs/README.md#predicate-language) configuration defines the class of 6 | questions answerable by each agent. Internally, the configuration is used to create a 7 | [Model Executer](model_executer.py) that can answer questions matching the language (specified in 8 | the config) by executing the steps (specified in the config). 9 | 10 | ## Using agents in your code 11 | Agents have to be built for each example since each example has a different world context. 
To 12 | build an agent for a given question in your code, refer to code in 13 | [participant_execution.py](../inference/participant_execution.py#L55-L59) 14 | 15 | 16 | ## Defining a new operation 17 | The operations (as shown in Table 1 and Table 2) are implemented in `operation_executer.py`. You can 18 | add more operations to this class by modifying the [execute_operation](operation_executer.py#L190) 19 | function. 20 | -------------------------------------------------------------------------------- /commaqa/execution/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/CommaQA/663fda8fa280306297a805aefb671b04661aef74/commaqa/execution/__init__.py -------------------------------------------------------------------------------- /commaqa/execution/constants.py: -------------------------------------------------------------------------------- 1 | MATH_MODEL = "math_special" 2 | KBLOOKUP_MODEL = "kblookup" 3 | -------------------------------------------------------------------------------- /commaqa/execution/kblookup.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from commaqa.dataset.utils import get_predicate_args 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | class KBLookup: 9 | def __init__(self, kb): 10 | self.kb = kb 11 | 12 | def ask_question(self, question_predicate): 13 | return self.ask_question_predicate(question_predicate) 14 | 15 | def ask_question_predicate(self, question_predicate): 16 | predicate, pred_args = get_predicate_args(question_predicate) 17 | answers = [] 18 | facts_used = [] 19 | for fact in self.kb[predicate]: 20 | fact_pred, fact_args = get_predicate_args(fact) 21 | if len(pred_args) != len(fact_args): 22 | raise ValueError( 23 | "Mismatch in specification args {} and fact args {}".format( 24 | pred_args, fact_args 25 | )) 26 | mismatch = False 27 | answer = "" 28 | for p, f in 
zip(pred_args, fact_args): 29 | # KB fact arg doesn't match the predicate arg 30 | if p != "?" and p != f and p != "_": 31 | mismatch = True 32 | # predicate arg is a query, populate answer with fact arg 33 | elif p == "?": 34 | answer = f 35 | # if all args matched, add answer 36 | if not mismatch: 37 | answers.append(answer) 38 | facts_used.append(fact) 39 | if len(answers) == 0: 40 | logger.debug("No matching facts for {}. Facts:\n{}".format(question_predicate, 41 | self.kb[predicate])) 42 | 43 | # If its a boolean query, use number of answers 44 | if "?" not in pred_args: 45 | if len(answers) == 0: 46 | return "no", facts_used 47 | else: 48 | return "yes", facts_used 49 | else: 50 | return answers, facts_used 51 | -------------------------------------------------------------------------------- /commaqa/execution/math_model.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import re 4 | from json import JSONDecodeError 5 | 6 | from commaqa.execution.model_executer import ModelExecutor 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class MathModel(ModelExecutor): 12 | 13 | def __init__(self, **kwargs): 14 | self.func_regex = { 15 | "is_greater\((.+) \| (.+)\)": self.greater_than, 16 | "is_smaller\((.+) \| (.+)\)": self.smaller_than, 17 | "diff\((.+) \| (.+)\)": self.diff, 18 | "belongs_to\((.+) \| (.+)\)": self.belongs_to, 19 | "max\((.+)\)": self.max, 20 | "min\((.+)\)": self.min, 21 | "count\((.+)\)": self.count 22 | 23 | } 24 | super(MathModel, self).__init__(**kwargs) 25 | 26 | @staticmethod 27 | def get_number(num): 28 | # can only extract numbers from strings 29 | if isinstance(num, int) or isinstance(num, float): 30 | return num 31 | if not isinstance(num, str): 32 | return None 33 | try: 34 | item = json.loads(num) 35 | except JSONDecodeError: 36 | logger.debug("Could not JSON parse: " + num) 37 | return None 38 | if isinstance(item, list): 39 | if (len(item)) != 1: 40 | 
logger.debug("List of values instead of single number in {}".format(num)) 41 | return None 42 | item = item[0] 43 | if isinstance(item, list): 44 | logger.debug("Could not parse float from list within the list: {}".format(item)) 45 | return None 46 | try: 47 | return float(item) 48 | except ValueError: 49 | logger.debug("Could not parse float from: " + item) 50 | return None 51 | 52 | def max(self, groups): 53 | if len(groups) != 1: 54 | raise ValueError("Incorrect regex for max. " 55 | "Did not find 1 group: {}".format(groups)) 56 | try: 57 | entity = json.loads(groups[0]) 58 | 59 | if isinstance(entity, list): 60 | numbers = [] 61 | for x in entity: 62 | num = MathModel.get_number(x) 63 | if num is None: 64 | if self.ignore_input_mismatch: 65 | logger.debug("Cannot parse as number: {}".format(x)) 66 | return None, [] 67 | else: 68 | raise ValueError("Cannot parse as number: {} in {}".format(x, entity)) 69 | numbers.append(num) 70 | else: 71 | logger.debug("max can only handle list of entities. Arg: " + str(entity)) 72 | return None, [] 73 | except JSONDecodeError: 74 | logger.error("Could not parse: {}".format(groups[0])) 75 | raise 76 | return max(numbers), [] 77 | 78 | def min(self, groups): 79 | if len(groups) != 1: 80 | raise ValueError("Incorrect regex for min. " 81 | "Did not find 1 group: {}".format(groups)) 82 | try: 83 | entity = json.loads(groups[0]) 84 | if isinstance(entity, list): 85 | numbers = [] 86 | for x in entity: 87 | num = MathModel.get_number(x) 88 | if num is None: 89 | if self.ignore_input_mismatch: 90 | logger.debug("Cannot parse as number: {}".format(x)) 91 | return None, [] 92 | else: 93 | raise ValueError("Cannot parse as number: {} in {}".format(x, entity)) 94 | numbers.append(num) 95 | else: 96 | logger.debug("min can only handle list of entities. 
Arg: " + str(entity)) 97 | return None, [] 98 | except JSONDecodeError: 99 | logger.debug("Could not parse: {}".format(groups[0])) 100 | if self.ignore_input_mismatch: 101 | return None, [] 102 | else: 103 | raise 104 | return min(numbers), [] 105 | 106 | def count(self, groups): 107 | if len(groups) != 1: 108 | raise ValueError("Incorrect regex for max. " 109 | "Did not find 1 group: {}".format(groups)) 110 | try: 111 | entity = json.loads(groups[0]) 112 | if isinstance(entity, list): 113 | return len(entity), [] 114 | else: 115 | logger.debug("count can only handle list of entities. Arg: " + str(entity)) 116 | return None, [] 117 | except JSONDecodeError: 118 | logger.debug("Could not parse: {}".format(groups[0])) 119 | if self.ignore_input_mismatch: 120 | return None, [] 121 | else: 122 | raise 123 | 124 | def belongs_to(self, groups): 125 | if len(groups) != 2: 126 | raise ValueError("Incorrect regex for belongs_to. " 127 | "Did not find 2 groups: {}".format(groups)) 128 | try: 129 | entity = json.loads(groups[0]) 130 | if isinstance(entity, list): 131 | if len(entity) > 1: 132 | logger.debug( 133 | "belongs_to can only handle single entity as 1st arg. Args:" + str(groups)) 134 | return None, [] 135 | else: 136 | entity = entity[0] 137 | except JSONDecodeError: 138 | entity = groups[0] 139 | try: 140 | ent_list = json.loads(groups[1]) 141 | except JSONDecodeError: 142 | logger.debug("Could not JSON parse: " + groups[1]) 143 | raise 144 | 145 | if not isinstance(ent_list, list): 146 | logger.debug("belongs_to can only handle lists as 2nd arg. Args:" + str(groups)) 147 | return None, [] 148 | if entity in ent_list: 149 | return "yes", [] 150 | else: 151 | return "no", [] 152 | 153 | def diff(self, groups): 154 | if len(groups) != 2: 155 | raise ValueError("Incorrect regex for diff. 
" 156 | "Did not find 2 groups: {}".format(groups)) 157 | num1 = MathModel.get_number(groups[0]) 158 | num2 = MathModel.get_number(groups[1]) 159 | if num1 is None or num2 is None: 160 | if self.ignore_input_mismatch: 161 | # can not compare with Nones 162 | return None, [] 163 | else: 164 | raise ValueError("Cannot answer diff with {}".format(groups)) 165 | if num2 > num1: 166 | return round(num2 - num1, 3), [] 167 | else: 168 | return round(num1 - num2, 3), [] 169 | 170 | def greater_than(self, groups): 171 | if len(groups) != 2: 172 | raise ValueError("Incorrect regex for greater_than. " 173 | "Did not find 2 groups: {}".format(groups)) 174 | num1 = MathModel.get_number(groups[0]) 175 | num2 = MathModel.get_number(groups[1]) 176 | if num1 is None or num2 is None: 177 | if self.ignore_input_mismatch: 178 | # can not compare with Nones 179 | return None, [] 180 | else: 181 | raise ValueError("Cannot answer gt with {}".format(groups)) 182 | if num1 > num2: 183 | return "yes", [] 184 | else: 185 | return "no", [] 186 | 187 | def smaller_than(self, groups): 188 | if len(groups) != 2: 189 | raise ValueError("Incorrect regex for smaller_than. 
" 190 | "Did not find 2 groups: {}".format(groups)) 191 | num1 = MathModel.get_number(groups[0]) 192 | num2 = MathModel.get_number(groups[1]) 193 | if num1 is None or num2 is None: 194 | if self.ignore_input_mismatch: 195 | # can not compare with Nones 196 | return None, [] 197 | else: 198 | raise ValueError("Cannot answer lt with {}".format(groups)) 199 | if num1 < num2: 200 | return "yes", [] 201 | else: 202 | return "no", [] 203 | 204 | def ask_question_predicate(self, question_predicate): 205 | for regex, func in self.func_regex.items(): 206 | m = re.match(regex, question_predicate) 207 | if m: 208 | return func(m.groups()) 209 | raise ValueError("Could not parse: {}".format(question_predicate)) 210 | -------------------------------------------------------------------------------- /commaqa/execution/model_executer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | 4 | from commaqa.configs.utils import execute_steps 5 | from commaqa.dataset.utils import get_predicate_args, align_assignments, get_question_indices, \ 6 | valid_answer, NOANSWER 7 | from commaqa.execution.constants import KBLOOKUP_MODEL 8 | from commaqa.execution.operation_executer import OperationExecuter 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class ModelExecutor: 14 | def __init__(self, predicate_language, model_name, kblookup, ignore_input_mismatch=False): 15 | self.predicate_language = predicate_language 16 | self.model_name = model_name 17 | self.kblookup = kblookup 18 | self.ignore_input_mismatch = ignore_input_mismatch 19 | self.num_calls = 0 20 | 21 | def find_qpred_assignments(self, input_question, question_definition): 22 | question_re = re.escape(question_definition) 23 | varid_groupid = {} 24 | qindices = get_question_indices(question_definition) 25 | for num in qindices: 26 | question_re = question_re.replace("\\$" + str(num), 27 | "(?P.+)") 28 | varid_groupid["$" + str(num)] = "G" + str(num) 29 
| 30 | qmatch = re.match(question_re, input_question) 31 | if qmatch: 32 | assignments = {} 33 | for varid, groupid in varid_groupid.items(): 34 | assignments[varid] = qmatch.group(groupid) 35 | return assignments 36 | return None 37 | 38 | def ask_question(self, input_question): 39 | self.num_calls += 1 40 | qpred, qargs = get_predicate_args(input_question) 41 | if qpred is not None: 42 | return self.ask_question_predicate(question_predicate=input_question) 43 | else: 44 | answers, facts_used = None, None 45 | for pred_lang in self.predicate_language: 46 | for question in pred_lang.questions: 47 | assignments = self.find_qpred_assignments(input_question=input_question, 48 | question_definition=question) 49 | if assignments is not None: 50 | new_pred = pred_lang.predicate 51 | for varid, assignment in assignments.items(): 52 | new_pred = new_pred.replace(varid, assignment) 53 | answers, facts_used = self.ask_question_predicate(new_pred) 54 | if valid_answer(answers): 55 | # if this is valid answer, return it 56 | return answers, facts_used 57 | 58 | # if answers is not None: 59 | # # some match found for the question but no valid answer. 60 | # # Return the last matching answer. 61 | # return answers, facts_used 62 | if not self.ignore_input_mismatch: 63 | raise ValueError("No matching question found for {} " 64 | "in pred_lang:\n{}".format(input_question, 65 | self.predicate_language)) 66 | else: 67 | # no matching question. 
return NOANSWER 68 | return NOANSWER, [] 69 | 70 | def ask_question_predicate(self, question_predicate): 71 | qpred, qargs = get_predicate_args(question_predicate) 72 | for pred_lang in self.predicate_language: 73 | mpred, margs = get_predicate_args(pred_lang.predicate) 74 | if mpred != qpred: 75 | continue 76 | if pred_lang.steps: 77 | model_library = {KBLOOKUP_MODEL: self.kblookup} 78 | kb_executor = OperationExecuter(model_library) 79 | source_assignments = {x: x for x in qargs} 80 | curr_assignment, assignment_map = align_assignments( 81 | target_predicate=pred_lang.predicate, 82 | source_predicate=question_predicate, 83 | source_assignments=source_assignments 84 | ) 85 | assignments = execute_steps(steps=pred_lang.steps, 86 | input_assignments=curr_assignment, 87 | executer=kb_executor, 88 | pred_lang_config=None, 89 | input_model=KBLOOKUP_MODEL) 90 | 91 | if assignments: 92 | last_answer = pred_lang.steps[-1].answer 93 | return assignments[last_answer], assignments["facts_used"] 94 | elif assignments is None: 95 | # execution failed, try next predicate 96 | continue 97 | else: 98 | logger.debug("No answer found for question: {}".format(question_predicate)) 99 | return [], [] 100 | else: 101 | return self.kblookup.ask_question_predicate(question_predicate) 102 | # No match found for predicate 103 | error = "No matching predicate for {}".format(question_predicate) 104 | if self.ignore_input_mismatch: 105 | logger.debug(error) 106 | return NOANSWER, [] 107 | else: 108 | raise ValueError(error) 109 | -------------------------------------------------------------------------------- /commaqa/execution/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from commaqa.execution.constants import MATH_MODEL 4 | from commaqa.execution.kblookup import KBLookup 5 | from commaqa.execution.math_model import MathModel 6 | from commaqa.execution.model_executer import ModelExecutor 7 | 8 | logger = 
logging.getLogger(__name__) 9 | 10 | 11 | def build_models(pred_lang_config, complete_kb, ignore_input_mismatch=False): 12 | model_library = {} 13 | kblookup = KBLookup(kb=complete_kb) 14 | for model_name, configs in pred_lang_config.items(): 15 | if model_name == MATH_MODEL: 16 | model = MathModel(predicate_language=configs, 17 | model_name=model_name, 18 | kblookup=kblookup, 19 | ignore_input_mismatch=ignore_input_mismatch) 20 | else: 21 | model = ModelExecutor(predicate_language=configs, 22 | model_name=model_name, 23 | kblookup=kblookup, 24 | ignore_input_mismatch=ignore_input_mismatch) 25 | model_library[model_name] = model 26 | return model_library 27 | -------------------------------------------------------------------------------- /commaqa/inference/README.md: -------------------------------------------------------------------------------- 1 | # Running Inference 2 | There are two modes of inference as described in our paper. 3 | 4 | ## Greedy Search 5 | Greedy Search selects the most likely question decomposition at each step rather than considering 6 | multiple decomposition strategies. This is much faster than beam search but can not recover from any 7 | failures, e.g. if the most likely decomposition at a given step asks a question to the textqa agent 8 | but it can not answer it. 
9 | 10 | To run inference using greedy search, run 11 | ```shell 12 | model_path=[PATH TO DECOMPOSER MODEL] \ 13 | remodel_path=[PATH TO DATASET FOLDER IN COMMAQA FORMAT] \ 14 | filename=[FILENAME, e.g., train/dev/test.json] \ 15 | python commaqa/inference/configurable_inference.py \ 16 | --input [FILE IN DROP FORMAT] \ 17 | --config configs/inference/commaqav1_greedy_search.jsonnet \ 18 | --reader drop \ 19 | --output predictions.json 20 | ``` 21 | 22 | 23 | ## Beam Search 24 | Since our dataset (and other tasks in general) don't always have a pre-determined strategy to answer 25 | a question, we may need to consider multiple question decompositions at each step and then select 26 | the ones that do succeed. We use beam search to consider multiple decompositions at each step. To 27 | run inference in this mode, use: 28 | 29 | ```shell 30 | model_path=[PATH TO DECOMPOSER MODEL] \ 31 | remodel_path=[PATH TO DATASET FOLDER IN COMMAQA FORMAT] \ 32 | filename=[FILENAME, e.g., train/dev/test.json] \ 33 | python commaqa/inference/configurable_inference.py \ 34 | --input [FILE IN DROP FORMAT] \ 35 | --config configs/inference/commaqav1_beam_search.jsonnet \ 36 | --reader drop \ 37 | --output predictions.json 38 | ``` 39 | 40 | 41 | ## Inference using provided dataset and models 42 | For example, to run inference on CommaQA-E using the provided [datasets](../../README.md#Dataset) 43 | and [models](../../README.md#Models), 44 | 1. Unzip the dataset `commaqa_explicit.zip` into `commaqa_explicit` 45 | 2. Unzip the model `commaqa_e_oracle_model.zip` into `commaqa_explicit_oracle_model` 46 | 3. 
Call inference: 47 | ```shell 48 | model_path=commaqa_explicit_oracle_model/ \ 49 | remodel_path=commaqa_explicit/commaqa/ \ 50 | filename=test.json \ 51 | python commaqa/inference/configurable_inference.py \ 52 | --input commaqa_explicit/drop/${filename} \ 53 | --config configs/inference/commaqav1_beam_search.jsonnet \ 54 | --reader drop \ 55 | --output predictions.json 56 | ``` 57 | 58 | You can change the dataset and model paths to run inference on a different split. You can run greedy 59 | inference by changing the config file to `commaqav1_greedy_search.jsonnet`. -------------------------------------------------------------------------------- /commaqa/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/CommaQA/663fda8fa280306297a805aefb671b04661aef74/commaqa/inference/__init__.py -------------------------------------------------------------------------------- /commaqa/inference/configurable_inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import os 5 | 6 | import _jsonnet 7 | 8 | from commaqa.inference.constants import MODEL_NAME_CLASS, READER_NAME_CLASS 9 | from commaqa.inference.dataset_readers import DatasetReader 10 | from commaqa.inference.model_search import ( 11 | ModelController, 12 | BestFirstDecomposer, QuestionGeneratorData) 13 | from commaqa.inference.utils import get_environment_variables 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | def parse_arguments(): 19 | arg_parser = argparse.ArgumentParser(description='Convert HotPotQA dataset into SQUAD format') 20 | arg_parser.add_argument('--input', type=str, required=True, help="Input QA file") 21 | arg_parser.add_argument('--output', type=str, required=True, help="Output file") 22 | arg_parser.add_argument('--config', type=str, required=True, help="Model configs") 23 | 
arg_parser.add_argument('--reader', type=str, required=True, help="Dataset reader", 24 | choices=READER_NAME_CLASS.keys()) 25 | arg_parser.add_argument('--debug', action='store_true', default=False, 26 | help="Debug output") 27 | arg_parser.add_argument('--demo', action='store_true', default=False, 28 | help="Demo mode") 29 | arg_parser.add_argument('--threads', default=1, type=int, 30 | help="Number of threads (use MP if set to >1)") 31 | return arg_parser.parse_args() 32 | 33 | 34 | def load_decomposer(config_map): 35 | print("loading participant models (might take a while)...") 36 | model_map = {} 37 | for key, value in config_map["models"].items(): 38 | class_name = value.pop("name") 39 | if class_name not in MODEL_NAME_CLASS: 40 | raise ValueError("No class mapped to model name: {} in MODEL_NAME_CLASS:{}".format( 41 | class_name, MODEL_NAME_CLASS)) 42 | model = MODEL_NAME_CLASS[class_name](**value) 43 | if key in config_map: 44 | raise ValueError("Overriding key: {} with value: {} using instantiated model of type:" 45 | " {}".format(key, config_map[key], class_name)) 46 | config_map[key] = model.query 47 | model_map[key] = model 48 | ## instantiating 49 | controller = ModelController(config_map, QuestionGeneratorData) 50 | decomposer = BestFirstDecomposer(controller) 51 | return decomposer, model_map 52 | 53 | 54 | if __name__ == "__main__": 55 | 56 | args = parse_arguments() 57 | if args.debug: 58 | logging.basicConfig(level=logging.DEBUG) 59 | else: 60 | logging.basicConfig(level=logging.ERROR) 61 | if args.config.endswith(".jsonnet"): 62 | ext_vars = get_environment_variables() 63 | logger.info("Parsing config with external variables: {}".format(ext_vars)) 64 | config_map = json.loads(_jsonnet.evaluate_file(args.config, ext_vars=ext_vars)) 65 | else: 66 | with open(args.config, "r") as input_fp: 67 | config_map = json.load(input_fp) 68 | 69 | decomposer, model_map = load_decomposer(config_map) 70 | reader: DatasetReader = READER_NAME_CLASS[args.reader]() 71 
| 72 | print("Running decomposer on examples") 73 | qid_answer_chains = [] 74 | 75 | if args.demo: 76 | while True: 77 | qid = input("QID: ") 78 | question = input("Question: ") 79 | example = { 80 | "qid": qid, 81 | "query": question, 82 | "question": question 83 | } 84 | final_state, other_states = decomposer.find_answer_decomp(example, debug=args.debug) 85 | if final_state is None: 86 | print("FAILED!") 87 | else: 88 | if args.debug: 89 | for other_state in other_states: 90 | data = other_state.data 91 | for q, a, s in zip(data["question_seq"], data["answer_seq"], 92 | data["score_seq"]): 93 | print("Q: {} A: {} S:{}".format(q, a, s), end='\t') 94 | print("Score: " + str(other_state._score)) 95 | data = final_state._data 96 | chain = example["question"] 97 | for q, a in zip(data["question_seq"], data["answer_seq"]): 98 | chain += " Q: {} A: {}".format(q, a) 99 | chain += " S: " + str(final_state._score) 100 | print(chain) 101 | else: 102 | if args.threads > 1: 103 | import multiprocessing as mp 104 | 105 | mp.set_start_method("spawn") 106 | with mp.Pool(args.threads) as p: 107 | qid_answer_chains = p.map(decomposer.return_qid_prediction, 108 | reader.read_examples(args.input)) 109 | else: 110 | for example in reader.read_examples(args.input): 111 | qid_answer_chains.append( 112 | decomposer.return_qid_prediction(example, debug=args.debug)) 113 | 114 | num_call_metrics = {} 115 | for participant in model_map.values(): 116 | for model, num_calls in participant.return_model_calls().items(): 117 | print("Number of calls to {}: {}".format(model, num_calls)) 118 | num_call_metrics[model] = num_calls 119 | metrics_json = { 120 | "num_calls": num_call_metrics 121 | } 122 | metrics_file = os.path.join(os.path.dirname(args.output), "metrics.json") 123 | 124 | with open(metrics_file, "w") as output_fp: 125 | json.dump(metrics_json, output_fp) 126 | 127 | predictions = {x[0]: x[1] for x in qid_answer_chains} 128 | with open(args.output, "w") as output_fp: 129 | 
json.dump(predictions, output_fp) 130 | 131 | chains = [x[2] for x in qid_answer_chains] 132 | ext_index = args.output.rfind(".") 133 | chain_tsv = args.output[:ext_index] + "_chains.tsv" 134 | with open(chain_tsv, "w") as output_fp: 135 | for chain in chains: 136 | output_fp.write(chain + "\n") 137 | -------------------------------------------------------------------------------- /commaqa/inference/constants.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from commaqa.inference.dataset_readers import HotpotQAReader, DatasetReader, DropReader 4 | from commaqa.inference.participant_execution import ExecutionParticipant 5 | from commaqa.inference.participant_qgen import LMGenParticipant, RandomGenParticipant 6 | from commaqa.inference.participant_util import DumpChainsParticipant 7 | 8 | MODEL_NAME_CLASS = { 9 | "lmgen": LMGenParticipant, 10 | "randgen": RandomGenParticipant, 11 | "dump_chains": DumpChainsParticipant, 12 | "operation_executer": ExecutionParticipant 13 | } 14 | 15 | READER_NAME_CLASS: Dict[str, DatasetReader] = { 16 | "hotpot": HotpotQAReader, 17 | "drop": DropReader 18 | } 19 | -------------------------------------------------------------------------------- /commaqa/inference/dataset_readers.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | class DatasetReader: 5 | 6 | def read_examples(self, file): 7 | return NotImplementedError("read_examples not implemented by " + self.__class__.__name__) 8 | 9 | 10 | class HotpotQAReader(DatasetReader): 11 | 12 | def read_examples(self, file): 13 | with open(file, 'r') as input_fp: 14 | input_json = json.load(input_fp) 15 | 16 | for entry in input_json: 17 | yield { 18 | "qid": entry["_id"], 19 | "query": entry["question"], 20 | # metadata 21 | "answer": entry["answer"], 22 | "question": entry["question"], 23 | "type": entry.get("type", ""), 24 | "level": entry.get("level", "") 
25 | } 26 | 27 | 28 | def format_drop_answer(answer_json): 29 | if answer_json["number"]: 30 | return answer_json["number"] 31 | if len(answer_json["spans"]): 32 | return answer_json["spans"] 33 | # only date possible 34 | date_json = answer_json["date"] 35 | if not (date_json["day"] or date_json["month"] or date_json["year"]): 36 | print("Number, Span or Date not set in {}".format(answer_json)) 37 | return None 38 | return date_json["day"] + "-" + date_json["month"] + "-" + date_json["year"] 39 | 40 | 41 | class DropReader(DatasetReader): 42 | 43 | def read_examples(self, file): 44 | with open(file, 'r') as input_fp: 45 | input_json = json.load(input_fp) 46 | 47 | for paraid, item in input_json.items(): 48 | para = item["passage"] 49 | for qa_pair in item["qa_pairs"]: 50 | question = qa_pair["question"] 51 | qid = qa_pair["query_id"] 52 | answer = format_drop_answer(qa_pair["answer"]) 53 | yield { 54 | "qid": qid, 55 | "query": question, 56 | # metadata 57 | "answer": answer, 58 | "question": question 59 | } 60 | -------------------------------------------------------------------------------- /commaqa/inference/model_search.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import heapq 3 | import json 4 | import logging 5 | 6 | 7 | class BasicDataInstance(dict): 8 | _REQUIRED_ATTRS = set([]) 9 | 10 | def __init__(self, input_data): 11 | dict.__init__({}) 12 | self.update(input_data) 13 | for item in type(self)._REQUIRED_ATTRS: 14 | if item not in self: 15 | self[item] = [] 16 | 17 | 18 | class QuestionGeneratorData(BasicDataInstance): 19 | _REQUIRED_ATTRS = set([ 20 | "question_seq", 21 | "subquestion_seq", 22 | "answer_seq", 23 | "command_seq", 24 | "model_seq", 25 | "operation_seq", 26 | "score_seq", 27 | "para_seq" 28 | ]) 29 | 30 | 31 | class ParticipantModel(object): 32 | """Base model in this case for coordinating different models. 
Provides a general 33 | class to structure all contributing models (in this case, by defining a single 34 | function `query`, which is the single method that is called for each model). 35 | 36 | """ 37 | 38 | def query(self, state, debug=False): 39 | """The main function that interfaces with the overall search and 40 | model controller, and manipulates the incoming data. 41 | 42 | :param state: the state of controller and model flow. 43 | :type state: launchpadqa.question_search.model_search.SearchState 44 | :rtype: list 45 | """ 46 | raise NotImplementedError("Must implement to work inside of controller!") 47 | 48 | def return_model_calls(self): 49 | """ 50 | :return: a dict of made by this participant 51 | """ 52 | raise NotImplementedError("Must implement to work inside of controller!") 53 | 54 | 55 | class ModelController(object): 56 | """This class is a `ModelController` that takes multiple (arbitrary) 57 | models and a control specification of how to interface the different 58 | models (which can be thought of as a kind of state graph). For example 59 | 60 | """ 61 | 62 | def __init__(self, model_list, 63 | data_class=BasicDataInstance): 64 | """Create an instance of a ComplexModel 65 | 66 | :param model_list: a list of models with identifiers and 67 | control flow. 
68 | :type model_list: dict 69 | """ 70 | if "start_state" not in model_list: 71 | raise ValueError('Must specify start state') 72 | if "end_state" not in model_list: 73 | raise ValueError('Must specify end state') 74 | self.model_list = model_list 75 | self.data_class = data_class 76 | 77 | def execute(self, state, debug=False): 78 | """Executes a command and query 79 | 80 | :param state: a given state in search 81 | :type state: SearchState (defined here) 82 | :returns: a list of output 83 | :rtype: list 84 | """ 85 | if state.next not in self.model_list: 86 | self.logger.error("Can not handle next state: " + state.next) 87 | return [] 88 | try: 89 | model_func = self.model_list[state.next] 90 | 91 | model_output = model_func(state, debug=debug) 92 | 93 | if not isinstance(model_output, list): 94 | return [model_output] 95 | return model_output 96 | except Exception as e: 97 | self.logger.error(e, exc_info=True) 98 | raise ValueError('Error caught during model execution: %s' % e) 99 | 100 | def init_data(self, data_instance): 101 | """Create an initialized version of the data object 102 | that will get through around. 103 | 104 | :param data_instance: any arbitrary piece of data. 
105 | :rtype: self.data_class 106 | """ 107 | return self.data_class(data_instance) 108 | 109 | def command_model(self, command): 110 | return "command := %s" % \ 111 | (self.model_list[command].__name__) 112 | 113 | def key_val(self, key): 114 | return self.model_list["keys"][key] 115 | 116 | @property 117 | def start_state(self): 118 | return self.model_list["start_state"] 119 | 120 | @property 121 | def end_state(self): 122 | return self.model_list["end_state"] 123 | 124 | @property 125 | def logger(self): 126 | """Returns a logger instance 127 | """ 128 | level = '.'.join([__name__, type(self).__name__]) 129 | return logging.getLogger(level) 130 | 131 | 132 | ## utility class for controlling and recording search state 133 | 134 | class SearchState(object): 135 | """Tracks and records the state of a given search. 136 | 137 | """ 138 | 139 | def __init__(self, json_data, 140 | command, 141 | score=0.0, 142 | last_output='UNKNOWN', 143 | ): 144 | """Keep track of different stages in the state 145 | 146 | :param json_data: some basic, json represntation of data 147 | """ 148 | self._data = json_data 149 | self._score = score 150 | self._next = command 151 | self._last_output = last_output 152 | 153 | def copy(self): 154 | """Does a deep copy of the state 155 | 156 | :returns: new search state 157 | """ 158 | new_data = copy.deepcopy(self._data) 159 | new_score = copy.deepcopy(self._score) 160 | new_next = copy.deepcopy(self._next) 161 | 162 | return SearchState( 163 | new_data, 164 | new_next, 165 | new_score, 166 | last_output="UNKNOWN", 167 | ) 168 | 169 | @property 170 | def last_output(self): 171 | return self._last_output 172 | 173 | @last_output.setter 174 | def last_output(self, new_output): 175 | self._last_output = new_output 176 | 177 | ## important to implement to work 178 | ## with the heap datastructures 179 | def __lt__(self, other): 180 | if self.score < other.score: 181 | return True 182 | return False 183 | 184 | def __eq__(self, other): 185 | if 
self.score == other.score: 186 | return True 187 | return False 188 | 189 | @property 190 | def data(self): 191 | return self._data 192 | 193 | @property 194 | def score(self): 195 | return self._score 196 | 197 | @property 198 | def next(self): 199 | return self._next 200 | 201 | @next.setter 202 | def next(self, value): 203 | self._next = value 204 | 205 | @data.setter 206 | def data(self, value): 207 | self._data = value 208 | 209 | ## string method (especially for debugging) 210 | def __str__(self): 211 | return "[OUTPUT] val=%s [SCORE] %s" % (self._last_output, 212 | str(self._score)) 213 | 214 | 215 | ## THE BASIC SEARCH STRATEGIES (largely from the other code) 216 | 217 | class QuestionSearchBase(object): 218 | 219 | def __init__(self, model_controller): 220 | """Create a `QuestionDecomposer instance` 221 | 222 | :param model_ensemble: a collection of models with control instructions 223 | """ 224 | self.controller = model_controller 225 | 226 | def find_answer_decomp(self, json_input, debug=False): 227 | """Main question decomposition function 228 | 229 | :param json_input: the input to all of the models. 
230 | """ 231 | raise NotImplementedError 232 | 233 | @classmethod 234 | def from_config(cls, config): 235 | """Load a model from configuration 236 | 237 | :param config: the global configuration 238 | """ 239 | pass 240 | 241 | def return_qid_prediction(self, example, debug=False): 242 | final_state, other_states = self.find_answer_decomp(example, debug=debug) 243 | if final_state is None: 244 | print(example["question"] + " FAILED!") 245 | chain = example["qid"] + "\t" + example["question"] 246 | return (example["qid"], "", chain) 247 | else: 248 | data = final_state._data 249 | chain = example["qid"] + "\t" + example["question"] 250 | for m, q, a in zip(data["model_seq"], data["question_seq"], data["answer_seq"]): 251 | chain += "\tQ: ({}) {} A: {}".format(m, q, a) 252 | chain += "\tS: " + str(final_state._score) 253 | print(chain) 254 | final_answer = data["answer_seq"][-1] 255 | try: 256 | json_answer = json.loads(final_answer) 257 | # use this only if list (ignore numbers, etc) 258 | if isinstance(json_answer, list): 259 | final_answer = json_answer 260 | except ValueError: 261 | # Not a valid json ignore 262 | pass 263 | return (example["qid"], final_answer, chain) 264 | 265 | 266 | class BestFirstDecomposer(QuestionSearchBase): 267 | 268 | def find_answer_decomp(self, json_input, debug=False): 269 | """Run the question decomposer. The main function here is to use 270 | the controller to pass around inputs to the different models, then 271 | keep a track of the search state and terminate when the shortest path 272 | has been found. 
273 | 274 | :param json_input: some input to the model 275 | """ 276 | ## start state of controller : e.g., generate 277 | start_command = self.controller.start_state 278 | start_data = self.controller.init_data(json_input) 279 | 280 | ## min-heap 281 | heap = [] 282 | init_input = json_input["question"] if json_input["question"] else "UNKNOWN" 283 | if debug: print("[START QUERY] : %s" % init_input) 284 | 285 | init_state = SearchState(start_data, ## initial input 286 | start_command, ## starting point 287 | score=0.0, ## starting score 288 | ) 289 | 290 | ## push it to heap 291 | heapq.heappush(heap, init_state) 292 | max_step = 0 293 | 294 | ## todo : add constraints on search (e.g., beam sizes, etc..) 295 | 296 | ## start the main search 297 | while True: 298 | if len(heap) == 0: 299 | if debug: print("[FAILED]: %s" % init_input) 300 | return None, [] 301 | 302 | ## pop from heap 303 | current_state = heapq.heappop(heap) 304 | 305 | if debug: 306 | print("[MIN_STATE] command=%s" % (current_state.next)) 307 | 308 | ## end state 309 | if current_state.next == self.controller.end_state: 310 | if debug: print("[TERMINATED]\n%s" % current_state) 311 | return current_state, heap 312 | 313 | ## generate output and new stated 314 | for new_state in self.controller.execute(current_state, debug=debug): 315 | ## debug view 316 | if debug: print("\t%s" % new_state) 317 | 318 | ## push onto heap 319 | heapq.heappush(heap, new_state) 320 | -------------------------------------------------------------------------------- /commaqa/inference/participant_execution.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import re 4 | 5 | from commaqa.configs.predicate_language_config import ModelQuestionConfig 6 | from commaqa.dataset.utils import valid_answer, nonempty_answer 7 | from commaqa.execution.operation_executer import OperationExecuter 8 | from commaqa.execution.utils import build_models 9 | from 
commaqa.inference.model_search import ParticipantModel 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class ExecutionParticipant(ParticipantModel): 15 | def __init__(self, remodel_file, next_model="gen", skip_empty_answers=False): 16 | self.next_model = next_model 17 | self.skip_empty_answers = skip_empty_answers 18 | self.per_model_calls = {"executer": 0, "op_executer": 0} 19 | if remodel_file: 20 | with open(remodel_file, "r") as input_fp: 21 | input_json = json.load(input_fp) 22 | self.kb_lang_groups = [] 23 | self.qid_to_kb_lang_idx = {} 24 | for input_item in input_json: 25 | kb = input_item["kb"] 26 | pred_lang = input_item["pred_lang_config"] 27 | idx = len(self.kb_lang_groups) 28 | self.kb_lang_groups.append((kb, pred_lang)) 29 | for qa_pair in input_item["qa_pairs"]: 30 | qid = qa_pair["id"] 31 | self.qid_to_kb_lang_idx[qid] = idx 32 | self.operation_regex = re.compile("\((.+)\) \[([^\]]+)\] (.*)") 33 | 34 | def return_model_calls(self): 35 | return self.per_model_calls 36 | 37 | def query(self, state, debug=False): 38 | """The main function that interfaces with the overall search and 39 | model controller, and manipulates the incoming data. 40 | 41 | :param state: the state of controller and model flow. 
42 | :type state: launchpadqa.question_search.model_search.SearchState 43 | :rtype: list 44 | """ 45 | ## the data 46 | data = state._data 47 | self.per_model_calls["executer"] += 1 48 | step_model_key = "executer_step{}".format(len(data["question_seq"])) 49 | if step_model_key not in self.per_model_calls: 50 | self.per_model_calls[step_model_key] = 0 51 | self.per_model_calls[step_model_key] += 1 52 | 53 | question = data["question_seq"][-1] 54 | qid = data["qid"] 55 | (kb, pred_lang) = self.kb_lang_groups[self.qid_to_kb_lang_idx[qid]] 56 | model_configurations = {} 57 | for model_name, configs in pred_lang.items(): 58 | model_configurations[model_name] = [ModelQuestionConfig(config) for config in configs] 59 | model_lib = build_models(model_configurations, kb, ignore_input_mismatch=True) 60 | ### run the model (as before) 61 | if debug: print(": %s, qid=%s" % (question, qid)) 62 | m = self.operation_regex.match(question) 63 | if m is None: 64 | logger.debug("No match for {}".format(question)) 65 | return [] 66 | assignment = {} 67 | for ans_idx, ans in enumerate(data["answer_seq"]): 68 | assignment["#" + str(ans_idx + 1)] = json.loads(ans) 69 | executer = OperationExecuter(model_library=model_lib, ignore_input_mismatch=True) 70 | answers, facts_used = executer.execute_operation(operation=m.group(1), 71 | model=m.group(2), 72 | question=m.group(3), 73 | assignments=assignment) 74 | for model_name, model in model_lib.items(): 75 | if model_name not in self.per_model_calls: 76 | self.per_model_calls[model_name] = 0 77 | self.per_model_calls[model_name] += model.num_calls 78 | 79 | self.per_model_calls["op_executer"] += executer.num_calls 80 | if not valid_answer(answers): 81 | logger.debug("Invalid answer for qid: {} question: {} chain: {}!".format( 82 | qid, question, ", ".join(data["question_seq"]))) 83 | return [] 84 | if self.skip_empty_answers and not nonempty_answer(answers): 85 | logger.debug("Empty answer for qid: {} question: {} chain: {}!".format( 86 | 
qid, question, ", ".join(data["question_seq"]))) 87 | return [] 88 | 89 | # copy state 90 | new_state = state.copy() 91 | 92 | ## add answer 93 | new_state.data["answer_seq"].append(json.dumps(answers)) 94 | new_state.data["para_seq"].append("") 95 | new_state.data["command_seq"].append("qa") 96 | new_state.data["model_seq"].append(m.group(2)) 97 | new_state.data["operation_seq"].append(m.group(1)) 98 | new_state.data["subquestion_seq"].append(m.group(3)) 99 | ## change output 100 | new_state.last_output = answers 101 | new_state.next = self.next_model 102 | 103 | return [new_state] 104 | -------------------------------------------------------------------------------- /commaqa/inference/participant_qgen.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import random 4 | import re 5 | from itertools import product, permutations 6 | 7 | from commaqa.inference.model_search import ParticipantModel 8 | from commaqa.inference.utils import get_sequence_representation, stem_filter_tokenization, BLANK, \ 9 | stop_words_set 10 | from commaqa.models.generator import LMGenerator 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class LMGenParticipant(ParticipantModel): 16 | 17 | def __init__(self, scale_by_step=1, add_eos=False, add_prefix="", next_model="execute", 18 | end_state="[EOQ]", **kwargs): 19 | self.scale_by_step = scale_by_step 20 | self.add_eos = add_eos 21 | self.add_prefix = add_prefix 22 | self.next_model = next_model 23 | self.end_state = end_state 24 | self.num_calls = 0 25 | self.lmgen = LMGenerator(**kwargs) 26 | 27 | def return_model_calls(self): 28 | return {"lmgen": self.num_calls} 29 | 30 | def query(self, state, debug=False): 31 | """The main function that interfaces with the overall search and 32 | model controller, and manipulates the incoming data. 
33 | 34 | :param data: should have a dictionary as input containing 35 | mutable data 36 | :type data: dict 37 | :param state: the state of controller and model flow. 38 | :type state: launchpadqa.question_search.model_search.SearchState 39 | :rtype: list 40 | :raises: ValueError 41 | """ 42 | ## first checks state of `json_input` to figure out how to format things 43 | ## the first question 44 | data = state.data 45 | question_seq = data["question_seq"] 46 | answer_seq = data["answer_seq"] 47 | gen_seq = get_sequence_representation(origq=data["query"], question_seq=question_seq, 48 | answer_seq=answer_seq) 49 | if self.add_prefix: 50 | gen_seq = self.add_prefix + gen_seq 51 | if self.add_eos: 52 | gen_seq = gen_seq + "" 53 | 54 | if debug: print(": %s" % gen_seq) 55 | 56 | ## eventual output 57 | new_states = [] 58 | ## go through generated questions 59 | output_seq_scores = self.lmgen.generate_text_sequence(gen_seq) 60 | self.num_calls += 1 61 | observed_outputs = set() 62 | for (output_seq, score) in output_seq_scores: 63 | output = output_seq.strip() 64 | # catch potentially spurious duplicates 65 | if output in observed_outputs: 66 | continue 67 | else: 68 | observed_outputs.add(output) 69 | # copy state 70 | new_state = state.copy() 71 | ## add new question to question_seq 72 | new_state.data["question_seq"].append(output) 73 | if output == self.end_state: 74 | new_state.next = self.end_state 75 | else: 76 | new_state.next = self.next_model 77 | # lower is better, same as the scores returned by generate_text_sequence 78 | assert score >= 0, "Score from generation assumed to be +ve. Got: {}! 
Needs to be " \ 79 | "+ve to ensure monotonically increasing scores as expected by the" \ 80 | " search.".format(score) 81 | new_state._score += score 82 | new_state.data["score_seq"].append(score) 83 | new_state.data["command_seq"].append("gen") 84 | ## mark the last output 85 | new_state.last_output = output 86 | new_states.append(new_state) 87 | ## 88 | return new_states 89 | 90 | 91 | class RandomGenParticipant(ParticipantModel): 92 | 93 | def __init__(self, operations_file, model_questions_file, sample_operations, sample_questions, 94 | max_steps=6, next_model="execute", topk_questions=True, end_state="[EOQ]"): 95 | self.operations = self.load_operations(operations_file) 96 | self.model_questions = self.load_model_questions(model_questions_file) 97 | self.sample_operations = sample_operations 98 | self.sample_questions = sample_questions 99 | self.end_state = end_state 100 | self.next_model = next_model 101 | self.max_steps = max_steps 102 | self.num_calls = 0 103 | self.topk_questions = topk_questions 104 | 105 | def return_model_calls(self): 106 | return {"randomgen": self.num_calls} 107 | 108 | def load_operations(self, operations_file): 109 | with open(operations_file, "r") as input_fp: 110 | ops = [x.strip() for x in input_fp.readlines()] 111 | return ops 112 | 113 | def load_model_questions(self, model_questions_file): 114 | model_question_list = {} 115 | with open(model_questions_file, "r") as input_fp: 116 | for line in input_fp: 117 | fields = line.strip().split("\t") 118 | model = fields[0] 119 | if model not in model_question_list: 120 | model_question_list[model] = [] 121 | for question in fields[1:]: 122 | question_entities = self.find_question_entities(question) 123 | for q_ent in question_entities: 124 | question = question.replace(q_ent, BLANK) 125 | model_question_list[model].append(question) 126 | # get unique questions 127 | output_model_questions = [] 128 | for model_key, question_list in model_question_list.items(): 129 | unique_questions 
= list(set(question_list)) 130 | for q in unique_questions: 131 | output_model_questions.append((model_key, q)) 132 | print(model_key, q) 133 | logger.info("{} Questions in {} language".format(len(unique_questions), 134 | model_key)) 135 | 136 | return output_model_questions 137 | 138 | def select(self, population, sample_size_or_prop, samplek=True): 139 | if sample_size_or_prop >= 1: 140 | if samplek: 141 | return random.sample(population, k=sample_size_or_prop) 142 | else: 143 | return population[:sample_size_or_prop] 144 | else: 145 | if samplek: 146 | return random.sample(population, k=math.ceil(sample_size_or_prop * len(population))) 147 | else: 148 | return population[:math.ceil(sample_size_or_prop * len(population))] 149 | 150 | def build_end_state(self, state): 151 | new_state = state.copy() 152 | output = self.end_state 153 | new_state.data["question_seq"].append(output) 154 | new_state.next = self.end_state 155 | new_state.data["score_seq"].append(0) 156 | new_state.data["command_seq"].append("gen") 157 | ## mark the last output 158 | new_state.last_output = output 159 | return new_state 160 | 161 | def score_question(self, sub_question, complex_question): 162 | sub_question_tokens = set(stem_filter_tokenization(sub_question)) 163 | if len(sub_question_tokens) == 0: 164 | logger.debug("No tokens found in sub_question: {}!!".format(sub_question)) 165 | return 0.0 166 | complex_question_tokens = set(stem_filter_tokenization(complex_question)) 167 | overlap = sub_question_tokens.intersection(complex_question_tokens) 168 | # only penalized for sub-question length 169 | return len(overlap) / len(sub_question_tokens) 170 | 171 | def find_question_entities(self, origq): 172 | entities = [] 173 | for m in re.finditer("\\b([A-Z]\w+)", origq): 174 | if m.group(1).lower() not in stop_words_set: 175 | entities.append(m.group(1)) 176 | 177 | for m in re.finditer("([0-9\.]+)", origq): 178 | entities.append(m.group(1)) 179 | return entities 180 | 181 | def 
replace_blanks(self, blanked_str, fillers): 182 | num_blanks = blanked_str.count(BLANK) 183 | output_strings = [] 184 | if num_blanks > 0: 185 | filler_permutations = permutations(fillers, num_blanks) 186 | for permutation in filler_permutations: 187 | new_str = blanked_str 188 | for filler_val in permutation: 189 | new_str = new_str.replace(BLANK, filler_val, 1) 190 | output_strings.append(new_str) 191 | else: 192 | output_strings = [blanked_str] 193 | return output_strings 194 | 195 | def query(self, state, debug=False): 196 | data = state.data 197 | num_steps = len(data["question_seq"]) 198 | # push for one extra step so that all shorter chains have been explored 199 | if num_steps > self.max_steps: 200 | return [self.build_end_state(state)] 201 | origq = data["query"] 202 | answer_strs = [] 203 | if num_steps == 0: 204 | # hard-coded to only consider select in the first step 205 | ops = ["select"] 206 | else: 207 | for x in range(num_steps): 208 | answer_strs.append("#" + str(x + 1)) 209 | operations_pool = [] 210 | for op in self.operations: 211 | operations_pool.extend(self.replace_blanks(op, answer_strs)) 212 | ops = self.select(operations_pool, self.sample_operations) 213 | 214 | question_entities = self.find_question_entities(origq) 215 | # hack to only use a filler in one of the steps 216 | potential_fillers = question_entities + answer_strs 217 | filler_pool = [] 218 | for filler in potential_fillers: 219 | found_match = False 220 | for question in state.data["subquestion_seq"]: 221 | if filler in question: 222 | found_match = True 223 | break 224 | if not found_match: 225 | filler_pool.append(filler) 226 | 227 | questions_pool = [(m, newq) for (m, q) in self.model_questions 228 | for newq in self.replace_blanks(q, filler_pool)] 229 | if self.topk_questions: 230 | sorted_model_questions = sorted(questions_pool, reverse=True, 231 | key=lambda x: self.score_question(x[1], origq)) 232 | model_questions = self.select(sorted_model_questions, 
self.sample_questions, 233 | samplek=False) 234 | else: 235 | model_questions = self.select(questions_pool, self.sample_questions, samplek=True) 236 | op_model_qs_prod = product(ops, model_questions) 237 | ## eventual output 238 | new_states = [] 239 | self.num_calls += 1 240 | for (op, model_qs) in op_model_qs_prod: 241 | (model, question) = model_qs 242 | 243 | # no point repeating the exact same question 244 | if question in state.data["subquestion_seq"]: 245 | continue 246 | # copy state 247 | 248 | new_state = state.copy() 249 | output = "({}) [{}] {}".format(op, model, question) 250 | ## add new question to question_seq 251 | new_state.data["question_seq"].append(output) 252 | new_state.next = self.next_model 253 | new_state.data["score_seq"].append(1) 254 | new_state._score += 1 255 | new_state.data["command_seq"].append("gen") 256 | ## mark the last output 257 | new_state.last_output = output 258 | new_states.append(new_state) 259 | ## 260 | # if len(data["question_seq"]) > 0: 261 | # new_states.append(self.build_end_state(state)) 262 | return new_states 263 | -------------------------------------------------------------------------------- /commaqa/inference/participant_util.py: -------------------------------------------------------------------------------- 1 | from commaqa.inference.model_search import ParticipantModel 2 | from commaqa.inference.utils import get_sequence_representation 3 | 4 | 5 | class DumpChainsParticipant(ParticipantModel): 6 | 7 | def __init__(self, output_file, next_model="gen"): 8 | self.output_file = output_file 9 | self.next_model = next_model 10 | self.num_calls = 0 11 | 12 | def return_model_calls(self): 13 | return {"dumpchains": self.num_calls} 14 | 15 | def dump_chain(self, state): 16 | data = state.data 17 | origq = data["query"] 18 | qchain = data["question_seq"] 19 | achain = data["answer_seq"] 20 | sequence = get_sequence_representation(origq=origq, question_seq=qchain, answer_seq=achain) 21 | ans = achain[-1] 22 | with 
open(self.output_file, 'a') as chains_fp: 23 | chains_fp.write(data["qid"] + "\t" + sequence + "\t" + ans + "\n") 24 | 25 | def query(self, state, debug=False): 26 | self.num_calls += 1 27 | if len(state.data["question_seq"]) > 0: 28 | self.dump_chain(state) 29 | new_state = state.copy() 30 | new_state.next = self.next_model 31 | return new_state 32 | -------------------------------------------------------------------------------- /commaqa/inference/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Dict 3 | 4 | from nltk import word_tokenize 5 | from nltk.corpus import stopwords 6 | from nltk.stem.porter import PorterStemmer 7 | 8 | stemmer = PorterStemmer() 9 | 10 | stop_words_set = set(stopwords.words('english')) 11 | 12 | QUESTION_MARKER = " Q: " 13 | COMPQ_MARKER = " QC: " 14 | SIMPQ_MARKER = " QS: " 15 | INTERQ_MARKER = " QI: " 16 | ANSWER_MARKER = " A: " 17 | EOQ_MARKER = "[EOQ]" 18 | LIST_JOINER = " + " 19 | BLANK = "__" 20 | WH_WORDS = set(["who", "what", "where", "how", "why", "when", "which"]) 21 | 22 | 23 | def get_sequence_representation(origq: str, question_seq: List[str], answer_seq: List[str]): 24 | ret_seq = COMPQ_MARKER + origq 25 | if len(question_seq) != len(answer_seq): 26 | raise ValueError("Number of generated questions and answers should match before" 27 | "question generation. 
Qs: {} As: {}".format(question_seq, answer_seq)) 28 | 29 | for aidx in range(len(answer_seq)): 30 | ret_seq += INTERQ_MARKER 31 | ret_seq += question_seq[aidx] 32 | ret_seq += ANSWER_MARKER + answer_seq[aidx] 33 | ret_seq += SIMPQ_MARKER 34 | return ret_seq 35 | 36 | 37 | def tokenize_str(input_str): 38 | return word_tokenize(input_str) 39 | 40 | 41 | def stem_tokens(token_arr): 42 | return [stemmer.stem(token) for token in token_arr] 43 | 44 | 45 | def filter_stop_tokens(token_arr): 46 | return [token for token in token_arr if token not in stop_words_set] 47 | 48 | 49 | def stem_filter_tokenization(input_str): 50 | return stem_tokens(filter_stop_tokens(tokenize_str(input_str.lower()))) 51 | 52 | 53 | # functions borrowed from AllenNLP to parse JSONNET with env vars 54 | def get_environment_variables() -> Dict[str, str]: 55 | """ 56 | Wraps `os.environ` to filter out non-encodable values. 57 | """ 58 | return {key: value for key, value in os.environ.items() if _is_encodable(value)} 59 | 60 | 61 | def _is_encodable(value: str) -> bool: 62 | """ 63 | We need to filter out environment variables that can't 64 | be unicode-encoded to avoid a "surrogates not allowed" 65 | error in jsonnet. 66 | """ 67 | # Idiomatically you'd like to not check the != b"" 68 | # but mypy doesn't like that. 
69 | return (value == "") or (value.encode("utf-8", "ignore") != b"") 70 | -------------------------------------------------------------------------------- /commaqa/models/generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoConfig, AutoTokenizer, AutoModelWithLMHead 3 | from transformers.generation_utils import SampleEncoderDecoderOutput 4 | import logging 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | class LMGenerator: 10 | 11 | def __init__(self, model_path, device=None, 12 | generation_args={}, encoder_args={}, decoder_args={}): 13 | if device is None: 14 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 15 | else: 16 | self.device = device 17 | 18 | self.config = AutoConfig.from_pretrained(model_path) 19 | self.tokenizer = AutoTokenizer.from_pretrained(model_path) 20 | self.model = AutoModelWithLMHead.from_pretrained(model_path, config=self.config).to( 21 | self.device) 22 | self.generation_args = generation_args 23 | # always generate output with scores 24 | self.generation_args["output_scores"] = True 25 | self.generation_args["return_dict_in_generate"] = True 26 | self.encoder_args = encoder_args 27 | self.decoder_args = decoder_args 28 | 29 | def generate_text_sequence(self, input_text): 30 | """ 31 | :param input_text: 32 | :return: returns a sequence of tuples (string, score) where lower score is better 33 | """ 34 | encoded_prompt = self.tokenizer.encode(input_text, **self.encoder_args) 35 | 36 | encoded_prompt = encoded_prompt.to(self.device) 37 | generated_dict = self.model.generate(input_ids=encoded_prompt, **self.generation_args) 38 | 39 | generated_seqs = generated_dict.sequences 40 | if isinstance(generated_dict, SampleEncoderDecoderOutput): 41 | logger.warning("No scores generated when sampled sequences") 42 | generated_scores = [0] * len(generated_seqs) 43 | else: 44 | generated_scores = 
generated_dict.sequences_scores.tolist() 45 | if len(generated_seqs.shape) > 2: 46 | generated_seqs.squeeze_() 47 | 48 | output_seq_score = [] 49 | 50 | for generated_sequence_idx, generated_seq in enumerate(generated_seqs): 51 | generated_output = generated_seq.tolist() 52 | text = self.tokenizer.decode(generated_output, **self.decoder_args) 53 | # flip the negative logit so that sequence with lowest scores is best 54 | output_seq_score.append((text, -generated_scores[generated_sequence_idx])) 55 | 56 | # Ensure sorted output 57 | return sorted(output_seq_score, key=lambda x: x[1]) 58 | -------------------------------------------------------------------------------- /configs/commaqav1/explicit/entities.libsonnet: -------------------------------------------------------------------------------- 1 | { 2 | movie: [ 3 | "Pestok", 4 | "Gambilla", 5 | "Dewbar", 6 | "Tannine", 7 | "Bronchum", 8 | "Skia", 9 | "Hoopdoodle", 10 | "Tayenne", 11 | "Misapportionment", 12 | "Spadesan", 13 | "Assamplifier", 14 | "Myristorrhoid", 15 | "Clenestration", 16 | "Sprezzler", 17 | "Polytetrafluoromethane", 18 | "Waxhead", 19 | "Pastillobox", 20 | "Compresse", 21 | "Zalate", 22 | "Covetic", 23 | "Conforancy", 24 | "Vitrilateral", 25 | "Partnershipmaker", 26 | "Hutter", 27 | "Tantor", 28 | "Riften", 29 | "Tetroxidine", 30 | "Gigato", 31 | "Spidertail", 32 | "Nippurum", 33 | "Honeybean", 34 | "Epicuratorion", 35 | "Tirelock", 36 | "Sahaki", 37 | "Premercy", 38 | "Pugo", 39 | "Quassa", 40 | "Geissant", 41 | "Chimpwurst", 42 | "Quinsid", 43 | "Jenga", 44 | "Printip", 45 | "Vitter", 46 | "Sheepcrest", 47 | "Gordanoel", 48 | "Hayout", 49 | "Coesestion", 50 | "Nohit", 51 | "Pipesia", 52 | "Sagali", 53 | "Featsaw", 54 | "Caudacite", 55 | "Kavashpat", 56 | "Skob", 57 | "Meather", 58 | "Misgendery", 59 | "Tarta", 60 | "Calamba", 61 | "Peaseman", 62 | "Pneumodendron", 63 | "Nilitude", 64 | "Dickerhead", 65 | "Kraof", 66 | "Booyah", 67 | "Biscus", 68 | "Warpstone", 69 | "Midshipwoman", 70 | 
"Chickenpot", 71 | "Cougarism", 72 | "Noenometer", 73 | "Percevalence", 74 | "Vucumber", 75 | "Coule", 76 | "Subwort", 77 | "Skirtsicine", 78 | "Coacheship", 79 | "Teetermark", 80 | "Pigmold", 81 | "Calcivore", 82 | "Cataflower", 83 | ], 84 | person: [ 85 | "Jimayo", 86 | "Bioperatology", 87 | "Gigafuna", 88 | "Alpinista", 89 | "Wetherality", 90 | "Gastrat", 91 | "Sequinolone", 92 | "Teeplemole", 93 | "Huckberryberry", 94 | "Parchin", 95 | "Straviolence", 96 | "Chickenshaw", 97 | "Gutskin", 98 | "Flumph", 99 | "Sclerocybin", 100 | "Midcareer", 101 | "Dormitula", 102 | "Haldron", 103 | "Monsterscar", 104 | "Comander", 105 | "Firmline", 106 | "Lappagee", 107 | "Microcomputing", 108 | "Bibbogey", 109 | "Conanopeia", 110 | "Sustainableness", 111 | "Bioplankton", 112 | "Sealt", 113 | "Dumasite", 114 | "Carpoon", 115 | "Magainitis", 116 | "Fondlement", 117 | "Metatoun", 118 | "Mimicocycle", 119 | "Carblock", 120 | "Fidelice", 121 | "Impassivism", 122 | "Zayage", 123 | "Lougerière", 124 | "Gigabut", 125 | "Deplexology", 126 | "Topboard", 127 | "Metrix", 128 | "Kapod", 129 | "Muntaril", 130 | "Zakti", 131 | "Sickkin", 132 | "Diarmallurgy", 133 | "Compositon", 134 | "Sapien", 135 | ], 136 | r_year: [ 137 | "1999", 138 | "1965", 139 | "1960", 140 | "1954", 141 | "1964", 142 | "1998", 143 | "1971", 144 | "2009", 145 | "1991", 146 | "1989", 147 | "2005", 148 | "1973", 149 | "1984", 150 | "1970", 151 | "2011", 152 | "1959", 153 | "1972", 154 | "1975", 155 | "1987", 156 | "1956", 157 | "2016", 158 | "1986", 159 | "1966", 160 | "1974", 161 | "1967", 162 | "1993", 163 | "2012", 164 | "2010", 165 | "2019", 166 | "2000", 167 | "2004", 168 | "1963", 169 | "1957", 170 | "2013", 171 | "2018", 172 | "1997", 173 | "1962", 174 | "2007", 175 | "1977", 176 | "1979", 177 | "1961", 178 | "2003", 179 | "1982", 180 | "2001", 181 | "1990", 182 | "1953", 183 | "1988", 184 | "2008", 185 | "2006", 186 | "1992", 187 | ], 188 | b_year: [ 189 | "1902", 190 | "1945", 191 | "1949", 192 | "1938", 193 | 
"1910", 194 | "1904", 195 | "1931", 196 | "1948", 197 | "1912", 198 | "1911", 199 | "1942", 200 | "1939", 201 | "1935", 202 | "1933", 203 | "1924", 204 | "1928", 205 | "1917", 206 | "1909", 207 | "1905", 208 | "1921", 209 | "1922", 210 | "1915", 211 | "1913", 212 | "1906", 213 | "1943", 214 | "1926", 215 | "1947", 216 | "1930", 217 | "1907", 218 | "1929", 219 | ], 220 | p_award: [ 221 | "Alterygium", 222 | "Rafflecopter", 223 | "Kinkster", 224 | "Mosei", 225 | "Monkeynote", 226 | "Divetail", 227 | "Quinion", 228 | "Custodio", 229 | "Roontang", 230 | "Minimiseries", 231 | "Hollowside", 232 | "Diaqum", 233 | "Polyquadrase", 234 | "Aniconder", 235 | "Malwarp", 236 | "Waxseer", 237 | "Lidus", 238 | "Goatfly", 239 | "Lameze", 240 | "Microthèsema", 241 | "Halfbill", 242 | "Glodome", 243 | "Prophococcus", 244 | "Semilist", 245 | "Trifogation", 246 | "Undercabin", 247 | "Metricization", 248 | "Counterprogram", 249 | "Modiparity", 250 | "Fannyxist", 251 | ], 252 | nation: [ 253 | "Kinneticket", 254 | "Rattlider", 255 | "Microbead", 256 | "Siphonometer", 257 | "Biopsie", 258 | "Dynope", 259 | "Lampplate", 260 | "Thistlewood", 261 | "Chickenskull", 262 | "Louage", 263 | "Tatkin", 264 | "Whime", 265 | "Clony", 266 | "Moulminer", 267 | "Dentalogy", 268 | "Legault", 269 | "Triclops", 270 | "Moulole", 271 | "Calderita", 272 | "Poquet", 273 | "Dessaless", 274 | "Schelpla", 275 | "Zaggery", 276 | "Spanulum", 277 | "Microfouling", 278 | "Knoppock", 279 | "Obility", 280 | "Stridery", 281 | "Loisy", 282 | "Piperfish", 283 | ], 284 | m_award: [ 285 | "Maldezine", 286 | "Chowwurst", 287 | "Glag", 288 | "Sockbox", 289 | "Erowid", 290 | "Prowecap", 291 | "Leatherie", 292 | "Microsouenesis", 293 | "Posteria", 294 | "Hallowcock", 295 | "Neuropsychotaxis", 296 | "Sabonade", 297 | "Brownbeard", 298 | "Pistarmen", 299 | "Hydrallium", 300 | "Potcrash", 301 | "Stoptite", 302 | "Cockspit", 303 | "Gutney", 304 | "Dextrite", 305 | "Triptychology", 306 | "Pianogram", 307 | "Airpipe", 308 | 
"Tachychronograph", 309 | "Antipositive", 310 | "Segumen", 311 | "Periosteis", 312 | "Grangebagger", 313 | "Paleodactyl", 314 | "Slauspost", 315 | "Electrodesal", 316 | "Dysmetis", 317 | "Heptelphism", 318 | "Zorgion", 319 | "Pludgel", 320 | "Goosehead", 321 | "Bouchery", 322 | "Hochelfoil", 323 | "Pompasole", 324 | "Po'Rsiera", 325 | "Mariskenna", 326 | "Monoxandrite", 327 | "Heilbron", 328 | "Siligar", 329 | "Handt", 330 | "Jubeus", 331 | "Trummer", 332 | "Dessication", 333 | "Headlet", 334 | "Pennepiece", 335 | ], 336 | } 337 | -------------------------------------------------------------------------------- /configs/commaqav1/explicit/movies1.jsonnet: -------------------------------------------------------------------------------- 1 | local entities = import "entities.libsonnet"; 2 | local table_predicate = import "table_predicates.libsonnet"; 3 | local text_predicate = import "text_predicates.libsonnet"; 4 | local combined_predicates = table_predicate + text_predicate; 5 | local predicate_languages = import "predicate_language.libsonnet"; 6 | local theories = import "theories.libsonnet"; 7 | local predicate_names = ["table_year", "table_directed", "table_maward", "table_writer", "text_actor", "text_produced", "text_paward", "text_dob", "text_nation"]; 8 | local predicates = { [p]: combined_predicates[p] for p in predicate_names }; 9 | local predicate_language = { [key]: predicate_languages[p][key] for p in predicate_names for key in std.objectFields(predicate_languages[p]) }; 10 | { 11 | version: 3.0, 12 | entities: entities, 13 | predicates: predicates, 14 | predicate_language: predicate_language, 15 | theories: theories, 16 | } 17 | -------------------------------------------------------------------------------- /configs/commaqav1/explicit/movies1_compgen.jsonnet: -------------------------------------------------------------------------------- 1 | local entities = import "entities.libsonnet"; 2 | local table_predicate = import "table_predicates.libsonnet"; 3 
| local text_predicate = import "text_predicates.libsonnet"; 4 | local combined_predicates = table_predicate + text_predicate; 5 | local predicate_languages = import "predicate_language.libsonnet"; 6 | local theories = import "theories_compgen.libsonnet"; 7 | local predicate_names = ["table_year", "table_directed", "table_maward", "table_writer", "text_actor", "text_produced", "text_paward", "text_dob", "text_nation"]; 8 | local predicates = { [p]: combined_predicates[p] for p in predicate_names }; 9 | local predicate_language = { [key]: predicate_languages[p][key] for p in predicate_names for key in std.objectFields(predicate_languages[p]) }; 10 | { 11 | version: 3.0, 12 | entities: entities, 13 | predicates: predicates, 14 | predicate_language: predicate_language, 15 | theories: theories, 16 | } 17 | -------------------------------------------------------------------------------- /configs/commaqav1/explicit/movies2.jsonnet: -------------------------------------------------------------------------------- 1 | local entities = import "entities.libsonnet"; 2 | local table_predicate = import "table_predicates.libsonnet"; 3 | local text_predicate = import "text_predicates.libsonnet"; 4 | local combined_predicates = table_predicate + text_predicate; 5 | local predicate_languages = import "predicate_language.libsonnet"; 6 | local theories = import "theories.libsonnet"; 7 | local predicate_names = ["table_year", "table_directed", "table_maward", "text_writer", "text_actor", "text_produced", "table_paward", "text_dob", "text_nation"]; 8 | local predicates = { [p]: combined_predicates[p] for p in predicate_names }; 9 | local predicate_language = { [key]: predicate_languages[p][key] for p in predicate_names for key in std.objectFields(predicate_languages[p]) }; 10 | { 11 | version: 3.0, 12 | entities: entities, 13 | predicates: predicates, 14 | predicate_language: predicate_language, 15 | theories: theories, 16 | } 17 | 
-------------------------------------------------------------------------------- /configs/commaqav1/explicit/movies2_compgen.jsonnet: -------------------------------------------------------------------------------- 1 | local entities = import "entities.libsonnet"; 2 | local table_predicate = import "table_predicates.libsonnet"; 3 | local text_predicate = import "text_predicates.libsonnet"; 4 | local combined_predicates = table_predicate + text_predicate; 5 | local predicate_languages = import "predicate_language.libsonnet"; 6 | local theories = import "theories_compgen.libsonnet"; 7 | local predicate_names = ["table_year", "table_directed", "table_maward", "text_writer", "text_actor", "text_produced", "table_paward", "text_dob", "text_nation"]; 8 | local predicates = { [p]: combined_predicates[p] for p in predicate_names }; 9 | local predicate_language = { [key]: predicate_languages[p][key] for p in predicate_names for key in std.objectFields(predicate_languages[p]) }; 10 | { 11 | version: 3.0, 12 | entities: entities, 13 | predicates: predicates, 14 | predicate_language: predicate_language, 15 | theories: theories, 16 | } 17 | -------------------------------------------------------------------------------- /configs/commaqav1/explicit/table_predicates.libsonnet: -------------------------------------------------------------------------------- 1 | { 2 | table_year: { 3 | args: ["movie", "r_year"], 4 | nary: ["1", "n"], 5 | language: ["movie: $1 ; year: $2", "movie: $1 ; release year: $2"], 6 | }, 7 | table_directed: { 8 | args: ["movie", "person"], 9 | nary: ["1", "n"], 10 | language: ["movie: $1 ; director: $2", "movie: $1 ; directed by: $2"], 11 | }, 12 | table_actor: { 13 | args: ["movie", "person"], 14 | nary: ["n", "n"], 15 | language: ["movie: $1 ; actor: $2", "actor: $2 ; movie: $1"], 16 | }, 17 | table_writer: { 18 | args: ["movie", "person"], 19 | nary: ["n", "n"], 20 | language: ["movie: $1 ; writer: $2", "movie: $1 ; written by: $2"], 21 | }, 22 | 
table_produced: { 23 | args: ["movie", "person"], 24 | nary: ["n", "n"], 25 | language: ["movie: $1 ; producer: $2", "producer: $2 ; movie: $1"], 26 | }, 27 | table_maward: { 28 | args: ["movie", "m_award"], 29 | nary: ["1", "n"], 30 | language: ["movie: $1 ; award: $2", "movie: $1 ; awarded: $2"], 31 | }, 32 | table_paward: { 33 | args: ["person", "p_award"], 34 | nary: ["1", "n"], 35 | language: ["person: $1 ; award: $2", "award: $2 ; winner: $1"], 36 | }, 37 | } 38 | -------------------------------------------------------------------------------- /configs/commaqav1/explicit/text_predicates.libsonnet: -------------------------------------------------------------------------------- 1 | { 2 | text_directed: { 3 | args: ["movie", "person"], 4 | nary: ["1", "n"], 5 | language: ["$1 was a movie directed by $2", "$2 directed the movie $1"], 6 | }, 7 | text_actor: { 8 | args: ["movie", "person"], 9 | nary: ["n", "n"], 10 | language: ["$2 acted in the movie $1", "$2 was an actor in the movie $1"], 11 | }, 12 | text_writer: { 13 | args: ["movie", "person"], 14 | nary: ["n", "n"], 15 | language: ["$2 was one of the writers for the movie $1", "$2 wrote for the movie $1"], 16 | }, 17 | text_produced: { 18 | args: ["movie", "person"], 19 | nary: ["n", "n"], 20 | language: ["$2 was one of the producers of the movie $1", "$2 produced the movie $1 with others"], 21 | }, 22 | text_paward: { 23 | args: ["person", "p_award"], 24 | nary: ["1", "n"], 25 | language: ["$1 won the $2 award", "$2 was awarded to $1"], 26 | }, 27 | text_dob: { 28 | args: ["person", "b_year"], 29 | nary: ["1", "n"], 30 | language: ["$1 was born in $2", "$1 was born in the year $2"], 31 | }, 32 | text_nation: { 33 | args: ["person", "nation"], 34 | nary: ["1", "n"], 35 | language: ["$1 is from the country of $2", "$1 grew up in the nation of $2"], 36 | }, 37 | } 38 | -------------------------------------------------------------------------------- /configs/commaqav1/explicit/theories.libsonnet: 
-------------------------------------------------------------------------------- 1 | [ 2 | { 3 | init: { 4 | "$1": "nation", 5 | }, 6 | questions: [ 7 | "What movies have people from the country $1 acted in?", 8 | ], 9 | steps: [ 10 | { 11 | operation: "select", 12 | question: "nation_p($1, ?)", 13 | answer: "#1", 14 | }, 15 | { 16 | operation: "project_values_flat_unique", 17 | question: "acted_m(#1, ?)", 18 | answer: "#2", 19 | }, 20 | ], 21 | }, 22 | { 23 | init: { 24 | "$1": "nation", 25 | }, 26 | questions: [ 27 | "What movies have the directors from $1 directed?", 28 | ], 29 | steps: [ 30 | { 31 | operation: "select", 32 | question: "nation_p($1, ?)", 33 | answer: "#1", 34 | }, 35 | { 36 | operation: "project_values_flat_unique", 37 | question: "directed_m(#1, ?)", 38 | answer: "#2", 39 | }, 40 | ], 41 | }, 42 | { 43 | init: { 44 | "$1": "b_year", 45 | }, 46 | questions: [ 47 | "What awards have movies produced by people born in $1 won?", 48 | ], 49 | steps: [ 50 | { 51 | operation: "select", 52 | question: "dob_p($1, ?)", 53 | answer: "#1", 54 | }, 55 | { 56 | operation: "project_values_flat_unique", 57 | question: "produced_m(#1, ?)", 58 | answer: "#2", 59 | }, 60 | { 61 | operation: "project_values_flat_unique", 62 | question: "maward_a(#2, ?)", 63 | answer: "#3", 64 | }, 65 | ], 66 | }, 67 | { 68 | init: { 69 | "$1": "b_year", 70 | }, 71 | questions: [ 72 | "What awards have movies written by people born in $1 won?", 73 | ], 74 | steps: [ 75 | { 76 | operation: "select", 77 | question: "dob_p($1, ?)", 78 | answer: "#1", 79 | }, 80 | { 81 | operation: "project_values_flat_unique", 82 | question: "wrote_m(#1, ?)", 83 | answer: "#2", 84 | }, 85 | { 86 | operation: "project_values_flat_unique", 87 | question: "maward_a(#2, ?)", 88 | answer: "#3", 89 | }, 90 | ], 91 | }, 92 | { 93 | init: { 94 | "$1": "p_award", 95 | }, 96 | questions: [ 97 | "What awards did the movies directed by the $1 winners receive?", 98 | ], 99 | steps: [ 100 | { 101 | operation: 
"select", 102 | question: "paward_p($1, ?)", 103 | answer: "#1", 104 | }, 105 | { 106 | operation: "project_values_flat_unique", 107 | question: "directed_m(#1, ?)", 108 | answer: "#2", 109 | }, 110 | { 111 | operation: "project_values_flat_unique", 112 | question: "maward_a(#2, ?)", 113 | answer: "#3", 114 | }, 115 | ], 116 | }, 117 | { 118 | init: { 119 | "$1": "m_award", 120 | }, 121 | questions: [ 122 | "What awards have the actors of the $1 winning movies received?", 123 | ], 124 | steps: [ 125 | { 126 | operation: "select", 127 | question: "maward_m($1, ?)", 128 | answer: "#1", 129 | }, 130 | { 131 | operation: "project_values_flat_unique", 132 | question: "acted_a(#1, ?)", 133 | answer: "#2", 134 | }, 135 | { 136 | operation: "project_values_flat_unique", 137 | question: "paward_a(#2, ?)", 138 | answer: "#3", 139 | }, 140 | ], 141 | }, 142 | ] 143 | -------------------------------------------------------------------------------- /configs/commaqav1/explicit/theories_compgen.libsonnet: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | init: { 4 | "$1": "nation", 5 | }, 6 | questions: [ 7 | "What movies have the people from $1 written?", 8 | ], 9 | steps: [ 10 | { 11 | operation: "select", 12 | question: "nation_p($1, ?)", 13 | answer: "#1", 14 | }, 15 | { 16 | operation: "project_values_flat_unique", 17 | question: "wrote_m(#1, ?)", 18 | answer: "#2", 19 | }, 20 | ], 21 | }, 22 | { 23 | init: { 24 | "$1": "b_year", 25 | }, 26 | questions: [ 27 | "What movies have people born in $1 acted in?", 28 | ], 29 | steps: [ 30 | { 31 | operation: "select", 32 | question: "dob_p($1, ?)", 33 | answer: "#1", 34 | }, 35 | { 36 | operation: "project_values_flat_unique", 37 | question: "acted_m(#1, ?)", 38 | answer: "#2", 39 | }, 40 | ], 41 | }, 42 | { 43 | init: { 44 | "$1": "nation", 45 | }, 46 | questions: [ 47 | "What awards have movies produced by people from $1 won?", 48 | ], 49 | steps: [ 50 | { 51 | operation: "select", 
52 | question: "nation_p($1, ?)", 53 | answer: "#1", 54 | }, 55 | { 56 | operation: "project_values_flat_unique", 57 | question: "produced_m(#1, ?)", 58 | answer: "#2", 59 | }, 60 | { 61 | operation: "project_values_flat_unique", 62 | question: "maward_a(#2, ?)", 63 | answer: "#3", 64 | }, 65 | ], 66 | }, 67 | { 68 | init: { 69 | "$1": "b_year", 70 | }, 71 | questions: [ 72 | "What awards have movies directed by people born in $1 won?", 73 | ], 74 | steps: [ 75 | { 76 | operation: "select", 77 | question: "dob_p($1, ?)", 78 | answer: "#1", 79 | }, 80 | { 81 | operation: "project_values_flat_unique", 82 | question: "directed_m(#1, ?)", 83 | answer: "#2", 84 | }, 85 | { 86 | operation: "project_values_flat_unique", 87 | question: "maward_a(#2, ?)", 88 | answer: "#3", 89 | }, 90 | ], 91 | }, 92 | { 93 | init: { 94 | "$1": "nation", 95 | }, 96 | questions: [ 97 | "What awards have movies written by people from $1 won?", 98 | ], 99 | steps: [ 100 | { 101 | operation: "select", 102 | question: "nation_p($1, ?)", 103 | answer: "#1", 104 | }, 105 | { 106 | operation: "project_values_flat_unique", 107 | question: "wrote_m(#1, ?)", 108 | answer: "#2", 109 | }, 110 | { 111 | operation: "project_values_flat_unique", 112 | question: "maward_a(#2, ?)", 113 | answer: "#3", 114 | }, 115 | ], 116 | }, 117 | { 118 | init: { 119 | "$1": "m_award", 120 | }, 121 | questions: [ 122 | "What awards have the directors of the $1 winning movies received?", 123 | ], 124 | steps: [ 125 | { 126 | operation: "select", 127 | question: "maward_m($1, ?)", 128 | answer: "#1", 129 | }, 130 | { 131 | operation: "project_values_flat_unique", 132 | question: "directed_d(#1, ?)", 133 | answer: "#2", 134 | }, 135 | { 136 | operation: "project_values_flat_unique", 137 | question: "paward_a(#2, ?)", 138 | answer: "#3", 139 | }, 140 | ], 141 | }, 142 | ] 143 | -------------------------------------------------------------------------------- /configs/commaqav1/implicit/items0.jsonnet: 
-------------------------------------------------------------------------------- 1 | local entities = import "entities.libsonnet"; 2 | local kb_predicate = import "kb_predicates.libsonnet"; 3 | local text_predicate = import "text_predicates.libsonnet"; 4 | local combined_predicates = text_predicate + kb_predicate; 5 | local predicate_languages = import "predicate_language.libsonnet"; 6 | local all_theories = import "theories.libsonnet"; 7 | //local all_predicates = ["text_dob", "text_dod", "text_occupation", "text_field", "text_invent", "text_used_o", "text_used_f", "text_founded", "text_inventor", "text_developed", "text_makes", "text_usedin", "text_contains", "kb_studied_f", "kb_graduate_o", "kb_isa"]; 8 | local predicate_names = ["text_dob", "text_dod", "text_occupation", "text_field", "text_invent", "text_founded", "text_inventor", "text_developed", "text_makes", "text_usedin", "text_contains", "kb_studied_f", "kb_graduate_o", "kb_isa"]; 9 | local theories = [all_theories[0]]; 10 | local predicates = { [p]: combined_predicates[p] for p in predicate_names }; 11 | local predicate_language = { [key]: predicate_languages[p][key] for p in predicate_names + ["math_predicates"] for key in std.objectFields(predicate_languages[p]) }; 12 | { 13 | version: 3.0, 14 | entities: entities, 15 | predicates: predicates, 16 | predicate_language: predicate_language, 17 | theories: theories, 18 | } 19 | -------------------------------------------------------------------------------- /configs/commaqav1/implicit/items1.jsonnet: -------------------------------------------------------------------------------- 1 | local entities = import "entities.libsonnet"; 2 | local kb_predicate = import "kb_predicates.libsonnet"; 3 | local text_predicate = import "text_predicates.libsonnet"; 4 | local combined_predicates = text_predicate + kb_predicate; 5 | local predicate_languages = import "predicate_language.libsonnet"; 6 | local all_theories = import "theories.libsonnet"; 7 | local 
predicate_names = ["text_dob", "text_dod", "text_occupation", "text_field", "text_used_f", "text_founded", "text_inventor", "text_developed", "text_makes", "text_usedin", "text_contains", "kb_studied_f", "kb_graduate_o", "kb_isa"]; 8 | local theories = [all_theories[1]]; 9 | local predicates = { [p]: combined_predicates[p] for p in predicate_names }; 10 | local predicate_language = { [key]: predicate_languages[p][key] for p in predicate_names + ["math_predicates"] for key in std.objectFields(predicate_languages[p]) }; 11 | { 12 | version: 3.0, 13 | entities: entities, 14 | predicates: predicates, 15 | predicate_language: predicate_language, 16 | theories: theories, 17 | } 18 | -------------------------------------------------------------------------------- /configs/commaqav1/implicit/items2.jsonnet: -------------------------------------------------------------------------------- 1 | local entities = import "entities.libsonnet"; 2 | local kb_predicate = import "kb_predicates.libsonnet"; 3 | local text_predicate = import "text_predicates.libsonnet"; 4 | local combined_predicates = text_predicate + kb_predicate; 5 | local predicate_languages = import "predicate_language.libsonnet"; 6 | local all_theories = import "theories.libsonnet"; 7 | local predicate_names = ["text_dob", "text_dod", "text_occupation", "text_field", "text_used_o", "text_founded", "text_inventor", "text_developed", "text_makes", "text_usedin", "text_contains", "kb_studied_f", "kb_graduate_o", "kb_isa"]; 8 | local theories = [all_theories[2]]; 9 | local predicates = { [p]: combined_predicates[p] for p in predicate_names }; 10 | local predicate_language = { [key]: predicate_languages[p][key] for p in predicate_names + ["math_predicates"] for key in std.objectFields(predicate_languages[p]) }; 11 | { 12 | version: 3.0, 13 | entities: entities, 14 | predicates: predicates, 15 | predicate_language: predicate_language, 16 | theories: theories, 17 | } 18 | 
-------------------------------------------------------------------------------- /configs/commaqav1/implicit/items3.jsonnet: -------------------------------------------------------------------------------- 1 | local entities = import "entities.libsonnet"; 2 | local kb_predicate = import "kb_predicates.libsonnet"; 3 | local text_predicate = import "text_predicates.libsonnet"; 4 | local combined_predicates = text_predicate + kb_predicate; 5 | local predicate_languages = import "predicate_language.libsonnet"; 6 | local all_theories = import "theories.libsonnet"; 7 | local predicate_names = ["text_dob", "text_dod", "text_occupation", "text_field", "text_invent", "text_used_o", "text_used_f", "text_founded", "text_inventor", "text_developed", "text_contains", "kb_studied_f", "kb_graduate_o", "kb_isa"]; 8 | local theories = [all_theories[3]]; 9 | local predicates = { [p]: combined_predicates[p] for p in predicate_names }; 10 | local predicate_language = { [key]: predicate_languages[p][key] for p in predicate_names + ["math_predicates"] for key in std.objectFields(predicate_languages[p]) }; 11 | { 12 | version: 3.0, 13 | entities: entities, 14 | predicates: predicates, 15 | predicate_language: predicate_language, 16 | theories: theories, 17 | } 18 | -------------------------------------------------------------------------------- /configs/commaqav1/implicit/items4.jsonnet: -------------------------------------------------------------------------------- 1 | local entities = import "entities.libsonnet"; 2 | local kb_predicate = import "kb_predicates.libsonnet"; 3 | local text_predicate = import "text_predicates.libsonnet"; 4 | local combined_predicates = text_predicate + kb_predicate; 5 | local predicate_languages = import "predicate_language.libsonnet"; 6 | local all_theories = import "theories.libsonnet"; 7 | local predicate_names = ["text_dob", "text_dod", "text_occupation", "text_field", "text_invent", "text_used_o", "text_used_f", "text_founded", "text_inventor", 
"text_usedin", "text_contains", "kb_studied_f", "kb_graduate_o", "kb_isa"]; 8 | local theories = [all_theories[4]]; 9 | local predicates = { [p]: combined_predicates[p] for p in predicate_names }; 10 | local predicate_language = { [key]: predicate_languages[p][key] for p in predicate_names + ["math_predicates"] for key in std.objectFields(predicate_languages[p]) }; 11 | { 12 | version: 3.0, 13 | entities: entities, 14 | predicates: predicates, 15 | predicate_language: predicate_language, 16 | theories: theories, 17 | } 18 | -------------------------------------------------------------------------------- /configs/commaqav1/implicit/items5.jsonnet: -------------------------------------------------------------------------------- 1 | local entities = import "entities.libsonnet"; 2 | local kb_predicate = import "kb_predicates.libsonnet"; 3 | local text_predicate = import "text_predicates.libsonnet"; 4 | local combined_predicates = text_predicate + kb_predicate; 5 | local predicate_languages = import "predicate_language.libsonnet"; 6 | local all_theories = import "theories.libsonnet"; 7 | local predicate_names = ["text_dob", "text_dod", "text_occupation", "text_field", "text_invent", "text_used_o", "text_used_f", "text_founded", "text_inventor", "text_makes", "text_contains", "kb_studied_f", "kb_graduate_o", "kb_isa"]; 8 | local theories = [all_theories[5]]; 9 | local predicates = { [p]: combined_predicates[p] for p in predicate_names }; 10 | local predicate_language = { [key]: predicate_languages[p][key] for p in predicate_names + ["math_predicates"] for key in std.objectFields(predicate_languages[p]) }; 11 | { 12 | version: 3.0, 13 | entities: entities, 14 | predicates: predicates, 15 | predicate_language: predicate_language, 16 | theories: theories, 17 | } 18 | -------------------------------------------------------------------------------- /configs/commaqav1/implicit/kb_predicates.libsonnet: 
-------------------------------------------------------------------------------- 1 | { 2 | kb_studied_f: { 3 | args: ["occupation2", "field"], 4 | nary: ["1", "n"], 5 | language: ["Study $2 | MotivatedByGoal | Work as $1", "Working as $1 | HasPrerequisite | Studying $2"], 6 | }, 7 | kb_graduate_o: { 8 | args: ["field2", "occupation"], 9 | nary: ["1", "n"], 10 | language: ["Study $1 | MotivatedByGoal | Work as $2", "Working as $2 | HasPrerequisite | Studying $1"], 11 | }, 12 | kb_isa: { 13 | args: ["device", "obj"], 14 | nary: ["n", "1"], 15 | language: ["$1 | Isa | $2", "$1 device | Isa | $2 object"], 16 | }, 17 | } 18 | -------------------------------------------------------------------------------- /configs/commaqav1/implicit/text_predicates.libsonnet: -------------------------------------------------------------------------------- 1 | { 2 | text_dob: { 3 | args: ["person", "b_year"], 4 | nary: ["1", "n"], 5 | language: ["$1 was born in the year $2", "$1 was born in $2"], 6 | }, 7 | text_dod: { 8 | args: ["person", "d_year"], 9 | nary: ["1", "n"], 10 | language: ["$1 died in $2", "$1 died in the year $2"], 11 | }, 12 | text_occupation: { 13 | args: ["person", "occupation"], 14 | nary: ["n", "1"], 15 | language: ["$1 works as a $2", "$1's occupation is $2"], 16 | }, 17 | text_field: { 18 | args: ["person", "field"], 19 | nary: ["n", "1"], 20 | language: ["$1 studied $2 in college", "$1's field of study was $2"], 21 | }, 22 | text_invent: { 23 | args: ["obj", "year"], 24 | nary: ["1", "n"], 25 | language: ["$1 was first invented in the year $2", "$1 was invented in $2"], 26 | }, 27 | text_used_o: { 28 | args: ["obj", "occupation2"], 29 | nary: ["1", "n"], 30 | language: ["$1 is often used by people working as $2", "A $2 would often use a $1"], 31 | }, 32 | text_used_f: { 33 | args: ["obj", "field2"], 34 | nary: ["1", "n"], 35 | language: ["$1 is commonly used in the field of $2", "When studying $2, $1 would be used"], 36 | }, 37 | text_founded: { 38 | args: 
["person", "company"], 39 | nary: ["n", "1"], 40 | language: ["$1 founded the company $2", "$1 was the founder of the company $2"], 41 | }, 42 | text_inventor: { 43 | args: ["person", "tech"], 44 | nary: ["n", "1"], 45 | language: ["$1 invented the technology of $2", "$1 was the inventor of $2 technology"], 46 | }, 47 | text_developed: { 48 | args: ["company", "device"], 49 | nary: ["n", "1"], 50 | language: ["The $2 was developed at $1", "$2 was developed by the $1 company"], 51 | }, 52 | text_makes: { 53 | args: ["company", "material"], 54 | nary: ["n", "1"], 55 | language: ["$1 is a provider of the material $2", "$2 is produced by the company $1"], 56 | }, 57 | text_usedin: { 58 | args: ["tech", "device"], 59 | nary: ["n", "1"], 60 | language: ["The $1 technology was instrumental in the development of $2", "$2 device was developed based on the $1 technology"], 61 | }, 62 | text_contains: { 63 | args: ["material", "obj"], 64 | nary: ["n", "1"], 65 | language: ["$2 is made using $1 material", "$1 material is needed to make $2"], 66 | }, 67 | } 68 | -------------------------------------------------------------------------------- /configs/commaqav1/implicit/theories.libsonnet: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | init: { 4 | "$1": "person", 5 | }, 6 | questions: [ 7 | "What objects has $1 likely used?", 8 | ], 9 | steps: [ 10 | { 11 | operation: "select", 12 | question: "died_y($1, ?)", 13 | answer: "#1", 14 | }, 15 | { 16 | operation: "select", 17 | question: "invented(?)", 18 | answer: "#2", 19 | }, 20 | { 21 | operation: "project", 22 | question: "invent_y(#2, ?)", 23 | answer: "#3", 24 | }, 25 | { 26 | operation: "filterValues(#3)_keys", 27 | question: "is_smaller(#3 | #1)", 28 | answer: "#4", 29 | }, 30 | ], 31 | }, 32 | { 33 | init: { 34 | "$1": "person", 35 | }, 36 | questions: [ 37 | "What objects has $1 likely used?", 38 | ], 39 | steps: [ 40 | { 41 | operation: "select", 42 | question: 
"works_o($1, ?)", 43 | answer: "#1", 44 | }, 45 | { 46 | operation: "project_values_flat_unique", 47 | question: "graduate_f2(#1, ?)", 48 | answer: "#2", 49 | }, 50 | { 51 | operation: "project_values_flat_unique", 52 | question: "usedf_ob(#2, ?)", 53 | answer: "#3", 54 | }, 55 | ], 56 | }, 57 | { 58 | init: { 59 | "$1": "person", 60 | }, 61 | questions: [ 62 | "What objects has $1 likely used?", 63 | ], 64 | steps: [ 65 | { 66 | operation: "select", 67 | question: "field_f($1, ?)", 68 | answer: "#1", 69 | }, 70 | { 71 | operation: "project_values_flat_unique", 72 | question: "study_o2(#1, ?)", 73 | answer: "#2", 74 | }, 75 | { 76 | operation: "project_values_flat_unique", 77 | question: "usedo_ob(#2, ?)", 78 | answer: "#3", 79 | }, 80 | ], 81 | }, 82 | { 83 | init: { 84 | "$1": "person", 85 | }, 86 | questions: [ 87 | "What objects has $1 helped to make?", 88 | ], 89 | steps: [ 90 | { 91 | operation: "select", 92 | question: "founded_c($1, ?)", 93 | answer: "#1", 94 | }, 95 | { 96 | operation: "project_values_flat_unique", 97 | question: "developed_d(#1, ?)", 98 | answer: "#2", 99 | }, 100 | { 101 | operation: "project_values_flat_unique", 102 | question: "isa_o(#2, ?)", 103 | answer: "#3", 104 | }, 105 | ], 106 | }, 107 | { 108 | init: { 109 | "$1": "person", 110 | }, 111 | questions: [ 112 | "What objects has $1 helped to make?", 113 | ], 114 | steps: [ 115 | { 116 | operation: "select", 117 | question: "invented_t($1, ?)", 118 | answer: "#1", 119 | }, 120 | { 121 | operation: "project_values_flat_unique", 122 | question: "usedin_d(#1, ?)", 123 | answer: "#2", 124 | }, 125 | { 126 | operation: "project_values_flat_unique", 127 | question: "isa_o(#2, ?)", 128 | answer: "#3", 129 | }, 130 | ], 131 | }, 132 | { 133 | init: { 134 | "$1": "person", 135 | }, 136 | questions: [ 137 | "What objects has $1 helped to make?", 138 | ], 139 | steps: [ 140 | { 141 | operation: "select", 142 | question: "founded_c($1, ?)", 143 | answer: "#1", 144 | }, 145 | { 146 | operation: 
"project_values_flat_unique", 147 | question: "makes_m(#1, ?)", 148 | answer: "#2", 149 | }, 150 | { 151 | operation: "project_values_flat_unique", 152 | question: "contains_o(#2, ?)", 153 | answer: "#3", 154 | }, 155 | ], 156 | }, 157 | ] 158 | -------------------------------------------------------------------------------- /configs/commaqav1/numeric/entities.libsonnet: -------------------------------------------------------------------------------- 1 | { 2 | nation: [ 3 | "Skaltay", 4 | "Spasmograph", 5 | "Shrineroom", 6 | "Besprit", 7 | "Misapportionment", 8 | "Premercy", 9 | "Antipositive", 10 | "Carpoon", 11 | "Chickenpot", 12 | "Roozoid", 13 | "Epicuratorion", 14 | "Cockspit", 15 | "Waxseer", 16 | "Midshipwoman", 17 | "Pestok", 18 | "Riften", 19 | "Pistarmen", 20 | "Tranzoid", 21 | "Chasmogon", 22 | "Triuplet", 23 | "Cygris", 24 | "Tupeloide", 25 | "Pneumodendron", 26 | "Portel", 27 | "Noenometer", 28 | "Stoveback", 29 | "Pantalooza", 30 | "Vitimix", 31 | "Boochen", 32 | "Coathanger", 33 | "Fannyxist", 34 | "Halfbill", 35 | "Briffle", 36 | "Tremolophore", 37 | "Haystone", 38 | "Lougerière", 39 | "Vexa", 40 | "Blaubrudin", 41 | "Gigafuna", 42 | "Maldezine", 43 | ], 44 | personj: [ 45 | "Tantor", 46 | "Quiltskin", 47 | "Prostigma", 48 | "Hydromagnetism", 49 | "Lushale", 50 | "Mossia", 51 | "Jimayo", 52 | "Syphactery", 53 | "Segurologyphyte", 54 | "Papernike", 55 | "Seeper", 56 | "Moulminer", 57 | "Crowdstrike", 58 | "Biopsie", 59 | "Chaudelaire", 60 | "Erowid", 61 | "Bioplankton", 62 | "Sequinodactyl", 63 | "Entine", 64 | "Cutthrough", 65 | "Sockbox", 66 | "Mochit", 67 | "Jenga", 68 | "Googolome", 69 | "Trifogation", 70 | "Peaseman", 71 | "Featsaw", 72 | "Spanulum", 73 | "Lumberition", 74 | "Palmorra", 75 | "Sealt", 76 | "Partnershipmaker", 77 | "Compositon", 78 | "Trussellation", 79 | "Pludgel", 80 | "Queness", 81 | "Tachydid", 82 | "Knebbit", 83 | "Jungdowda", 84 | "Chittagood", 85 | "Cavette", 86 | "Autocybe", 87 | "Polyacrylate", 88 | "Catbox", 89 | 
"Trille", 90 | "Magpul", 91 | "Stretchwork", 92 | "Dessaless", 93 | "Vitule", 94 | "Zekkobe", 95 | "Thym", 96 | "Lowrise", 97 | "Pipesia", 98 | "Insimetry", 99 | "Beavertail", 100 | "Saltcoat", 101 | "Kittencrest", 102 | "Modiparity", 103 | "Jockolypse", 104 | "Polyhoney", 105 | "Clenestration", 106 | "Cassamide", 107 | "Ployer", 108 | "Zalate", 109 | "Predigime", 110 | "Diaqum", 111 | "Duckberry", 112 | "Buncha", 113 | "Minimiseries", 114 | "Tartaritis", 115 | "Corporateist", 116 | "Coacheship", 117 | "Wetherality", 118 | "Shadbery", 119 | "Nailbone", 120 | "Headlet", 121 | "Minatura", 122 | "Kinkhole", 123 | "Rigatil", 124 | "Brasscoating", 125 | ], 126 | persond: [ 127 | "Divetail", 128 | "Sahaki", 129 | "Siligar", 130 | "Skullard", 131 | "Biscus", 132 | "Knoppock", 133 | "Mimicocycle", 134 | "Undercabin", 135 | "Riddlemat", 136 | "Vout", 137 | "Stoptite", 138 | "Cabrinder", 139 | "Glodome", 140 | "Colorectomy", 141 | "Fidelice", 142 | "Lechpin", 143 | "Aniconder", 144 | "Sclerotostomy", 145 | "Karmacogram", 146 | "Roor", 147 | "Zayage", 148 | "Gauconium", 149 | "Beancounter", 150 | "Waddletail", 151 | "Triclops", 152 | "Airpipe", 153 | "Noosecutter", 154 | "Defintion", 155 | "Quintoy", 156 | "Cranon", 157 | "Flumph", 158 | "Coordsman", 159 | "Lobsteroid", 160 | "Segumen", 161 | "Terbaryan", 162 | "Chimpwurst", 163 | "Superpredator", 164 | "Misophonia", 165 | "Bronchosol", 166 | "Potcrash", 167 | "Barbition", 168 | "Cabaretillonite", 169 | "Microindication", 170 | "Darecline", 171 | "Metatoun", 172 | "Polypartity", 173 | "Cordic", 174 | "Hoopdoodle", 175 | "Quinsid", 176 | "Waterpipe", 177 | "Waxbox", 178 | "Mechanicism", 179 | "Metrix", 180 | "Honeywax", 181 | "Fremettia", 182 | "Dewbar", 183 | "Malcomoration", 184 | "Pianogram", 185 | "Blumen", 186 | "Whime", 187 | "Nephewskin", 188 | "Guazepam", 189 | "Spursium", 190 | "Infiling", 191 | "Cheapnose", 192 | "Pompasole", 193 | "Cooperativism", 194 | "Brownbeard", 195 | "Bluechase", 196 | "Monsterscar", 197 | 
"Kavashpat", 198 | "Gigabut", 199 | "Wriststroke", 200 | "Frangile", 201 | "Legault", 202 | "Karfman", 203 | "Endography", 204 | "Dumasite", 205 | "Barbrauch", 206 | "Semilist", 207 | ], 208 | dlength: [ 209 | "65.0", 210 | "61.8", 211 | "73.2", 212 | "58.4", 213 | "49.2", 214 | "65.6", 215 | "48.2", 216 | "65.8", 217 | "52.6", 218 | "60.0", 219 | "44.8", 220 | "71.0", 221 | "70.2", 222 | "64.6", 223 | "70.6", 224 | "48.4", 225 | "62.0", 226 | "57.8", 227 | "64.0", 228 | "67.6", 229 | "66.2", 230 | "66.6", 231 | "54.2", 232 | "53.8", 233 | "53.0", 234 | "66.8", 235 | "66.4", 236 | "69.2", 237 | "63.4", 238 | "54.6", 239 | "72.0", 240 | "46.6", 241 | "63.2", 242 | "55.0", 243 | "72.6", 244 | "72.4", 245 | "63.8", 246 | "49.0", 247 | "63.0", 248 | "51.6", 249 | "45.8", 250 | "45.0", 251 | "47.4", 252 | "58.0", 253 | "52.2", 254 | "50.8", 255 | "45.4", 256 | "62.2", 257 | "68.2", 258 | "55.8", 259 | "57.2", 260 | "71.8", 261 | "58.2", 262 | "44.0", 263 | "44.2", 264 | "58.8", 265 | "71.6", 266 | "59.8", 267 | "67.0", 268 | "46.0", 269 | "50.6", 270 | "58.6", 271 | "72.8", 272 | "50.4", 273 | "50.0", 274 | "73.4", 275 | "52.4", 276 | "48.6", 277 | "59.6", 278 | "46.8", 279 | "44.6", 280 | "57.6", 281 | "60.2", 282 | "63.6", 283 | "71.4", 284 | "60.4", 285 | "49.8", 286 | "62.6", 287 | "54.4", 288 | "61.0", 289 | "69.6", 290 | "57.4", 291 | "52.8", 292 | "53.6", 293 | "60.8", 294 | "57.0", 295 | "73.0", 296 | "65.2", 297 | "72.2", 298 | "56.4", 299 | "51.2", 300 | "52.0", 301 | "51.8", 302 | "61.6", 303 | "47.6", 304 | "64.8", 305 | "54.8", 306 | "48.0", 307 | "61.2", 308 | "68.8", 309 | "59.4", 310 | "47.2", 311 | "56.8", 312 | "46.4", 313 | "71.2", 314 | "49.6", 315 | "73.6", 316 | "50.2", 317 | "70.8", 318 | "55.2", 319 | "69.8", 320 | "70.0", 321 | "59.0", 322 | "55.4", 323 | "62.8", 324 | "70.4", 325 | "48.8", 326 | "51.4", 327 | "44.4", 328 | "64.4", 329 | ], 330 | jlength: [ 331 | "88.0", 332 | "87.0", 333 | "71.2", 334 | "73.8", 335 | "93.8", 336 | "84.8", 337 | 
"76.4", 338 | "66.4", 339 | "80.2", 340 | "64.8", 341 | "93.0", 342 | "83.6", 343 | "71.8", 344 | "66.6", 345 | "90.8", 346 | "77.0", 347 | "79.4", 348 | "81.0", 349 | "73.0", 350 | "92.0", 351 | "64.6", 352 | "83.2", 353 | "90.4", 354 | "73.6", 355 | "86.6", 356 | "89.6", 357 | "72.6", 358 | "72.2", 359 | "64.0", 360 | "72.0", 361 | "78.6", 362 | "79.2", 363 | "86.2", 364 | "85.8", 365 | "66.8", 366 | "77.2", 367 | "75.2", 368 | "65.2", 369 | "88.8", 370 | "88.6", 371 | "82.0", 372 | "86.8", 373 | "92.2", 374 | "87.8", 375 | "80.0", 376 | "87.6", 377 | "65.8", 378 | "91.0", 379 | "79.0", 380 | "75.4", 381 | "85.6", 382 | "65.0", 383 | "82.8", 384 | "76.6", 385 | "70.0", 386 | "69.4", 387 | "84.4", 388 | "68.4", 389 | "91.6", 390 | "80.4", 391 | "72.4", 392 | "83.8", 393 | "79.6", 394 | "69.0", 395 | "89.2", 396 | "65.4", 397 | "66.0", 398 | "70.6", 399 | "68.0", 400 | "83.4", 401 | "69.6", 402 | "89.4", 403 | "64.2", 404 | "81.8", 405 | "78.2", 406 | "91.8", 407 | "92.4", 408 | "68.8", 409 | "84.6", 410 | "73.2", 411 | "67.0", 412 | "75.8", 413 | "69.2", 414 | "85.4", 415 | "87.2", 416 | "71.6", 417 | "75.6", 418 | "64.4", 419 | "86.4", 420 | "72.8", 421 | "69.8", 422 | "68.6", 423 | "67.2", 424 | "80.8", 425 | "85.2", 426 | "67.4", 427 | "77.8", 428 | "75.0", 429 | "76.0", 430 | "71.0", 431 | "90.2", 432 | "88.4", 433 | "77.6", 434 | "67.8", 435 | "86.0", 436 | "77.4", 437 | "71.4", 438 | "93.2", 439 | "91.2", 440 | "65.6", 441 | "85.0", 442 | "70.2", 443 | "92.8", 444 | "74.8", 445 | "70.8", 446 | "84.0", 447 | "89.8", 448 | "76.2", 449 | "67.6", 450 | "74.0", 451 | ], 452 | } 453 | -------------------------------------------------------------------------------- /configs/commaqav1/numeric/predicate_language.libsonnet: -------------------------------------------------------------------------------- 1 | { 2 | table_nationj: { 3 | "nationj_p($1, ?)": { 4 | model: "table", 5 | init: { "$1": "nation" }, 6 | questions: ["Who are the javelin throwers from $1?", "Which 
javelin throwers are from the country $1?"], 7 | steps: [ 8 | { 9 | operation: "select", 10 | question: "table_nationj(?, $1)", 11 | answer: "#1", 12 | }, 13 | ], 14 | }, 15 | "nationj_n($1, ?)": { 16 | model: "table", 17 | init: { "$1": "personj" }, 18 | questions: ["Which country does $1 play for?", "Which country is $1 from?"], 19 | steps: [ 20 | { 21 | operation: "select", 22 | question: "table_nationj($1, ?)", 23 | answer: "#1", 24 | }, 25 | ], 26 | }, 27 | }, 28 | table_nationd: { 29 | "nationd_p($1, ?)": { 30 | model: "table", 31 | init: { "$1": "nation" }, 32 | questions: ["Who are the discus throwers from $1?", "Which discus throwers are from the country $1?"], 33 | steps: [ 34 | { 35 | operation: "select", 36 | question: "table_nationd(?, $1)", 37 | answer: "#1", 38 | }, 39 | ], 40 | }, 41 | "nationd_n($1, ?)": { 42 | model: "table", 43 | init: { "$1": "persond" }, 44 | questions: ["Which country does $1 play for?", "Which country is $1 from?"], 45 | steps: [ 46 | { 47 | operation: "select", 48 | question: "table_nationd($1, ?)", 49 | answer: "#1", 50 | }, 51 | ], 52 | }, 53 | }, 54 | text_jthrow: { 55 | "jthrow_l($1, ?)": { 56 | model: "text", 57 | init: { "$1": "personj" }, 58 | questions: ["What were the lengths of the javelin throws by $1?", "What lengths were $1's javelin throws?"], 59 | steps: [ 60 | { 61 | operation: "select", 62 | question: "text_jthrow($1, ?)", 63 | answer: "#1", 64 | }, 65 | ], 66 | }, 67 | "jthrow_p($1, ?)": { 68 | model: "text", 69 | init: { "$1": "jlength" }, 70 | questions: ["Who threw the javelin for $1?", "Who was a javelin thrower for $1?"], 71 | steps: [ 72 | { 73 | operation: "select", 74 | question: "text_jthrow(?, $1)", 75 | answer: "#1", 76 | }, 77 | ], 78 | }, 79 | "jthrows(?)": { 80 | model: "text", 81 | init: {}, 82 | questions: ["Who performed javelin throws?", "Who threw javelins?"], 83 | steps: [ 84 | { 85 | operation: "select_unique", 86 | question: "text_jthrow(?, _)", 87 | answer: "#1", 88 | }, 89 | ], 90 | 
}, 91 | }, 92 | text_dthrow: { 93 | "dthrow_l($1, ?)": { 94 | model: "text", 95 | init: { "$1": "persond" }, 96 | questions: ["What were the lengths of the discus throws by $1?", "What lengths were $1's discus throws?"], 97 | steps: [ 98 | { 99 | operation: "select", 100 | question: "text_dthrow($1, ?)", 101 | answer: "#1", 102 | }, 103 | ], 104 | }, 105 | "dthrow_p($1, ?)": { 106 | model: "text", 107 | init: { "$1": "dlength" }, 108 | questions: ["Who threw the discus for $1?", "Who was a discus thrower for $1?"], 109 | steps: [ 110 | { 111 | operation: "select", 112 | question: "text_dthrow(?, $1)", 113 | answer: "#1", 114 | }, 115 | ], 116 | }, 117 | "dthrows(?)": { 118 | model: "text", 119 | init: {}, 120 | questions: ["Who performed discus throws?", "Who threw discus?"], 121 | steps: [ 122 | { 123 | operation: "select_unique", 124 | question: "text_dthrow(?, _)", 125 | answer: "#1", 126 | }, 127 | ], 128 | }, 129 | }, 130 | math_predicates: { 131 | "max($1)": { 132 | model: "math_special", 133 | init: { "$1": "list(dlength)" }, 134 | questions: ["Which is the largest value in $1?", "What is the largest value among $1?"], 135 | }, 136 | "min($1)": { 137 | model: "math_special", 138 | init: { "$1": "list(jlength)" }, 139 | questions: ["Which is the smallest value in $1?", "What is the smallest value among $1?"], 140 | }, 141 | "count($1)": { 142 | model: "math_special", 143 | init: { "$1": "list(dlength)" }, 144 | questions: ["How many items are there in $1?", "What is the length of $1?"], 145 | }, 146 | "is_smaller($1 | $2)": { 147 | model: "math_special", 148 | init: { "$1": "dlength", "$2": "dlength" }, 149 | questions: ["Is $1 smaller than $2?", "Is $1 less in value than $2?"], 150 | }, 151 | "is_greater($1 | $2)": { 152 | model: "math_special", 153 | init: { "$1": "jlength", "$2": "jlength" }, 154 | questions: ["Is $1 greater than $2?", "Is $1 higher in value than $2?"], 155 | }, 156 | "diff($1 | $2)": { 157 | model: "math_special", 158 | init: { "$1":
"dlength", "$2": "jlength" }, 159 | questions: ["What is the difference between $1 and $2?", "What is the difference in values between $1 and $2?"], 160 | }, 161 | }, 162 | } 163 | -------------------------------------------------------------------------------- /configs/commaqav1/numeric/sports.jsonnet: -------------------------------------------------------------------------------- 1 | local entities = import "entities.libsonnet"; 2 | local table_predicate = import "table_predicates.libsonnet"; 3 | local text_predicate = import "text_predicates.libsonnet"; 4 | local combined_predicates = table_predicate + text_predicate; 5 | local predicate_languages = import "predicate_language.libsonnet"; 6 | local theories = import "theories.libsonnet"; 7 | local predicate_names = ["table_nationj", "table_nationd", "text_dthrow", "text_jthrow"]; 8 | local predicates = { [p]: combined_predicates[p] for p in predicate_names }; 9 | local predicate_language = { [key]: predicate_languages[p][key] for p in predicate_names + ["math_predicates"] for key in std.objectFields(predicate_languages[p]) }; 10 | { 11 | version: 3.0, 12 | entities: entities, 13 | predicates: predicates, 14 | predicate_language: predicate_language, 15 | theories: theories, 16 | } 17 | -------------------------------------------------------------------------------- /configs/commaqav1/numeric/sports_compgen.jsonnet: -------------------------------------------------------------------------------- 1 | local entities = import "entities.libsonnet"; 2 | local table_predicate = import "table_predicates.libsonnet"; 3 | local text_predicate = import "text_predicates.libsonnet"; 4 | local combined_predicates = table_predicate + text_predicate; 5 | local predicate_languages = import "predicate_language.libsonnet"; 6 | local theories = import "theories_compgen.libsonnet"; 7 | local predicate_names = ["table_nationj", "table_nationd", "text_dthrow", "text_jthrow"]; 8 | local predicates = { [p]: combined_predicates[p] for p in 
predicate_names }; 9 | local predicate_language = { [key]: predicate_languages[p][key] for p in predicate_names + ["math_predicates"] for key in std.objectFields(predicate_languages[p]) }; 10 | { 11 | version: 3.0, 12 | entities: entities, 13 | predicates: predicates, 14 | predicate_language: predicate_language, 15 | theories: theories, 16 | } 17 | -------------------------------------------------------------------------------- /configs/commaqav1/numeric/table_predicates.libsonnet: -------------------------------------------------------------------------------- 1 | { 2 | table_nationd: { 3 | args: ["persond", "nation"], 4 | nary: ["1", "*"], 5 | language: ["athlete: $1 ; country: $2; sport: Discus Throw", "Athlete: $1 ; Nation: $2; Sport: Discus"], 6 | }, 7 | table_nationj: { 8 | args: ["personj", "nation"], 9 | nary: ["1", "*"], 10 | language: ["athlete: $1 ; country: $2; sport: Javelin Throw", "Athlete: $1 ; Nation: $2; Sport: Javelin"], 11 | }, 12 | } 13 | -------------------------------------------------------------------------------- /configs/commaqav1/numeric/text_predicates.libsonnet: -------------------------------------------------------------------------------- 1 | { 2 | text_jthrow: { 3 | args: ["personj", "jlength"], 4 | nary: ["*", "1"], 5 | language: ["$1 hurled the javelin to a distance of $2", "$1 registered a throw of $2 in the javelin event"], 6 | }, 7 | text_dthrow: { 8 | args: ["persond", "dlength"], 9 | nary: ["*", "1"], 10 | language: ["$1 threw the discus to a distance of $2", "$1 registered a discus throw of $2"], 11 | }, 12 | } 13 | -------------------------------------------------------------------------------- /configs/commaqav1/numeric/theories.libsonnet: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | init: { 4 | "$1": "nation", 5 | }, 6 | questions: [ 7 | "What was the gap between the longest and shortest javelin throws by athletes from $1?", 8 | ], 9 | steps: [ 10 | { 11 | operation:
"select", 12 | question: "nationj_p($1, ?)", 13 | answer: "#1", 14 | }, 15 | { 16 | operation: "project_values_flat", 17 | question: "jthrow_l(#1, ?)", 18 | answer: "#2", 19 | }, 20 | { 21 | operation: "select", 22 | question: "max(#2)", 23 | answer: "#3", 24 | }, 25 | { 26 | operation: "select", 27 | question: "min(#2)", 28 | answer: "#4", 29 | }, 30 | { 31 | operation: "select", 32 | question: "diff(#3 | #4)", 33 | answer: "#5", 34 | }, 35 | ], 36 | }, 37 | { 38 | init: { 39 | "$1": "persond", 40 | }, 41 | questions: [ 42 | "What was the gap between the longest and shortest discus throws by $1?", 43 | ], 44 | steps: [ 45 | { 46 | operation: "select", 47 | question: "dthrow_l($1, ?)", 48 | answer: "#1", 49 | }, 50 | { 51 | operation: "select", 52 | question: "max(#1)", 53 | answer: "#2", 54 | }, 55 | { 56 | operation: "select", 57 | question: "min(#1)", 58 | answer: "#3", 59 | }, 60 | { 61 | operation: "select", 62 | question: "diff(#2 | #3)", 63 | answer: "#4", 64 | }, 65 | ], 66 | }, 67 | { 68 | init: { 69 | "$1": "nation", 70 | "$2": "nation", 71 | }, 72 | questions: [ 73 | "What was the gap between the best javelin throws from $1 and $2?", 74 | ], 75 | steps: [ 76 | { 77 | operation: "select", 78 | question: "nationj_p($1, ?)", 79 | answer: "#1", 80 | }, 81 | { 82 | operation: "project_values_flat", 83 | question: "jthrow_l(#1, ?)", 84 | answer: "#2", 85 | }, 86 | { 87 | operation: "select", 88 | question: "max(#2)", 89 | answer: "#3", 90 | }, 91 | { 92 | operation: "select", 93 | question: "nationj_p($2, ?)", 94 | answer: "#4", 95 | }, 96 | { 97 | operation: "project_values_flat", 98 | question: "jthrow_l(#4, ?)", 99 | answer: "#5", 100 | }, 101 | { 102 | operation: "select", 103 | question: "max(#5)", 104 | answer: "#6", 105 | }, 106 | { 107 | operation: "select", 108 | question: "diff(#3 | #6)", 109 | answer: "#7", 110 | }, 111 | ], 112 | }, 113 | { 114 | init: { 115 | "$1": "jlength", 116 | }, 117 | questions: [ 118 | "Who threw javelins longer than $1?", 
119 | ], 120 | steps: [ 121 | { 122 | operation: "select", 123 | question: "jthrows(?)", 124 | answer: "#1", 125 | }, 126 | { 127 | operation: "project", 128 | question: "jthrow_l(#1, ?)", 129 | answer: "#2", 130 | }, 131 | { 132 | operation: "projectValues", 133 | question: "max(#2)", 134 | answer: "#3", 135 | }, 136 | { 137 | operation: "filterValues_keys", 138 | question: "is_greater(#3 | $1)", 139 | answer: "#4", 140 | }, 141 | ], 142 | }, 143 | { 144 | init: { 145 | "$1": "dlength", 146 | }, 147 | questions: [ 148 | "Who threw discuses shorter than $1?", 149 | ], 150 | steps: [ 151 | { 152 | operation: "select", 153 | question: "dthrows(?)", 154 | answer: "#1", 155 | }, 156 | { 157 | operation: "project", 158 | question: "dthrow_l(#1, ?)", 159 | answer: "#2", 160 | }, 161 | { 162 | operation: "projectValues", 163 | question: "min(#2)", 164 | answer: "#3", 165 | }, 166 | { 167 | operation: "filterValues_keys", 168 | question: "is_smaller(#3 | $1)", 169 | answer: "#4", 170 | }, 171 | ], 172 | }, 173 | { 174 | init: { 175 | "$1": "dlength", 176 | }, 177 | questions: [ 178 | "How many discus throws were shorter than $1?", 179 | ], 180 | steps: [ 181 | { 182 | operation: "select", 183 | question: "dthrows(?)", 184 | answer: "#1", 185 | }, 186 | { 187 | operation: "project_values_flat", 188 | question: "dthrow_l(#1, ?)", 189 | answer: "#2", 190 | }, 191 | { 192 | operation: "filter(#2)", 193 | question: "is_smaller(#2 | $1)", 194 | answer: "#3", 195 | }, 196 | { 197 | operation: "select", 198 | question: "count(#3)", 199 | answer: "#4", 200 | }, 201 | ], 202 | }, 203 | ] 204 | -------------------------------------------------------------------------------- /configs/commaqav1/numeric/theories_compgen.libsonnet: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | init: { 4 | "$1": "nation", 5 | }, 6 | questions: [ 7 | "What was the gap between the longest and shortest discus throws by athletes from $1?", 8 | ], 9 | steps: [ 
10 | { 11 | operation: "select", 12 | question: "nationd_p($1, ?)", 13 | answer: "#1", 14 | }, 15 | { 16 | operation: "project_values_flat", 17 | question: "dthrow_l(#1, ?)", 18 | answer: "#2", 19 | }, 20 | { 21 | operation: "select", 22 | question: "max(#2)", 23 | answer: "#3", 24 | }, 25 | { 26 | operation: "select", 27 | question: "min(#2)", 28 | answer: "#4", 29 | }, 30 | { 31 | operation: "select", 32 | question: "diff(#3 | #4)", 33 | answer: "#5", 34 | }, 35 | ], 36 | }, 37 | { 38 | init: { 39 | "$1": "personj", 40 | }, 41 | questions: [ 42 | "What was the gap between the longest and shortest javelin throws by $1?", 43 | ], 44 | steps: [ 45 | { 46 | operation: "select", 47 | question: "jthrow_l($1, ?)", 48 | answer: "#1", 49 | }, 50 | { 51 | operation: "select", 52 | question: "max(#1)", 53 | answer: "#2", 54 | }, 55 | { 56 | operation: "select", 57 | question: "min(#1)", 58 | answer: "#3", 59 | }, 60 | { 61 | operation: "select", 62 | question: "diff(#2 | #3)", 63 | answer: "#4", 64 | }, 65 | ], 66 | }, 67 | { 68 | init: { 69 | "$1": "nation", 70 | "$2": "nation", 71 | }, 72 | questions: [ 73 | "What was the gap between the best discus throws from $1 and $2?", 74 | ], 75 | steps: [ 76 | { 77 | operation: "select", 78 | question: "nationd_p($1, ?)", 79 | answer: "#1", 80 | }, 81 | { 82 | operation: "project_values_flat", 83 | question: "dthrow_l(#1, ?)", 84 | answer: "#2", 85 | }, 86 | { 87 | operation: "select", 88 | question: "max(#2)", 89 | answer: "#3", 90 | }, 91 | { 92 | operation: "select", 93 | question: "nationd_p($2, ?)", 94 | answer: "#4", 95 | }, 96 | { 97 | operation: "project_values_flat", 98 | question: "dthrow_l(#4, ?)", 99 | answer: "#5", 100 | }, 101 | { 102 | operation: "select", 103 | question: "max(#5)", 104 | answer: "#6", 105 | }, 106 | { 107 | operation: "select", 108 | question: "diff(#3 | #6)", 109 | answer: "#7", 110 | }, 111 | ], 112 | }, 113 | { 114 | init: { 115 | "$1": "dlength", 116 | }, 117 | questions: [ 118 | "Who threw
discus throws longer than $1?", 119 | ], 120 | steps: [ 121 | { 122 | operation: "select", 123 | question: "dthrows(?)", 124 | answer: "#1", 125 | }, 126 | { 127 | operation: "project", 128 | question: "dthrow_l(#1, ?)", 129 | answer: "#2", 130 | }, 131 | { 132 | operation: "projectValues", 133 | question: "max(#2)", 134 | answer: "#3", 135 | }, 136 | { 137 | operation: "filterValues_keys", 138 | question: "is_greater(#3 | $1)", 139 | answer: "#4", 140 | }, 141 | ], 142 | }, 143 | { 144 | init: { 145 | "$1": "jlength", 146 | }, 147 | questions: [ 148 | "Who threw javelins shorter than $1?", 149 | ], 150 | steps: [ 151 | { 152 | operation: "select", 153 | question: "jthrows(?)", 154 | answer: "#1", 155 | }, 156 | { 157 | operation: "project", 158 | question: "jthrow_l(#1, ?)", 159 | answer: "#2", 160 | }, 161 | { 162 | operation: "projectValues", 163 | question: "min(#2)", 164 | answer: "#3", 165 | }, 166 | { 167 | operation: "filterValues_keys", 168 | question: "is_smaller(#3 | $1)", 169 | answer: "#4", 170 | }, 171 | ], 172 | }, 173 | { 174 | init: { 175 | "$1": "jlength", 176 | }, 177 | questions: [ 178 | "How many javelin throws were shorter than $1?", 179 | ], 180 | steps: [ 181 | { 182 | operation: "select", 183 | question: "jthrows(?)", 184 | answer: "#1", 185 | }, 186 | { 187 | operation: "project_values_flat", 188 | question: "jthrow_l(#1, ?)", 189 | answer: "#2", 190 | }, 191 | { 192 | operation: "filter(#2)", 193 | question: "is_smaller(#2 | $1)", 194 | answer: "#3", 195 | }, 196 | { 197 | operation: "select", 198 | question: "count(#3)", 199 | answer: "#4", 200 | }, 201 | ], 202 | }, 203 | ] 204 | -------------------------------------------------------------------------------- /configs/inference/commaqav1_beam_search.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "start_state": "gen", 3 | "end_state": "[EOQ]", 4 | "models": { 5 | "gen": { 6 | "name": "lmgen", 7 | "model_path": std.extVar("model_path"),
8 | "generation_args": { 9 | "max_length": 40, 10 | "num_return_sequences": 5, 11 | "top_p": 0.95, 12 | "top_k": 10, 13 | "do_sample": false, 14 | "num_beams": 10 15 | }, 16 | "encoder_args": { 17 | "add_special_tokens": false, 18 | "return_tensors": "pt" 19 | }, 20 | "decoder_args": { 21 | "clean_up_tokenization_spaces": true, 22 | "skip_special_tokens": true 23 | }, 24 | "next_model": "execute", 25 | "end_state": "[EOQ]" 26 | }, 27 | "execute": { 28 | "name": "operation_executer", 29 | "remodel_file": std.extVar("remodel_path") + "/" + std.extVar("filename"), 30 | "next_model": "gen", 31 | "skip_empty_answers": true 32 | } 33 | } 34 | } -------------------------------------------------------------------------------- /configs/inference/commaqav1_brute_force.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "start_state": "gen", 3 | "end_state": "[EOQ]", 4 | "models": { 5 | "gen": { 6 | "name": "randgen", 7 | "next_model": "execute", 8 | "end_state": "[EOQ]", 9 | "operations_file": std.extVar("lang_path") + "/operations.txt", 10 | "model_questions_file": std.extVar("lang_path") + "/model_questions.tsv", 11 | "sample_operations": std.parseInt(std.extVar("sample_operations_percent")) / 100, 12 | "sample_questions": std.parseInt(std.extVar("num_questions")), 13 | "max_steps": std.parseInt(std.extVar("max_steps")), 14 | "topk_questions": std.extVar("topk_questions") 15 | }, 16 | "execute": { 17 | "name": "operation_executer", 18 | "remodel_file": std.extVar("remodel_path") + "/" + std.extVar("filename"), 19 | "next_model": "chains", 20 | "skip_empty_answers": true 21 | }, 22 | "chains": { 23 | "name": "dump_chains", 24 | "output_file": std.extVar("output_dir") + "/all_chains.tsv", 25 | "next_model": "gen" 26 | } 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /configs/inference/commaqav1_greedy_search.jsonnet: 
-------------------------------------------------------------------------------- 1 | { 2 | "start_state": "gen", 3 | "end_state": "[EOQ]", 4 | "models": { 5 | "gen": { 6 | "name": "lmgen", 7 | "model_path": std.extVar("model_path"), 8 | "generation_args": { 9 | "max_length": 40, 10 | "num_return_sequences": 1, 11 | "top_p": 0.95, 12 | "top_k": 10, 13 | "do_sample": false, 14 | "num_beams": 10 15 | }, 16 | "encoder_args": { 17 | "add_special_tokens": false, 18 | "return_tensors": "pt" 19 | }, 20 | "decoder_args": { 21 | "clean_up_tokenization_spaces": true, 22 | "skip_special_tokens": true 23 | }, 24 | "next_model": "execute", 25 | "end_state": "[EOQ]" 26 | }, 27 | "execute": { 28 | "name": "operation_executer", 29 | "remodel_file": std.extVar("remodel_path") + "/" + std.extVar("filename"), 30 | "next_model": "gen", 31 | "skip_empty_answers": true 32 | } 33 | } 34 | } -------------------------------------------------------------------------------- /configs/inference/commaqav1_sample_search.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "start_state": "gen", 3 | "end_state": "[EOQ]", 4 | "models": { 5 | "gen": { 6 | "name": "lmgen", 7 | "model_path": std.extVar("model_path"), 8 | "generation_args": { 9 | "max_length": 40, 10 | "num_return_sequences": 10, 11 | "top_p": 0.95, 12 | "top_k": 10, 13 | "do_sample": true, 14 | "num_beams": 1 15 | }, 16 | "encoder_args": { 17 | "add_special_tokens": false, 18 | "return_tensors": "pt" 19 | }, 20 | "decoder_args": { 21 | "clean_up_tokenization_spaces": true, 22 | "skip_special_tokens": true 23 | }, 24 | "next_model": "execute", 25 | "end_state": "[EOQ]" 26 | }, 27 | "execute": { 28 | "name": "operation_executer", 29 | "remodel_file": std.extVar("remodel_path") + "/" + std.extVar("filename"), 30 | "next_model": "gen", 31 | "skip_empty_answers": true 32 | } 33 | } 34 | } -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | jsonnet 2 | transformers==4.11.3 3 | torch==1.10.0 4 | # Needed for T5 models 5 | sentencepiece 6 | protobuf==3.19.0 7 | nltk 8 | # Only needed for DROP evaluation script 9 | scipy 10 | -------------------------------------------------------------------------------- /scripts/build_commaqav1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | # 4 | echo "Building CommaQA-E" 5 | bash scripts/build_datasets.sh \ 6 | configs/commaqav1/explicit/movies1.jsonnet,configs/commaqav1/explicit/movies2.jsonnet \ 7 | output/commaqav1/explicit/commaqa 8 | 9 | echo "Building CommaQA-I" 10 | bash scripts/build_datasets.sh \ 11 | configs/commaqav1/implicit/items0.jsonnet,configs/commaqav1/implicit/items1.jsonnet,configs/commaqav1/implicit/items2.jsonnet,configs/commaqav1/implicit/items3.jsonnet,configs/commaqav1/implicit/items4.jsonnet,configs/commaqav1/implicit/items5.jsonnet \ 12 | output/commaqav1/implicit/commaqa 13 | 14 | echo "Building CommaQA-N" 15 | bash scripts/build_datasets.sh \ 16 | configs/commaqav1/numeric/sports.jsonnet \ 17 | output/commaqav1/numeric/commaqa 18 | 19 | 20 | echo "Building decompositions" 21 | bash scripts/build_decompositions.sh \ 22 | output/commaqav1/numeric/commaqa \ 23 | output/commaqav1/numeric/decomp 24 | 25 | bash scripts/build_decompositions.sh \ 26 | output/commaqav1/implicit/commaqa \ 27 | output/commaqav1/implicit/decomp 28 | 29 | bash scripts/build_decompositions.sh \ 30 | output/commaqav1/explicit/commaqa \ 31 | output/commaqav1/explicit/decomp 32 | 33 | 34 | echo "Create language" 35 | 36 | for d in explicit implicit numeric; 37 | do 38 | mkdir -p output/commaqav1/${d}/language 39 | jq -r ".predicate_language[]|[.model, .questions[]]|@tsv" \ 40 | output/commaqav1/${d}/commaqa/source*.json | sed 's/$[0-9]/__/g' | sort -u > \ 41 | output/commaqav1/${d}/language/model_questions.tsv 42 | done 43 
| 44 | 45 | for f in explicit implicit numeric; 46 | do 47 | mkdir -p output/commaqav1/${f}/restricted_language 48 | jq -r ".[].qa_pairs[].decomposition[]|[.m, .q]|@tsv" \ 49 | output/commaqav1/${f}/commaqa/train.json | \ 50 | sed 's/[$#][0-9]/__/g' | sort -u > \ 51 | output/commaqav1/${f}/restricted_language/model_questions.tsv 52 | jq -r ".[].qa_pairs[].decomposition[].op" \ 53 | output/commaqav1/${f}/commaqa/train.json | \ 54 | sed 's/[$#][0-9]/__/g' | sort -u > \ 55 | output/commaqav1/${f}/restricted_language/operations.txt 56 | done 57 | # build_datasets.sh exports PYTHONPATH only inside its own child shell, so set it for these direct python calls. 58 | mkdir -p output/commaqav1/compgen/ 59 | PYTHONPATH=. python commaqa/dataset/build_dataset.py \ 60 | --input_json configs/commaqav1/numeric/sports_compgen.jsonnet \ 61 | --output output/commaqav1/compgen/numeric_test.json --entity_percent 0.2 \ 62 | --num_groups 100 --num_examples_per_group 5 63 | 64 | PYTHONPATH=. python commaqa/dataset/build_dataset.py \ 65 | --input_json configs/commaqav1/explicit/movies1_compgen.jsonnet,configs/commaqav1/explicit/movies2_compgen.jsonnet \ 66 | --output output/commaqav1/compgen/explicit_test.json --entity_percent 0.2 \ 67 | --num_groups 100 --num_examples_per_group 5 -------------------------------------------------------------------------------- /scripts/build_datasets.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | # FORMAT: bash scripts/build_datasets.sh json_files output_directory 4 | json_files=$1 5 | output_dir=$2 6 | ent_per=${ent_per-0.2} 7 | groups=${groups-2000} 8 | egs=${egs-5} 9 | 10 | export PYTHONPATH="."
11 | 12 | mkdir -p ${output_dir} 13 | 14 | python commaqa/dataset/build_dataset.py \ 15 | --input_json ${json_files} \ 16 | --output ${output_dir} --entity_percent ${ent_per} \ 17 | --num_groups ${groups} --num_examples_per_group ${egs} 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /scripts/build_decompositions.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | input_dir=$1 5 | output_dir=$2 6 | 7 | if [ "$input_dir" = "$output_dir" ]; then 8 | echo "Same directory specified as input and output directory. Will overwrite files!" 9 | exit 1 10 | fi 11 | 12 | mkdir -p $2 13 | 14 | export PYTHONPATH="." 15 | 16 | for f in train dev test; 17 | do 18 | python commaqa/dataset/generate_decomposition_predictions.py \ 19 | --input_json ${input_dir}/${f}.json --decomp_json ${output_dir}/${f}.json 20 | done 21 | -------------------------------------------------------------------------------- /scripts/build_docker_image.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | IMAGE_NAME=commaqa 6 | DOCKERFILE_NAME=Dockerfile 7 | 8 | # Image name 9 | GIT_HASH=`git log --format="%h" -n 1` 10 | IMAGE=${IMAGE_NAME}_$USER-$GIT_HASH # braces required: without them bash reads the unset variable IMAGE_NAME_ 11 | 12 | docker build -f $DOCKERFILE_NAME -t $IMAGE . 13 | 14 | echo -e "\033[0;32m Built image $IMAGE. If using Beaker, now run: \033[0m" 15 | echo -e "\033[0;35m beaker image create --name=$IMAGE --description \"CommaQA Repo; Git Hash: $GIT_HASH\" $IMAGE \033[0m" 16 | --------------------------------------------------------------------------------