├── .flake8 ├── .github ├── dependabot.yml └── workflows │ └── main.yml ├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── allennlp_semparse ├── __init__.py ├── common │ ├── __init__.py │ ├── action_space_walker.py │ ├── date.py │ ├── errors.py │ ├── knowledge_graph.py │ ├── sql │ │ ├── __init__.py │ │ └── text2sql_utils.py │ ├── util.py │ └── wikitables │ │ ├── __init__.py │ │ ├── table_question_context.py │ │ └── wikitables_evaluator.py ├── dataset_readers │ ├── __init__.py │ ├── atis.py │ ├── grammar_based_text2sql.py │ ├── nlvr.py │ ├── template_text2sql.py │ └── wikitables.py ├── domain_languages │ ├── __init__.py │ ├── domain_language.py │ ├── nlvr_language.py │ └── wikitables_language.py ├── fields │ ├── __init__.py │ ├── knowledge_graph_field.py │ └── production_rule_field.py ├── models │ ├── __init__.py │ ├── atis │ │ ├── __init__.py │ │ └── atis_semantic_parser.py │ ├── nlvr │ │ ├── __init__.py │ │ ├── nlvr_coverage_semantic_parser.py │ │ ├── nlvr_direct_semantic_parser.py │ │ └── nlvr_semantic_parser.py │ ├── text2sql_parser.py │ └── wikitables │ │ ├── __init__.py │ │ ├── wikitables_erm_semantic_parser.py │ │ ├── wikitables_mml_semantic_parser.py │ │ └── wikitables_semantic_parser.py ├── nltk_languages │ ├── __init__.py │ ├── contexts │ │ └── __init__.py │ ├── type_declarations │ │ ├── __init__.py │ │ └── type_declaration.py │ └── worlds │ │ ├── __init__.py │ │ └── world.py ├── parsimonious_languages │ ├── __init__.py │ ├── contexts │ │ ├── __init__.py │ │ ├── atis_sql_table_context.py │ │ ├── atis_tables.py │ │ ├── sql_context_utils.py │ │ └── text2sql_table_context.py │ ├── executors │ │ ├── __init__.py │ │ └── sql_executor.py │ └── worlds │ │ ├── __init__.py │ │ ├── atis_world.py │ │ └── text2sql_world.py ├── predictors │ ├── __init__.py │ ├── atis_parser.py │ ├── nlvr_parser.py │ └── wikitables_parser.py ├── state_machines │ ├── __init__.py │ ├── beam_search.py │ ├── constrained_beam_search.py │ ├── states │ │ ├── __init__.py │ │ ├── checklist_statelet.py │ │ ├── coverage_state.py │ │ ├── grammar_based_state.py │ │ ├── grammar_statelet.py │ │ ├── lambda_grammar_statelet.py │ │ ├── rnn_statelet.py │ │ └── state.py │ ├── trainers │ │ ├── __init__.py │ │ ├── decoder_trainer.py │ │ ├── expected_risk_minimization.py │ │ └── maximum_marginal_likelihood.py │ ├── transition_functions │ │ ├── __init__.py │ │ ├── basic_transition_function.py │ │ ├── coverage_transition_function.py │ │ ├── linking_coverage_transition_function.py │ │ ├── linking_transition_function.py │ │ └── transition_function.py │ └── util.py └── version.py ├── codecov.yml ├── dev-requirements.txt ├── pyproject.toml ├── pytest.ini ├── requirements.txt ├── scripts ├── examine_sql_coverage.py ├── get_version.py ├── nlvr │ ├── generate_data_from_erm_model.py │ ├── get_nlvr_logical_forms.py │ ├── group_nlvr_worlds.py │ └── sed_commands.txt ├── reformat_text2sql_data.py └── wikitables │ ├── generate_data_from_erm_model.py │ ├── preprocess_data.py │ └── search_for_logical_forms.py ├── setup.py ├── test_fixtures ├── atis │ ├── experiment.json │ └── serialization │ │ ├── best.th │ │ ├── model.tar.gz │ │ └── vocabulary │ │ ├── non_padded_namespaces.txt │ │ ├── rule_labels.txt │ │ └── tokens.txt ├── data │ ├── atis │ │ └── sample.json │ ├── corenlp_processed_tables │ │ ├── TEST-1.table │ │ ├── TEST-10.table │ │ ├── TEST-11.table │ │ ├── TEST-2.table │ │ ├── TEST-3.table │ │ ├── TEST-4.table │ │ ├── TEST-5.table │ │ ├── TEST-6.table │ │ ├── TEST-7.table │ │ ├── TEST-8.table │ │ └── TEST-9.table │ ├── nlvr │ │ ├── 
sample_grouped_data.jsonl │ │ ├── sample_processed_data.jsonl │ │ └── sample_ungrouped_data.jsonl │ ├── quarel.jsonl │ ├── text2sql │ │ ├── restaurants-schema.csv │ │ ├── restaurants.db │ │ └── restaurants_tiny.json │ └── wikitables │ │ ├── action_space_walker_output │ │ ├── nt-0.gz │ │ └── nt-1.gz │ │ ├── action_space_walker_output_with_single_tarball │ │ └── all_lfs_tarball.tar.gz │ │ ├── dpd_output │ │ ├── nt-0.gz │ │ ├── nt-1.gz │ │ └── nt-64.gz │ │ ├── lots_of_ors_example.examples │ │ ├── sample_data.examples │ │ ├── sample_data_preprocessed.jsonl │ │ ├── sample_table.tagged │ │ ├── sample_table.tsv │ │ ├── sample_table_with_date.tsv │ │ └── tables │ │ ├── 109.tsv │ │ ├── 341.tagged │ │ ├── 346.tagged │ │ ├── 590.csv │ │ ├── 590.tagged │ │ ├── 590.tsv │ │ ├── 622.csv │ │ ├── 622.tagged │ │ └── 622.tsv ├── elmo │ ├── config │ │ └── characters_token_embedder.json │ ├── elmo_token_embeddings.hdf5 │ ├── lm_embeddings_0.hdf5 │ ├── lm_embeddings_1.hdf5 │ ├── lm_embeddings_2.hdf5 │ ├── lm_weights.hdf5 │ ├── options.json │ ├── sentences.json │ └── vocab_test.txt ├── nlvr_coverage_semantic_parser │ ├── experiment.json │ ├── mml_init_experiment.json │ ├── serialization │ │ ├── best.th │ │ ├── model.tar.gz │ │ └── vocabulary │ │ │ ├── denotations.txt │ │ │ ├── non_padded_namespaces.txt │ │ │ ├── rule_labels.txt │ │ │ └── tokens.txt │ └── ungrouped_experiment.json ├── nlvr_direct_semantic_parser │ ├── experiment.json │ └── serialization │ │ ├── best.th │ │ ├── model.tar.gz │ │ └── vocabulary │ │ ├── denotations.txt │ │ ├── non_padded_namespaces.txt │ │ ├── rule_labels.txt │ │ └── tokens.txt ├── text2sql │ └── experiment.json └── wikitables │ ├── experiment-elmo-no-features.json │ ├── experiment-erm.json │ ├── experiment-mixture.json │ ├── experiment.json │ └── serialization │ ├── best.th │ ├── model.tar.gz │ └── vocabulary │ ├── non_padded_namespaces.txt │ ├── rule_labels.txt │ └── tokens.txt ├── tests ├── __init__.py ├── common │ ├── __init__.py │ ├── action_space_walker_test.py │ ├── date_test.py │ ├── sql │ │ ├── __init__.py │ │ └── text2sql_utils_test.py │ ├── util_test.py │ └── wikitables │ │ ├── __init__.py │ │ └── table_question_context_test.py ├── dataset_readers │ ├── __init__.py │ ├── atis_test.py │ ├── grammar_based_text2sql_test.py │ ├── nlvr_test.py │ ├── template_text2sql_test.py │ └── wikitables_test.py ├── domain_languages │ ├── __init__.py │ ├── domain_language_test.py │ ├── nlvr_language_test.py │ └── wikitables_language_test.py ├── fields │ ├── __init__.py │ ├── knowledge_graph_field_test.py │ └── production_rule_field_test.py ├── models │ ├── __init__.py │ ├── atis │ │ ├── __init__.py │ │ ├── atis_grammar_statelet_test.py │ │ └── atis_semantic_parser_test.py │ ├── nlvr │ │ ├── __init__.py │ │ ├── nlvr_coverage_semantic_parser_test.py │ │ └── nlvr_direct_semantic_parser_test.py │ ├── quarel │ │ └── __init__.py │ ├── text2sql_parser_test.py │ └── wikitables │ │ ├── __init__.py │ │ ├── wikitables_erm_semantic_parser_test.py │ │ └── wikitables_mml_semantic_parser_test.py ├── nltk_languages │ ├── __init__.py │ ├── contexts │ │ └── __init__.py │ ├── type_declarations │ │ ├── __init__.py │ │ └── type_declaration_test.py │ └── worlds │ │ ├── __init__.py │ │ └── world_test.py ├── parsimonious_languages │ ├── __init__.py │ ├── contexts │ │ └── __init__.py │ ├── executors │ │ ├── __init__.py │ │ └── sql_executor_test.py │ └── worlds │ │ ├── __init__.py │ │ ├── atis_world_test.py │ │ └── text2sql_world_test.py ├── predictors │ ├── __init__.py │ ├── atis_parser_test.py │ ├── 
nlvr_parser_test.py │ └── wikitables_parser_test.py ├── semparse_test_case.py └── state_machines │ ├── __init__.py │ ├── beam_search_test.py │ ├── constrained_beam_search_test.py │ ├── simple_transition_system.py │ ├── states │ ├── __init__.py │ ├── grammar_statelet_test.py │ └── lambda_grammar_statelet_test.py │ ├── trainers │ ├── __init__.py │ ├── expected_risk_minimization_test.py │ └── maximum_marginal_likelihood_test.py │ ├── transition_functions │ ├── __init__.py │ └── basic_transition_function_test.py │ └── util_test.py └── training_config ├── wikitables_erm_parser.jsonnet └── wikitables_mml_parser.jsonnet /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 115 3 | 4 | ignore = 5 | # these rules don't play well with black 6 | E203 # whitespace before : 7 | W503 # line break before binary operator 8 | 9 | exclude = 10 | build/** 11 | doc/** 12 | 13 | per-file-ignores = 14 | # __init__.py files are allowed to have unused imports and lines-too-long 15 | allennlp_semparse/__init__.py:F401 16 | allennlp_semparse/**/__init__.py:F401,E501 17 | 18 | # tests don't have to respect 19 | # E731: do not assign a lambda expression, use a def 20 | # F401: unused imports 21 | tests/**:E731,F401 22 | 23 | # E402: module level import not at top of file 24 | scripts/**:E402 25 | 26 | # E266: too many leading '#' for block comment 27 | allennlp_semparse/common/wikitables/wikitables_evaluator.py:E266 28 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: pip 4 | directory: "/" 5 | schedule: 6 | interval: "weekly" 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | /*.iml 107 | /.idea/ 108 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.7.2 2 | 3 | # Setup a spot for the code 4 | WORKDIR /allennlp_semparse 5 | 6 | # Install Python dependencies 7 | COPY requirements.txt requirements.txt 8 | RUN pip install -r requirements.txt 9 | 10 | COPY .flake8 .flake8 11 | COPY pytest.ini pytest.ini 12 | COPY pyproject.toml pyproject.toml 13 | COPY training_config/ training_config/ 14 | COPY tests tests/ 15 | COPY test_fixtures test_fixtures/ 16 | COPY allennlp_semparse allennlp_semparse/ 17 | 18 | CMD ["/bin/bash"] 19 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # allennlp-semparse 2 | [![Build status](https://github.com/allenai/allennlp-semparse/workflows/CI/badge.svg)](https://github.com/allenai/allennlp-semparse/actions?workflow=CI) 3 | [![PyPI](https://img.shields.io/pypi/v/allennlp-semparse)](https://pypi.org/project/allennlp-semparse/) 4 | [![codecov](https://codecov.io/gh/allenai/allennlp-semparse/branch/master/graph/badge.svg)](https://codecov.io/gh/allenai/allennlp-semparse) 5 | 6 | A framework for building semantic parsers (including neural module networks) with AllenNLP, built by the authors of AllenNLP 7 | 8 | ## Installing 9 | 10 | `allennlp-semparse` is available on PyPI. You can install through `pip` with 11 | 12 | ``` 13 | pip install allennlp-semparse 14 | ``` 15 | 16 | ## Supported datasets 17 | 18 | - ATIS 19 | - Text2SQL 20 | - NLVR 21 | - WikiTableQuestions 22 | 23 | ## Supported models 24 | 25 | - Grammar-based decoding models, following the parser originally introduced in [Neural 26 | Semantic Parsing with Type Constraints for Semi-Structured 27 | Tables](https://www.semanticscholar.org/paper/Neural-Semantic-Parsing-with-Type-Constraints-for-Krishnamurthy-Dasigi/8c6f58ed0ebf379858c0bbe02c53ee51b3eb398a). 28 | The models that are currently checked in are all based on this parser, applied to various datasets. 29 | - Neural module networks. We don't have models checked in for this yet, but `DomainLanguage` 30 | supports defining them, and we will add some models to the repo once papers go through peer 31 | review. 
The code is slow (batching is hard), but it works. 32 | 33 | 34 | ## Tutorials 35 | 36 | Coming sometime in the future... You can look at [this old 37 | tutorial](https://github.com/allenai/allennlp/blob/master/tutorials/getting_started/semantic_parsing.md), 38 | but the part about using NLTK to define a grammar is outdated. Now you can use `DomainLanguage` to 39 | define a python executor, and we analyze the type annotations in the functions in that executor to 40 | automatically infer a grammar for you. It is much easier to use than it used to be. Until we get 41 | around to writing a better tutorial for this, the best way to get started using this is to look at 42 | some examples. The simplest is the [Arithmetic 43 | language](https://github.com/allenai/allennlp-semparse/blob/master/tests/domain_languages/domain_language_test.py) 44 | in the `DomainLanguage` test (there's also a bit of description in the [`DomainLanguage` 45 | docstring](https://github.com/allenai/allennlp-semparse/blob/bbc8fde3a354ee1708ae900f09be9aa2adc8177f/allennlp_semparse/domain_languages/domain_language.py#L204-L270)). 46 | After looking at those, you can look at more complex (real) examples in the [`domain_languages` 47 | module](https://github.com/allenai/allennlp-semparse/tree/master/allennlp_semparse/domain_languages). 48 | Note that the executor you define can have _learned parameters_, making it a neural module network. 49 | The best place to get an example of that is currently [this unfinished implementation of N2NMNs on 50 | the CLEVR 51 | dataset](https://github.com/matt-gardner/allennlp/blob/neural_module_networks/allennlp/semparse/domain_languages/visual_reasoning_language.py). 52 | We'll have more examples of doing this in the not-too-distant future. 53 | -------------------------------------------------------------------------------- /allennlp_semparse/__init__.py: -------------------------------------------------------------------------------- 1 | from allennlp_semparse.common.action_space_walker import ActionSpaceWalker 2 | from allennlp_semparse.domain_languages.domain_language import ( 3 | DomainLanguage, 4 | predicate, 5 | predicate_with_side_args, 6 | ) 7 | -------------------------------------------------------------------------------- /allennlp_semparse/common/__init__.py: -------------------------------------------------------------------------------- 1 | from allennlp_semparse.common.date import Date 2 | from allennlp_semparse.common.errors import ParsingError, ExecutionError 3 | from allennlp_semparse.common.util import ( 4 | NUMBER_CHARACTERS, 5 | MONTH_NUMBERS, 6 | ORDER_OF_MAGNITUDE_WORDS, 7 | NUMBER_WORDS, 8 | ) 9 | -------------------------------------------------------------------------------- /allennlp_semparse/common/date.py: -------------------------------------------------------------------------------- 1 | from allennlp_semparse.common.errors import ExecutionError 2 | 3 | 4 | class Date: 5 | def __init__(self, year: int, month: int, day: int) -> None: 6 | self.year = year 7 | self.month = month 8 | self.day = day 9 | 10 | def __eq__(self, other) -> bool: 11 | # Note that the logic below renders equality to be non-transitive. That is, 12 | # Date(2018, -1, -1) == Date(2018, 2, 3) and Date(2018, -1, -1) == Date(2018, 4, 5) 13 | # but Date(2018, 2, 3) != Date(2018, 4, 5). 
14 | if not isinstance(other, Date): 15 | raise ExecutionError("only compare Dates with Dates") 16 | year_is_same = self.year == -1 or other.year == -1 or self.year == other.year 17 | month_is_same = self.month == -1 or other.month == -1 or self.month == other.month 18 | day_is_same = self.day == -1 or other.day == -1 or self.day == other.day 19 | return year_is_same and month_is_same and day_is_same 20 | 21 | def __gt__(self, other) -> bool: 22 | # The logic below is tricky, and is based on some assumptions we make about date comparison. 23 | # Year, month or day being -1 means that we do not know its value. In those cases, the 24 | # we consider the comparison to be undefined, and return False if all the fields that are 25 | # more significant than the field being compared are equal. However, when year is -1 for both 26 | # dates being compared, it is safe to assume that the year is not specified because it is 27 | # the same. So we make an exception just in that case. That is, we deem the comparison 28 | # undefined only when one of the year values is -1, but not both. 29 | if not isinstance(other, Date): 30 | raise ExecutionError("only compare Dates with Dates") 31 | # We're doing an exclusive or below. 32 | if (self.year == -1) != (other.year == -1): 33 | return False # comparison undefined 34 | # If both years are -1, we proceed. 35 | if self.year != other.year: 36 | return self.year > other.year 37 | # The years are equal and not -1, or both are -1. 38 | if self.month == -1 or other.month == -1: 39 | return False 40 | if self.month != other.month: 41 | return self.month > other.month 42 | # The months and years are equal and not -1 43 | if self.day == -1 or other.day == -1: 44 | return False 45 | return self.day > other.day 46 | 47 | def __ge__(self, other) -> bool: 48 | if not isinstance(other, Date): 49 | raise ExecutionError("only compare Dates with Dates") 50 | return self > other or self == other 51 | 52 | def __str__(self): 53 | if (self.month, self.day) == (-1, -1): 54 | # If we have only the year, return just that so that the official evaluator does the 55 | # comparison against the target as if both are numbers. 56 | return str(self.year) 57 | return f"{self.year}-{self.month}-{self.day}" 58 | 59 | def __hash__(self): 60 | return hash(str(self)) 61 | 62 | def to_json(self): 63 | return str(self) 64 | 65 | @classmethod 66 | def make_date(cls, string: str) -> "Date": 67 | year_string, month_string, day_string = string.split("-") 68 | year = -1 69 | month = -1 70 | day = -1 71 | try: 72 | year = int(year_string) 73 | except ValueError: 74 | pass 75 | try: 76 | month = int(month_string) 77 | except ValueError: 78 | pass 79 | try: 80 | day = int(day_string) 81 | except ValueError: 82 | pass 83 | return Date(year, month, day) 84 | -------------------------------------------------------------------------------- /allennlp_semparse/common/errors.py: -------------------------------------------------------------------------------- 1 | class ParsingError(Exception): 2 | """ 3 | This exception gets raised when there is a parsing error during logical form processing. This 4 | might happen because you're not handling the full set of possible logical forms, for instance, 5 | and having this error provides a consistent way to catch those errors and log how frequently 6 | this occurs. 
7 | """ 8 | 9 | def __init__(self, message): 10 | super().__init__() 11 | self.message = message 12 | 13 | def __str__(self): 14 | return repr(self.message) 15 | 16 | 17 | class ExecutionError(Exception): 18 | """ 19 | This exception gets raised when you're trying to execute a logical form that your executor does 20 | not understand. This may be because your logical form contains a function with an invalid name 21 | or a set of arguments whose types do not match those that the function expects. 22 | """ 23 | 24 | def __init__(self, message): 25 | super().__init__() 26 | self.message = message 27 | 28 | def __str__(self): 29 | return repr(self.message) 30 | -------------------------------------------------------------------------------- /allennlp_semparse/common/knowledge_graph.py: -------------------------------------------------------------------------------- 1 | """ 2 | A ``KnowledgeGraph`` is a graphical representation of some structured knowledge source: say a 3 | table, figure or an explicit knowledge base. 4 | """ 5 | 6 | from typing import Dict, List, Set 7 | 8 | 9 | class KnowledgeGraph: 10 | """ 11 | A ``KnowledgeGraph`` represents a collection of entities and their relationships. 12 | 13 | The ``KnowledgeGraph`` currently stores (untyped) neighborhood information and text 14 | representations of each entity (if there is any). 15 | 16 | The knowledge base itself can be a table (like in WikitableQuestions), a figure (like in NLVR) 17 | or some other structured knowledge source. This abstract class needs to be inherited for 18 | implementing the functionality appropriate for a given KB. 19 | 20 | All of the parameters listed below are stored as public attributes. 21 | 22 | Parameters 23 | ---------- 24 | entities : ``Set[str]`` 25 | The string identifiers of the entities in this knowledge graph. We sort this set and store 26 | it as a list. The sorting is so that we get a guaranteed consistent ordering across 27 | separate runs of the code. 28 | neighbors : ``Dict[str, List[str]]`` 29 | A mapping from string identifiers to other string identifiers, denoting which entities are 30 | neighbors in the graph. 31 | entity_text : ``Dict[str, str]`` 32 | If you have additional text associated with each entity (other than its string identifier), 33 | you can store that here. This might be, e.g., the text in a table cell, or the description 34 | of a wikipedia entity. 
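
    Example
    -------
    A minimal, illustrative construction (the entity identifiers below are made up for this
    example; the actual naming scheme comes from whichever reader builds the graph)::

        from allennlp_semparse.common.knowledge_graph import KnowledgeGraph

        graph = KnowledgeGraph(
            entities={"fb:row.row_1", "fb:cell.january"},
            neighbors={"fb:row.row_1": ["fb:cell.january"], "fb:cell.january": ["fb:row.row_1"]},
            entity_text={"fb:row.row_1": "row 1", "fb:cell.january": "January"},
        )
        # graph.entities is stored sorted: ["fb:cell.january", "fb:row.row_1"]
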
35 | """ 36 | 37 | def __init__( 38 | self, 39 | entities: Set[str], 40 | neighbors: Dict[str, List[str]], 41 | entity_text: Dict[str, str] = None, 42 | ) -> None: 43 | self.entities = sorted(entities) 44 | self.neighbors = neighbors 45 | self.entity_text = entity_text 46 | 47 | def __eq__(self, other): 48 | if isinstance(self, other.__class__): 49 | return self.__dict__ == other.__dict__ 50 | return NotImplemented 51 | -------------------------------------------------------------------------------- /allennlp_semparse/common/sql/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/allennlp_semparse/common/sql/__init__.py -------------------------------------------------------------------------------- /allennlp_semparse/common/util.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | 4 | NUMBER_CHARACTERS = {"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", ".", "-"} 5 | MONTH_NUMBERS = { 6 | "january": 1, 7 | "jan": 1, 8 | "february": 2, 9 | "feb": 2, 10 | "march": 3, 11 | "mar": 3, 12 | "april": 4, 13 | "apr": 4, 14 | "may": 5, 15 | "june": 6, 16 | "jun": 6, 17 | "july": 7, 18 | "jul": 7, 19 | "august": 8, 20 | "aug": 8, 21 | "september": 9, 22 | "sep": 9, 23 | "october": 10, 24 | "oct": 10, 25 | "november": 11, 26 | "nov": 11, 27 | "december": 12, 28 | "dec": 12, 29 | } 30 | ORDER_OF_MAGNITUDE_WORDS = {"hundred": 100, "thousand": 1000, "million": 1000000} 31 | NUMBER_WORDS = { 32 | "zero": 0, 33 | "one": 1, 34 | "two": 2, 35 | "three": 3, 36 | "four": 4, 37 | "five": 5, 38 | "six": 6, 39 | "seven": 7, 40 | "eight": 8, 41 | "nine": 9, 42 | "ten": 10, 43 | "first": 1, 44 | "second": 2, 45 | "third": 3, 46 | "fourth": 4, 47 | "fifth": 5, 48 | "sixth": 6, 49 | "seventh": 7, 50 | "eighth": 8, 51 | "ninth": 9, 52 | "tenth": 10, 53 | **MONTH_NUMBERS, 54 | } 55 | 56 | 57 | def lisp_to_nested_expression(lisp_string: str) -> List: 58 | """ 59 | Takes a logical form as a lisp string and returns a nested list representation of the lisp. 60 | For example, "(count (division first))" would get mapped to ['count', ['division', 'first']]. 
61 | """ 62 | stack: List = [] 63 | current_expression: List = [] 64 | tokens = lisp_string.split() 65 | for token in tokens: 66 | while token[0] == "(": 67 | nested_expression: List = [] 68 | current_expression.append(nested_expression) 69 | stack.append(current_expression) 70 | current_expression = nested_expression 71 | token = token[1:] 72 | current_expression.append(token.replace(")", "")) 73 | while token[-1] == ")": 74 | current_expression = stack.pop() 75 | token = token[:-1] 76 | return current_expression[0] 77 | -------------------------------------------------------------------------------- /allennlp_semparse/common/wikitables/__init__.py: -------------------------------------------------------------------------------- 1 | from allennlp_semparse.common.wikitables.table_question_context import ( 2 | TableQuestionContext, 3 | CellValueType, 4 | ) 5 | -------------------------------------------------------------------------------- /allennlp_semparse/dataset_readers/__init__.py: -------------------------------------------------------------------------------- 1 | from allennlp_semparse.dataset_readers.atis import AtisDatasetReader 2 | from allennlp_semparse.dataset_readers.grammar_based_text2sql import ( 3 | GrammarBasedText2SqlDatasetReader, 4 | ) 5 | from allennlp_semparse.dataset_readers.nlvr import NlvrDatasetReader 6 | from allennlp_semparse.dataset_readers.template_text2sql import TemplateText2SqlDatasetReader 7 | from allennlp_semparse.dataset_readers.wikitables import WikiTablesDatasetReader 8 | -------------------------------------------------------------------------------- /allennlp_semparse/dataset_readers/template_text2sql.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | import logging 3 | import json 4 | import glob 5 | import os 6 | 7 | 8 | from allennlp.common.file_utils import cached_path 9 | from allennlp.data import DatasetReader 10 | from allennlp.data.fields import TextField, Field, SequenceLabelField, LabelField 11 | from allennlp.data.instance import Instance 12 | from allennlp.data.tokenizers import Token 13 | from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer 14 | 15 | from allennlp_semparse.common.sql import text2sql_utils 16 | 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | @DatasetReader.register("template_text2sql") 22 | class TemplateText2SqlDatasetReader(DatasetReader): 23 | """ 24 | Reads text2sql data for the sequence tagging and template prediction baseline 25 | from `"Improving Text to SQL Evaluation Methodology" `_. 26 | 27 | Parameters 28 | ---------- 29 | use_all_sql : ``bool``, optional (default = False) 30 | Whether to use all of the sql queries which have identical semantics, 31 | or whether to just use the first one. 32 | token_indexers : ``Dict[str, TokenIndexer]``, optional (default=``{"tokens": SingleIdTokenIndexer()}``) 33 | We use this to define the input representation for the text. See :class:`TokenIndexer`. 34 | Note that the `output` tags will always correspond to single token IDs based on how they 35 | are pre-tokenised in the data file. 36 | cross_validation_split_to_exclude : ``int``, optional (default = None) 37 | Some of the text2sql datasets are very small, so you may need to do cross validation. 38 | Here, you can specify a integer corresponding to a split_{int}.json file not to include 39 | in the training set. 
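
    For example (a sketch only -- the data path is hypothetical and assumes the files were
    produced by ``scripts/reformat_text2sql_data.py``)::

        from allennlp_semparse.dataset_readers import TemplateText2SqlDatasetReader

        reader = TemplateText2SqlDatasetReader(use_all_sql=False, cross_validation_split_to_exclude=2)
        # ``_read`` filters out any file whose basename contains "2" (e.g. split_2.json),
        # so that split can be held out for evaluation.
        instances = reader.read("/path/to/reformatted/restaurants/*.json")
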
40 | """ 41 | 42 | def __init__( 43 | self, 44 | use_all_sql: bool = False, 45 | token_indexers: Dict[str, TokenIndexer] = None, 46 | cross_validation_split_to_exclude: int = None, 47 | **kwargs, 48 | ) -> None: 49 | super().__init__(**kwargs) 50 | self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()} 51 | self._use_all_sql = use_all_sql 52 | self._cross_validation_split_to_exclude = str(cross_validation_split_to_exclude) 53 | 54 | def _read(self, file_path: str): 55 | """ 56 | This dataset reader consumes the data from 57 | https://github.com/jkkummerfeld/text2sql-data/tree/master/data 58 | formatted using ``scripts/reformat_text2sql_data.py``. 59 | 60 | Parameters 61 | ---------- 62 | file_path : ``str``, required. 63 | For this dataset reader, file_path can either be a path to a file `or` a 64 | path to a directory containing json files. The reason for this is because 65 | some of the text2sql datasets require cross validation, which means they are split 66 | up into many small files, for which you only want to exclude one. 67 | """ 68 | files = [ 69 | p 70 | for p in glob.glob(file_path) 71 | if self._cross_validation_split_to_exclude not in os.path.basename(p) 72 | ] 73 | 74 | for path in files: 75 | with open(cached_path(path), "r") as data_file: 76 | data = json.load(data_file) 77 | 78 | for sql_data in text2sql_utils.process_sql_data(data, self._use_all_sql): 79 | template = " ".join(sql_data.sql) 80 | yield self.text_to_instance(sql_data.text, sql_data.variable_tags, template) 81 | 82 | def text_to_instance( 83 | self, # type: ignore 84 | query: List[str], 85 | slot_tags: List[str] = None, 86 | sql_template: str = None, 87 | ) -> Instance: 88 | fields: Dict[str, Field] = {} 89 | tokens = TextField([Token(t) for t in query], self._token_indexers) 90 | fields["tokens"] = tokens 91 | 92 | if slot_tags is not None and sql_template is not None: 93 | slot_field = SequenceLabelField(slot_tags, tokens, label_namespace="slot_tags") 94 | template = LabelField(sql_template, label_namespace="template_labels") 95 | fields["slot_tags"] = slot_field 96 | fields["template"] = template 97 | 98 | return Instance(fields) 99 | -------------------------------------------------------------------------------- /allennlp_semparse/domain_languages/__init__.py: -------------------------------------------------------------------------------- 1 | from allennlp_semparse.domain_languages.domain_language import ( 2 | DomainLanguage, 3 | START_SYMBOL, 4 | predicate, 5 | predicate_with_side_args, 6 | ) 7 | from allennlp_semparse.domain_languages.nlvr_language import NlvrLanguage 8 | from allennlp_semparse.domain_languages.wikitables_language import WikiTablesLanguage 9 | -------------------------------------------------------------------------------- /allennlp_semparse/fields/__init__.py: -------------------------------------------------------------------------------- 1 | from allennlp_semparse.fields.knowledge_graph_field import KnowledgeGraphField 2 | from allennlp_semparse.fields.production_rule_field import ProductionRuleField 3 | -------------------------------------------------------------------------------- /allennlp_semparse/models/__init__.py: -------------------------------------------------------------------------------- 1 | from allennlp_semparse.models.atis.atis_semantic_parser import AtisSemanticParser 2 | from allennlp_semparse.models.nlvr.nlvr_coverage_semantic_parser import NlvrCoverageSemanticParser 3 | from allennlp_semparse.models.nlvr.nlvr_direct_semantic_parser import 
NlvrDirectSemanticParser 4 | from allennlp_semparse.models.text2sql_parser import Text2SqlParser 5 | from allennlp_semparse.models.wikitables.wikitables_erm_semantic_parser import ( 6 | WikiTablesErmSemanticParser, 7 | ) 8 | from allennlp_semparse.models.wikitables.wikitables_mml_semantic_parser import ( 9 | WikiTablesMmlSemanticParser, 10 | ) 11 | -------------------------------------------------------------------------------- /allennlp_semparse/models/atis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/allennlp_semparse/models/atis/__init__.py -------------------------------------------------------------------------------- /allennlp_semparse/models/nlvr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/allennlp_semparse/models/nlvr/__init__.py -------------------------------------------------------------------------------- /allennlp_semparse/models/wikitables/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/allennlp_semparse/models/wikitables/__init__.py -------------------------------------------------------------------------------- /allennlp_semparse/nltk_languages/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/allennlp_semparse/nltk_languages/__init__.py -------------------------------------------------------------------------------- /allennlp_semparse/nltk_languages/contexts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/allennlp_semparse/nltk_languages/contexts/__init__.py -------------------------------------------------------------------------------- /allennlp_semparse/nltk_languages/type_declarations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/allennlp_semparse/nltk_languages/type_declarations/__init__.py -------------------------------------------------------------------------------- /allennlp_semparse/nltk_languages/worlds/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/allennlp_semparse/nltk_languages/worlds/__init__.py -------------------------------------------------------------------------------- /allennlp_semparse/parsimonious_languages/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/allennlp_semparse/parsimonious_languages/__init__.py -------------------------------------------------------------------------------- /allennlp_semparse/parsimonious_languages/contexts/__init__.py: -------------------------------------------------------------------------------- 1 | from 
allennlp_semparse.parsimonious_languages.contexts.atis_sql_table_context import ( 2 | AtisSqlTableContext, 3 | ) 4 | -------------------------------------------------------------------------------- /allennlp_semparse/parsimonious_languages/executors/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Executors are classes that deterministically transform programs in domain specific languages 3 | into denotations. We have one executor defined for each language-domain pair that we handle. 4 | """ 5 | from allennlp_semparse.parsimonious_languages.executors.sql_executor import SqlExecutor 6 | -------------------------------------------------------------------------------- /allennlp_semparse/parsimonious_languages/executors/sql_executor.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List 3 | 4 | import sqlite3 5 | import multiprocessing 6 | from multiprocessing import Process 7 | from allennlp.common.file_utils import cached_path 8 | 9 | logger = logging.getLogger(__name__) 10 | MULTIPROCESSING_LOGGER = multiprocessing.get_logger() 11 | 12 | 13 | class SqlExecutor: 14 | """ 15 | This class evaluates SQL queries by connecting to a SQLite database. Because SQLite is disk-based, 16 | we just need to provide one file with the location. We execute the predicted SQL query and the labeled 17 | queries against the database and check if they execute to the same table. 18 | """ 19 | 20 | def __init__(self, database_file: str) -> None: 21 | # Save the location of our SQLite database, so we can connect to it and execute SQL queries for denotation accuracy. 22 | self._database_file = cached_path(database_file) 23 | 24 | def evaluate_sql_query(self, predicted_sql_query: str, sql_query_labels: List[str]) -> int: 25 | # We set the logging level for the subprocesses to warning, otherwise, it will 26 | # log every time a process starts and stops. 27 | MULTIPROCESSING_LOGGER.setLevel(logging.WARNING) 28 | 29 | # Since the query might hang, we run in another process and kill it if it 30 | # takes too long. 31 | process = Process( 32 | target=self._evaluate_sql_query_subprocess, 33 | args=(self._database_file, predicted_sql_query, sql_query_labels), 34 | ) 35 | process.start() 36 | 37 | # If the query has not finished in 10 seconds then we will proceed. 38 | process.join(10) 39 | denotation_correct = process.exitcode # type: ignore 40 | 41 | if process.is_alive(): 42 | logger.warning("Evaluating query took over 10 seconds, skipping query") 43 | process.terminate() 44 | process.join() 45 | 46 | if denotation_correct is None: 47 | denotation_correct = 0 48 | 49 | return denotation_correct 50 | 51 | @staticmethod 52 | def _evaluate_sql_query_subprocess( 53 | database_file: str, predicted_query: str, sql_query_labels: List[str] 54 | ) -> None: 55 | """ 56 | We evaluate here whether the predicted query and the query label evaluate to the 57 | exact same table. This method is only called by the subprocess, so we just exit with 58 | 1 if it is correct and 0 otherwise.
59 | """ 60 | 61 | connection = sqlite3.connect(database_file) 62 | cursor = connection.cursor() 63 | 64 | postprocessed_predicted_query = SqlExecutor.postprocess_query_sqlite(predicted_query) 65 | 66 | try: 67 | cursor.execute(postprocessed_predicted_query) 68 | predicted_rows = cursor.fetchall() 69 | except sqlite3.Error as error: 70 | logger.warning(f"Error executing predicted: {error}") 71 | exit(0) 72 | 73 | # If predicted table matches any of the reference tables then it is counted as correct. 74 | target_rows = None 75 | for sql_query_label in sql_query_labels: 76 | postprocessed_sql_query_label = SqlExecutor.postprocess_query_sqlite(sql_query_label) 77 | try: 78 | cursor.execute(postprocessed_sql_query_label) 79 | target_rows = cursor.fetchall() 80 | except sqlite3.Error as error: 81 | logger.warning(f"Error executing predicted: {error}") 82 | if predicted_rows == target_rows: 83 | exit(1) 84 | exit(0) 85 | 86 | @staticmethod 87 | def postprocess_query_sqlite(query: str): 88 | # The dialect of SQL that SQLite takes is not exactly the same as the labeled data. 89 | # We strip off the parentheses that surround the entire query here. 90 | query = query.strip() 91 | if query.startswith("("): 92 | return query[1 : query.rfind(")")] + ";" 93 | return query 94 | -------------------------------------------------------------------------------- /allennlp_semparse/parsimonious_languages/worlds/__init__.py: -------------------------------------------------------------------------------- 1 | from allennlp_semparse.parsimonious_languages.worlds.atis_world import AtisWorld 2 | -------------------------------------------------------------------------------- /allennlp_semparse/predictors/__init__.py: -------------------------------------------------------------------------------- 1 | from allennlp_semparse.predictors.atis_parser import AtisParserPredictor 2 | from allennlp_semparse.predictors.nlvr_parser import NlvrParserPredictor 3 | from allennlp_semparse.predictors.wikitables_parser import WikiTablesParserPredictor 4 | -------------------------------------------------------------------------------- /allennlp_semparse/predictors/atis_parser.py: -------------------------------------------------------------------------------- 1 | from allennlp.common.util import JsonDict 2 | from allennlp.data import Instance 3 | from allennlp.predictors.predictor import Predictor 4 | 5 | 6 | @Predictor.register("atis-parser") 7 | class AtisParserPredictor(Predictor): 8 | """ 9 | Predictor for the :class:`~allennlp_semparse.models.atis.AtisSemanticParser` model. 10 | """ 11 | 12 | def _json_to_instance(self, json_dict: JsonDict) -> Instance: 13 | """ 14 | Expects JSON that looks like ``{"utterance": "..."}``. 
15 | """ 16 | utterance = json_dict["utterance"] 17 | return self._dataset_reader.text_to_instance([utterance]) 18 | -------------------------------------------------------------------------------- /allennlp_semparse/predictors/nlvr_parser.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | from allennlp.common.util import JsonDict 5 | from allennlp.data import Instance 6 | from allennlp.predictors.predictor import Predictor 7 | 8 | 9 | @Predictor.register("nlvr-parser") 10 | class NlvrParserPredictor(Predictor): 11 | def _json_to_instance(self, json_dict: JsonDict) -> Instance: 12 | sentence = json_dict["sentence"] 13 | if "worlds" in json_dict: 14 | # This is grouped data 15 | worlds = json_dict["worlds"] 16 | if isinstance(worlds, str): 17 | worlds = json.loads(worlds) 18 | else: 19 | structured_rep = json_dict["structured_rep"] 20 | if isinstance(structured_rep, str): 21 | structured_rep = json.loads(structured_rep) 22 | worlds = [structured_rep] 23 | identifier = json_dict["identifier"] if "identifier" in json_dict else None 24 | instance = self._dataset_reader.text_to_instance( 25 | sentence=sentence, # type: ignore 26 | structured_representations=worlds, 27 | identifier=identifier, 28 | ) 29 | return instance 30 | 31 | def dump_line(self, outputs: JsonDict) -> str: 32 | if "identifier" in outputs: 33 | # Returning CSV lines for official evaluation 34 | identifier = outputs["identifier"] 35 | denotation = outputs["denotations"][0][0] 36 | return f"{identifier},{denotation}\n" 37 | else: 38 | return json.dumps(outputs) + "\n" 39 | -------------------------------------------------------------------------------- /allennlp_semparse/predictors/wikitables_parser.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from allennlp.common.util import JsonDict 4 | from allennlp.data import Instance 5 | from allennlp.predictors.predictor import Predictor 6 | 7 | 8 | @Predictor.register("wikitables-parser") 9 | class WikiTablesParserPredictor(Predictor): 10 | """ 11 | Wrapper for the 12 | :class:`~allennlp.models.encoder_decoders.wikitables_semantic_parser.WikiTablesSemanticParser` 13 | model. 14 | """ 15 | 16 | def _json_to_instance(self, json_dict: JsonDict) -> Instance: 17 | """ 18 | Expects JSON that looks like ``{"question": "...", "table": "..."}``. 19 | """ 20 | question_text = json_dict["question"] 21 | table_rows = json_dict["table"].split("\n") 22 | 23 | # We are directly passing the raw table rows here. The code in ``TableQuestionContext`` will do some 24 | # minimal processing to extract dates and numbers from the cells. 25 | instance = self._dataset_reader.text_to_instance( 26 | question_text, # type: ignore 27 | table_rows, 28 | ) 29 | return instance 30 | 31 | def predict_json(self, inputs: JsonDict) -> JsonDict: 32 | """ 33 | We need to override this because of the interactive beam search aspects. 34 | """ 35 | instance = self._json_to_instance(inputs) 36 | 37 | # Get the rules out of the instance 38 | index_to_rule = [ 39 | production_rule_field.rule 40 | for production_rule_field in instance.fields["actions"].field_list 41 | ] 42 | rule_to_index = {rule: i for i, rule in enumerate(index_to_rule)} 43 | 44 | # A sequence of strings to force, then convert them to ints 45 | initial_tokens = inputs.get("initial_sequence", []) 46 | 47 | # Want to get initial_sequence on the same device as the model. 
48 | initial_sequence = torch.tensor( 49 | [rule_to_index[token] for token in initial_tokens], 50 | device=next(self._model.parameters()).device, 51 | ) 52 | 53 | # Replace beam search with one that forces the initial sequence 54 | original_beam_search = self._model._beam_search 55 | interactive_beam_search = original_beam_search.constrained_to(initial_sequence) 56 | self._model._beam_search = interactive_beam_search 57 | 58 | # Now get results 59 | results = self.predict_instance(instance) 60 | 61 | # And add in the choices. Need to convert from idxs to rules. 62 | results["choices"] = [ 63 | [ 64 | (probability, action) 65 | for probability, action in zip(pa["action_probabilities"], pa["considered_actions"]) 66 | ] 67 | for pa in results["predicted_actions"] 68 | ] 69 | 70 | results["beam_snapshots"] = { 71 | # For each batch_index, we get a list of beam snapshots 72 | batch_index: [ 73 | # Each beam_snapshots consists of a list of timesteps, 74 | # each of which is a list of pairs (score, sequence). 75 | # The sequence is the *indices* of the rules, which we 76 | # want to convert to the string representations. 77 | [ 78 | (score, [index_to_rule[idx] for idx in sequence]) 79 | for score, sequence in timestep_snapshot 80 | ] 81 | for timestep_snapshot in beam_snapshots 82 | ] 83 | for batch_index, beam_snapshots in interactive_beam_search.beam_snapshots.items() 84 | } 85 | 86 | # Restore original beam search 87 | self._model._beam_search = original_beam_search 88 | 89 | return results 90 | -------------------------------------------------------------------------------- /allennlp_semparse/state_machines/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains code for using state machines in a model to do transition-based decoding. 3 | "Transition-based decoding" is where you start in some state, iteratively transition between 4 | states, and have some kind of supervision signal that tells you which end states, or which 5 | transition sequences, are "good". 6 | 7 | Typical seq2seq decoding, where you have a fixed vocabulary and no constraints on your output, can 8 | be done much more efficiently than we do in this code. This is intended for structured models that 9 | have constraints on their outputs. 10 | 11 | The key abstractions in this code are the following: 12 | 13 | - ``State`` represents the current state of decoding, containing a list of all of the actions 14 | taken so far, and a current score for the state. It also has methods around determining 15 | whether the state is "finished" and for combining states for batched computation. 16 | - ``TransitionFunction`` is a ``torch.nn.Module`` that models the transition function between 17 | states. Its main method is ``take_step``, which generates a ranked list of next states given 18 | a current state. 19 | - ``DecoderTrainer`` is an algorithm for training the transition function with some kind of 20 | supervision signal. There are many options for training algorithms and supervision signals; 21 | this is an abstract class that is generic over the type of the supervision signal. 22 | 23 | There is also a generic ``BeamSearch`` class for finding the ``k`` highest-scoring transition 24 | sequences given a trained ``TransitionFunction`` and an initial ``State``. 
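
As a rough, schematic illustration of how these pieces fit together (``initial_state`` and
``transition_function`` below are placeholders, real decoding goes through ``BeamSearch`` or a
``DecoderTrainer``, and optional arguments to ``take_step`` are elided), a greedy decoder over a
state with ``group_size == 1`` would look like::

    state = initial_state              # some ``State`` subclass built from the encoder outputs
    while not state.is_finished():
        # ``take_step`` returns a ranked list of next states; keep only the best-scoring one.
        state = transition_function.take_step(state)[0]
    best_action_sequence = state.action_history[0]
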
25 | """ 26 | from allennlp_semparse.state_machines.beam_search import BeamSearch 27 | from allennlp_semparse.state_machines.constrained_beam_search import ConstrainedBeamSearch 28 | from allennlp_semparse.state_machines.states import State 29 | from allennlp_semparse.state_machines.trainers import DecoderTrainer 30 | from allennlp_semparse.state_machines.transition_functions import TransitionFunction 31 | -------------------------------------------------------------------------------- /allennlp_semparse/state_machines/states/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains the ``State`` abstraction for defining state-machine-based decoders, and some 3 | pre-built concrete ``State`` classes for various kinds of decoding (e.g., a ``GrammarBasedState`` 4 | for doing grammar-based decoding, where the output is a sequence of production rules from a 5 | grammar). 6 | 7 | The module also has some ``Statelet`` classes to help represent the ``State`` by grouping together 8 | related pieces, including ``RnnStatelet``, which you can use to keep track of a decoder RNN's 9 | internal state, ``GrammarStatelet``, which keeps track of what actions are allowed at each timestep 10 | of decoding (if your outputs are production rules from a grammar), and ``ChecklistStatelet`` that 11 | keeps track of coverage information if you are training a coverage-based parser. 12 | """ 13 | from allennlp_semparse.state_machines.states.checklist_statelet import ChecklistStatelet 14 | from allennlp_semparse.state_machines.states.coverage_state import CoverageState 15 | from allennlp_semparse.state_machines.states.grammar_based_state import GrammarBasedState 16 | from allennlp_semparse.state_machines.states.grammar_statelet import GrammarStatelet 17 | from allennlp_semparse.state_machines.states.lambda_grammar_statelet import LambdaGrammarStatelet 18 | from allennlp_semparse.state_machines.states.rnn_statelet import RnnStatelet 19 | from allennlp_semparse.state_machines.states.state import State 20 | -------------------------------------------------------------------------------- /allennlp_semparse/state_machines/states/checklist_statelet.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | 5 | from allennlp.nn import util 6 | 7 | 8 | class ChecklistStatelet: 9 | """ 10 | This class keeps track of checklist related variables that are used while training a coverage 11 | based semantic parser (or any other kind of transition based constrained decoder). This is 12 | intended to be used within a ``State``. 13 | 14 | Parameters 15 | ---------- 16 | terminal_actions : ``torch.Tensor`` 17 | A vector containing the indices of terminal actions, required for computing checklists for 18 | next states based on current actions. The idea is that we will build checklists 19 | corresponding to the presence or absence of just the terminal actions. But in principle, 20 | they can be all actions that are relevant to checklist computation. 21 | checklist_target : ``torch.Tensor`` 22 | Targets corresponding to checklist that indicate the states in which we want the checklist to 23 | ideally be. It is the same size as ``terminal_actions``, and it contains 1 for each corresponding 24 | action in the list that we want to see in the final logical form, and 0 for each corresponding 25 | action that we do not. 
26 | checklist_mask : ``torch.Tensor`` 27 | Mask corresponding to ``terminal_actions``, indicating which of those actions are relevant 28 | for checklist computation. For example, if the parser is penalizing non-agenda terminal 29 | actions, all the terminal actions are relevant. 30 | checklist : ``torch.Tensor`` 31 | A checklist indicating how many times each action in its agenda has been chosen previously. 32 | It contains the actual counts of the agenda actions. 33 | terminal_indices_dict: ``Dict[int, int]``, optional 34 | Mapping from batch action indices to indices in any of the four vectors above. If not 35 | provided, this mapping will be computed here. 36 | """ 37 | 38 | def __init__( 39 | self, 40 | terminal_actions: torch.Tensor, 41 | checklist_target: torch.Tensor, 42 | checklist_mask: torch.Tensor, 43 | checklist: torch.Tensor, 44 | terminal_indices_dict: Dict[int, int] = None, 45 | ) -> None: 46 | self.terminal_actions = terminal_actions 47 | self.checklist_target = checklist_target 48 | self.checklist_mask = checklist_mask 49 | self.checklist = checklist 50 | if terminal_indices_dict is not None: 51 | self.terminal_indices_dict = terminal_indices_dict 52 | else: 53 | self.terminal_indices_dict = {} 54 | for checklist_index, batch_action_index in enumerate(terminal_actions.detach().cpu()): 55 | action_index = int(batch_action_index[0]) 56 | if action_index == -1: 57 | continue 58 | self.terminal_indices_dict[action_index] = checklist_index 59 | 60 | def update(self, action: torch.Tensor) -> "ChecklistStatelet": 61 | """ 62 | Takes an action index, updates checklist and returns an updated state. 63 | """ 64 | checklist_addition = (self.terminal_actions == action).float() 65 | new_checklist = self.checklist + checklist_addition 66 | new_checklist_state = ChecklistStatelet( 67 | terminal_actions=self.terminal_actions, 68 | checklist_target=self.checklist_target, 69 | checklist_mask=self.checklist_mask, 70 | checklist=new_checklist, 71 | terminal_indices_dict=self.terminal_indices_dict, 72 | ) 73 | return new_checklist_state 74 | 75 | def get_balance(self) -> torch.Tensor: 76 | return self.checklist_mask * (self.checklist_target - self.checklist) 77 | 78 | def __eq__(self, other): 79 | if isinstance(self, other.__class__): 80 | return all( 81 | [ 82 | util.tensors_equal(self.terminal_actions, other.terminal_actions), 83 | util.tensors_equal(self.checklist_target, other.checklist_target), 84 | util.tensors_equal(self.checklist_mask, other.checklist_mask), 85 | util.tensors_equal(self.checklist, other.checklist), 86 | self.terminal_indices_dict == other.terminal_indices_dict, 87 | ] 88 | ) 89 | return NotImplemented 90 | -------------------------------------------------------------------------------- /allennlp_semparse/state_machines/states/rnn_statelet.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | 5 | from allennlp.nn import util 6 | 7 | 8 | class RnnStatelet: 9 | """ 10 | This class keeps track of all of decoder-RNN-related variables that you need during decoding. 11 | This includes things like the current decoder hidden state, the memory cell (for LSTM 12 | decoders), the encoder output that you need for computing attentions, and so on. 13 | 14 | This is intended to be used `inside` a ``State``, which likely has other things it has to keep 15 | track of for doing constrained decoding. 
16 | 17 | Parameters 18 | ---------- 19 | hidden_state : ``torch.Tensor`` 20 | This holds the LSTM hidden state, with shape ``(decoder_output_dim,)`` if the decoder 21 | has 1 layer and ``(num_layers, decoder_output_dim)`` otherwise. 22 | memory_cell : ``torch.Tensor`` 23 | This holds the LSTM memory cell, with shape ``(decoder_output_dim,)`` if the decoder has 24 | 1 layer and ``(num_layers, decoder_output_dim)`` otherwise. 25 | previous_action_embedding : ``torch.Tensor`` 26 | This holds the embedding for the action we took at the last timestep (which gets input to 27 | the decoder). Has shape ``(action_embedding_dim,)``. 28 | attended_input : ``torch.Tensor`` 29 | This holds the attention-weighted sum over the input representations that we computed in 30 | the previous timestep. We keep this as part of the state because we use the previous 31 | attention as part of our decoder cell update. Has shape ``(encoder_output_dim,)``. 32 | encoder_outputs : ``List[torch.Tensor]`` 33 | A list of variables, each of shape ``(input_sequence_length, encoder_output_dim)``, 34 | containing the encoder outputs at each timestep. The list is over batch elements, and we 35 | do the input this way so we can easily do a ``torch.cat`` on a list of indices into this 36 | batched list. 37 | 38 | Note that all of the above parameters are single tensors, while the encoder outputs and 39 | mask are lists of length ``batch_size``. We always pass around the encoder outputs and 40 | mask unmodified, regardless of what's in the grouping for this state. We'll use the 41 | ``batch_indices`` for the group to pull pieces out of these lists when we're ready to 42 | actually do some computation. 43 | encoder_output_mask : ``List[torch.Tensor]`` 44 | A list of variables, each of shape ``(input_sequence_length,)``, containing a mask over 45 | question tokens for each batch instance. This is a list over batch elements, for the same 46 | reasons as above. 
47 | """ 48 | 49 | def __init__( 50 | self, 51 | hidden_state: torch.Tensor, 52 | memory_cell: torch.Tensor, 53 | previous_action_embedding: torch.Tensor, 54 | attended_input: torch.Tensor, 55 | encoder_outputs: List[torch.Tensor], 56 | encoder_output_mask: List[torch.Tensor], 57 | ) -> None: 58 | self.hidden_state = hidden_state 59 | self.memory_cell = memory_cell 60 | self.previous_action_embedding = previous_action_embedding 61 | self.attended_input = attended_input 62 | self.encoder_outputs = encoder_outputs 63 | self.encoder_output_mask = encoder_output_mask 64 | 65 | def __eq__(self, other): 66 | if isinstance(self, other.__class__): 67 | return all( 68 | [ 69 | util.tensors_equal(self.hidden_state, other.hidden_state, tolerance=1e-5), 70 | util.tensors_equal(self.memory_cell, other.memory_cell, tolerance=1e-5), 71 | util.tensors_equal( 72 | self.previous_action_embedding, 73 | other.previous_action_embedding, 74 | tolerance=1e-5, 75 | ), 76 | util.tensors_equal(self.attended_input, other.attended_input, tolerance=1e-5), 77 | util.tensors_equal(self.encoder_outputs, other.encoder_outputs, tolerance=1e-5), 78 | util.tensors_equal( 79 | self.encoder_output_mask, other.encoder_output_mask, tolerance=1e-5 80 | ), 81 | ] 82 | ) 83 | return NotImplemented 84 | -------------------------------------------------------------------------------- /allennlp_semparse/state_machines/states/state.py: -------------------------------------------------------------------------------- 1 | from typing import Generic, List, TypeVar 2 | 3 | import torch 4 | 5 | # Note that the bound here is `State` itself. This is what lets us have methods that take 6 | # lists of a `State` subclass and output structures with the subclass. Really ugly that we 7 | # have to do this generic typing _for our own class_, but it makes mypy happy and gives us good 8 | # type checking in a few important methods. 9 | T = TypeVar("T", bound="State") 10 | 11 | 12 | class State(Generic[T]): 13 | """ 14 | Represents the (batched) state of a transition-based decoder. 15 | 16 | There are two different kinds of batching we need to distinguish here. First, there's the 17 | batch of training instances passed to ``model.forward()``. We'll use "batch" and 18 | ``batch_size`` to refer to this through the docs and code. We additionally batch together 19 | computation for several states at the same time, where each state could be from the same 20 | training instance in the original batch, or different instances. We use "group" and 21 | ``group_size`` in the docs and code to refer to this kind of batching, to distinguish it from 22 | the batch of training instances. 23 | 24 | So, using this terminology, a single ``State`` object represents a `grouped` collection of 25 | states. Because different states in this group might finish at different timesteps, we have 26 | methods and member variables to handle some bookkeeping around this, to split and regroup 27 | things. 28 | 29 | Parameters 30 | ---------- 31 | batch_indices : ``List[int]`` 32 | A ``group_size``-length list, where each element specifies which ``batch_index`` that group 33 | element came from. 34 | 35 | Our internal variables (like scores, action histories, hidden states, whatever) are 36 | `grouped`, and our ``group_size`` is likely different from the original ``batch_size``. 37 | This variable keeps track of which batch instance each group element came from (e.g., to 38 | know what the correct action sequences are, or which encoder outputs to use). 
39 | action_history : ``List[List[int]]`` 40 | The list of actions taken so far in this state. This is also grouped, so each state in the 41 | group has a list of actions. 42 | score : ``List[torch.Tensor]`` 43 | This state's score. It's a variable, because typically we'll be computing a loss based on 44 | this score, and using it for backprop during training. Like the other variables here, this 45 | is a ``group_size``-length list. 46 | """ 47 | 48 | def __init__( 49 | self, batch_indices: List[int], action_history: List[List[int]], score: List[torch.Tensor] 50 | ) -> None: 51 | self.batch_indices = batch_indices 52 | self.action_history = action_history 53 | self.score = score 54 | 55 | def is_finished(self) -> bool: 56 | """ 57 | If this state has a ``group_size`` of 1, this returns whether the single action sequence in 58 | this state is finished or not. If this state has a ``group_size`` other than 1, this 59 | method raises an error. 60 | """ 61 | raise NotImplementedError 62 | 63 | @classmethod 64 | def combine_states(cls, states: List[T]) -> T: 65 | """ 66 | Combines a list of states, each with their own group size, into a single state. 67 | """ 68 | raise NotImplementedError 69 | 70 | def __eq__(self, other): 71 | if isinstance(self, other.__class__): 72 | return self.__dict__ == other.__dict__ 73 | return NotImplemented 74 | -------------------------------------------------------------------------------- /allennlp_semparse/state_machines/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from allennlp_semparse.state_machines.trainers.decoder_trainer import DecoderTrainer 2 | from allennlp_semparse.state_machines.trainers.expected_risk_minimization import ( 3 | ExpectedRiskMinimization, 4 | ) 5 | from allennlp_semparse.state_machines.trainers.maximum_marginal_likelihood import ( 6 | MaximumMarginalLikelihood, 7 | ) 8 | -------------------------------------------------------------------------------- /allennlp_semparse/state_machines/trainers/decoder_trainer.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Generic, TypeVar 2 | 3 | import torch 4 | 5 | from allennlp_semparse.state_machines.states import State 6 | from allennlp_semparse.state_machines.transition_functions import TransitionFunction 7 | 8 | SupervisionType = TypeVar("SupervisionType") 9 | 10 | 11 | class DecoderTrainer(Generic[SupervisionType]): 12 | """ 13 | ``DecoderTrainers`` define a training regime for transition-based decoders. A 14 | ``DecoderTrainer`` assumes an initial ``State``, a ``TransitionFunction`` function that can 15 | traverse the state space, and some supervision signal. Given these things, the 16 | ``DecoderTrainer`` trains the ``TransitionFunction`` function to traverse the state space to 17 | end up at good end states. 18 | 19 | Concrete implementations of this abstract base class could do things like maximum marginal 20 | likelihood, SEARN, LaSO, or other structured learning algorithms. If you're just trying to 21 | maximize the probability of a single target sequence where the possible outputs are the same 22 | for each timestep (as in, e.g., typical machine translation training regimes), there are way 23 | more efficient ways to do that than using this API. 
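A minimal sketch of how the SupervisionType generic plays out in practice: MaximumMarginalLikelihood, defined later in this repository, declares itself as DecoderTrainer[Tuple[torch.Tensor, torch.Tensor]], so the (targets, target_mask) pair below is the supervision it expects. The shapes are invented for illustration.

from typing import Tuple
import torch
from allennlp_semparse.state_machines.trainers import MaximumMarginalLikelihood

trainer = MaximumMarginalLikelihood(beam_size=5)
supervision: Tuple[torch.Tensor, torch.Tensor] = (
    torch.zeros(2, 3, 4, dtype=torch.long),  # (batch_size, num_gold_sequences, sequence_length)
    torch.ones(2, 3, 4, dtype=torch.long),   # matching 1*0* mask
)
# trainer.decode(initial_state, transition_function, supervision) then returns {"loss": ...},
# where initial_state and transition_function come from the model being trained.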
24 | """ 25 | 26 | def decode( 27 | self, 28 | initial_state: State, 29 | transition_function: TransitionFunction, 30 | supervision: SupervisionType, 31 | ) -> Dict[str, torch.Tensor]: 32 | """ 33 | Takes an initial state object, a means of transitioning from state to state, and a 34 | supervision signal, and uses the supervision to train the transition function to pick 35 | "good" states. 36 | 37 | This function should typically return a ``loss`` key during training, which the ``Model`` 38 | will use as its loss. 39 | 40 | Parameters 41 | ---------- 42 | initial_state : ``State`` 43 | This is the initial state for decoding, typically initialized after running some kind 44 | of encoder on some inputs. 45 | transition_function : ``TransitionFunction`` 46 | This is the transition function that scores all possible actions that can be taken in a 47 | given state, and returns a ranked list of next states at each step of decoding. 48 | supervision : ``SupervisionType`` 49 | This is the supervision that is used to train the ``transition_function`` function to 50 | pick "good" states. You can use whatever kind of supervision you want (e.g., a single 51 | "gold" action sequence, a set of possible "gold" action sequences, a reward function, 52 | etc.). We use ``typing.Generics`` to make sure that our static type checker is happy 53 | with how you've matched the supervision that you provide in the model to the 54 | ``DecoderTrainer`` that you want to use. 55 | """ 56 | raise NotImplementedError 57 | -------------------------------------------------------------------------------- /allennlp_semparse/state_machines/trainers/maximum_marginal_likelihood.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict, List, Tuple 3 | 4 | import torch 5 | 6 | from allennlp.nn import util 7 | 8 | from allennlp_semparse.state_machines.constrained_beam_search import ConstrainedBeamSearch 9 | from allennlp_semparse.state_machines.states import State 10 | from allennlp_semparse.state_machines.trainers.decoder_trainer import DecoderTrainer 11 | from allennlp_semparse.state_machines.transition_functions import TransitionFunction 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class MaximumMarginalLikelihood(DecoderTrainer[Tuple[torch.Tensor, torch.Tensor]]): 17 | """ 18 | This class trains a decoder by maximizing the marginal likelihood of the targets. That is, 19 | during training, we are given a `set` of acceptable or possible target sequences, and we 20 | optimize the `sum` of the probability the model assigns to each item in the set. This allows 21 | the model to distribute its probability mass over the set however it chooses, without forcing 22 | `all` of the given target sequences to have high probability. This is helpful, for example, if 23 | you have good reason to expect that the correct target sequence is in the set, but aren't sure 24 | `which` of the sequences is actually correct. 25 | 26 | This implementation of maximum marginal likelihood requires the model you use to be `locally 27 | normalized`; that is, at each decoding timestep, we assume that the model creates a normalized 28 | probability distribution over actions. This assumption is necessary, because we do no explicit 29 | normalization in our loss function, we just sum the probabilities assigned to all correct 30 | target sequences, relying on the local normalization at each time step to push probability mass 31 | from bad actions to good ones. 
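A worked example of the objective (illustrative numbers only): if one instance has three acceptable action sequences whose decoder log-probabilities are -1.2, -2.3, and -4.0, the loss is the negative log of their summed probabilities, i.e. the -logsumexp over sequence scores that decode() computes below.

import torch

sequence_log_probs = torch.tensor([-1.2, -2.3, -4.0])
loss = -torch.logsumexp(sequence_log_probs, dim=0)  # -log(e**-1.2 + e**-2.3 + e**-4.0) ~= 0.87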
32 | 33 | Parameters 34 | ---------- 35 | beam_size : ``int``, optional (default=None) 36 | We can optionally run a constrained beam search over the provided targets during decoding. 37 | This narrows the set of transition sequences that are marginalized over in the loss 38 | function, keeping only the top ``beam_size`` sequences according to the model. If this is 39 | ``None``, we will keep all of the provided sequences in the loss computation. 40 | """ 41 | 42 | def __init__(self, beam_size: int = None) -> None: 43 | self._beam_size = beam_size 44 | 45 | def decode( 46 | self, 47 | initial_state: State, 48 | transition_function: TransitionFunction, 49 | supervision: Tuple[torch.Tensor, torch.Tensor], 50 | ) -> Dict[str, torch.Tensor]: 51 | targets, target_mask = supervision 52 | beam_search = ConstrainedBeamSearch(self._beam_size, targets, target_mask) 53 | finished_states: Dict[int, List[State]] = beam_search.search( 54 | initial_state, transition_function 55 | ) 56 | 57 | loss = 0 58 | for instance_states in finished_states.values(): 59 | scores = [state.score[0].view(-1) for state in instance_states] 60 | loss += -util.logsumexp(torch.cat(scores)) 61 | return {"loss": loss / len(finished_states)} 62 | -------------------------------------------------------------------------------- /allennlp_semparse/state_machines/transition_functions/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains ``TransitionFunctions`` for state-machine-based decoders. The 3 | ``TransitionFunction`` parameterizes transitions between ``States``. These ``TransitionFunctions`` 4 | are all pytorch `Modules`` that have trainable parameters. The :class:`BasicTransitionFunction` is 5 | simply an LSTM decoder with attention over an input utterance, and the other classes typically 6 | subclass this and add functionality to it. 7 | """ 8 | from allennlp_semparse.state_machines.transition_functions.basic_transition_function import ( 9 | BasicTransitionFunction, 10 | ) 11 | from allennlp_semparse.state_machines.transition_functions.coverage_transition_function import ( 12 | CoverageTransitionFunction, 13 | ) 14 | from allennlp_semparse.state_machines.transition_functions.linking_coverage_transition_function import ( 15 | LinkingCoverageTransitionFunction, 16 | ) 17 | from allennlp_semparse.state_machines.transition_functions.linking_transition_function import ( 18 | LinkingTransitionFunction, 19 | ) 20 | from allennlp_semparse.state_machines.transition_functions.transition_function import ( 21 | TransitionFunction, 22 | ) 23 | -------------------------------------------------------------------------------- /allennlp_semparse/state_machines/transition_functions/transition_function.py: -------------------------------------------------------------------------------- 1 | from typing import Generic, List, Set, TypeVar 2 | 3 | import torch 4 | 5 | from allennlp_semparse.state_machines.states import State 6 | 7 | StateType = TypeVar("StateType", bound=State) 8 | 9 | 10 | class TransitionFunction(torch.nn.Module, Generic[StateType]): 11 | """ 12 | A ``TransitionFunction`` is a module that assigns scores to state transitions in a 13 | transition-based decoder. 14 | 15 | The ``TransitionFunction`` takes a ``State`` and outputs a ranked list of next states, ordered 16 | by the state's score. 
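A hypothetical, do-nothing subclass (not in the repository) just to make the contract concrete before the details below; real subclasses such as BasicTransitionFunction score actions with an LSTM decoder plus attention and return genuinely ranked states.

from typing import List, Set

from allennlp_semparse.state_machines.states import State
from allennlp_semparse.state_machines.transition_functions import TransitionFunction

class DoNothingTransitionFunction(TransitionFunction[State]):
    def take_step(
        self, state: State, max_actions: int = None, allowed_actions: List[Set] = None
    ) -> List[State]:
        # A real implementation returns ungrouped next states, already sorted by score.
        return []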
17 | 18 | The intention with this class is that a model will implement a subclass of 19 | ``TransitionFunction`` that defines how exactly you want to handle the input and what 20 | computations get done at each step of decoding, and how states are scored. This subclass then 21 | gets passed to a ``DecoderTrainer`` to have its parameters trained. 22 | """ 23 | 24 | def forward(self, *inputs): 25 | raise RuntimeError("call .take_step() instead of .forward()") 26 | 27 | def take_step( 28 | self, state: StateType, max_actions: int = None, allowed_actions: List[Set] = None 29 | ) -> List[StateType]: 30 | """ 31 | The main method in the ``TransitionFunction`` API. This function defines the computation 32 | done at each step of decoding and returns a ranked list of next states. 33 | 34 | The input state is `grouped`, to allow for efficient computation, but the output states 35 | should all have a ``group_size`` of 1, to make things easier on the decoding algorithm. 36 | They will get regrouped later as needed. 37 | 38 | Because of the way we handle grouping in the decoder states, constructing a new state is 39 | actually a relatively expensive operation. If you know a priori that only some of the 40 | states will be needed (either because you have a set of gold action sequences, or you have 41 | a fixed beam size), passing that information into this function will keep us from 42 | constructing more states than we need, which will greatly speed up your computation. 43 | 44 | IMPORTANT: This method `must` returns states already sorted by their score, otherwise 45 | ``BeamSearch`` and other methods will break. For efficiency, we do not perform an 46 | additional sort in those methods. 47 | 48 | ALSO IMPORTANT: When ``allowed_actions`` is given and ``max_actions`` is not, we assume you 49 | want to evaluate all possible states and do not need any sorting (e.g., this is true for 50 | maximum marginal likelihood training that does not use a beam search). In this case, we 51 | may skip the sorting step for efficiency reasons. 52 | 53 | Parameters 54 | ---------- 55 | state : ``State`` 56 | The current state of the decoder, which we will take a step `from`. We may be grouping 57 | together computation for several states here. Because we can have several states for 58 | each instance in the original batch being evaluated at the same time, we use 59 | ``group_size`` for this kind of batching, and ``batch_size`` for the `original` batch 60 | in ``model.forward.`` 61 | max_actions : ``int``, optional 62 | If you know that you will only need a certain number of states out of this (e.g., in a 63 | beam search), you can pass in the max number of actions that you need, and we will only 64 | construct that many states (for each `batch` instance - `not` for each `group` 65 | instance!). This can save a whole lot of computation if you have an action space 66 | that's much larger than your beam size. 67 | allowed_actions : ``List[Set]``, optional 68 | If the ``DecoderTrainer`` has constraints on which actions need to be evaluated (e.g., 69 | maximum marginal likelihood only needs to evaluate action sequences in a given set), 70 | you can pass those constraints here, to avoid constructing state objects unnecessarily. 71 | If there are no constraints from the trainer, passing a value of ``None`` here will 72 | allow all actions to be considered. 73 | 74 | This is a list because it is `batched` - every instance in the batch has a set of 75 | allowed actions. 
Note that the size of this list is the ``group_size`` in the 76 | ``State``, `not` the ``batch_size`` of ``model.forward``. The training algorithm needs 77 | to convert from the `batched` allowed action sequences that it has to a `grouped` 78 | allowed action sequence list. 79 | 80 | Returns 81 | ------- 82 | next_states : ``List[State]`` 83 | A list of next states, ordered by score. 84 | """ 85 | raise NotImplementedError 86 | -------------------------------------------------------------------------------- /allennlp_semparse/state_machines/util.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import Dict, List, Optional, Set, Tuple, Union 3 | 4 | import torch 5 | 6 | 7 | def construct_prefix_tree( 8 | targets: Union[torch.Tensor, List[List[List[int]]]], 9 | target_mask: Optional[Union[torch.Tensor, List[List[List[int]]]]] = None, 10 | ) -> List[Dict[Tuple[int, ...], Set[int]]]: 11 | """ 12 | Takes a list of valid target action sequences and creates a mapping from all possible 13 | (valid) action prefixes to allowed actions given that prefix. While the method is called 14 | ``construct_prefix_tree``, we're actually returning a map that has as keys the paths to 15 | `all internal nodes of the trie`, and as values all of the outgoing edges from that node. 16 | 17 | ``targets`` is assumed to be a tensor of shape ``(batch_size, num_valid_sequences, 18 | sequence_length)``. If the mask is not ``None``, it is assumed to have the same shape, and 19 | we will ignore any value in ``targets`` that has a value of ``0`` in the corresponding 20 | position in the mask. We assume that the mask has the format 1*0* for each item in 21 | ``targets`` - that is, once we see our first zero, we stop processing that target. 22 | 23 | For example, if ``targets`` is the following tensor: ``[[1, 2, 3], [1, 4, 5]]``, the return 24 | value will be: ``{(): set([1]), (1,): set([2, 4]), (1, 2): set([3]), (1, 4): set([5])}``. 25 | 26 | This could be used, e.g., to do an efficient constrained beam search, or to efficiently 27 | evaluate the probability of all of the target sequences. 28 | """ 29 | batched_allowed_transitions: List[Dict[Tuple[int, ...], Set[int]]] = [] 30 | 31 | if not isinstance(targets, list): 32 | assert targets.dim() == 3, "targets tensor needs to be batched!" 33 | targets = targets.detach().cpu().numpy().tolist() 34 | if target_mask is not None: 35 | if not isinstance(target_mask, list): 36 | target_mask = target_mask.detach().cpu().numpy().tolist() 37 | else: 38 | target_mask = [None for _ in targets] 39 | 40 | for instance_targets, instance_mask in zip(targets, target_mask): 41 | allowed_transitions: Dict[Tuple[int, ...], Set[int]] = defaultdict(set) 42 | for i, target_sequence in enumerate(instance_targets): 43 | history: Tuple[int, ...] 
= () 44 | for j, action in enumerate(target_sequence): 45 | if instance_mask and instance_mask[i][j] == 0: 46 | break 47 | allowed_transitions[history].add(action) 48 | history = history + (action,) 49 | batched_allowed_transitions.append(allowed_transitions) 50 | return batched_allowed_transitions 51 | -------------------------------------------------------------------------------- /allennlp_semparse/version.py: -------------------------------------------------------------------------------- 1 | _MAJOR = "0" 2 | _MINOR = "0" 3 | _REVISION = "4" 4 | 5 | VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR) 6 | VERSION = "{0}.{1}.{2}".format(_MAJOR, _MINOR, _REVISION) 7 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | precision: 0 3 | round: down 4 | status: 5 | patch: 6 | default: 7 | target: 90 8 | project: 9 | default: 10 | threshold: 1% 11 | changes: false 12 | comment: false 13 | -------------------------------------------------------------------------------- /dev-requirements.txt: -------------------------------------------------------------------------------- 1 | #### TESTING-RELATED PACKAGES #### 2 | 3 | # Checks style, syntax, and other useful errors. 4 | flake8 5 | 6 | # Static type checking 7 | mypy==0.942 8 | 9 | # Automatic code formatting 10 | black==22.1.0 11 | 12 | # Running unit tests. 13 | pytest 14 | 15 | # Allows generation of coverage reports with pytest. 16 | pytest-cov 17 | 18 | # Lets you run tests in forked subprocesses. 19 | pytest-forked 20 | 21 | # Lets you run tests in parallel. 22 | pytest-xdist 23 | 24 | # Allows codecov to generate coverage reports 25 | coverage 26 | codecov 27 | 28 | # For running tests that aren't 100% reliable 29 | flaky 30 | 31 | #### PACKAGE-UPLOAD PACKAGES #### 32 | 33 | # Pypi uploads 34 | twine>=1.11.0 35 | setuptools 36 | wheel 37 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 100 3 | 4 | include = '\.pyi?$' 5 | 6 | exclude = ''' 7 | ( 8 | __pycache__ 9 | | \bbuild\b 10 | | \.git 11 | | \.mypy_cache 12 | | \.pytest_cache 13 | | \.vscode 14 | | \.venv 15 | | \bdist\b 16 | | \bdoc\b 17 | ) 18 | ''' 19 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | testpaths = tests/ 3 | log_format = %(asctime)s - %(levelname)s - %(name)s - %(message)s 4 | log_level = DEBUG 5 | filterwarnings = 6 | # Note: When a warning matches more than one option in the list, 7 | # the action for the _last_ matching option is performed. 8 | # 9 | # individual warnings filters are specified as a sequence of fields separated by colons: 10 | # action:message:category:module:line 11 | # 12 | # 13 | # how to explicitly test warns 14 | # using `unittest`: https://docs.python.org/3/library/warnings.html#testing-warnings 15 | # using `pytest`: https://docs.pytest.org/en/4.1.0/warnings.html#assertwarnings 16 | # 17 | # Our policy here is to ignore (silence) any deprecation warnings from _outside_ allennlp, but to 18 | # treat any _internal_ deprecation warnings as errors. If we get a deprecation warning from things 19 | # we call in another library, we will just rely on seeing those outside of tests. 
The purpose of 20 | # having these errors here is to make sure that we do not deprecate things lightly in allennlp. 21 | ignore::DeprecationWarning 22 | ignore::PendingDeprecationWarning 23 | error::DeprecationWarning:allennlp.*: 24 | error::PendingDeprecationWarning:allennlp.*: 25 | error::DeprecationWarning:allennlp_semparse.*: 26 | error::PendingDeprecationWarning:allennlp_semparse.*: 27 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | allennlp>=2.0,<3.0 2 | 3 | spacy<3.0 4 | 5 | # Used to create grammars for parsing SQL 6 | parsimonious>=0.8.0 7 | 8 | # Used to format and postprocess SQL 9 | sqlparse>=0.2.4 10 | 11 | # Used to process tables in WikiTableQuestions 12 | unidecode 13 | 14 | # Used in KnowledgeGraphField 15 | editdistance 16 | 17 | # Used for the type system for some languages 18 | nltk 19 | 20 | # Used for some tests 21 | flaky 22 | -------------------------------------------------------------------------------- /scripts/get_version.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | from typing import Dict 5 | 6 | 7 | def parse_args(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("version_type", choices=["stable", "latest", "current"]) 10 | parser.add_argument("--minimal", action="store_true", default=False) 11 | return parser.parse_args() 12 | 13 | 14 | def post_process(version: str, minimal: bool = False): 15 | if version.startswith("v"): 16 | return version if not minimal else version[1:] 17 | return version if minimal else f"v{version}" 18 | 19 | 20 | def get_current_version() -> str: 21 | VERSION: Dict[str, str] = {} 22 | with open("allennlp_semparse/version.py", "r") as version_file: 23 | exec(version_file.read(), VERSION) 24 | return VERSION["VERSION"] 25 | 26 | 27 | def get_latest_version() -> str: 28 | # Import this here so this requirements isn't mandatory when we just want to 29 | # call `get_current_version`. 30 | import requests 31 | 32 | resp = requests.get("https://api.github.com/repos/allenai/allennlp-semparse/tags") 33 | return resp.json()[0]["name"] 34 | 35 | 36 | def get_stable_version() -> str: 37 | import requests 38 | 39 | resp = requests.get("https://api.github.com/repos/allenai/allennlp-semparse/releases/latest") 40 | return resp.json()["tag_name"] 41 | 42 | 43 | def main() -> None: 44 | opts = parse_args() 45 | if opts.version_type == "stable": 46 | print(post_process(get_stable_version(), opts.minimal)) 47 | elif opts.version_type == "latest": 48 | print(post_process(get_latest_version(), opts.minimal)) 49 | elif opts.version_type == "current": 50 | print(post_process(get_current_version(), opts.minimal)) 51 | else: 52 | raise NotImplementedError 53 | 54 | 55 | if __name__ == "__main__": 56 | main() 57 | -------------------------------------------------------------------------------- /scripts/nlvr/generate_data_from_erm_model.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | 3 | 4 | import sys 5 | import os 6 | import json 7 | import argparse 8 | 9 | sys.path.insert( 10 | 0, os.path.dirname(os.path.dirname(os.path.abspath(os.path.join(__file__, os.pardir)))) 11 | ) 12 | 13 | from allennlp.data.dataset_readers import NlvrDatasetReader 14 | from allennlp.models import NlvrCoverageSemanticParser 15 | from allennlp.models.archival import load_archive 16 | from allennlp.semparse.worlds import NlvrWorld 17 | 18 | 19 | def make_data( 20 | input_file: str, output_file: str, archived_model_file: str, max_num_decoded_sequences: int 21 | ) -> None: 22 | reader = NlvrDatasetReader(output_agendas=True) 23 | model = load_archive(archived_model_file).model 24 | if not isinstance(model, NlvrCoverageSemanticParser): 25 | model_type = type(model) 26 | raise RuntimeError( 27 | f"Expected an archived NlvrCoverageSemanticParser, but found {model_type} instead" 28 | ) 29 | # Tweaking the decoder trainer to coerce the it to generate a k-best list. Setting k to 100 30 | # here, so that we can filter out the inconsistent ones later. 31 | model._decoder_trainer._max_num_decoded_sequences = 100 32 | num_outputs = 0 33 | num_sentences = 0 34 | with open(output_file, "w") as outfile: 35 | for line in open(input_file): 36 | num_sentences += 1 37 | input_data = json.loads(line) 38 | sentence = input_data["sentence"] 39 | structured_representations = input_data["worlds"] 40 | labels = input_data["labels"] 41 | instance = reader.text_to_instance(sentence, structured_representations) 42 | outputs = model.forward_on_instance(instance) 43 | action_strings = outputs["best_action_strings"] 44 | logical_forms = outputs["logical_form"] 45 | correct_sequences = [] 46 | # Checking for consistency 47 | worlds = [NlvrWorld(structure) for structure in structured_representations] 48 | for sequence, logical_form in zip(action_strings, logical_forms): 49 | denotations = [world.execute(logical_form) for world in worlds] 50 | denotations_are_correct = [ 51 | label.lower() == str(denotation).lower() 52 | for label, denotation in zip(labels, denotations) 53 | ] 54 | if all(denotations_are_correct): 55 | correct_sequences.append(sequence) 56 | correct_sequences = correct_sequences[:max_num_decoded_sequences] 57 | if not correct_sequences: 58 | continue 59 | output_data = { 60 | "id": input_data["identifier"], 61 | "sentence": sentence, 62 | "correct_sequences": correct_sequences, 63 | "worlds": structured_representations, 64 | "labels": input_data["labels"], 65 | } 66 | json.dump(output_data, outfile) 67 | outfile.write("\n") 68 | num_outputs += 1 69 | outfile.close() 70 | sys.stderr.write(f"{num_outputs} out of {num_sentences} sentences have outputs.") 71 | 72 | 73 | if __name__ == "__main__": 74 | parser = argparse.ArgumentParser() 75 | parser.add_argument("input", type=str, help="Input data file") 76 | parser.add_argument("output", type=str, help="Output data file") 77 | parser.add_argument( 78 | "archived_model", type=str, help="Path to archived model.tar.gz to use for decoding" 79 | ) 80 | parser.add_argument( 81 | "--max-num-sequences", 82 | type=int, 83 | dest="max_num_sequences", 84 | help="Maximum number of sequences per instance to output", 85 | default=20, 86 | ) 87 | args = parser.parse_args() 88 | make_data(args.input, args.output, args.archived_model, args.max_num_sequences) 89 | -------------------------------------------------------------------------------- /scripts/nlvr/group_nlvr_worlds.py: -------------------------------------------------------------------------------- 
1 | #! /usr/bin/env python 2 | """ 3 | NLVR dataset has at most four worlds corresponding to each sentence (with 93% of the sentences 4 | appearing with four worlds), identified by the prefixes in identifiers. This script groups the 5 | worlds and corresponding labels together to enable training a parser with this information. 6 | """ 7 | 8 | 9 | import json 10 | import argparse 11 | from collections import defaultdict 12 | 13 | 14 | def group_dataset(input_file: str, output_file: str) -> None: 15 | instance_groups = defaultdict(lambda: {"worlds": [], "labels": []}) 16 | for line in open(input_file): 17 | data = json.loads(line) 18 | # "identifier" in the original dataset looks something like 4055-3, where 4055 is common 19 | # across all four instances with the same sentence, but different worlds, and the suffix 20 | # differentiates among the four instances. 21 | identifier = data["identifier"].split("-")[0] 22 | instance_groups[identifier]["identifier"] = identifier 23 | instance_groups[identifier]["sentence"] = data["sentence"] 24 | instance_groups[identifier]["worlds"].append(data["structured_rep"]) 25 | instance_groups[identifier]["labels"].append(data["label"]) 26 | 27 | with open(output_file, "w") as output: 28 | for instance_group in instance_groups.values(): 29 | json.dump(instance_group, output) 30 | output.write("\n") 31 | 32 | 33 | if __name__ == "__main__": 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument("input_file", type=str, help="NLVR data file in json format") 36 | parser.add_argument("output_file", type=str, help="Grouped output file in json format") 37 | args = parser.parse_args() 38 | group_dataset(args.input_file, args.output_file) 39 | -------------------------------------------------------------------------------- /scripts/nlvr/sed_commands.txt: -------------------------------------------------------------------------------- 1 | s/<:>/<,>/g 2 | s//>/g 3 | s//>/g 4 | s//>/g 5 | s//>/g 6 | s//>/g 7 | s//>/g 8 | s//>/g 9 | s//>/g 10 | s///g 11 | s///g 12 | s///g 13 | s///g 14 | s///g 15 | s/[[:<:]]int[[:>:]]/e/g 16 | s/[[:<:]]bool[[:>:]]/t/g 17 | s/Color/c/g 18 | s/Shape/s/g 19 | s/Set\[Box\]/b/g 20 | s/Set\[Object\]/o/g 21 | -------------------------------------------------------------------------------- /scripts/reformat_text2sql_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | from collections import defaultdict 5 | from typing import Dict, Any, Iterable, Tuple 6 | import glob 7 | import argparse 8 | 9 | sys.path.insert(0, os.path.dirname(os.path.abspath(os.path.join(__file__, os.pardir)))) 10 | 11 | JsonDict = Dict[str, Any] 12 | 13 | 14 | def process_dataset(data: JsonDict, split_type: str) -> Iterable[Tuple[str, JsonDict]]: 15 | 16 | splits = defaultdict(list) 17 | 18 | for example in data: 19 | if split_type == "query_split": 20 | example_split = example["query-split"] 21 | splits[example_split].append(example) 22 | 23 | else: 24 | sentences = example.pop("sentences") 25 | 26 | for sentence in sentences: 27 | new_example = example.copy() 28 | new_example["sentences"] = [sentence] 29 | split = sentence["question-split"] 30 | splits[split].append(new_example) 31 | 32 | for split, examples in splits.items(): 33 | if split.isdigit(): 34 | yield ("split_" + split + ".json", examples) 35 | else: 36 | yield (split + ".json", examples) 37 | 38 | 39 | def main(output_directory: int, data: str) -> None: 40 | """ 41 | Processes the text2sql data into the following 
directory structure: 42 | 43 | ``dataset/{query_split, question_split}/{train,dev,test}.json`` 44 | 45 | for datasets which have train, dev and test splits, or: 46 | 47 | ``dataset/{query_split, question_split}/{split_{split_id}}.json`` 48 | 49 | for datasets which use cross validation. 50 | 51 | The JSON format is identical to the original datasets, apart from they 52 | are split into separate files with respect to the split_type. This means that 53 | for the question split, all of the sql data is duplicated for each sentence 54 | which is bucketed together as having the same semantics. 55 | 56 | As an example, the following blob would be put "as-is" into the query split 57 | dataset, and split into two datasets with identical blobs for the question split, 58 | differing only in the "sentence" key, where blob1 would end up in the train split 59 | and blob2 would be in the dev split, with the rest of the json duplicated in each. 60 | { 61 | "comments": [], 62 | "old-name": "", 63 | "query-split": "train", 64 | "sentences": [{blob1, "question-split": "train"}, {blob2, "question-split": "dev"}], 65 | "sql": [], 66 | "variables": [] 67 | }, 68 | 69 | Parameters 70 | ---------- 71 | output_directory : str, required. 72 | The output directory. 73 | data: str, default = None 74 | The path to the data director of https://github.com/jkkummerfeld/text2sql-data. 75 | """ 76 | json_files = glob.glob(os.path.join(data, "*.json")) 77 | 78 | for dataset in json_files: 79 | dataset_name = os.path.basename(dataset)[:-5] 80 | print( 81 | f"Processing dataset: {dataset} into query and question " 82 | f"splits at output path: {output_directory + '/' + dataset_name}" 83 | ) 84 | full_dataset = json.load(open(dataset)) 85 | if not isinstance(full_dataset, list): 86 | full_dataset = [full_dataset] 87 | 88 | for split_type in ["query_split", "question_split"]: 89 | dataset_out = os.path.join(output_directory, dataset_name, split_type) 90 | 91 | for split, split_dataset in process_dataset(full_dataset, split_type): 92 | dataset_out = os.path.join(output_directory, dataset_name, split_type) 93 | os.makedirs(dataset_out, exist_ok=True) 94 | json.dump(split_dataset, open(os.path.join(dataset_out, split), "w"), indent=4) 95 | 96 | 97 | if __name__ == "__main__": 98 | 99 | parser = argparse.ArgumentParser( 100 | description="process text2sql data into a more readable format." 101 | ) 102 | parser.add_argument("--out", type=str, help="The serialization directory.") 103 | parser.add_argument("--data", type=str, help="The path to the text2sql data directory.") 104 | args = parser.parse_args() 105 | main(args.out, args.data) 106 | -------------------------------------------------------------------------------- /scripts/wikitables/generate_data_from_erm_model.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | 3 | 4 | import sys 5 | import os 6 | import gzip 7 | import argparse 8 | 9 | sys.path.insert( 10 | 0, os.path.dirname(os.path.dirname(os.path.abspath(os.path.join(__file__, os.pardir)))) 11 | ) 12 | 13 | from allennlp.data.dataset_readers import WikiTablesDatasetReader 14 | from allennlp.data.dataset_readers.semantic_parsing.wikitables import util 15 | from allennlp.models.archival import load_archive 16 | 17 | 18 | def make_data( 19 | input_examples_file: str, 20 | tables_directory: str, 21 | archived_model_file: str, 22 | output_dir: str, 23 | num_logical_forms: int, 24 | ) -> None: 25 | reader = WikiTablesDatasetReader( 26 | tables_directory=tables_directory, keep_if_no_logical_forms=True, output_agendas=True 27 | ) 28 | dataset = reader.read(input_examples_file) 29 | input_lines = [] 30 | with open(input_examples_file) as input_file: 31 | input_lines = input_file.readlines() 32 | archive = load_archive(archived_model_file) 33 | model = archive.model 34 | model.training = False 35 | model._decoder_trainer._max_num_decoded_sequences = 100 36 | for instance, example_line in zip(dataset, input_lines): 37 | outputs = model.forward_on_instance(instance) 38 | world = instance.fields["world"].metadata 39 | parsed_info = util.parse_example_line(example_line) 40 | example_id = parsed_info["id"] 41 | target_list = parsed_info["target_values"] 42 | logical_forms = outputs["logical_form"] 43 | correct_logical_forms = [] 44 | for logical_form in logical_forms: 45 | if world.evaluate_logical_form(logical_form, target_list): 46 | correct_logical_forms.append(logical_form) 47 | if len(correct_logical_forms) >= num_logical_forms: 48 | break 49 | num_found = len(correct_logical_forms) 50 | print(f"{num_found} found for {example_id}") 51 | if num_found == 0: 52 | continue 53 | if not os.path.exists(output_dir): 54 | os.makedirs(output_dir) 55 | output_file = gzip.open(os.path.join(output_dir, f"{example_id}.gz"), "wb") 56 | for logical_form in correct_logical_forms: 57 | logical_form_line = (logical_form + "\n").encode("utf-8") 58 | output_file.write(logical_form_line) 59 | output_file.close() 60 | 61 | 62 | if __name__ == "__main__": 63 | argparser = argparse.ArgumentParser() 64 | argparser.add_argument("input", type=str, help="Input file") 65 | argparser.add_argument("tables_directory", type=str, help="Tables directory") 66 | argparser.add_argument("archived_model", type=str, help="Archived model.tar.gz") 67 | argparser.add_argument( 68 | "--output-dir", type=str, dest="output_dir", help="Output directory", default="erm_output" 69 | ) 70 | argparser.add_argument( 71 | "--num-logical-forms", 72 | type=int, 73 | dest="num_logical_forms", 74 | help="Number of logical forms to output", 75 | default=10, 76 | ) 77 | args = argparser.parse_args() 78 | make_data( 79 | args.input, 80 | args.tables_directory, 81 | args.archived_model, 82 | args.output_dir, 83 | args.num_logical_forms, 84 | ) 85 | -------------------------------------------------------------------------------- /scripts/wikitables/preprocess_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import sys 5 | 6 | sys.path.insert(0, os.path.dirname(os.path.abspath(os.path.join(__file__, os.pardir)))) 7 | logging.basicConfig( 8 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", level=logging.DEBUG 9 | ) 10 | from allennlp.commands.train import datasets_from_params 11 | from allennlp.common import Params 12 | from 
allennlp.data import Instance 13 | 14 | 15 | def main(params: Params, outdir: str): 16 | os.makedirs(outdir, exist_ok=True) 17 | params["dataset_reader"]["include_table_metadata"] = True 18 | if "validation_dataset_reader" in params: 19 | params["validation_dataset_reader"]["include_table_metadata"] = True 20 | all_datasets = datasets_from_params(params) 21 | for name, dataset in all_datasets.items(): 22 | with open(outdir + name + ".jsonl", "w") as outfile: 23 | for instance in iter(dataset): 24 | outfile.write(to_json_line(instance) + "\n") 25 | 26 | 27 | def to_json_line(instance: Instance): 28 | json_obj = {} 29 | question_tokens = instance.fields["question"].tokens 30 | json_obj["question_tokens"] = [ 31 | {"text": token.text, "lemma": token.lemma_} for token in question_tokens 32 | ] 33 | json_obj["table_lines"] = instance.fields["table_metadata"].metadata 34 | 35 | action_map = {i: action.rule for i, action in enumerate(instance.fields["actions"].field_list)} 36 | 37 | if "target_action_sequences" in instance.fields: 38 | targets = [] 39 | for target_sequence in instance.fields["target_action_sequences"].field_list: 40 | targets.append([]) 41 | for target_index_field in target_sequence.field_list: 42 | targets[-1].append(action_map[target_index_field.sequence_index]) 43 | 44 | json_obj["target_action_sequences"] = targets 45 | 46 | json_obj["example_lisp_string"] = instance.fields["example_lisp_string"].metadata 47 | 48 | entity_texts = [] 49 | for entity_text in instance.fields["table"].entity_texts: 50 | tokens = [{"text": token.text, "lemma": token.lemma_} for token in entity_text] 51 | entity_texts.append(tokens) 52 | json_obj["entity_texts"] = entity_texts 53 | json_obj["linking_features"] = instance.fields["table"].linking_features 54 | return json.dumps(json_obj) 55 | 56 | 57 | if __name__ == "__main__": 58 | param_file = sys.argv[1] 59 | outdir = "wikitables_preprocessed_data/" 60 | params = Params.from_file(param_file) 61 | main(params, outdir) 62 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | import os 3 | 4 | # PEP0440 compatible formatted version, see: 5 | # https://www.python.org/dev/peps/pep-0440/ 6 | # 7 | # release markers: 8 | # X.Y 9 | # X.Y.Z # For bugfix releases 10 | # 11 | # pre-release markers: 12 | # X.YaN # Alpha release 13 | # X.YbN # Beta release 14 | # X.YrcN # Release Candidate 15 | # X.Y # Final release 16 | 17 | # version.py defines the VERSION and VERSION_SHORT variables. 18 | # We use exec here so we don't import allennlp_semparse whilst setting up. 19 | VERSION = {} 20 | with open("allennlp_semparse/version.py") as version_file: 21 | exec(version_file.read(), VERSION) 22 | 23 | # Load requirements.txt with a special case for allennlp so we can handle 24 | # cross-library integration testing. 
25 | with open("requirements.txt") as requirements_file: 26 | import re 27 | 28 | def requirement_is_allennlp(req: str) -> bool: 29 | if req == "allennlp": 30 | return True 31 | if re.match(r"^allennlp[>=<]", req): 32 | return True 33 | if re.match(r"^(git\+)?(https|ssh)://(git@)?github\.com/.*/allennlp\.git", req): 34 | return True 35 | return False 36 | 37 | def fix_url_dependencies(req: str) -> str: 38 | """Pip and setuptools disagree about how URL dependencies should be handled.""" 39 | m = re.match( 40 | r"^(git\+)?(https|ssh)://(git@)?github\.com/([\w-]+)/(?P[\w-]+)\.git", req 41 | ) 42 | if m is None: 43 | return req 44 | else: 45 | return f"{m.group('name')} @ {req}" 46 | 47 | install_requirements = [] 48 | allennlp_requirements = [] 49 | for line in requirements_file: 50 | line = line.strip() 51 | if line.startswith("#") or len(line) <= 0: 52 | continue 53 | if requirement_is_allennlp(line): 54 | allennlp_requirements.append(line) 55 | else: 56 | install_requirements.append(line) 57 | 58 | assert len(allennlp_requirements) == 1 59 | allennlp_override = os.environ.get("ALLENNLP_VERSION_OVERRIDE") 60 | if allennlp_override is not None: 61 | if len(allennlp_override) > 0: 62 | allennlp_requirements = [allennlp_override] 63 | else: 64 | allennlp_requirements = [] 65 | 66 | install_requirements.extend(allennlp_requirements) 67 | install_requirements = [fix_url_dependencies(req) for req in install_requirements] 68 | 69 | setup( 70 | name="allennlp_semparse", 71 | version=VERSION["VERSION"], 72 | description=( 73 | "A framework for building semantic parsers (including neural " 74 | "module networks) with AllenNLP, built by the authors of AllenNLP" 75 | ), 76 | long_description=open("README.md").read(), 77 | long_description_content_type="text/markdown", 78 | classifiers=[ 79 | "Intended Audience :: Science/Research", 80 | "Development Status :: 3 - Alpha", 81 | "License :: OSI Approved :: Apache Software License", 82 | "Programming Language :: Python :: 3.6", 83 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 84 | ], 85 | keywords="allennlp NLP deep learning machine reading semantic parsing parsers", 86 | url="https://github.com/allenai/allennlp-semparse", 87 | author="Allen Institute for Artificial Intelligence", 88 | author_email="allennlp@allenai.org", 89 | license="Apache", 90 | packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]), 91 | install_requires=install_requirements, 92 | include_package_data=True, 93 | python_requires=">=3.6.1", 94 | zip_safe=False, 95 | ) 96 | -------------------------------------------------------------------------------- /test_fixtures/atis/experiment.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "atis", 4 | "database_file": "https://allennlp.s3.amazonaws.com/datasets/atis/atis.db" 5 | }, 6 | "validation_dataset_reader": { 7 | "type": "atis", 8 | "database_file": "https://allennlp.s3.amazonaws.com/datasets/atis/atis.db", 9 | "keep_if_unparseable": true 10 | }, 11 | "train_data_path": "test_fixtures/data/atis/sample.json", 12 | "validation_data_path": "test_fixtures/data/atis/sample.json", 13 | "model": { 14 | "type": "atis_parser", 15 | "utterance_embedder": { 16 | "token_embedders": { 17 | "tokens": { 18 | "type": "embedding", 19 | "embedding_dim": 20, 20 | "trainable": true 21 | } 22 | } 23 | }, 24 | "action_embedding_dim": 10, 25 | "encoder": { 26 | "type": "lstm", 27 | "input_size": 20, 28 | "hidden_size": 40, 29 | 
"bidirectional": true, 30 | "num_layers": 1 31 | }, 32 | "decoder_beam_search": { 33 | "beam_size": 5 34 | }, 35 | "decoder_num_layers": 2, 36 | "max_decoding_steps": 10, 37 | "input_attention": {"type": "dot_product"}, 38 | "dropout": 0.5, 39 | "database_file": "https://allennlp.s3.amazonaws.com/datasets/atis/atis.db" 40 | }, 41 | "data_loader": { 42 | "batch_size" : 4 43 | }, 44 | "trainer": { 45 | "num_epochs": 2, 46 | "patience": 5, 47 | "cuda_device": -1, 48 | "optimizer": { 49 | "type": "sgd", 50 | "lr": 0.1 51 | } 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /test_fixtures/atis/serialization/best.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/test_fixtures/atis/serialization/best.th -------------------------------------------------------------------------------- /test_fixtures/atis/serialization/model.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/test_fixtures/atis/serialization/model.tar.gz -------------------------------------------------------------------------------- /test_fixtures/atis/serialization/vocabulary/non_padded_namespaces.txt: -------------------------------------------------------------------------------- 1 | *tags 2 | *labels 3 | -------------------------------------------------------------------------------- /test_fixtures/atis/serialization/vocabulary/tokens.txt: -------------------------------------------------------------------------------- 1 | @@UNKNOWN@@ 2 | the 3 | from 4 | to 5 | show 6 | me 7 | one 8 | way 9 | detroit 10 | westchester 11 | county 12 | are 13 | fare 14 | flight 15 | fares 16 | tacoma 17 | montreal 18 | flights 19 | what 20 | most 21 | expensive 22 | there 23 | with 24 | highest 25 | how 26 | many 27 | is 28 | prices 29 | of 30 | these 31 | all 32 | any 33 | other 34 | that 35 | salt 36 | lake 37 | city 38 | milwaukee 39 | -------------------------------------------------------------------------------- /test_fixtures/data/corenlp_processed_tables/TEST-11.table: -------------------------------------------------------------------------------- 1 | row col id content tokens lemmaTokens posTags nerTags nerValues number date num2 list listId 2 | -1 0 fb:row.row.position Position position position NN O 3 | -1 1 fb:row.row.swara Swara swara swara NN LOCATION 4 | -1 2 fb:row.row.short_name Short name short|name short|name JJ|NN O|O | 5 | -1 3 fb:row.row.notation Notation notation notation NN O 6 | -1 4 fb:row.row.mnemonic Mnemonic mnemonic mnemonic JJ O 7 | 0 0 fb:cell.1 1 1 1 CD NUMBER 1.0 1.0 8 | 0 1 fb:cell.shadja Shadja shadja shadja FW PERSON 9 | 0 2 fb:cell.sa Sa sa sa NN O 10 | 0 3 fb:cell.s S s s NN O 11 | 0 4 fb:cell.sa Sa sa sa NN O 12 | 1 0 fb:cell.2 2 2 2 CD NUMBER 2.0 2.0 13 | 1 1 fb:cell.shuddha_rishabha Shuddha Rishabha shuddha|rishabha Shuddha|Rishabha NNP|NNP PERSON|PERSON | 14 | 1 2 fb:cell.ri Ri ri Ri NNP PERSON 15 | 1 3 fb:cell.r1 R1 r1 r1 NN O 1.0 16 | 1 4 fb:cell.ra ra ra ra NN O 17 | 2 0 fb:cell.3 3 3 3 CD NUMBER 3.0 3.0 18 | 2 1 fb:cell.chatushruti_rishabha Chatushruti Rishabha chatushruti|rishabha Chatushruti|Rishabha NNP|NNP PERSON|PERSON | 19 | 2 2 fb:cell.ri Ri ri Ri NNP PERSON 20 | 2 3 fb:cell.r2 R2 r2 r2 NN O 2.0 21 | 2 4 fb:cell.ri Ri ri Ri NNP PERSON 22 | 3 0 fb:cell.3 3 3 3 CD NUMBER 3.0 
3.0 23 | 3 1 fb:cell.shuddha_gandhara Shuddha Gandhara shuddha|gandhara Shuddha|Gandhara NNP|NNP O|LOCATION | 24 | 3 2 fb:cell.ga Ga ga Ga NNP O 25 | 3 3 fb:cell.g1 G1 g1 g1 NN O 1.0 26 | 3 4 fb:cell.ga Ga ga Ga NNP O 27 | 4 0 fb:cell.4 4 4 4 CD NUMBER 4.0 4.0 28 | 4 1 fb:cell.shatshruti_rishabha Shatshruti Rishabha shatshruti|rishabha shatshruti|rishabha FW|FW PERSON|PERSON | 29 | 4 2 fb:cell.ri Ri ri Ri NNP PERSON 30 | 4 3 fb:cell.r3 R3 r3 r3 NN O 3.0 31 | 4 4 fb:cell.ru ru ru ru NN O 32 | 5 0 fb:cell.4 4 4 4 CD NUMBER 4.0 4.0 33 | 5 1 fb:cell.sadharana_gandhara Sadharana Gandhara sadharana|gandhara Sadharana|Gandhara NNP|NNP O|LOCATION | 34 | 5 2 fb:cell.ga Ga ga Ga NNP O 35 | 5 3 fb:cell.g2 G2 g2 g2 NN O 2.0 36 | 5 4 fb:cell.gi gi gi gi NN O 37 | 6 0 fb:cell.5 5 5 5 CD NUMBER 5.0 5.0 38 | 6 1 fb:cell.antara_gandhara Antara Gandhara antara|gandhara Antara|Gandhara NNP|NNP LOCATION|LOCATION | 39 | 6 2 fb:cell.ga Ga ga Ga NNP O 40 | 6 3 fb:cell.g3 G3 g3 g3 NN O 3.0 41 | 6 4 fb:cell.gu gu gu gu NN O 42 | 7 0 fb:cell.6 6 6 6 CD NUMBER 6.0 6.0 43 | 7 1 fb:cell.shuddha_madhyama Shuddha Madhyama shuddha|madhyama Shuddha|Madhyama NNP|NNP PERSON|PERSON | 44 | 7 2 fb:cell.ma Ma ma Ma NNP O 45 | 7 3 fb:cell.m1 M1 m1 m1 NN O 1.0 46 | 7 4 fb:cell.ma Ma ma Ma NNP O 47 | 8 0 fb:cell.7 7 7 7 CD NUMBER 7.0 7.0 48 | 8 1 fb:cell.prati_madhyama Prati Madhyama prati|madhyama Prati|Madhyama NNP|NNP PERSON|PERSON | 49 | 8 2 fb:cell.ma Ma ma Ma NNP O 50 | 8 3 fb:cell.m2 M2 m2 m2 NN O 2.0 51 | 8 4 fb:cell.mi mi mi mi FW O 52 | 9 0 fb:cell.8 8 8 8 CD NUMBER 8.0 8.0 53 | 9 1 fb:cell.panchama Panchama panchama panchama NN PERSON 54 | 9 2 fb:cell.pa Pa pa Pa NNP O 55 | 9 3 fb:cell.p P p p NN O 56 | 9 4 fb:cell.pa Pa pa Pa NNP O 57 | 10 0 fb:cell.9 9 9 9 CD NUMBER 9.0 9.0 58 | 10 1 fb:cell.shuddha_dhaivata Shuddha Dhaivata shuddha|dhaivata Shuddha|Dhaivata NNP|NNP PERSON|PERSON | 59 | 10 2 fb:cell.dha Dha dha dha NN O 60 | 10 3 fb:cell.d1 D1 d1 d1 NN O 1.0 61 | 10 4 fb:cell.dha Dha dha dha NN O 62 | 11 0 fb:cell.10 10 10 10 CD NUMBER 10.0 10.0 63 | 11 1 fb:cell.chatushruti_dhaivata Chatushruti Dhaivata chatushruti|dhaivata Chatushruti|Dhaivata NNP|NNP O|O | 64 | 11 2 fb:cell.dha Dha dha dha NN O 65 | 11 3 fb:cell.d2 D2 d2 d2 NN O 2.0 66 | 11 4 fb:cell.dhi dhi dhi dhus NN O 67 | 12 0 fb:cell.10 10 10 10 CD NUMBER 10.0 10.0 68 | 12 1 fb:cell.shuddha_nishada Shuddha Nishada shuddha|nishada Shuddha|Nishada NNP|NNP PERSON|PERSON | 69 | 12 2 fb:cell.ni Ni ni nus NN O 70 | 12 3 fb:cell.n1 N1 n1 n1 NN O 1.0 71 | 12 4 fb:cell.na na na na TO O 72 | 13 0 fb:cell.11 11 11 11 CD NUMBER 11.0 11.0 73 | 13 1 fb:cell.shatshruti_dhaivata Shatshruti Dhaivata shatshruti|dhaivata Shatshruti|Dhaivata NNP|NNP PERSON|PERSON | 74 | 13 2 fb:cell.dha Dha dha dha NN O 75 | 13 3 fb:cell.d3 D3 d3 d3 NN O 3.0 76 | 13 4 fb:cell.dhu dhu dhu dhu NN O 77 | 14 0 fb:cell.11 11 11 11 CD NUMBER 11.0 11.0 78 | 14 1 fb:cell.kaisiki_nishada Kaisiki Nishada kaisiki|nishada Kaisiki|Nishada NNP|NNP PERSON|PERSON | 79 | 14 2 fb:cell.ni Ni ni nus NN O 80 | 14 3 fb:cell.n2 N2 n2 n2 NN O 2.0 81 | 14 4 fb:cell.ni Ni ni nus NN O 82 | 15 0 fb:cell.12 12 12 12 CD NUMBER 12.0 12.0 83 | 15 1 fb:cell.kakali_nishada Kakali Nishada kakali|nishada Kakali|Nishada NNP|NNP PERSON|PERSON | 84 | 15 2 fb:cell.ni Ni ni nus NN O 85 | 15 3 fb:cell.n3 N3 n3 n3 NN O 3.0 86 | 15 4 fb:cell.nu nu nu nu NN O 87 | -------------------------------------------------------------------------------- /test_fixtures/data/corenlp_processed_tables/TEST-7.table: 
-------------------------------------------------------------------------------- 1 | row col id content tokens lemmaTokens posTags nerTags nerValues number date num2 list listId 2 | -1 0 fb:row.row.player Player player Player NNP O 3 | -1 1 fb:row.row.games_played Games Played games|played Games|play NNPS|VBD O|O | 4 | -1 2 fb:row.row.field_goals Field Goals field|goals field|goal NN|NNS O|O | 5 | -1 3 fb:row.row.free_throws Free Throws free|throws Free|throw NNP|VBZ O|O | 6 | -1 4 fb:row.row.points Points points point NNS O 7 | 0 0 fb:cell.ralf_woods Ralf Woods ralf|woods Ralf|Woods NNP|NNP PERSON|PERSON | 8 | 0 1 fb:cell.16 16 16 16 CD NUMBER 16.0 16.0 9 | 0 2 fb:cell.54 54 54 54 CD NUMBER 54.0 54.0 10 | 0 3 fb:cell.70 70 70 70 CD NUMBER 70.0 70.0 11 | 0 4 fb:cell.178 178 178 178 CD NUMBER 178.0 178.0 12 | 1 0 fb:cell.clyde_alwood Clyde Alwood clyde|alwood Clyde|Alwood NNP|NNP PERSON|PERSON | 13 | 1 1 fb:cell.15 15 15 15 CD NUMBER 15.0 15.0 14 | 1 2 fb:cell.57 57 57 57 CD NUMBER 57.0 57.0 15 | 1 3 fb:cell.0 0 0 0 CD NUMBER 0.0 0.0 16 | 1 4 fb:cell.114 114 114 114 CD NUMBER 114.0 114.0 17 | 2 0 fb:cell.ernest_mckay Ernest McKay ernest|mckay Ernest|McKay NNP|NNP PERSON|PERSON | 18 | 2 1 fb:cell.15 15 15 15 CD NUMBER 15.0 15.0 19 | 2 2 fb:cell.39 39 39 39 CD NUMBER 39.0 39.0 20 | 2 3 fb:cell.3 3 3 3 CD NUMBER 3.0 3.0 21 | 2 4 fb:cell.81 81 81 81 CD NUMBER 81.0 81.0 22 | 3 0 fb:cell.ray_woods Ray Woods ray|woods Ray|Woods NNP|NNP PERSON|PERSON | 23 | 3 1 fb:cell.16 16 16 16 CD NUMBER 16.0 16.0 24 | 3 2 fb:cell.19 19 19 19 CD NUMBER 19.0 19.0 25 | 3 3 fb:cell.0 0 0 0 CD NUMBER 0.0 0.0 26 | 3 4 fb:cell.38 38 38 38 CD NUMBER 38.0 38.0 27 | 4 0 fb:cell.john_felmley John Felmley john|felmley John|Felmley NNP|NNP PERSON|PERSON | 28 | 4 1 fb:cell.6 6 6 6 CD NUMBER 6.0 6.0 29 | 4 2 fb:cell.7 7 7 7 CD NUMBER 7.0 7.0 30 | 4 3 fb:cell.4 4 4 4 CD NUMBER 4.0 4.0 31 | 4 4 fb:cell.18 18 18 18 CD NUMBER 18.0 18.0 32 | 5 0 fb:cell.george_halas George Halas george|halas George|Halas NNP|NNP PERSON|PERSON | 33 | 5 1 fb:cell.11 11 11 11 CD NUMBER 11.0 11.0 34 | 5 2 fb:cell.5 5 5 5 CD NUMBER 5.0 5.0 35 | 5 3 fb:cell.0 0 0 0 CD NUMBER 0.0 0.0 36 | 5 4 fb:cell.10 10 10 10 CD NUMBER 10.0 10.0 37 | 6 0 fb:cell.r_c_haas R.C. 
Haas r.c.|haas R.C.|Haas NNP|NNP PERSON|PERSON | 38 | 6 1 fb:cell.3 3 3 3 CD NUMBER 3.0 3.0 39 | 6 2 fb:cell.1 1 1 1 CD NUMBER 1.0 1.0 40 | 6 3 fb:cell.0 0 0 0 CD NUMBER 0.0 0.0 41 | 6 4 fb:cell.2 2 2 2 CD NUMBER 2.0 2.0 42 | 7 0 fb:cell.gordon_otto Gordon Otto gordon|otto Gordon|Otto NNP|NNP PERSON|PERSON | 43 | 7 1 fb:cell.4 4 4 4 CD NUMBER 4.0 4.0 44 | 7 2 fb:cell.1 1 1 1 CD NUMBER 1.0 1.0 45 | 7 3 fb:cell.0 0 0 0 CD NUMBER 0.0 0.0 46 | 7 4 fb:cell.2 2 2 2 CD NUMBER 2.0 2.0 47 | -------------------------------------------------------------------------------- /test_fixtures/data/nlvr/sample_ungrouped_data.jsonl: -------------------------------------------------------------------------------- 1 | {"sentence": "There is a circle closely touching a corner of a box.", "label": "true", "identifier": "1304-0", "evals": {"r0": "true"}, "structured_rep": [[{"y_loc": 21, "size": 20, "type": "triangle", "x_loc": 27, "color": "Yellow"}, {"y_loc": 60, "size": 10, "type": "circle", "x_loc": 59, "color": "Yellow"}], [{"y_loc": 81, "size": 10, "type": "triangle", "x_loc": 48, "color": "Yellow"}, {"y_loc": 64, "size": 20, "type": "circle", "x_loc": 77, "color": "#0099ff"}], [{"y_loc": 2, "size": 20, "type": "triangle", "x_loc": 62, "color": "Yellow"}, {"y_loc": 70, "size": 30, "type": "circle", "x_loc": 70, "color": "Black"}, {"y_loc": 51, "size": 20, "type": "circle", "x_loc": 30, "color": "#0099ff"}, {"y_loc": 42, "size": 20, "type": "circle", "x_loc": 67, "color": "Yellow"}, {"y_loc": 73, "size": 20, "type": "circle", "x_loc": 37, "color": "Black"}, {"y_loc": 14, "size": 30, "type": "triangle", "x_loc": 7, "color": "Yellow"}, {"y_loc": 27, "size": 10, "type": "circle", "x_loc": 48, "color": "Black"}]]} 2 | {"sentence": "There are 2 yellow blocks", "label": "true", "identifier": "2170-1", "evals": {"r1": "true"}, "structured_rep": [[{"y_loc": 80, "size": 20, "type": "square", "x_loc": 40, "color": "Black"}], [{"y_loc": 80, "size": 20, "type": "square", "x_loc": 40, "color": "#0099ff"}], [{"y_loc": 80, "size": 20, "type": "square", "x_loc": 40, "color": "Black"}, {"y_loc": 59, "size": 20, "type": "square", "x_loc": 40, "color": "Yellow"}, {"y_loc": 38, "size": 20, "type": "square", "x_loc": 40, "color": "Yellow"}]]} 3 | {"sentence": "There is a box without a blue item.", "label": "false", "identifier": "700-2", "evals": {"r0": "false"}, "structured_rep": [[{"y_loc": 38, "size": 10, "type": "circle", "x_loc": 70, "color": "#0099ff"}, {"y_loc": 53, "size": 20, "type": "triangle", "x_loc": 12, "color": "#0099ff"}], [{"y_loc": 39, "size": 30, "type": "triangle", "x_loc": 53, "color": "Black"}, {"y_loc": 70, "size": 30, "type": "circle", "x_loc": 1, "color": "Yellow"}, {"y_loc": 72, "size": 20, "type": "circle", "x_loc": 61, "color": "Black"}, {"y_loc": 37, "size": 30, "type": "triangle", "x_loc": 4, "color": "#0099ff"}, {"y_loc": 80, "size": 20, "type": "triangle", "x_loc": 36, "color": "#0099ff"}, {"y_loc": 85, "size": 10, "type": "circle", "x_loc": 86, "color": "Yellow"}, {"y_loc": 9, "size": 20, "type": "triangle", "x_loc": 6, "color": "Yellow"}], [{"y_loc": 48, "size": 30, "type": "square", "x_loc": 4, "color": "#0099ff"}, {"y_loc": 9, "size": 20, "type": "circle", "x_loc": 80, "color": "Yellow"}, {"y_loc": 45, "size": 10, "type": "square", "x_loc": 53, "color": "Black"}, {"y_loc": 6, "size": 10, "type": "circle", "x_loc": 68, "color": "#0099ff"}]]} 4 | -------------------------------------------------------------------------------- /test_fixtures/data/quarel.jsonl: 
-------------------------------------------------------------------------------- 1 | {"id":"GeneralQR_V1_Fr_0223","question":"Mike was snowboarding on the snow and hit a piece of ice. He went much faster on the ice because _____ is smoother. (A) snow (B) ice","answer_index":1,"logical_forms":["(infer (speed higher world1) (smoothness higher world2) (smoothness higher world1))","(infer (speed higher world2) (smoothness higher world1) (smoothness higher world2))"],"world_literals":{"world1":"ice","world2":"snow"}} 2 | {"id":"GeneralQR_V1_Fr_0334","question":"A car gets very hot as it drives up a muddy hill, but stays cool as it drives up a grass hill. The car warms on on the muddy hill because the muddy hill has (A) more friction (B) less friction.","answer_index":0,"logical_forms":["(infer (and (heat low world1) (heat high world2)) (friction higher world2) (friction lower world2))","(infer (and (heat low world2) (heat high world1)) (friction higher world1) (friction lower world1))"],"world_literals":{"world1":"grass hill","world2":"muddy hill"}} 3 | {"id":"GeneralQR_V1_B5_1282","question":"Juan is injured in a car accident, which necessitates a hospital stay where he is unable to maintain the strength in his arm. Juan notices that his throwing arm feels extremely frail compared to the level of strength it had when he was healthy. If Juan decides to throw a ball with his friend, when will his throw travel less distance? (A) When Juan's arm is healthy (B) When Juan's arm is weak after the hospital stay.","answer_index":1,"logical_forms":["(infer (strength lower world1) (distance lower world2) (distance lower world1))","(infer (strength lower world2) (distance lower world1) (distance lower world2))"],"world_literals":{"world1":"Juan after a hospital stay","world2":"Juan when healthy"}} 4 | -------------------------------------------------------------------------------- /test_fixtures/data/text2sql/restaurants-schema.csv: -------------------------------------------------------------------------------- 1 | Table, Field, Primary Key, Foreign Key, Type 2 | RESTAURANT, RESTAURANT_ID, y, n, int(11) 3 | RESTAURANT, NAME, n, n, varchar(255) 4 | RESTAURANT, FOOD_TYPE, n, n, varchar(255) 5 | RESTAURANT, CITY_NAME, n, y, varchar(255) 6 | RESTAURANT, RATING, n, n, "decimal(1,1)" 7 | -, -, -, -, - 8 | LOCATION, RESTAURANT_ID, y, y, int(11) 9 | LOCATION, HOUSE_NUMBER, n, n, int(11) 10 | LOCATION, STREET_NAME, n, n, varchar(255) 11 | LOCATION, CITY_NAME, n, y, varchar(255) 12 | -, -, -, -, - 13 | GEOGRAPHIC, CITY_NAME, y, n, varchar(255) 14 | GEOGRAPHIC, COUNTY, n, n, varchar(255) 15 | GEOGRAPHIC, REGION, n, n, varchar(255) 16 | -------------------------------------------------------------------------------- /test_fixtures/data/text2sql/restaurants.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/test_fixtures/data/text2sql/restaurants.db -------------------------------------------------------------------------------- /test_fixtures/data/wikitables/action_space_walker_output/nt-0.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/test_fixtures/data/wikitables/action_space_walker_output/nt-0.gz -------------------------------------------------------------------------------- /test_fixtures/data/wikitables/action_space_walker_output/nt-1.gz: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/test_fixtures/data/wikitables/action_space_walker_output/nt-1.gz -------------------------------------------------------------------------------- /test_fixtures/data/wikitables/action_space_walker_output_with_single_tarball/all_lfs_tarball.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/test_fixtures/data/wikitables/action_space_walker_output_with_single_tarball/all_lfs_tarball.tar.gz -------------------------------------------------------------------------------- /test_fixtures/data/wikitables/dpd_output/nt-0.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/test_fixtures/data/wikitables/dpd_output/nt-0.gz -------------------------------------------------------------------------------- /test_fixtures/data/wikitables/dpd_output/nt-1.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/test_fixtures/data/wikitables/dpd_output/nt-1.gz -------------------------------------------------------------------------------- /test_fixtures/data/wikitables/dpd_output/nt-64.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/test_fixtures/data/wikitables/dpd_output/nt-64.gz -------------------------------------------------------------------------------- /test_fixtures/data/wikitables/lots_of_ors_example.examples: -------------------------------------------------------------------------------- 1 | (example (id nt-64) (utterance "how many districts are there in virginia?") (context (graph tables.TableKnowledgeGraph csv/204-csv/109.csv)) (targetValue (list (description "22")))) 2 | -------------------------------------------------------------------------------- /test_fixtures/data/wikitables/sample_data.examples: -------------------------------------------------------------------------------- 1 | (example (id nt-0) (utterance "what was the last year where this team was a part of the usl a-league?") (context (graph tables.TableKnowledgeGraph tables/590.csv)) (targetValue (list (description "2004")))) 2 | (example (id nt-1) (utterance "in what city did piotr's last 1st place finish occur?") (context (graph tables.TableKnowledgeGraph tables/622.csv)) (targetValue (list (description "Bangkok, Thailand")))) 3 | (example (id nt-2) (utterance "When did piotr's last 1st place finish occur?") (context (graph tables.TableKnowledgeGraph tables/622.csv)) (targetValue (list (description "2007")))) 4 | -------------------------------------------------------------------------------- /test_fixtures/data/wikitables/sample_table.tagged: -------------------------------------------------------------------------------- 1 | row col id content tokens lemmaTokens posTags nerTags nerValues number date num2 list listId 2 | -1 0 fb:row.row.year Year year year NN DURATION 3 | -1 1 fb:row.row.division Division division Division NNP O 4 | -1 2 fb:row.row.league League league League NNP O 5 | -1 3 fb:row.row.regular_season 
Regular Season regular|season regular|season JJ|NN O|O | 6 | -1 4 fb:row.row.playoffs Playoffs playoffs playoff NNS O 7 | -1 5 fb:row.row.open_cup Open Cup open|cup Open|Cup NNP|NNP MISC|MISC | 8 | -1 6 fb:row.row.avg_attendance Avg. Attendance avg|.|attendance avg|.|attendance NN|.|NN O|O|O || 9 | 0 0 fb:cell.2001 2001 2001 2001 CD DATE 2001 2001.0 2001-xx-xx 10 | 0 1 fb:cell.2 2 2 2 CD NUMBER 2.0 2.0 11 | 0 2 fb:cell.usl_a_league USL A-League usl|a|league USL|A|League NNP|NNP|NNP MISC|MISC|MISC || 12 | 0 3 fb:cell.4th_western 4th, Western 4th|,|western 4th|,|western JJ|,|JJ ORDINAL|O|MISC 4.0|| 4.0 4th|Western fb:part.4th|fb:part.western 13 | 0 4 fb:cell.quarterfinals Quarterfinals quarterfinals quarterfinal NNS O 14 | 0 5 fb:cell.did_not_qualify Did not qualify did|not|qualify do|not|qualify VBD-AUX|RB|VB O|O|O || 15 | 0 6 fb:cell.7_169 7,169 7,169 7,169 CD NUMBER 7169.0 7169.0 16 | 1 0 fb:cell.2005 2005 2005 2005 CD DATE 2005 2005.0 2005-xx-xx 17 | 1 1 fb:cell.2 2 2 2 CD NUMBER 2.0 2.0 18 | 1 2 fb:cell.usl_first_division USL First Division usl|first|division USL|First|Division NNP|NNP|NNP MISC|ORDINAL|O |1.0| 19 | 1 3 fb:cell.5th 5th 5th 5th NN ORDINAL 5.0 5.0 5th fb:part.5th 20 | 1 4 fb:cell.quarterfinals Quarterfinals quarterfinals quarterfinal NNS O 21 | 1 5 fb:cell.4th_round 4th Round 4th|round 4th|Round JJ|NNP ORDINAL|O 4.0| 4.0 22 | 1 6 fb:cell.6_028 6,028 6,028 6,028 CD NUMBER 6028.0 6028.0 23 | -------------------------------------------------------------------------------- /test_fixtures/data/wikitables/sample_table.tsv: -------------------------------------------------------------------------------- 1 | Year Division League Regular Season Playoffs Open Cup Score 2 | 2001 2 USL A-League 4th, Western Quarterfinals Did not qualify 20-30 3 | 2005 2 USL First Division 5th Quarterfinals 4th Round 50-40 4 | -------------------------------------------------------------------------------- /test_fixtures/data/wikitables/sample_table_with_date.tsv: -------------------------------------------------------------------------------- 1 | Date Division League Regular Season Playoffs Open Cup Avg. Attendance 2 | January 2001 2 USL-A-League 4th, Western Quarterfinals Did not qualify 7,169 3 | March 2005 2 USL First Division 5th Quarterfinals 4th Round 6,028 4 | -------------------------------------------------------------------------------- /test_fixtures/data/wikitables/tables/109.tsv: -------------------------------------------------------------------------------- 1 | District Incumbent Party First\nelected Result Candidates 2 | Virginia 1 Thomas Newton, Jr. Adams-Clay Republican 1801 Re-elected Thomas Newton, Jr. 3 | Virginia 2 Arthur Smith Crawford Republican 1821 Retired\nJacksonian gain James Trezvant (J) 75.7%\nRichard Eppes (DR) 24.3% 4 | Virginia 3 William S. Archer Crawford Republican 1820 (special) Re-elected William S. Archer (J) 100% 5 | Virginia 4 Mark Alexander Crawford Republican 1819 Re-elected Mark Alexander (J) 6 | Virginia 5 John Randolph Crawford Republican 1799\n1819 Re-elected John Randolph (J) 100% 7 | Virginia 6 George Tucker Crawford Republican 1819 Retired\nJacksonian gain Thomas Davenport (J) 53.9%\nJames Lanier 22.6%\nBarzillai Graves 16.3%\nJohn D. Urquhart 7.2% 8 | Virginia 7 Jabez Leftwich Crawford Republican 1821 Lost re-election\nJacksonian gain Nathaniel H. 
Claiborne (J) 51.4%\nJabez Leftwich (C-DR) 48.6% 9 | Virginia 8 Burwell Bassett Crawford Republican 1805\n1821 Re-elected Burwell Bassett (J) 95.3%\nServant Jones (DR) 4.5%\nReuben Washer 0.2% 10 | Virginia 9 Andrew Stevenson Crawford Republican 1821 Re-elected Andrew Stevenson (J) 100% 11 | Virginia 10 William C. Rives Crawford Republican 1823 Re-elected William C. Rives (J) 100% 12 | Virginia 11 Philip P. Barbour Crawford Republican 1814 (special) Retired\nAdams gain Robert Taylor (A) 100% 13 | Virginia 12 Robert S. Garnett Crawford Republican 1817 Re-elected Robert S. Garnett (J) 68.5%\nJohn H. Upshaw 31.5% 14 | Virginia 13 John Taliaferro Crawford Republican 1824 (special) Re-elected John Taliaferro (A) 63.3%\nJohn Hooe (F) 26.7% 15 | Virginia 14 Charles F. Mercer Crawford Republican 1817 Re-elected Charles F. Mercer (A) 16 | Virginia 15 John S. Barbour Crawford Republican 1823 Re-elected John S. Barbour (J) 53.7%\nThomas Marshall (F) 46.3% 17 | Virginia 16 James Stephenson Federalist 1821 Retired\nAdams gain William Armstrong (A) 57.1%\nEdward Colston (F) 42.9% 18 | Virginia 17 Jared Williams Crawford Republican 1819 Retired\nAdams gain Alfred H. Powell (A) 42.0%\nWilliam Steenergen (DR) 21.5%\nAugustine C. Smith (DR) 20.3%\nSamuel Kercheval (DR) 13.6%\nRobert Allen (DR) 2.6% 19 | Virginia 18 Joseph Johnson Jackson Republican 1823 Re-elected Joseph Johnson (J) 62.0%\nPhillip Doddridge (F) 38.0% 20 | Virginia 19 William McCoy Crawford Republican 1811 Re-elected William McCoy (J) 70.2%\nDaniel Sheffey (F) 29.8% 21 | Virginia 20 John Floyd Crawford Republican 1817 Re-elected John Floyd (J) 84.7%\nAllen Taylor (F) 15.3% 22 | Virginia 21 William Smith Crawford Republican 1821 Re-elected William Smith (J) 55.2%\nJames Lovell (DR) 44.8% 23 | Virginia 22 Alexander Smyth Crawford Republican 1817 Retired\nAdams gain Benjamin Estil (A) 58.9%\nJoseph Crockett (DR) 32.0%\nWilliam Graham (DR) 9.1% 24 | -------------------------------------------------------------------------------- /test_fixtures/data/wikitables/tables/341.tagged: -------------------------------------------------------------------------------- 1 | row col id content tokens lemmaTokens posTags nerTags nerValues number date num2 list listId 2 | -1 0 fb:row.row.name_in_english Name in English name|in|english name|in|English VB|IN|NNP O|O|MISC || 3 | -1 1 fb:row.row.location_in_english Location in English location|in|english location|in|English NN|IN|NNP O|O|MISC || 4 | 0 0 fb:cell.lake_gala Lake Gala lake|gala Lake|Gala NNP|NNP LOCATION|LOCATION | 5 | 0 1 fb:cell.edirne Edirne edirne Edirne NNP LOCATION Edirne fb:part.edirne 6 | 1 0 fb:cell.paradeniz Paradeniz paradeniz paradeniz NN LOCATION 7 | 1 1 fb:cell.mersin Mersin mersin Mersin NNP LOCATION Mersin fb:part.mersin 8 | -------------------------------------------------------------------------------- /test_fixtures/data/wikitables/tables/590.csv: -------------------------------------------------------------------------------- 1 | "Year","Division","League","Regular Season","Playoffs","Open Cup","Avg. 
Attendance" 2 | "2001","2","USL A-League","4th, Western","Quarterfinals","Did not qualify","7,169" 3 | "2002","2","USL A-League","2nd, Pacific","1st Round","Did not qualify","6,260" 4 | "2003","2","USL A-League","3rd, Pacific","Did not qualify","Did not qualify","5,871" 5 | "2004","2","USL A-League","1st, Western","Quarterfinals","4th Round","5,628" 6 | "2005","2","USL First Division","5th","Quarterfinals","4th Round","6,028" 7 | "2006","2","USL First Division","11th","Did not qualify","3rd Round","5,575" 8 | "2007","2","USL First Division","2nd","Semifinals","2nd Round","6,851" 9 | "2008","2","USL First Division","11th","Did not qualify","1st Round","8,567" 10 | "2009","2","USL First Division","1st","Semifinals","3rd Round","9,734" 11 | "2010","2","USSF D-2 Pro League","3rd, USL (3rd)","Quarterfinals","3rd Round","10,727" 12 | -------------------------------------------------------------------------------- /test_fixtures/data/wikitables/tables/590.tsv: -------------------------------------------------------------------------------- 1 | Year Division League Regular Season Playoffs Open Cup Avg. Attendance 2 | 2001 2 USL A-League 4th, Western Quarterfinals Did not qualify 7,169 3 | 2002 2 USL A-League 2nd, Pacific 1st Round Did not qualify 6,260 4 | 2003 2 USL A-League 3rd, Pacific Did not qualify Did not qualify 5,871 5 | 2004 2 USL A-League 1st, Western Quarterfinals 4th Round 5,628 6 | 2005 2 USL First Division 5th Quarterfinals 4th Round 6,028 7 | 2006 2 USL First Division 11th Did not qualify 3rd Round 5,575 8 | 2007 2 USL First Division 2nd Semifinals 2nd Round 6,851 9 | 2008 2 USL First Division 11th Did not qualify 1st Round 8,567 10 | 2009 2 USL First Division 1st Semifinals 3rd Round 9,734 11 | 2010 2 USSF D-2 Pro League 3rd, USL (3rd) Quarterfinals 3rd Round 10,727 12 | -------------------------------------------------------------------------------- /test_fixtures/data/wikitables/tables/622.csv: -------------------------------------------------------------------------------- 1 | "Year","Competition","Venue","Position","Event","Notes" 2 | "2001","World Youth Championships","Debrecen, Hungary","2nd","400 m","47.12" 3 | "2001","World Youth Championships","Debrecen, Hungary","1st","Medley relay","1:50.46" 4 | "2001","European Junior Championships","Grosseto, Italy","1st","4x400 m relay","3:06.12" 5 | "2003","European Junior Championships","Tampere, Finland","3rd","400 m","46.69" 6 | "2003","European Junior Championships","Tampere, Finland","2nd","4x400 m relay","3:08.62" 7 | "2005","European U23 Championships","Erfurt, Germany","11th (sf)","400 m","46.62" 8 | "2005","European U23 Championships","Erfurt, Germany","1st","4x400 m relay","3:04.41" 9 | "2005","Universiade","Izmir, Turkey","7th","400 m","46.89" 10 | "2005","Universiade","Izmir, Turkey","1st","4x400 m relay","3:02.57" 11 | "2006","World Indoor Championships","Moscow, Russia","2nd (h)","4x400 m relay","3:06.10" 12 | "2006","European Championships","Gothenburg, Sweden","3rd","4x400 m relay","3:01.73" 13 | "2007","European Indoor Championships","Birmingham, United Kingdom","3rd","4x400 m relay","3:08.14" 14 | "2007","Universiade","Bangkok, Thailand","7th","400 m","46.85" 15 | "2007","Universiade","Bangkok, Thailand","1st","4x400 m relay","3:02.05" 16 | "2008","World Indoor Championships","Valencia, Spain","4th","4x400 m relay","3:08.76" 17 | "2008","Olympic Games","Beijing, China","7th","4x400 m relay","3:00.32" 18 | "2009","Universiade","Belgrade, Serbia","2nd","4x400 m relay","3:05.69" 19 | 
-------------------------------------------------------------------------------- /test_fixtures/data/wikitables/tables/622.tsv: -------------------------------------------------------------------------------- 1 | Year Competition Venue Position Event Notes 2 | 2001 World Youth Championships Debrecen, Hungary 2nd 400 m 47.12 3 | 2001 World Youth Championships Debrecen, Hungary 1st Medley relay 1:50.46 4 | 2001 European Junior Championships Grosseto, Italy 1st 4x400 m relay 3:06.12 5 | 2003 European Junior Championships Tampere, Finland 3rd 400 m 46.69 6 | 2003 European Junior Championships Tampere, Finland 2nd 4x400 m relay 3:08.62 7 | 2005 European U23 Championships Erfurt, Germany 11th (sf) 400 m 46.62 8 | 2005 European U23 Championships Erfurt, Germany 1st 4x400 m relay 3:04.41 9 | 2005 Universiade Izmir, Turkey 7th 400 m 46.89 10 | 2005 Universiade Izmir, Turkey 1st 4x400 m relay 3:02.57 11 | 2006 World Indoor Championships Moscow, Russia 2nd (h) 4x400 m relay 3:06.10 12 | 2006 European Championships Gothenburg, Sweden 3rd 4x400 m relay 3:01.73 13 | 2007 European Indoor Championships Birmingham, United Kingdom 3rd 4x400 m relay 3:08.14 14 | 2007 Universiade Bangkok, Thailand 7th 400 m 46.85 15 | 2007 Universiade Bangkok, Thailand 1st 4x400 m relay 3:02.05 16 | 2008 World Indoor Championships Valencia, Spain 4th 4x400 m relay 3:08.76 17 | 2008 Olympic Games Beijing, China 7th 4x400 m relay 3:00.32 18 | 2009 Universiade Belgrade, Serbia 2nd 4x400 m relay 3:05.69 19 | -------------------------------------------------------------------------------- /test_fixtures/elmo/config/characters_token_embedder.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "conll2003", 4 | "tag_label": "ner", 5 | "token_indexers": { 6 | "tokens": { 7 | "type": "single_id", 8 | "lowercase_tokens": true 9 | }, 10 | "elmo": { 11 | "type": "elmo_characters" 12 | } 13 | } 14 | }, 15 | "train_data_path": "allennlp/tests/fixtures/data/conll2003.txt", 16 | "validation_data_path": "allennlp/tests/fixtures/data/conll2003.txt", 17 | "model": { 18 | "type": "crf_tagger", 19 | "text_field_embedder": { 20 | "token_embedders": { 21 | "tokens": { 22 | "type": "embedding", 23 | "embedding_dim": 50 24 | }, 25 | "elmo": { 26 | "type": "elmo_token_embedder", 27 | "options_file": "allennlp/tests/fixtures/elmo/options.json", 28 | "weight_file": "allennlp/tests/fixtures/elmo/lm_weights.hdf5", 29 | } 30 | } 31 | }, 32 | "encoder": { 33 | "type": "gru", 34 | "input_size": 82, 35 | "hidden_size": 25, 36 | "num_layers": 2, 37 | "dropout": 0.5, 38 | "bidirectional": true 39 | }, 40 | "regularizer": [ 41 | ["transitions$", {"type": "l2", "alpha": 0.01}] 42 | ] 43 | }, 44 | "data_loader": {"batch_size": 32}, 45 | "trainer": { 46 | "optimizer": "adam", 47 | "num_epochs": 5, 48 | "cuda_device": -1 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /test_fixtures/elmo/elmo_token_embeddings.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/test_fixtures/elmo/elmo_token_embeddings.hdf5 -------------------------------------------------------------------------------- /test_fixtures/elmo/lm_embeddings_0.hdf5: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/test_fixtures/elmo/lm_embeddings_0.hdf5 -------------------------------------------------------------------------------- /test_fixtures/elmo/lm_embeddings_1.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/test_fixtures/elmo/lm_embeddings_1.hdf5 -------------------------------------------------------------------------------- /test_fixtures/elmo/lm_embeddings_2.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/test_fixtures/elmo/lm_embeddings_2.hdf5 -------------------------------------------------------------------------------- /test_fixtures/elmo/lm_weights.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/test_fixtures/elmo/lm_weights.hdf5 -------------------------------------------------------------------------------- /test_fixtures/elmo/options.json: -------------------------------------------------------------------------------- 1 | { 2 | "lstm": { 3 | "cell_clip": 3, 4 | "use_skip_connections": true, 5 | "n_layers": 2, 6 | "proj_clip": 3, 7 | "projection_dim": 16, 8 | "dim": 64 9 | }, 10 | "char_cnn": { 11 | "embedding": { 12 | "dim": 4 13 | }, 14 | "filters": [ 15 | [1, 4], 16 | [2, 8], 17 | [3, 16], 18 | [4, 32], 19 | [5, 64] 20 | ], 21 | "n_highway": 2, 22 | "n_characters": 262, 23 | "max_characters_per_token": 50, 24 | "activation": "relu" 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /test_fixtures/elmo/sentences.json: -------------------------------------------------------------------------------- 1 | [ 2 | [ 3 | "The U.S. Centers for Disease Control and Prevention initially advised school systems to close if outbreaks occurred , then reversed itself , saying the apparent mildness of the virus meant most schools and day care centers should stay open , even if they had confirmed cases of swine flu .", 4 | "When Ms. Winfrey invited Suzanne Somers to share her controversial views about bio-identical hormone treatment on her syndicated show in 2009 , it won Ms. Winfrey a rare dollop of unflattering press , including a Newsweek cover story titled \" Crazy Talk : Oprah , Wacky Cures & You . \"", 5 | "Elk calling -- a skill that hunters perfected long ago to lure game with the promise of a little romance -- is now its own sport .", 6 | "Don 't !", 7 | "Fish , ranked 98th in the world , fired 22 aces en route to a 6-3 , 6-7 ( 5 \/ 7 ) , 7-6 ( 7 \/ 4 ) win over seventh-seeded Argentinian David Nalbandian .", 8 | "Why does everything have to become such a big issue ?", 9 | "AMMAN ( Reuters ) - King Abdullah of Jordan will meet U.S. President Barack Obama in Washington on April 21 to lobby on behalf of Arab states for a stronger U.S. 
role in Middle East peacemaking , palace officials said on Sunday .", 10 | "To help keep traffic flowing the Congestion Charge will remain in operation through-out the strike and TfL will be suspending road works on major London roads wherever possible .", 11 | "If no candidate wins an absolute majority , there will be a runoff between the top two contenders , most likely in mid-October .", 12 | "Authorities previously served search warrants at Murray 's Las Vegas home and his businesses in Las Vegas and Houston ." 13 | ], 14 | [ 15 | "Brent North Sea crude for November delivery rose 84 cents to 68.88 dollars a barrel .", 16 | "That seems to have been their model up til now .", 17 | "Gordon will join Luol Deng on the GB team ; their respective NBA teams , the Detroit Pistons and the Chicago Bulls , play tonight .", 18 | "Nikam maintains the attacks were masterminded by the Muslim militant group Lashkar-e-Taiba .", 19 | "Last year , Williams was unseeded , ranked 81st and coming off one of her worst losses on tour -- in a Tier 4 event at Hobart -- yet she beat six seeded players en route to the title at Melbourne Park .", 20 | "It said that two officers involved in the case had been disciplined .", 21 | "\" There is more intelligence now being gathered , \" the official said , adding that such efforts would continue for some time .", 22 | "The majority will be of the standard 6X6 configuration for carrying personnel .", 23 | "\" Consequently , necessary actions may not be taken to reduce the risks to children of sexual exploitation and drug or alcohol misuse , \" the report said . \u2022 Almost two-thirds of inspected schools were good or outstanding , but the number of underperforming secondaries remained \" stubborn and persistent . \"", 24 | "What a World Cup ." 25 | ], 26 | [ 27 | "But , there have also been many cases of individuals and small groups of people protesting , as in the case of Rongye Adak , a nomad who called for the return of the Dalai Lama and for the freedom of Tibet during the Lithang Horse Racing Festival , in eastern Tibet .", 28 | "James Duncan , head of transportation at Bournemouth Borough Council , said : \" Our legal team is reviewing the entitlement of taxis to drop and pick up passengers at bus stops , only for as long as is absolutely necessary to fulfil that function and for no other reason .", 29 | "To Mo concerning the food log you kept -- Dr. Buchholz recommends the same thing .", 30 | "The CBO estimates that only 23 percent of that would be spent in 2009 and 2010 .", 31 | "Even so , Democrats slammed Bush as out of touch .", 32 | "An information campaign will be launched later to raise awareness of employment rights and how to enforce them .", 33 | "At the gallery the concept is less vague , as Ms. Piper cites specific instances of racial violence , political assassinations and the devastation of Hurricane Katrina .", 34 | "There have been some exceptions -- such as Medicare in 1965 .", 35 | "The government guidance will be reviewed early next year after a period of public comment .", 36 | "It wasn 't the most seaworthy of prizes ." 
37 | ] 38 | ] 39 | -------------------------------------------------------------------------------- /test_fixtures/nlvr_coverage_semantic_parser/experiment.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "nlvr", 4 | "output_agendas": true 5 | }, 6 | "vocabulary": { 7 | "non_padded_namespaces": ["denotations", "rule_labels"] 8 | }, 9 | "train_data_path": "test_fixtures/data/nlvr/sample_grouped_data.jsonl", 10 | "validation_data_path": "test_fixtures/data/nlvr/sample_grouped_data.jsonl", 11 | "model": { 12 | "type": "nlvr_coverage_parser", 13 | "sentence_embedder": { 14 | "token_embedders": { 15 | "tokens": { 16 | "type": "embedding", 17 | "embedding_dim": 25, 18 | "trainable": true 19 | } 20 | } 21 | }, 22 | "action_embedding_dim": 50, 23 | "encoder": { 24 | "type": "lstm", 25 | "input_size": 25, 26 | "hidden_size": 10, 27 | "num_layers": 1 28 | }, 29 | "beam_size": 40, 30 | "max_num_finished_states": 40, 31 | "max_decoding_steps": 40, 32 | "attention": {"type": "dot_product"}, 33 | "checklist_cost_weight": 0.8, 34 | "dynamic_cost_weight": { 35 | "wait_num_epochs": 0, 36 | "rate": 0.1 37 | }, 38 | "dropout": 0.3, 39 | "penalize_non_agenda_actions": true 40 | }, 41 | "data_loader": { 42 | "batch_sampler": { 43 | "type": "bucket", 44 | "padding_noise": 0.0, 45 | "batch_size": 4 46 | } 47 | }, 48 | "trainer": { 49 | "num_epochs": 1, 50 | "patience": 2, 51 | "cuda_device": -1, 52 | "optimizer": { 53 | "type": "sgd", 54 | "lr": 0.01 55 | } 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /test_fixtures/nlvr_coverage_semantic_parser/mml_init_experiment.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "nlvr", 4 | "output_agendas": true 5 | }, 6 | "vocabulary": { 7 | "non_padded_namespaces": ["denotations", "rule_labels"] 8 | }, 9 | "train_data_path": "test_fixtures/data/nlvr/sample_grouped_data.jsonl", 10 | "validation_data_path": "test_fixtures/data/nlvr/sample_grouped_data.jsonl", 11 | "model": { 12 | "type": "nlvr_coverage_parser", 13 | "sentence_embedder": { 14 | "token_embedders": { 15 | "tokens": { 16 | "type": "embedding", 17 | "embedding_dim": 25, 18 | "trainable": true 19 | } 20 | } 21 | }, 22 | "action_embedding_dim": 50, 23 | "encoder": { 24 | "type": "lstm", 25 | "input_size": 25, 26 | "hidden_size": 10, 27 | "num_layers": 1 28 | }, 29 | "beam_size": 20, 30 | "max_decoding_steps": 20, 31 | "attention": {"type": "dot_product"}, 32 | "checklist_cost_weight": 0.8, 33 | "dynamic_cost_weight": { 34 | "wait_num_epochs": 0, 35 | "rate": 0.1 36 | }, 37 | "penalize_non_agenda_actions": true, 38 | "initial_mml_model_file": "test_fixtures/semantic_parsing/nlvr_direct_semantic_parser/serialization/model.tar.gz" 39 | }, 40 | "data_loader": { 41 | "batch_sampler": { 42 | "type": "bucket", 43 | "padding_noise": 0.0, 44 | "batch_size": 2 45 | } 46 | }, 47 | "trainer": { 48 | "num_epochs": 1, 49 | "patience": 2, 50 | "cuda_device": -1, 51 | "optimizer": { 52 | "type": "sgd", 53 | "lr": 0.01 54 | } 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /test_fixtures/nlvr_coverage_semantic_parser/serialization/best.th: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/test_fixtures/nlvr_coverage_semantic_parser/serialization/best.th -------------------------------------------------------------------------------- /test_fixtures/nlvr_coverage_semantic_parser/serialization/model.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/test_fixtures/nlvr_coverage_semantic_parser/serialization/model.tar.gz -------------------------------------------------------------------------------- /test_fixtures/nlvr_coverage_semantic_parser/serialization/vocabulary/denotations.txt: -------------------------------------------------------------------------------- 1 | true 2 | false 3 | -------------------------------------------------------------------------------- /test_fixtures/nlvr_coverage_semantic_parser/serialization/vocabulary/non_padded_namespaces.txt: -------------------------------------------------------------------------------- 1 | denotations 2 | rule_labels 3 | -------------------------------------------------------------------------------- /test_fixtures/nlvr_coverage_semantic_parser/serialization/vocabulary/tokens.txt: -------------------------------------------------------------------------------- 1 | @@UNKNOWN@@ 2 | a 3 | There 4 | is 5 | circle 6 | closely 7 | touching 8 | corner 9 | of 10 | box 11 | . 12 | are 13 | 2 14 | yellow 15 | blocks 16 | -------------------------------------------------------------------------------- /test_fixtures/nlvr_coverage_semantic_parser/ungrouped_experiment.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "nlvr", 4 | "output_agendas": true 5 | }, 6 | "vocabulary": { 7 | "non_padded_namespaces": ["rule_labels", "denotations"] 8 | }, 9 | "train_data_path": "test_fixtures/data/nlvr/sample_ungrouped_data.jsonl", 10 | "validation_data_path": "test_fixtures/data/nlvr/sample_ungrouped_data.jsonl", 11 | "model": { 12 | "type": "nlvr_coverage_parser", 13 | "sentence_embedder": { 14 | "token_embedders": { 15 | "tokens": { 16 | "type": "embedding", 17 | "embedding_dim": 25, 18 | "trainable": true 19 | } 20 | } 21 | }, 22 | "action_embedding_dim": 50, 23 | "encoder": { 24 | "type": "lstm", 25 | "input_size": 25, 26 | "hidden_size": 10, 27 | "num_layers": 1 28 | }, 29 | "beam_size": 20, 30 | "max_decoding_steps": 20, 31 | "attention": {"type": "dot_product"}, 32 | "checklist_cost_weight": 0.8, 33 | "dynamic_cost_weight": { 34 | "wait_num_epochs": 0, 35 | "rate": 0.1 36 | }, 37 | "penalize_non_agenda_actions": true 38 | }, 39 | "data_loader": { 40 | "batch_sampler": { 41 | "type": "bucket", 42 | "padding_noise": 0.0, 43 | "batch_size": 2 44 | } 45 | }, 46 | "trainer": { 47 | "num_epochs": 1, 48 | "patience": 2, 49 | "cuda_device": -1, 50 | "optimizer": { 51 | "type": "sgd", 52 | "lr": 0.01 53 | } 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /test_fixtures/nlvr_direct_semantic_parser/experiment.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "nlvr", 4 | "output_agendas": false 5 | }, 6 | "vocabulary": { 7 | "non_padded_namespaces": ["denotations", "rule_labels"] 8 | }, 9 | "train_data_path": "test_fixtures/data/nlvr/sample_processed_data.jsonl", 10 | "validation_data_path": 
"test_fixtures/data/nlvr/sample_processed_data.jsonl", 11 | "model": { 12 | "type": "nlvr_direct_parser", 13 | "sentence_embedder": { 14 | "token_embedders": { 15 | "tokens": { 16 | "type": "embedding", 17 | "embedding_dim": 25, 18 | "trainable": true 19 | } 20 | } 21 | }, 22 | "action_embedding_dim": 50, 23 | "encoder": { 24 | "type": "lstm", 25 | "input_size": 25, 26 | "hidden_size": 10, 27 | "num_layers": 1 28 | }, 29 | "decoder_beam_search": { 30 | "beam_size": 5 31 | }, 32 | "max_decoding_steps": 20, 33 | "attention": {"type": "dot_product"}, 34 | "dropout": 0.2 35 | }, 36 | "data_loader": { 37 | "batch_sampler": { 38 | "type": "bucket", 39 | "padding_noise": 0.0, 40 | "batch_size": 2 41 | } 42 | }, 43 | "trainer": { 44 | "num_epochs": 1, 45 | "patience": 2, 46 | "cuda_device": -1, 47 | "optimizer": { 48 | "type": "sgd", 49 | "lr": 0.01 50 | } 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /test_fixtures/nlvr_direct_semantic_parser/serialization/best.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/test_fixtures/nlvr_direct_semantic_parser/serialization/best.th -------------------------------------------------------------------------------- /test_fixtures/nlvr_direct_semantic_parser/serialization/model.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/test_fixtures/nlvr_direct_semantic_parser/serialization/model.tar.gz -------------------------------------------------------------------------------- /test_fixtures/nlvr_direct_semantic_parser/serialization/vocabulary/denotations.txt: -------------------------------------------------------------------------------- 1 | true 2 | false 3 | -------------------------------------------------------------------------------- /test_fixtures/nlvr_direct_semantic_parser/serialization/vocabulary/non_padded_namespaces.txt: -------------------------------------------------------------------------------- 1 | denotations 2 | rule_labels 3 | -------------------------------------------------------------------------------- /test_fixtures/nlvr_direct_semantic_parser/serialization/vocabulary/tokens.txt: -------------------------------------------------------------------------------- 1 | @@UNKNOWN@@ 2 | a 3 | There 4 | is 5 | circle 6 | closely 7 | touching 8 | corner 9 | of 10 | box 11 | . 
12 | are 13 | 2 14 | yellow 15 | blocks 16 | -------------------------------------------------------------------------------- /test_fixtures/text2sql/experiment.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "grammar_based_text2sql", 4 | "database_file": "test_fixtures/data/text2sql/restaurants.db", 5 | "schema_path": "test_fixtures/data/text2sql/restaurants-schema.csv" 6 | }, 7 | "train_data_path": "test_fixtures/data/text2sql/restaurants_tiny.json", 8 | "validation_data_path": "test_fixtures/data/text2sql/restaurants_tiny.json", 9 | "model": { 10 | "type": "text2sql_parser", 11 | "utterance_embedder": { 12 | "token_embedders": { 13 | "tokens": { 14 | "type": "embedding", 15 | "embedding_dim": 5, 16 | "trainable": true 17 | } 18 | } 19 | }, 20 | "action_embedding_dim": 10, 21 | "encoder": { 22 | "type": "lstm", 23 | "input_size": 5, 24 | "hidden_size": 7, 25 | "bidirectional": true, 26 | "num_layers": 1 27 | }, 28 | "decoder_beam_search": { 29 | "beam_size": 5 30 | }, 31 | "max_decoding_steps": 10, 32 | "input_attention": {"type": "dot_product"}, 33 | "dropout": 0.0 34 | }, 35 | "data_loader": { 36 | "batch_size" : 4 37 | }, 38 | "trainer": { 39 | "num_epochs": 2, 40 | "patience": 5, 41 | "cuda_device": -1, 42 | "optimizer": { 43 | "type": "sgd", 44 | "lr": 0.1 45 | } 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /test_fixtures/wikitables/experiment-elmo-no-features.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "wikitables", 4 | "tables_directory": "test_fixtures/data/wikitables/", 5 | "dpd_output_directory": "test_fixtures/data/wikitables/dpd_output/", 6 | "question_token_indexers": { 7 | "elmo": { 8 | "type": "elmo_characters" 9 | } 10 | }, 11 | "max_table_tokens": 200 12 | }, 13 | "vocabulary": { 14 | "min_count": {"tokens": 1} 15 | }, 16 | "train_data_path": "test_fixtures/data/wikitables/sample_data.examples", 17 | "validation_data_path": "test_fixtures/data/wikitables/sample_data.examples", 18 | "model": { 19 | "type": "wikitables_mml_parser", 20 | "question_embedder": { 21 | "token_embedders": { 22 | "elmo":{ 23 | "type": "elmo_token_embedder", 24 | "options_file": "test_fixtures/elmo/options.json", 25 | "weight_file": "test_fixtures/elmo/lm_weights.hdf5", 26 | "do_layer_norm": false, 27 | "dropout": 0.0 28 | } 29 | } 30 | }, 31 | "action_embedding_dim": 50, 32 | "encoder": { 33 | "type": "lstm", 34 | "input_size": 64, 35 | "hidden_size": 10, 36 | "num_layers": 1 37 | }, 38 | "entity_encoder": { 39 | "type": "boe", 40 | "embedding_dim": 32, 41 | "averaged": true 42 | }, 43 | "decoder_beam_search": { 44 | "beam_size": 3 45 | }, 46 | "max_decoding_steps": 200, 47 | "attention": {"type": "dot_product"}, 48 | "num_linking_features": 0, 49 | "use_neighbor_similarity_for_linking": true, 50 | "tables_directory": "test_fixtures/data/wikitables/" 51 | }, 52 | "data_loader": { 53 | "batch_sampler": { 54 | "type": "bucket", 55 | "padding_noise": 0.0, 56 | "batch_size": 2 57 | } 58 | }, 59 | "trainer": { 60 | "num_epochs": 2, 61 | "patience": 10, 62 | "cuda_device": -1, 63 | "optimizer": { 64 | "type": "sgd", 65 | "lr": 0.01 66 | } 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /test_fixtures/wikitables/experiment-erm.json: -------------------------------------------------------------------------------- 1 | { 2 | 
"dataset_reader": { 3 | "type": "wikitables", 4 | "tables_directory": "test_fixtures/data/wikitables/", 5 | "output_agendas": true 6 | }, 7 | "train_data_path": "test_fixtures/data/wikitables/sample_data.examples", 8 | "validation_data_path": "test_fixtures/data/wikitables/sample_data.examples", 9 | "model": { 10 | "type": "wikitables_erm_parser", 11 | "question_embedder": { 12 | "token_embedders": { 13 | "tokens": { 14 | "type": "embedding", 15 | "embedding_dim": 25, 16 | "trainable": true 17 | } 18 | } 19 | }, 20 | "action_embedding_dim": 50, 21 | "encoder": { 22 | "type": "lstm", 23 | "input_size": 50, 24 | "hidden_size": 10, 25 | "num_layers": 1 26 | }, 27 | "entity_encoder": { 28 | "type": "boe", 29 | "embedding_dim": 25, 30 | "averaged": true 31 | }, 32 | "checklist_cost_weight": 0.6, 33 | "max_decoding_steps": 10, 34 | "decoder_beam_size": 50, 35 | "decoder_num_finished_states": 100, 36 | "attention": {"type": "dot_product"} 37 | }, 38 | "data_loader": { 39 | "batch_size": 2 40 | }, 41 | "trainer": { 42 | "num_epochs": 2, 43 | "patience": 10, 44 | "cuda_device": -1, 45 | "optimizer": { 46 | "type": "sgd", 47 | "lr": 0.01 48 | } 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /test_fixtures/wikitables/experiment-mixture.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "wikitables", 4 | "tables_directory": "test_fixtures/data/wikitables/", 5 | "dpd_output_directory": "test_fixtures/data/wikitables/dpd_output/" 6 | }, 7 | "train_data_path": "test_fixtures/data/wikitables/sample_data.examples", 8 | "validation_data_path": "test_fixtures/data/wikitables/sample_data.examples", 9 | "model": { 10 | "type": "wikitables_mml_parser", 11 | "tables_directory": "test_fixtures/data/wikitables/", 12 | "question_embedder": { 13 | "token_embedders": { 14 | "tokens": { 15 | "type": "embedding", 16 | "embedding_dim": 25, 17 | "trainable": true 18 | } 19 | } 20 | }, 21 | "action_embedding_dim": 50, 22 | "mixture_feedforward": { 23 | "input_dim": 10, 24 | "num_layers": 3, 25 | "hidden_dims": [5, 2, 1], 26 | "activations": ["relu", "sigmoid", "sigmoid"], 27 | "dropout": [0.0, 0.0, 0.0] 28 | }, 29 | "encoder": { 30 | "type": "lstm", 31 | "input_size": 50, 32 | "hidden_size": 10, 33 | "num_layers": 1 34 | }, 35 | "entity_encoder": { 36 | "type": "boe", 37 | "embedding_dim": 25, 38 | "averaged": true 39 | }, 40 | "decoder_beam_search": { 41 | "beam_size": 3 42 | }, 43 | "max_decoding_steps": 200, 44 | "attention": {"type": "dot_product"} 45 | }, 46 | "data_loader": { 47 | "batch_sampler": { 48 | "type": "bucket", 49 | "padding_noise": 0.0, 50 | "batch_size": 2 51 | } 52 | }, 53 | "trainer": { 54 | "num_epochs": 2, 55 | "patience": 10, 56 | "cuda_device": -1, 57 | "optimizer": { 58 | "type": "sgd", 59 | "lr": 0.01 60 | } 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /test_fixtures/wikitables/experiment.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "wikitables", 4 | "tables_directory": "test_fixtures/data/wikitables/", 5 | "offline_logical_forms_directory": "test_fixtures/data/wikitables/action_space_walker_output/" 6 | }, 7 | "train_data_path": "test_fixtures/data/wikitables/sample_data.examples", 8 | "validation_data_path": "test_fixtures/data/wikitables/sample_data.examples", 9 | "model": { 10 | "type": "wikitables_mml_parser", 11 | 
"question_embedder": { 12 | "token_embedders": { 13 | "tokens": { 14 | "type": "embedding", 15 | "embedding_dim": 25, 16 | "trainable": true 17 | } 18 | } 19 | }, 20 | "action_embedding_dim": 50, 21 | "encoder": { 22 | "type": "lstm", 23 | "input_size": 50, 24 | "hidden_size": 10, 25 | "num_layers": 1 26 | }, 27 | "entity_encoder": { 28 | "type": "boe", 29 | "embedding_dim": 25, 30 | "averaged": true 31 | }, 32 | "decoder_beam_search": { 33 | "beam_size": 3 34 | }, 35 | "max_decoding_steps": 200, 36 | "attention": {"type": "dot_product"} 37 | }, 38 | "data_loader": { 39 | "batch_sampler": { 40 | "type": "bucket", 41 | "padding_noise": 0.0, 42 | "batch_size": 2 43 | } 44 | }, 45 | "trainer": { 46 | "num_epochs": 2, 47 | "patience": 10, 48 | "cuda_device": -1, 49 | "optimizer": { 50 | "type": "sgd", 51 | "lr": 0.01 52 | } 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /test_fixtures/wikitables/serialization/best.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/test_fixtures/wikitables/serialization/best.th -------------------------------------------------------------------------------- /test_fixtures/wikitables/serialization/model.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/test_fixtures/wikitables/serialization/model.tar.gz -------------------------------------------------------------------------------- /test_fixtures/wikitables/serialization/vocabulary/non_padded_namespaces.txt: -------------------------------------------------------------------------------- 1 | *tags 2 | *labels 3 | -------------------------------------------------------------------------------- /test_fixtures/wikitables/serialization/vocabulary/rule_labels.txt: -------------------------------------------------------------------------------- 1 | -> same_as 2 | -> argmax 3 | -> argmin 4 | -> filter_date_equals 5 | -> filter_date_greater 6 | -> filter_date_greater_equals 7 | -> filter_date_lesser 8 | -> filter_date_lesser_equals 9 | -> filter_date_not_equals 10 | -> max_date 11 | -> min_date 12 | -> mode_date 13 | -> select_date 14 | -> diff 15 | -> filter_number_equals 16 | -> filter_number_greater 17 | -> filter_number_greater_equals 18 | -> filter_number_lesser 19 | -> filter_number_lesser_equals 20 | -> filter_number_not_equals 21 | -> average 22 | -> max_number 23 | -> min_number 24 | -> mode_number 25 | -> select_number 26 | -> sum 27 | -> filter_in 28 | -> filter_not_in 29 | -> mode_string 30 | -> select_string 31 | -> first 32 | -> last 33 | -> next 34 | -> previous 35 | -> count 36 | -> date 37 | @start@ -> Date 38 | @start@ -> List[str] 39 | @start@ -> Number 40 | Date -> [, List[Row], DateColumn] 41 | Date -> [, Number, Number, Number] 42 | List[Row] -> [, List[Row], Column] 43 | List[Row] -> [, List[Row], ComparableColumn] 44 | List[Row] -> [, List[Row], DateColumn, Date] 45 | List[Row] -> [, List[Row], NumberColumn, Number] 46 | List[Row] -> [, List[Row], StringColumn, List[str]] 47 | List[Row] -> [, List[Row]] 48 | List[Row] -> all_rows 49 | List[str] -> [, List[Row], StringColumn] 50 | Number -> [, List[Row], List[Row], NumberColumn] 51 | Number -> [, List[Row], NumberColumn] 52 | Number -> [, List[Row]] 53 | 
-------------------------------------------------------------------------------- /test_fixtures/wikitables/serialization/vocabulary/tokens.txt: -------------------------------------------------------------------------------- 1 | @@UNKNOWN@@ 2 | what 3 | was 4 | the 5 | last 6 | a 7 | ? 8 | year 9 | where 10 | this 11 | team 12 | part 13 | of 14 | usl 15 | - 16 | league 17 | in 18 | city 19 | did 20 | piotr 21 | 's 22 | 1st 23 | place 24 | finish 25 | occur 26 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | from .semparse_test_case import SemparseTestCase, ModelTestCase 2 | -------------------------------------------------------------------------------- /tests/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/tests/common/__init__.py -------------------------------------------------------------------------------- /tests/common/date_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from .. import SemparseTestCase 4 | 5 | from allennlp_semparse.common import Date, ExecutionError 6 | 7 | 8 | class TestDate(SemparseTestCase): 9 | def test_date_comparison_works(self): 10 | assert Date(2013, 12, 31) > Date(2013, 12, 30) 11 | assert Date(2013, 12, 31) == Date(2013, 12, -1) 12 | assert Date(2013, -1, -1) >= Date(2013, 12, 31) 13 | assert (Date(2013, 12, -1) > Date(2013, 12, 31)) is False 14 | with pytest.raises(ExecutionError, match="only compare Dates with Dates"): 15 | assert (Date(2013, 12, 31) > 2013) is False 16 | with pytest.raises(ExecutionError, match="only compare Dates with Dates"): 17 | assert (Date(2013, 12, 31) >= 2013) is False 18 | with pytest.raises(ExecutionError, match="only compare Dates with Dates"): 19 | assert Date(2013, 12, 31) != 2013 20 | assert (Date(2018, 1, 1) >= Date(-1, 2, 1)) is False 21 | assert (Date(2018, 1, 1) < Date(-1, 2, 1)) is False 22 | # When year is unknown in both cases, we can compare months and days. 23 | assert Date(-1, 2, 1) < Date(-1, 2, 3) 24 | # If both year and month are not know in both cases, the comparison is undefined, and both 25 | # < and >= return False. 26 | assert (Date(-1, -1, 1) < Date(-1, -1, 3)) is False 27 | assert (Date(-1, -1, 1) >= Date(-1, -1, 3)) is False 28 | # Same when year is known, but months are not. 29 | assert (Date(2018, -1, 1) < Date(2018, -1, 3)) is False 30 | assert (Date(2018, -1, 1) >= Date(2018, -1, 3)) is False 31 | -------------------------------------------------------------------------------- /tests/common/sql/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/tests/common/sql/__init__.py -------------------------------------------------------------------------------- /tests/common/util_test.py: -------------------------------------------------------------------------------- 1 | from .. 
import SemparseTestCase 2 | 3 | from allennlp_semparse.common import util 4 | 5 | 6 | class TestSemparseUtil(SemparseTestCase): 7 | def test_lisp_to_nested_expression(self): 8 | logical_form = "((reverse fb:row.row.year) (fb:row.row.league fb:cell.usl_a_league))" 9 | expression = util.lisp_to_nested_expression(logical_form) 10 | assert expression == [ 11 | ["reverse", "fb:row.row.year"], 12 | ["fb:row.row.league", "fb:cell.usl_a_league"], 13 | ] 14 | logical_form = "(count (and (division 1) (tier (!= null))))" 15 | expression = util.lisp_to_nested_expression(logical_form) 16 | assert expression == ["count", ["and", ["division", "1"], ["tier", ["!=", "null"]]]] 17 | -------------------------------------------------------------------------------- /tests/common/wikitables/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/tests/common/wikitables/__init__.py -------------------------------------------------------------------------------- /tests/dataset_readers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/tests/dataset_readers/__init__.py -------------------------------------------------------------------------------- /tests/dataset_readers/atis_test.py: -------------------------------------------------------------------------------- 1 | from allennlp.common.file_utils import cached_path 2 | from allennlp_semparse.dataset_readers import AtisDatasetReader 3 | from .. import SemparseTestCase 4 | 5 | from allennlp_semparse.parsimonious_languages.worlds import AtisWorld 6 | 7 | 8 | class TestAtisReader(SemparseTestCase): 9 | def test_atis_keep_unparseable(self): 10 | database_file = cached_path("https://allennlp.s3.amazonaws.com/datasets/atis/atis.db") 11 | reader = AtisDatasetReader(database_file=database_file, keep_if_unparseable=True) 12 | instance = reader.text_to_instance( 13 | utterances=["show me the one way flights from detroit me to westchester county"], 14 | sql_query_labels=["this is not a query that can be parsed"], 15 | ) 16 | 17 | # If we have a query that can't be parsed, we check that it only has one element in the list 18 | # of index fields and that index is the padding index, -1. 
19 | assert len(instance.fields["target_action_sequence"].field_list) == 1 20 | assert instance.fields["target_action_sequence"].field_list[0].sequence_index == -1 21 | 22 | def test_atis_read_from_file(self): 23 | data_path = SemparseTestCase.FIXTURES_ROOT / "data" / "atis" / "sample.json" 24 | database_file = "https://allennlp.s3.amazonaws.com/datasets/atis/atis.db" 25 | reader = AtisDatasetReader(database_file=database_file) 26 | 27 | instances = list(reader.read(str(data_path))) 28 | 29 | assert len(instances) == 13 30 | instance = instances[0] 31 | 32 | assert set(instance.fields.keys()) == { 33 | "utterance", 34 | "actions", 35 | "world", 36 | "sql_queries", 37 | "target_action_sequence", 38 | "linking_scores", 39 | } 40 | 41 | assert [t.text for t in instance.fields["utterance"].tokens] == [ 42 | "show", 43 | "me", 44 | "the", 45 | "one", 46 | "way", 47 | "flights", 48 | "from", 49 | "detroit", 50 | "to", 51 | "westchester", 52 | "county", 53 | ] 54 | 55 | assert isinstance(instance.fields["world"].as_tensor({}), AtisWorld) 56 | 57 | world = instance.fields["world"].metadata 58 | assert set(world.valid_actions["number"]) == { 59 | 'number -> ["1"]', 60 | 'number -> ["0"]', 61 | 'number -> ["41"]', 62 | 'number -> ["60"]', 63 | } 64 | 65 | assert world.linked_entities["string"]["airport_airport_code_string -> [\"'DTW'\"]"][2] == [ 66 | 0, 67 | 0, 68 | 0, 69 | 0, 70 | 0, 71 | 0, 72 | 0, 73 | 1, 74 | 0, 75 | 0, 76 | 0, 77 | ] # ``detroit`` -> ``DTW`` 78 | assert world.linked_entities["string"]["flight_stop_stop_airport_string -> [\"'DTW'\"]"][ 79 | 2 80 | ] == [ 81 | 0, 82 | 0, 83 | 0, 84 | 0, 85 | 0, 86 | 0, 87 | 0, 88 | 1, 89 | 0, 90 | 0, 91 | 0, 92 | ] # ``detroit`` -> ``DTW`` 93 | assert world.linked_entities["string"]["city_city_code_string -> [\"'DDTT'\"]"][2] == [ 94 | 0, 95 | 0, 96 | 0, 97 | 0, 98 | 0, 99 | 0, 100 | 0, 101 | 1, 102 | 0, 103 | 0, 104 | 0, 105 | ] # ``detroit`` -> ``DDTT`` 106 | assert world.linked_entities["string"]["fare_basis_economy_string -> [\"'NO'\"]"][2] == [ 107 | 0, 108 | 0, 109 | 0, 110 | 1, 111 | 1, 112 | 0, 113 | 0, 114 | 0, 115 | 0, 116 | 0, 117 | 0, 118 | ] # ``one way`` -> ``NO`` 119 | assert world.linked_entities["string"][ 120 | "city_city_name_string -> [\"'WESTCHESTER COUNTY'\"]" 121 | ][2] == [ 122 | 0, 123 | 0, 124 | 0, 125 | 0, 126 | 0, 127 | 0, 128 | 0, 129 | 0, 130 | 0, 131 | 1, 132 | 1, 133 | ] # ``westchester county`` -> ``WESTCHESTER COUNTY`` 134 | assert world.linked_entities["string"]["city_city_code_string -> [\"'HHPN'\"]"][2] == [ 135 | 0, 136 | 0, 137 | 0, 138 | 0, 139 | 0, 140 | 0, 141 | 0, 142 | 0, 143 | 0, 144 | 1, 145 | 1, 146 | ] # ``westchester county`` -> ``HHPN`` 147 | -------------------------------------------------------------------------------- /tests/dataset_readers/wikitables_test.py: -------------------------------------------------------------------------------- 1 | from allennlp.common import Params 2 | from .. 
import SemparseTestCase 3 | 4 | from allennlp_semparse.dataset_readers import WikiTablesDatasetReader 5 | from allennlp_semparse.domain_languages import WikiTablesLanguage 6 | 7 | 8 | def assert_dataset_correct(dataset): 9 | instances = list(dataset) 10 | assert len(instances) == 2 11 | instance = instances[0] 12 | 13 | assert instance.fields.keys() == { 14 | "question", 15 | "metadata", 16 | "table", 17 | "world", 18 | "actions", 19 | "target_action_sequences", 20 | "target_values", 21 | } 22 | 23 | question_tokens = [ 24 | "what", 25 | "was", 26 | "the", 27 | "last", 28 | "year", 29 | "where", 30 | "this", 31 | "team", 32 | "was", 33 | "a", 34 | "part", 35 | "of", 36 | "the", 37 | "usl", 38 | "a", 39 | "-", 40 | "league", 41 | "?", 42 | ] 43 | assert [t.text for t in instance.fields["question"].tokens] == question_tokens 44 | 45 | assert instance.fields["metadata"].as_tensor({})["question_tokens"] == question_tokens 46 | 47 | # The content of this will be tested indirectly by checking the actions; we'll just make 48 | # sure we get a WikiTablesWorld object in here. 49 | assert isinstance(instance.fields["world"].as_tensor({}), WikiTablesLanguage) 50 | 51 | action_fields = instance.fields["actions"].field_list 52 | actions = [action_field.rule for action_field in action_fields] 53 | 54 | # We should have been able to read all of the logical forms in the file. If one of them can't 55 | # be parsed, or the action sequences can't be mapped correctly, the DatasetReader will skip the 56 | # logical form, log an error, and keep going (i.e., it won't crash). 57 | num_action_sequences = len(instance.fields["target_action_sequences"].field_list) 58 | assert num_action_sequences == 10 59 | 60 | # We should have sorted the logical forms by length. This is the action sequence 61 | # corresponding to the shortest logical form in the examples _by tree size_, which is _not_ the 62 | # first one in the file, or the shortest logical form by _string length_. It's also a totally 63 | # made up logical form, just to demonstrate that we're sorting things correctly. 
64 | action_sequence = instance.fields["target_action_sequences"].field_list[0] 65 | action_indices = [action.sequence_index for action in action_sequence.field_list] 66 | actions = [actions[i] for i in action_indices] 67 | assert actions == [ 68 | "@start@ -> Number", 69 | "Number -> [, List[Row], NumberColumn]", 70 | " -> average", 71 | "List[Row] -> [, List[Row]]", 72 | " -> last", 73 | "List[Row] -> [, List[Row], StringColumn, List[str]]", 74 | " -> filter_in", 75 | "List[Row] -> all_rows", 76 | "StringColumn -> string_column:league", 77 | "List[str] -> string:usl_a_league", 78 | "NumberColumn -> number_column:year", 79 | ] 80 | 81 | 82 | class TestWikiTablesDatasetReader(SemparseTestCase): 83 | def test_reader_reads(self): 84 | offline_search_directory = ( 85 | self.FIXTURES_ROOT / "data" / "wikitables" / "action_space_walker_output" 86 | ) 87 | params = { 88 | "tables_directory": self.FIXTURES_ROOT / "data" / "wikitables", 89 | "offline_logical_forms_directory": offline_search_directory, 90 | } 91 | reader = WikiTablesDatasetReader.from_params(Params(params)) 92 | dataset = reader.read(self.FIXTURES_ROOT / "data" / "wikitables" / "sample_data.examples") 93 | assert_dataset_correct(dataset) 94 | 95 | def test_reader_reads_with_lfs_in_tarball(self): 96 | offline_search_directory = ( 97 | self.FIXTURES_ROOT 98 | / "data" 99 | / "wikitables" 100 | / "action_space_walker_output_with_single_tarball" 101 | ) 102 | params = { 103 | "tables_directory": self.FIXTURES_ROOT / "data" / "wikitables", 104 | "offline_logical_forms_directory": offline_search_directory, 105 | } 106 | reader = WikiTablesDatasetReader.from_params(Params(params)) 107 | dataset = reader.read(self.FIXTURES_ROOT / "data" / "wikitables" / "sample_data.examples") 108 | assert_dataset_correct(dataset) 109 | -------------------------------------------------------------------------------- /tests/domain_languages/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/tests/domain_languages/__init__.py -------------------------------------------------------------------------------- /tests/fields/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/tests/fields/__init__.py -------------------------------------------------------------------------------- /tests/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/tests/models/__init__.py -------------------------------------------------------------------------------- /tests/models/atis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/tests/models/atis/__init__.py -------------------------------------------------------------------------------- /tests/models/atis/atis_grammar_statelet_test.py: -------------------------------------------------------------------------------- 1 | from numpy.testing import assert_almost_equal 2 | import torch 3 | 4 | from allennlp.common import Params 5 | 6 | from allennlp_semparse.models.atis.atis_semantic_parser import AtisSemanticParser 7 | from 
allennlp_semparse.parsimonious_languages.worlds import AtisWorld 8 | from allennlp_semparse.state_machines.states import GrammarStatelet 9 | from ... import SemparseTestCase 10 | 11 | 12 | class TestAtisGrammarStatelet(SemparseTestCase): 13 | def test_atis_grammar_statelet(self): 14 | world = AtisWorld( 15 | [("give me all flights from boston to " "philadelphia next week arriving after lunch")] 16 | ) 17 | action_sequence = [ 18 | 'statement -> [query, ";"]', 19 | 'query -> ["(", "SELECT", distinct, select_results, "FROM", table_refs, ' 20 | 'where_clause, ")"]', 21 | 'distinct -> ["DISTINCT"]', 22 | "select_results -> [col_refs]", 23 | 'col_refs -> [col_ref, ",", col_refs]', 24 | 'col_ref -> ["city", ".", "city_code"]', 25 | "col_refs -> [col_ref]", 26 | 'col_ref -> ["city", ".", "city_name"]', 27 | "table_refs -> [table_name]", 28 | 'table_name -> ["city"]', 29 | 'where_clause -> ["WHERE", "(", conditions, ")"]', 30 | "conditions -> [condition]", 31 | "condition -> [biexpr]", 32 | 'biexpr -> ["city", ".", "city_name", binaryop, city_city_name_string]', 33 | 'binaryop -> ["="]', 34 | "city_city_name_string -> [\"'BOSTON'\"]", 35 | ] 36 | 37 | grammar_state = GrammarStatelet( 38 | ["statement"], world.valid_actions, AtisSemanticParser.is_nonterminal 39 | ) 40 | for action in action_sequence: 41 | grammar_state = grammar_state.take_action(action) 42 | assert grammar_state._nonterminal_stack == [] 43 | -------------------------------------------------------------------------------- /tests/models/atis/atis_semantic_parser_test.py: -------------------------------------------------------------------------------- 1 | from flaky import flaky 2 | 3 | from ... import ModelTestCase 4 | from allennlp_semparse.parsimonious_languages.contexts.sql_context_utils import ( 5 | action_sequence_to_sql, 6 | ) 7 | 8 | 9 | class TestAtisSemanticParser(ModelTestCase): 10 | def setup_method(self): 11 | super().setup_method() 12 | self.set_up_model( 13 | str(self.FIXTURES_ROOT / "atis" / "experiment.json"), 14 | str(self.FIXTURES_ROOT / "data" / "atis" / "sample.json"), 15 | ) 16 | 17 | @flaky 18 | def test_atis_model_can_train_save_and_load(self): 19 | self.ensure_model_can_train_save_and_load(self.param_file) 20 | 21 | def test_action_sequence_to_sql(self): 22 | action_sequence = [ 23 | 'statement -> [query, ";"]', 24 | 'query -> ["(", "SELECT", distinct, select_results, "FROM", table_refs, ' 25 | 'where_clause, ")"]', 26 | 'distinct -> ["DISTINCT"]', 27 | "select_results -> [col_refs]", 28 | 'col_refs -> [col_ref, ",", col_refs]', 29 | 'col_ref -> ["city", ".", "city_code"]', 30 | "col_refs -> [col_ref]", 31 | 'col_ref -> ["city", ".", "city_name"]', 32 | "table_refs -> [table_name]", 33 | 'table_name -> ["city"]', 34 | 'where_clause -> ["WHERE", "(", conditions, ")"]', 35 | "conditions -> [condition]", 36 | "condition -> [biexpr]", 37 | 'biexpr -> ["city", ".", "city_name", binaryop, city_city_name_string]', 38 | 'binaryop -> ["="]', 39 | "city_city_name_string -> [\"'BOSTON'\"]", 40 | ] 41 | 42 | sql_query = action_sequence_to_sql(action_sequence) 43 | assert ( 44 | sql_query == "( SELECT DISTINCT city . city_code , city . city_name " 45 | "FROM city WHERE ( city . 
city_name = 'BOSTON' ) ) ;" 46 | ) 47 | -------------------------------------------------------------------------------- /tests/models/nlvr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/tests/models/nlvr/__init__.py -------------------------------------------------------------------------------- /tests/models/nlvr/nlvr_coverage_semantic_parser_test.py: -------------------------------------------------------------------------------- 1 | from numpy.testing import assert_almost_equal 2 | import torch 3 | import pytest 4 | 5 | from allennlp.common import Params 6 | from ... import ModelTestCase 7 | from allennlp.data import Vocabulary 8 | from allennlp.models import Model 9 | from allennlp.models.archival import load_archive 10 | 11 | 12 | class TestNlvrCoverageSemanticParser(ModelTestCase): 13 | def setup_method(self): 14 | super().setup_method() 15 | self.set_up_model( 16 | self.FIXTURES_ROOT / "nlvr_coverage_semantic_parser" / "experiment.json", 17 | self.FIXTURES_ROOT / "data" / "nlvr" / "sample_grouped_data.jsonl", 18 | ) 19 | 20 | def test_model_can_train_save_and_load(self): 21 | self.ensure_model_can_train_save_and_load(self.param_file) 22 | 23 | def test_ungrouped_model_can_train_save_and_load(self): 24 | self.ensure_model_can_train_save_and_load( 25 | self.FIXTURES_ROOT / "nlvr_coverage_semantic_parser" / "ungrouped_experiment.json" 26 | ) 27 | 28 | def test_mml_initialized_model_can_train_save_and_load(self): 29 | self.ensure_model_can_train_save_and_load( 30 | self.FIXTURES_ROOT / "nlvr_coverage_semantic_parser" / "mml_init_experiment.json" 31 | ) 32 | 33 | def test_get_checklist_info(self): 34 | # Creating a fake all_actions field where actions 0, 2 and 4 are terminal productions. 35 | all_actions = [ 36 | (" -> top", True, None), 37 | ("fake_action", True, None), 38 | ("Color -> color_black", True, None), 39 | ("fake_action2", True, None), 40 | ("int -> 6", True, None), 41 | ] 42 | # Of the actions above, those at indices 0 and 4 are on the agenda, and there are padding 43 | # indices at the end. 
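# Illustrative aside (added; not part of the original test): the bookkeeping being asserted
# below, spelled out in plain Python. The indices are taken from the comments above, not
# computed the way the model computes them.
terminal_indices = [0, 2, 4]   # the terminal productions in the fake action list
agenda_indices = {0, 4}        # the non-padding entries of the agenda defined below
target_checklist_expected = [[1 if i in agenda_indices else 0] for i in terminal_indices]
# target_checklist_expected == [[1], [0], [1]], matching the assertion on target_checklist;
# the checklist mask is likewise expected to be all ones since all three terminal entries are real.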
44 | test_agenda = torch.Tensor([[0], [4], [-1], [-1]]) 45 | checklist_info = self.model._get_checklist_info(test_agenda, all_actions) 46 | target_checklist, terminal_actions, checklist_mask = checklist_info 47 | assert_almost_equal(target_checklist.data.numpy(), [[1], [0], [1]]) 48 | assert_almost_equal(terminal_actions.data.numpy(), [[0], [2], [4]]) 49 | assert_almost_equal(checklist_mask.data.numpy(), [[1], [1], [1]]) 50 | 51 | def test_initialize_weights_from_archive(self): 52 | original_model_parameters = self.model.named_parameters() 53 | original_model_weights = { 54 | name: parameter.data.clone().numpy() for name, parameter in original_model_parameters 55 | } 56 | mml_model_archive_file = ( 57 | self.FIXTURES_ROOT / "nlvr_direct_semantic_parser" / "serialization" / "model.tar.gz" 58 | ) 59 | archive = load_archive(mml_model_archive_file) 60 | archived_model_parameters = archive.model.named_parameters() 61 | self.model._initialize_weights_from_archive(archive) 62 | changed_model_parameters = dict(self.model.named_parameters()) 63 | for name, archived_parameter in archived_model_parameters: 64 | archived_weight = archived_parameter.data.numpy() 65 | original_weight = original_model_weights[name] 66 | changed_weight = changed_model_parameters[name].data.numpy() 67 | # We want to make sure that the weights in the original model have indeed been changed 68 | # after a call to ``_initialize_weights_from_archive``. 69 | with pytest.raises(AssertionError, match="Arrays are not almost equal"): 70 | assert_almost_equal(original_weight, changed_weight) 71 | # This also includes the sentence token embedder. Those weights will be the same 72 | # because the two models have the same vocabulary. 73 | assert_almost_equal(archived_weight, changed_weight) 74 | 75 | def test_get_vocab_index_mapping(self): 76 | mml_model_archive_file = ( 77 | self.FIXTURES_ROOT / "nlvr_direct_semantic_parser" / "serialization" / "model.tar.gz" 78 | ) 79 | archive = load_archive(mml_model_archive_file) 80 | mapping = self.model._get_vocab_index_mapping(archive.model.vocab) 81 | expected_mapping = [(i, i) for i in range(16)] 82 | assert mapping == expected_mapping 83 | 84 | new_vocab = Vocabulary() 85 | 86 | def copy_token_at_index(i): 87 | token = self.vocab.get_token_from_index(i, "tokens") 88 | new_vocab.add_token_to_namespace(token, "tokens") 89 | 90 | copy_token_at_index(5) 91 | copy_token_at_index(7) 92 | copy_token_at_index(10) 93 | mapping = self.model._get_vocab_index_mapping(new_vocab) 94 | # Mapping of indices from model vocabulary to new vocabulary. 0 and 1 are padding and unk 95 | # tokens. 96 | assert mapping == [(0, 0), (1, 1), (5, 2), (7, 3), (10, 4)] 97 | -------------------------------------------------------------------------------- /tests/models/nlvr/nlvr_direct_semantic_parser_test.py: -------------------------------------------------------------------------------- 1 | from ... 
import ModelTestCase 2 | 3 | 4 | class TestNlvrDirectSemanticParser(ModelTestCase): 5 | def setup_method(self): 6 | super().setup_method() 7 | self.set_up_model( 8 | self.FIXTURES_ROOT / "nlvr_direct_semantic_parser" / "experiment.json", 9 | self.FIXTURES_ROOT / "data" / "nlvr" / "sample_processed_data.jsonl", 10 | ) 11 | 12 | def test_model_can_train_save_and_load(self): 13 | self.ensure_model_can_train_save_and_load(self.param_file) 14 | -------------------------------------------------------------------------------- /tests/models/quarel/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/tests/models/quarel/__init__.py -------------------------------------------------------------------------------- /tests/models/text2sql_parser_test.py: -------------------------------------------------------------------------------- 1 | from .. import ModelTestCase 2 | 3 | from allennlp_semparse.state_machines.states import GrammarStatelet 4 | from allennlp_semparse.models.text2sql_parser import Text2SqlParser 5 | from allennlp_semparse.parsimonious_languages.worlds.text2sql_world import Text2SqlWorld 6 | 7 | 8 | class TestText2SqlParser(ModelTestCase): 9 | def setup_method(self): 10 | super().setup_method() 11 | 12 | self.set_up_model( 13 | str(self.FIXTURES_ROOT / "text2sql" / "experiment.json"), 14 | str(self.FIXTURES_ROOT / "data" / "text2sql" / "restaurants_tiny.json"), 15 | ) 16 | self.schema = str(self.FIXTURES_ROOT / "data" / "text2sql" / "restaurants-schema.csv") 17 | 18 | def test_model_can_train_save_and_load(self): 19 | self.ensure_model_can_train_save_and_load(self.param_file) 20 | 21 | def test_grammar_statelet(self): 22 | valid_actions = None 23 | world = Text2SqlWorld(self.schema) 24 | 25 | sql = ["SELECT", "COUNT", "(", "*", ")", "FROM", "LOCATION", ",", "RESTAURANT", ";"] 26 | action_sequence, valid_actions = world.get_action_sequence_and_all_actions(sql) 27 | 28 | grammar_state = GrammarStatelet( 29 | ["statement"], valid_actions, Text2SqlParser.is_nonterminal, reverse_productions=True 30 | ) 31 | for action in action_sequence: 32 | grammar_state = grammar_state.take_action(action) 33 | assert grammar_state._nonterminal_stack == [] 34 | -------------------------------------------------------------------------------- /tests/models/wikitables/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/tests/models/wikitables/__init__.py -------------------------------------------------------------------------------- /tests/models/wikitables/wikitables_erm_semantic_parser_test.py: -------------------------------------------------------------------------------- 1 | from flaky import flaky 2 | 3 | from ... import ModelTestCase 4 | 5 | 6 | class TestWikiTablesVariableFreeErm(ModelTestCase): 7 | def setup_method(self): 8 | super().setup_method() 9 | config_path = self.FIXTURES_ROOT / "wikitables" / "experiment-erm.json" 10 | data_path = self.FIXTURES_ROOT / "data" / "wikitables" / "sample_data.examples" 11 | self.set_up_model(config_path, data_path) 12 | 13 | @flaky 14 | def test_model_can_train_save_and_load(self): 15 | # We have very few embedded actions on our agenda, and so it's rare that this parameter 16 | # actually gets used. 
We know this parameter works from our NLVR ERM test, so it's easier 17 | # to just ignore it here than to try to finagle the test to make it so this has a non-zero 18 | # gradient. 19 | ignore = {"_decoder_step._checklist_multiplier"} 20 | self.ensure_model_can_train_save_and_load(self.param_file, gradients_to_ignore=ignore) 21 | -------------------------------------------------------------------------------- /tests/nltk_languages/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/tests/nltk_languages/__init__.py -------------------------------------------------------------------------------- /tests/nltk_languages/contexts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/tests/nltk_languages/contexts/__init__.py -------------------------------------------------------------------------------- /tests/nltk_languages/type_declarations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/tests/nltk_languages/type_declarations/__init__.py -------------------------------------------------------------------------------- /tests/nltk_languages/worlds/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/tests/nltk_languages/worlds/__init__.py -------------------------------------------------------------------------------- /tests/nltk_languages/worlds/world_test.py: -------------------------------------------------------------------------------- 1 | from ... import SemparseTestCase 2 | from allennlp_semparse.nltk_languages.worlds.world import World 3 | 4 | 5 | class FakeWorldWithoutRecursion(World): 6 | def all_possible_actions(self): 7 | # The logical forms this grammar allows are 8 | # (unary_function argument) 9 | # (binary_function argument argument) 10 | actions = [ 11 | "@start@ -> t", 12 | "t -> [, e]", 13 | " -> unary_function", 14 | " -> [>, e]", 15 | "> -> binary_function", 16 | "e -> argument", 17 | ] 18 | return actions 19 | 20 | 21 | class FakeWorldWithRecursion(FakeWorldWithoutRecursion): 22 | def all_possible_actions(self): 23 | # In addition to the forms allowed by ``FakeWorldWithoutRecursion``, this world allows 24 | # (unary_function (identity .... (argument))) 25 | # (binary_function (identity .... (argument)) (identity .... 
(argument))) 26 | actions = super(FakeWorldWithRecursion, self).all_possible_actions() 27 | actions.extend(["e -> [, e]", " -> identity"]) 28 | return actions 29 | 30 | 31 | class TestWorld(SemparseTestCase): 32 | def setup_method(self): 33 | super().setup_method() 34 | self.world_without_recursion = FakeWorldWithoutRecursion() 35 | self.world_with_recursion = FakeWorldWithRecursion() 36 | 37 | def test_get_paths_to_root_without_recursion(self): 38 | argument_paths = self.world_without_recursion.get_paths_to_root("e -> argument") 39 | assert argument_paths == [ 40 | ["e -> argument", "t -> [, e]", "@start@ -> t"], 41 | ["e -> argument", " -> [>, e]", "t -> [, e]", "@start@ -> t"], 42 | ] 43 | unary_function_paths = self.world_without_recursion.get_paths_to_root( 44 | " -> unary_function" 45 | ) 46 | assert unary_function_paths == [ 47 | [" -> unary_function", "t -> [, e]", "@start@ -> t"] 48 | ] 49 | binary_function_paths = self.world_without_recursion.get_paths_to_root( 50 | "> -> binary_function" 51 | ) 52 | assert binary_function_paths == [ 53 | [ 54 | "> -> binary_function", 55 | " -> [>, e]", 56 | "t -> [, e]", 57 | "@start@ -> t", 58 | ] 59 | ] 60 | 61 | def test_get_paths_to_root_with_recursion(self): 62 | argument_paths = self.world_with_recursion.get_paths_to_root("e -> argument") 63 | # Argument now has 4 paths, and the two new paths are with the identity function occurring 64 | # (only once) within unary and binary functions. 65 | assert argument_paths == [ 66 | ["e -> argument", "t -> [, e]", "@start@ -> t"], 67 | ["e -> argument", " -> [>, e]", "t -> [, e]", "@start@ -> t"], 68 | ["e -> argument", "e -> [, e]", "t -> [, e]", "@start@ -> t"], 69 | [ 70 | "e -> argument", 71 | "e -> [, e]", 72 | " -> [>, e]", 73 | "t -> [, e]", 74 | "@start@ -> t", 75 | ], 76 | ] 77 | identity_paths = self.world_with_recursion.get_paths_to_root(" -> identity") 78 | # Two identity paths, one through each of unary and binary function productions. 79 | assert identity_paths == [ 80 | [" -> identity", "e -> [, e]", "t -> [, e]", "@start@ -> t"], 81 | [ 82 | " -> identity", 83 | "e -> [, e]", 84 | " -> [>, e]", 85 | "t -> [, e]", 86 | "@start@ -> t", 87 | ], 88 | ] 89 | -------------------------------------------------------------------------------- /tests/parsimonious_languages/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/tests/parsimonious_languages/__init__.py -------------------------------------------------------------------------------- /tests/parsimonious_languages/contexts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/tests/parsimonious_languages/contexts/__init__.py -------------------------------------------------------------------------------- /tests/parsimonious_languages/executors/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/tests/parsimonious_languages/executors/__init__.py -------------------------------------------------------------------------------- /tests/parsimonious_languages/executors/sql_executor_test.py: -------------------------------------------------------------------------------- 1 | from ... 
import SemparseTestCase 2 | 3 | from allennlp_semparse.parsimonious_languages.executors import SqlExecutor 4 | 5 | 6 | class TestSqlExecutor(SemparseTestCase): 7 | def setup_method(self): 8 | super().setup_method() 9 | self._database_file = "https://allennlp.s3.amazonaws.com/datasets/atis/atis.db" 10 | 11 | def test_sql_accuracy_is_scored_correctly(self): 12 | sql_query_label = ( 13 | "( SELECT airport_service . airport_code " 14 | "FROM airport_service " 15 | "WHERE airport_service . city_code IN ( " 16 | "SELECT city . city_code FROM city " 17 | "WHERE city.city_name = 'BOSTON' ) ) ;" 18 | ) 19 | 20 | executor = SqlExecutor(self._database_file) 21 | postprocessed_sql_query_label = executor.postprocess_query_sqlite(sql_query_label) 22 | # If the predicted query and the label are the same, then we should get 1. 23 | assert ( 24 | executor.evaluate_sql_query( 25 | postprocessed_sql_query_label, [postprocessed_sql_query_label] 26 | ) 27 | == 1 28 | ) 29 | 30 | predicted_sql_query = ( 31 | "( SELECT airport_service . airport_code " 32 | "FROM airport_service " 33 | "WHERE airport_service . city_code IN ( " 34 | "SELECT city . city_code FROM city " 35 | "WHERE city.city_name = 'SEATTLE' ) ) ;" 36 | ) 37 | 38 | postprocessed_predicted_sql_query = executor.postprocess_query_sqlite(predicted_sql_query) 39 | # If the predicted query and the label are different we should get 0. 40 | assert ( 41 | executor.evaluate_sql_query( 42 | postprocessed_predicted_sql_query, [postprocessed_sql_query_label] 43 | ) 44 | == 0 45 | ) 46 | -------------------------------------------------------------------------------- /tests/parsimonious_languages/worlds/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/tests/parsimonious_languages/worlds/__init__.py -------------------------------------------------------------------------------- /tests/predictors/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/tests/predictors/__init__.py -------------------------------------------------------------------------------- /tests/predictors/atis_parser_test.py: -------------------------------------------------------------------------------- 1 | from flaky import flaky 2 | 3 | from .. import SemparseTestCase 4 | from allennlp.models.archival import load_archive 5 | from allennlp.predictors import Predictor 6 | 7 | 8 | class TestAtisParserPredictor(SemparseTestCase): 9 | @flaky 10 | def test_atis_parser_uses_named_inputs(self): 11 | inputs = {"utterance": "show me the flights to seattle"} 12 | 13 | archive_path = self.FIXTURES_ROOT / "atis" / "serialization" / "model.tar.gz" 14 | archive = load_archive(archive_path) 15 | predictor = Predictor.from_archive(archive, "atis-parser") 16 | 17 | result = predictor.predict_json(inputs) 18 | action_sequence = result.get("best_action_sequence") 19 | if action_sequence: 20 | # An untrained model will likely get into a loop, and not produce at finished states. 21 | # When the model gets into a loop it will not produce any valid SQL, so we don't get 22 | # any actions. This basically just tests if the model runs. 
23 | assert len(action_sequence) > 1 24 | assert all([isinstance(action, str) for action in action_sequence]) 25 | predicted_sql_query = result.get("predicted_sql_query") 26 | assert predicted_sql_query is not None 27 | 28 | @flaky 29 | def test_atis_parser_predicted_sql_present(self): 30 | inputs = {"utterance": "show me flights to seattle"} 31 | 32 | archive_path = self.FIXTURES_ROOT / "atis" / "serialization" / "model.tar.gz" 33 | archive = load_archive(archive_path) 34 | predictor = Predictor.from_archive(archive, "atis-parser") 35 | 36 | result = predictor.predict_json(inputs) 37 | predicted_sql_query = result.get("predicted_sql_query") 38 | assert predicted_sql_query is not None 39 | 40 | @flaky 41 | def test_atis_parser_batch_predicted_sql_present(self): 42 | inputs = [{"utterance": "show me flights to seattle"}] 43 | 44 | archive_path = self.FIXTURES_ROOT / "atis" / "serialization" / "model.tar.gz" 45 | archive = load_archive(archive_path) 46 | predictor = Predictor.from_archive(archive, "atis-parser") 47 | 48 | result = predictor.predict_batch_json(inputs) 49 | predicted_sql_query = result[0].get("predicted_sql_query") 50 | assert predicted_sql_query is not None 51 | -------------------------------------------------------------------------------- /tests/predictors/wikitables_parser_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from .. import SemparseTestCase 4 | from allennlp.models.archival import load_archive 5 | from allennlp.predictors import Predictor 6 | 7 | 8 | class TestWikiTablesParserPredictor(SemparseTestCase): 9 | def test_uses_named_inputs(self): 10 | inputs = {"question": "names", "table": "name\tdate\nmatt\t2017\npradeep\t2018"} 11 | 12 | archive_path = self.FIXTURES_ROOT / "wikitables" / "serialization" / "model.tar.gz" 13 | archive = load_archive(archive_path) 14 | predictor = Predictor.from_archive(archive, "wikitables-parser") 15 | 16 | result = predictor.predict_json(inputs) 17 | 18 | action_sequence = result.get("best_action_sequence") 19 | if action_sequence: 20 | # We don't currently disallow endless loops in the decoder, and an untrained seq2seq 21 | # model will easily get itself into a loop. An endless loop isn't a finished logical 22 | # form, so decoding doesn't return any finished states, which means no actions. So, 23 | # sadly, we don't have a great test here. This is just testing that the predictor 24 | # runs, basically. 
25 | assert len(action_sequence) > 1 26 | assert all([isinstance(action, str) for action in action_sequence]) 27 | 28 | logical_form = result.get("logical_form") 29 | assert logical_form is not None 30 | 31 | def test_answer_present(self): 32 | inputs = { 33 | "question": "Who is 18 years old?", 34 | "table": "Name\tAge\nShallan\t16\nKaladin\t18", 35 | } 36 | 37 | archive_path = self.FIXTURES_ROOT / "wikitables" / "serialization" / "model.tar.gz" 38 | archive = load_archive(archive_path) 39 | predictor = Predictor.from_archive(archive, "wikitables-parser") 40 | 41 | result = predictor.predict_json(inputs) 42 | answer = result.get("answer") 43 | assert answer is not None 44 | 45 | def test_interactive_beam_search(self): 46 | inputs = { 47 | "question": "Who is 18 years old?", 48 | "table": "Name\tAge\nShallan\t16\nKaladin\t18", 49 | } 50 | 51 | archive_path = self.FIXTURES_ROOT / "wikitables" / "serialization" / "model.tar.gz" 52 | archive = load_archive(archive_path) 53 | predictor = Predictor.from_archive(archive, "wikitables-parser") 54 | 55 | # This is not the start of the best sequence, but it will be once we force it. 56 | initial_tokens = [ 57 | "@start@ -> Number", 58 | "Number -> [, List[Row], NumberColumn]", 59 | ] 60 | 61 | # First let's try an unforced one. Its initial tokens should not be ours. 62 | result = predictor.predict_json(inputs) 63 | best_action_sequence = result["best_action_sequence"] 64 | assert best_action_sequence 65 | assert best_action_sequence[:2] != initial_tokens 66 | 67 | # Now let's try forcing it down the path of `initial_sequence` 68 | inputs["initial_sequence"] = initial_tokens 69 | result = predictor.predict_json(inputs) 70 | best_action_sequence = result["best_action_sequence"] 71 | assert best_action_sequence[:2] == initial_tokens 72 | 73 | # Should get choices back from beam search 74 | beam_search_choices = result["choices"] 75 | 76 | # Make sure that our forced choices appear as beam_search_choices. 77 | for choices, initial_token in zip(beam_search_choices, initial_tokens): 78 | assert any(token == initial_token for _, token in choices) 79 | 80 | # Should get back beams too 81 | beam_snapshots = result["beam_snapshots"] 82 | assert len(beam_snapshots) == 1 83 | assert 0 in beam_snapshots 84 | beams = beam_snapshots[0] 85 | 86 | for idx, (beam, action) in enumerate(zip(beams, best_action_sequence)): 87 | # First beam should have 1-element sequences, etc... 88 | assert all(len(sequence) == idx + 1 for _, sequence in beam) 89 | assert any(sequence[-1] == action for _, sequence in beam) 90 | 91 | def test_answer_present_with_batch_predict(self): 92 | inputs = [ 93 | {"question": "Who is 18 years old?", "table": "Name\tAge\nShallan\t16\nKaladin\t18"} 94 | ] 95 | 96 | archive_path = self.FIXTURES_ROOT / "wikitables" / "serialization" / "model.tar.gz" 97 | archive = load_archive(archive_path) 98 | predictor = Predictor.from_archive(archive, "wikitables-parser") 99 | 100 | result = predictor.predict_batch_json(inputs) 101 | answer = result[0].get("answer") 102 | assert answer is not None 103 | -------------------------------------------------------------------------------- /tests/semparse_test_case.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | from allennlp.common.testing import AllenNlpTestCase, ModelTestCase as AllenNlpModelTestCase 4 | 5 | # These imports are to get all of the items registered that we need. 
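# Illustrative aside (added; not part of the original file): AllenNLP components register
# themselves by name when their defining module is imported, which is why the bare imports
# below are enough for tests that later look components up by name (for example
# Predictor.from_archive(archive, "wikitables-parser")). A minimal, throwaway example of the
# same pattern, purely for illustration:
from allennlp.predictors import Predictor


@Predictor.register("toy-echo-predictor")
class _ToyEchoPredictor(Predictor):
    """Registered only to illustrate the decorator; it simply echoes its input back."""

    def predict_json(self, inputs):
        return inputs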
6 | from allennlp_semparse import models, dataset_readers, predictors 7 | 8 | ROOT = (pathlib.Path(__file__).parent / "..").resolve() 9 | 10 | 11 | class SemparseTestCase(AllenNlpTestCase): 12 | PROJECT_ROOT = ROOT 13 | MODULE_ROOT = PROJECT_ROOT / "allennlp_semparse" 14 | TOOLS_ROOT = None # just removing the reference from super class 15 | TESTS_ROOT = PROJECT_ROOT / "tests" 16 | FIXTURES_ROOT = PROJECT_ROOT / "test_fixtures" 17 | 18 | 19 | class ModelTestCase(AllenNlpModelTestCase): 20 | PROJECT_ROOT = ROOT 21 | MODULE_ROOT = PROJECT_ROOT / "allennlp_semparse" 22 | TOOLS_ROOT = None # just removing the reference from super class 23 | TESTS_ROOT = PROJECT_ROOT / "tests" 24 | FIXTURES_ROOT = PROJECT_ROOT / "test_fixtures" 25 | -------------------------------------------------------------------------------- /tests/state_machines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/tests/state_machines/__init__.py -------------------------------------------------------------------------------- /tests/state_machines/constrained_beam_search_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .. import SemparseTestCase 4 | 5 | from allennlp_semparse.state_machines import ConstrainedBeamSearch 6 | from .simple_transition_system import SimpleState, SimpleTransitionFunction 7 | 8 | 9 | class TestConstrainedBeamSearch(SemparseTestCase): 10 | def test_search(self): 11 | # The simple transition system starts at some number, adds one or two at each state, and 12 | # tries to get to 4. The highest scoring path has the shortest length and the highest 13 | # numbers (so always add two, unless you're at 3). From -3, there are lots of possible 14 | # sequences: [-2, -1, 0, 1, 2, 3, 4], [-1, 1, 3, 4], ... We'll specify a few of those up 15 | # front as "allowed", and use that to test the constrained beam search implementation. 16 | initial_state = SimpleState([0], [[]], [torch.Tensor([0.0])], [-3]) 17 | beam_size = 3 18 | allowed_sequences = torch.Tensor( 19 | [ 20 | [ 21 | [-2, -1, 0, 1, 2, 3, 4], 22 | [-2, 0, 2, 4, -1, -1, -1], 23 | [-1, 1, 3, 4, -1, -1, -1], 24 | [-2, -1, 0, 1, 2, 4, -1], 25 | [-1, 0, 1, 2, 3, 4, -1], 26 | [-1, 1, 2, 3, 4, -1, -1], 27 | ] 28 | ] 29 | ) 30 | mask = torch.Tensor( 31 | [ 32 | [ 33 | [1, 1, 1, 1, 1, 1, 1], 34 | [1, 1, 1, 1, 0, 0, 0], 35 | [1, 1, 1, 1, 0, 0, 0], 36 | [1, 1, 1, 1, 1, 1, 0], 37 | [1, 1, 1, 1, 1, 1, 0], 38 | [1, 1, 1, 1, 1, 0, 0], 39 | ] 40 | ] 41 | ) 42 | 43 | beam_search = ConstrainedBeamSearch(beam_size, allowed_sequences, mask) 44 | 45 | # Including the value in the score will make us pick states that have higher numbers first. 46 | # So with a beam size of 3, we'll get all of the states that start with `-1` after the 47 | # first step, even though in the end one of the states that starts with `-2` is better than 48 | # two of the states that start with `-1`. 49 | decoder_step = SimpleTransitionFunction(include_value_in_score=True) 50 | best_states = beam_search.search(initial_state, decoder_step) 51 | 52 | assert len(best_states) == 1 53 | assert best_states[0][0].action_history[0] == [-1, 1, 3, 4] 54 | assert best_states[0][1].action_history[0] == [-1, 1, 2, 3, 4] 55 | assert best_states[0][2].action_history[0] == [-1, 0, 1, 2, 3, 4] 56 | 57 | # With a beam size of 6, we should get the other allowed path of length 4 as the second 58 | # best result. 
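# Added worked check (not in the original test): with include_value_in_score=True each step
# contributes (-1 + 0.01 * value) to the score, so starting from a score of 0:
#     [-1, 1, 3, 4]  ->  -4 + 0.01 * (-1 + 1 + 3 + 4) = -3.93
#     [-2, 0, 2, 4]  ->  -4 + 0.01 * (-2 + 0 + 2 + 4) = -3.96
# Every longer allowed sequence pays at least one more -1, so once the beam is wide enough
# to keep both length-4 paths they come out first and second, as asserted below.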
59 | beam_size = 6 60 | beam_search = ConstrainedBeamSearch(beam_size, allowed_sequences, mask) 61 | decoder_step = SimpleTransitionFunction(include_value_in_score=True) 62 | best_states = beam_search.search(initial_state, decoder_step) 63 | 64 | assert len(best_states) == 1 65 | assert best_states[0][0].action_history[0] == [-1, 1, 3, 4] 66 | assert best_states[0][1].action_history[0] == [-2, 0, 2, 4] 67 | -------------------------------------------------------------------------------- /tests/state_machines/simple_transition_system.py: -------------------------------------------------------------------------------- 1 | """ 2 | We define a simple deterministic decoder here, that takes steps to add integers to list. At 3 | each step, the decoder takes the last integer in the list, and adds either 1 or 2 to produce the 4 | next element that will be added to the list. We initialize the list with the value 0 (or whatever 5 | you pick), and we say that a sequence is finished when the last element is 4. We define the score 6 | of a state as the negative of the number of elements (excluding the initial value) in the action 7 | history. 8 | """ 9 | from collections import defaultdict 10 | from typing import List, Set, Dict 11 | 12 | 13 | import torch 14 | 15 | from allennlp_semparse.state_machines import State, TransitionFunction 16 | 17 | 18 | class SimpleState(State["SimpleState"]): 19 | def __init__( 20 | self, 21 | batch_indices: List[int], 22 | action_history: List[List[int]], 23 | score: List[torch.Tensor], 24 | start_values: List[int] = None, 25 | ) -> None: 26 | super().__init__(batch_indices, action_history, score) 27 | self.start_values = start_values or [0] * len(batch_indices) 28 | 29 | def is_finished(self) -> bool: 30 | return self.action_history[0][-1] == 4 31 | 32 | @classmethod 33 | def combine_states(cls, states) -> "SimpleState": 34 | batch_indices = [batch_index for state in states for batch_index in state.batch_indices] 35 | action_histories = [ 36 | action_history for state in states for action_history in state.action_history 37 | ] 38 | scores = [score for state in states for score in state.score] 39 | start_values = [start_value for state in states for start_value in state.start_values] 40 | return SimpleState(batch_indices, action_histories, scores, start_values) 41 | 42 | def __repr__(self): 43 | return f"{self.action_history}" 44 | 45 | 46 | class SimpleTransitionFunction(TransitionFunction[SimpleState]): 47 | def __init__( 48 | self, valid_actions: Set[int] = None, include_value_in_score: bool = False 49 | ) -> None: 50 | # The default allowed actions are adding 1 or 2 to the last element. 51 | self._valid_actions = valid_actions or {1, 2} 52 | # If True, we will add a small multiple of the action take to the score, to encourage 53 | # getting higher numbers first (and to differentiate action sequences). 
54 | self._include_value_in_score = include_value_in_score 55 | 56 | def take_step( 57 | self, state: SimpleState, max_actions: int = None, allowed_actions: List[Set] = None 58 | ) -> List[SimpleState]: 59 | indexed_next_states: Dict[int, List[SimpleState]] = defaultdict(list) 60 | if not allowed_actions: 61 | allowed_actions = [None] * len(state.batch_indices) 62 | for batch_index, action_history, score, start_value, actions in zip( 63 | state.batch_indices, 64 | state.action_history, 65 | state.score, 66 | state.start_values, 67 | allowed_actions, 68 | ): 69 | 70 | prev_action = action_history[-1] if action_history else start_value 71 | for action in self._valid_actions: 72 | next_item = int(prev_action + action) 73 | if actions and next_item not in actions: 74 | continue 75 | new_history = action_history + [next_item] 76 | # For every action taken, we reduce the score by 1. 77 | new_score = score - 1 78 | if self._include_value_in_score: 79 | new_score += 0.01 * next_item 80 | new_state = SimpleState([batch_index], [new_history], [new_score]) 81 | indexed_next_states[batch_index].append(new_state) 82 | next_states: List[SimpleState] = [] 83 | for batch_next_states in indexed_next_states.values(): 84 | sorted_next_states = [(-state.score[0].data[0], state) for state in batch_next_states] 85 | sorted_next_states.sort(key=lambda x: x[0]) 86 | if max_actions is not None: 87 | sorted_next_states = sorted_next_states[:max_actions] 88 | next_states.extend(state[1] for state in sorted_next_states) 89 | return next_states 90 | -------------------------------------------------------------------------------- /tests/state_machines/states/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/tests/state_machines/states/__init__.py -------------------------------------------------------------------------------- /tests/state_machines/states/grammar_statelet_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from ... 
import SemparseTestCase 4 | 5 | from allennlp_semparse.state_machines.states import GrammarStatelet 6 | 7 | 8 | def is_nonterminal(symbol: str) -> bool: 9 | if symbol == "identity": 10 | return False 11 | if "lambda " in symbol: 12 | return False 13 | if symbol in {"x", "y", "z"}: 14 | return False 15 | return True 16 | 17 | 18 | class TestGrammarStatelet(SemparseTestCase): 19 | def test_is_finished_just_uses_nonterminal_stack(self): 20 | state = GrammarStatelet(["s"], {}, is_nonterminal) 21 | assert not state.is_finished() 22 | state = GrammarStatelet([], {}, is_nonterminal) 23 | assert state.is_finished() 24 | 25 | def test_get_valid_actions_uses_top_of_stack(self): 26 | s_actions = object() 27 | t_actions = object() 28 | e_actions = object() 29 | state = GrammarStatelet(["s"], {"s": s_actions, "t": t_actions}, is_nonterminal) 30 | assert state.get_valid_actions() == s_actions 31 | state = GrammarStatelet(["t"], {"s": s_actions, "t": t_actions}, is_nonterminal) 32 | assert state.get_valid_actions() == t_actions 33 | state = GrammarStatelet( 34 | ["e"], {"s": s_actions, "t": t_actions, "e": e_actions}, is_nonterminal 35 | ) 36 | assert state.get_valid_actions() == e_actions 37 | 38 | def test_take_action_crashes_with_mismatched_types(self): 39 | with pytest.raises(AssertionError): 40 | state = GrammarStatelet(["s"], {}, is_nonterminal) 41 | state.take_action("t -> identity") 42 | -------------------------------------------------------------------------------- /tests/state_machines/trainers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/tests/state_machines/trainers/__init__.py -------------------------------------------------------------------------------- /tests/state_machines/trainers/expected_risk_minimization_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from numpy.testing import assert_almost_equal 4 | 5 | from ... import SemparseTestCase 6 | 7 | from allennlp_semparse.state_machines.trainers import ExpectedRiskMinimization 8 | from ..simple_transition_system import SimpleState, SimpleTransitionFunction 9 | 10 | 11 | class TestExpectedRiskMinimization(SemparseTestCase): 12 | def setup_method(self): 13 | super().setup_method() 14 | self.initial_state = SimpleState([0], [[0]], [torch.Tensor([0.0])]) 15 | self.decoder_step = SimpleTransitionFunction() 16 | # Cost is the number of odd elements in the action history. 17 | self.supervision = lambda state: torch.Tensor( 18 | [sum([x % 2 != 0 for x in state.action_history[0]])] 19 | ) 20 | # High beam size ensures exhaustive search. 21 | self.trainer = ExpectedRiskMinimization( 22 | beam_size=100, normalize_by_length=False, max_decoding_steps=10 23 | ) 24 | 25 | def test_get_finished_states(self): 26 | finished_states = self.trainer._get_finished_states(self.initial_state, self.decoder_step) 27 | state_info = [(state.action_history[0], state.score[0].item()) for state in finished_states] 28 | # There will be exactly five finished states with the following paths. Each score is the 29 | # negative of one less than the number of elements in the action history. 
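# Added worked check (not in the original test): starting from 0 and adding 1 or 2 per step,
# the ways to land exactly on 4 are the compositions of 4 into parts of size 1 and 2:
# 2+2, 1+1+2, 1+2+1, 2+1+1, and 1+1+1+1, i.e. exactly five, matching the assertion below.
# A brute-force count of the same quantity, as a sanity check:
def _count_paths_to(total: int) -> int:
    # Number of +1/+2 step sequences that sum exactly to `total`.
    if total == 0:
        return 1
    return sum(_count_paths_to(total - step) for step in (1, 2) if step <= total)
# _count_paths_to(4) == 5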
30 | assert len(finished_states) == 5 31 | assert ([0, 2, 4], -2) in state_info 32 | assert ([0, 1, 2, 4], -3) in state_info 33 | assert ([0, 1, 3, 4], -3) in state_info 34 | assert ([0, 2, 3, 4], -3) in state_info 35 | assert ([0, 1, 2, 3, 4], -4) in state_info 36 | 37 | def test_decode(self): 38 | decoded_info = self.trainer.decode(self.initial_state, self.decoder_step, self.supervision) 39 | # The best state corresponds to the shortest path. 40 | best_state = decoded_info["best_final_states"][0][0] 41 | assert best_state.action_history[0] == [0, 2, 4] 42 | # The scores and costs corresponding to the finished states will be 43 | # [0, 2, 4] : -2, 0 44 | # [0, 1, 2, 4] : -3, 1 45 | # [0, 1, 3, 4] : -3, 2 46 | # [0, 2, 3, 4] : -3, 1 47 | # [0, 1, 2, 3, 4] : -4, 2 48 | 49 | # This is the normalization factor while re-normalizing probabilities on the beam 50 | partition = np.exp(-2) + np.exp(-3) + np.exp(-3) + np.exp(-3) + np.exp(-4) 51 | expected_loss = ( 52 | (np.exp(-2) * 0) 53 | + (np.exp(-3) * 1) 54 | + (np.exp(-3) * 2) 55 | + (np.exp(-3) * 1) 56 | + (np.exp(-4) * 2) 57 | ) / partition 58 | assert_almost_equal(decoded_info["loss"].data.numpy(), expected_loss) 59 | -------------------------------------------------------------------------------- /tests/state_machines/trainers/maximum_marginal_likelihood_test.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from numpy.testing import assert_almost_equal 4 | import torch 5 | 6 | from ... import SemparseTestCase 7 | 8 | from allennlp_semparse.state_machines.trainers import MaximumMarginalLikelihood 9 | from ..simple_transition_system import SimpleState, SimpleTransitionFunction 10 | 11 | 12 | class TestMaximumMarginalLikelihood(SemparseTestCase): 13 | def setup_method(self): 14 | super().setup_method() 15 | self.initial_state = SimpleState( 16 | [0, 1], [[], []], [torch.Tensor([0.0]), torch.Tensor([0.0])], [0, 1] 17 | ) 18 | self.decoder_step = SimpleTransitionFunction() 19 | self.targets = torch.Tensor( 20 | [[[2, 3, 4], [1, 3, 4], [1, 2, 4]], [[3, 4, 0], [2, 3, 4], [0, 0, 0]]] 21 | ) 22 | self.target_mask = torch.Tensor( 23 | [[[1, 1, 1], [1, 1, 1], [1, 1, 1]], [[1, 1, 0], [1, 1, 1], [0, 0, 0]]] 24 | ) 25 | 26 | self.supervision = (self.targets, self.target_mask) 27 | # High beam size ensures exhaustive search. 28 | self.trainer = MaximumMarginalLikelihood() 29 | 30 | def test_decode(self): 31 | decoded_info = self.trainer.decode(self.initial_state, self.decoder_step, self.supervision) 32 | 33 | # Our loss is the negative log sum of the scores from each target sequence. The score for 34 | # each sequence in our simple transition system is just `-sequence_length`. 
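# Added clarification (not in the original test): each score is a log-probability (in this
# toy system just -sequence_length), so the per-instance loss is
#     -log(sum over that instance's target sequences of exp(score))
# and the batch loss averages over the two instances, which is exactly what the next few
# lines compute numerically.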
35 | instance0_loss = math.log(math.exp(-3) * 3) # all three sequences have length 3 36 | instance1_loss = math.log(math.exp(-2) + math.exp(-3)) # one has length 2, one has length 3 37 | expected_loss = -(instance0_loss + instance1_loss) / 2 38 | assert_almost_equal(decoded_info["loss"].data.numpy(), expected_loss) 39 | -------------------------------------------------------------------------------- /tests/state_machines/transition_functions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/allennlp-semparse/751c25f5f59c4d7973f03dc05210f9f94752f1b5/tests/state_machines/transition_functions/__init__.py -------------------------------------------------------------------------------- /tests/state_machines/util_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .. import SemparseTestCase 4 | 5 | from allennlp_semparse.state_machines import util 6 | 7 | 8 | class TestStateMachinesUtil(SemparseTestCase): 9 | def test_create_allowed_transitions(self): 10 | targets = torch.Tensor( 11 | [[[2, 3, 4], [1, 3, 4], [1, 2, 4]], [[3, 4, 0], [2, 3, 4], [0, 0, 0]]] 12 | ) 13 | target_mask = torch.Tensor( 14 | [[[1, 1, 1], [1, 1, 1], [1, 1, 1]], [[1, 1, 0], [1, 1, 1], [0, 0, 0]]] 15 | ) 16 | prefix_tree = util.construct_prefix_tree(targets, target_mask) 17 | 18 | # There were two instances in this batch. 19 | assert len(prefix_tree) == 2 20 | 21 | # The first instance had six valid action sequence prefixes. 22 | assert len(prefix_tree[0]) == 6 23 | assert prefix_tree[0][()] == {1, 2} 24 | assert prefix_tree[0][(1,)] == {2, 3} 25 | assert prefix_tree[0][(1, 2)] == {4} 26 | assert prefix_tree[0][(1, 3)] == {4} 27 | assert prefix_tree[0][(2,)] == {3} 28 | assert prefix_tree[0][(2, 3)] == {4} 29 | 30 | # The second instance had four valid action sequence prefixes. 
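# Added worked check (not in the original test): instance 1 has two unpadded target
# sequences, [3, 4] and [2, 3, 4]. Collecting the allowed next step after every prefix of
# those gives exactly the four entries asserted below:
#     ()     -> {2, 3}
#     (2,)   -> {3}
#     (2, 3) -> {4}
#     (3,)   -> {4}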
31 | assert len(prefix_tree[1]) == 4 32 | assert prefix_tree[1][()] == {2, 3} 33 | assert prefix_tree[1][(2,)] == {3} 34 | assert prefix_tree[1][(2, 3)] == {4} 35 | assert prefix_tree[1][(3,)] == {4} 36 | -------------------------------------------------------------------------------- /training_config/wikitables_erm_parser.jsonnet: -------------------------------------------------------------------------------- 1 | // The Wikitables data is available at https://ppasupat.github.io/WikiTableQuestions/ 2 | { 3 | "random_seed": 4536, 4 | "numpy_seed": 9834, 5 | "pytorch_seed": 953, 6 | "dataset_reader": { 7 | "type": "wikitables", 8 | "lazy": false, 9 | "output_agendas": true, 10 | "tables_directory": "/wikitables_tagged/", 11 | "keep_if_no_logical_forms": true 12 | }, 13 | "vocabulary": { 14 | "min_count": {"tokens": 3}, 15 | "tokens_to_add": {"tokens": ["-1"]} 16 | }, 17 | "train_data_path": "/wikitables_raw_data/random-split-1-train.examples", 18 | "validation_data_path": "/wikitables_raw_data/random-split-1-dev.examples", 19 | "model": { 20 | "type": "wikitables_erm_parser", 21 | "question_embedder": { 22 | "tokens": { 23 | "type": "embedding", 24 | "embedding_dim": 200, 25 | "trainable": true 26 | } 27 | }, 28 | "action_embedding_dim": 100, 29 | "encoder": { 30 | "type": "lstm", 31 | "input_size": 400, 32 | "hidden_size": 100, 33 | "bidirectional": true, 34 | "num_layers": 1 35 | }, 36 | "entity_encoder": { 37 | "type": "boe", 38 | "embedding_dim": 200, 39 | "averaged": true 40 | }, 41 | "checklist_cost_weight": 0.2, 42 | "max_decoding_steps": 18, 43 | "decoder_beam_size": 50, 44 | "decoder_num_finished_states": 100, 45 | "attention": { 46 | "type": "bilinear", 47 | "vector_dim": 200, 48 | "matrix_dim": 200 49 | }, 50 | "dropout": 0.5, 51 | "mml_model_file": "/mml_model/model.tar.gz" 52 | }, 53 | "data_loader": { 54 | "batch_sampler": { 55 | "type": "bucket", 56 | "padding_noise": 0.0, 57 | "batch_size" : 10, 58 | }, 59 | }, 60 | "trainer": { 61 | "num_epochs": 30, 62 | "patience": 5, 63 | "validation_metric": "+denotation_acc", 64 | "cuda_device": -1, 65 | "optimizer": { 66 | "type": "sgd", 67 | "lr": 0.01 68 | } 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /training_config/wikitables_mml_parser.jsonnet: -------------------------------------------------------------------------------- 1 | // The Wikitables data is available at https://ppasupat.github.io/WikiTableQuestions/ 2 | { 3 | "random_seed": 4536, 4 | "numpy_seed": 9834, 5 | "pytorch_seed": 953, 6 | "dataset_reader": { 7 | "type": "wikitables", 8 | "tables_directory": "/wikitables_tagged", 9 | "offline_logical_forms_directory": "/offline_search_output/", 10 | "max_offline_logical_forms": 60, 11 | "lazy": false 12 | }, 13 | "validation_dataset_reader": { 14 | "type": "wikitables", 15 | "tables_directory": "/wikitables_tagged", 16 | "keep_if_no_logical_forms": true, 17 | "lazy": false 18 | }, 19 | "vocabulary": { 20 | "min_count": {"tokens": 3}, 21 | "tokens_to_add": {"tokens": ["-1"]} 22 | }, 23 | "train_data_path": "/wikitables_raw_data/random-split-1-train.examples", 24 | "validation_data_path": "/wikitables_raw_data/random-split-1-dev.examples", 25 | "model": { 26 | "type": "wikitables_mml_parser", 27 | "question_embedder": { 28 | "tokens": { 29 | "type": "embedding", 30 | "embedding_dim": 200, 31 | "trainable": true 32 | } 33 | }, 34 | "action_embedding_dim": 100, 35 | "encoder": { 36 | "type": "lstm", 37 | "input_size": 400, 38 | "hidden_size": 100, 39 | "bidirectional": 
true, 40 | "num_layers": 1 41 | }, 42 | "entity_encoder": { 43 | "type": "boe", 44 | "embedding_dim": 200, 45 | "averaged": true 46 | }, 47 | "decoder_beam_search": { 48 | "beam_size": 10 49 | }, 50 | "max_decoding_steps": 16, 51 | "attention": { 52 | "type": "bilinear", 53 | "vector_dim": 200, 54 | "matrix_dim": 200 55 | }, 56 | "dropout": 0.5 57 | }, 58 | "data_loader": { 59 | "batch_size": 1, 60 | }, 61 | "trainer": { 62 | "num_epochs": 100, 63 | "patience": 10, 64 | "cuda_device": 0, 65 | "grad_norm": 5.0, 66 | "validation_metric": "+denotation_acc", 67 | "optimizer": { 68 | "type": "sgd", 69 | "lr": 0.1 70 | }, 71 | "learning_rate_scheduler": { 72 | "type": "exponential", 73 | "gamma": 0.99 74 | } 75 | } 76 | } 77 | --------------------------------------------------------------------------------
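A short usage sketch (added; not part of the repository dump above): the two jsonnet files are
standard AllenNLP training configs, so, assuming the data paths inside them point at a local copy
of the WikiTableQuestions data, training can be launched either with
`allennlp train <config> -s <output_dir> --include-package allennlp_semparse` or programmatically,
roughly as follows (the serialization directory is a placeholder):

import allennlp_semparse.dataset_readers  # noqa: F401  -- importing registers the components
import allennlp_semparse.models  # noqa: F401

from allennlp.commands.train import train_model_from_file

train_model_from_file(
    "training_config/wikitables_mml_parser.jsonnet",
    "/tmp/wikitables_mml_parser",
)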